(์ด๋ฒ ํ๋ก์ ํธ ์ฝ๋๋ ํจ์บ ๋ฅ๋ฌ๋ ๊ฐ์๋ฅผ ์ฐธ๊ณ ํ ์ฝ๋์ด๋ค)
<์ด์ ํฌ์คํ
>
https://silvercoding.tistory.com/4
[MNIST ํ๋ก์ ํธ] 2. MNIST ๋ฐ์ดํฐ์ ์ ์ฒ๋ฆฌ, ์๊ฐํ
(์ด๋ฒ ํ๋ก์ ํธ ์ฝ๋๋ ํจ์บ ๋ฅ๋ฌ๋ ๊ฐ์๋ฅผ ์ฐธ๊ณ ํ ์ฝ๋์ด๋ค) <์ด์ ํฌ์คํ > https://silvercoding.tistory.com/3 [MNIST ํ๋ก์ ํธ] 1. MNIST ๋ฐ์ดํฐ ์์๋ณด๊ธฐ (์ด๋ฒ ํ๋ก์ ํธ ์ฝ๋๋ ํจ์บ ๋ฅ๋ฌ๋ ๊ฐ์๋ฅผ ์ฐธ๊ณ ํ
silvercoding.tistory.com
Noise ์ถ๊ฐํ๊ธฐ
https://www.tensorflow.org/tutorials/images/data_augmentation
๋ฐ์ดํฐ ์ฆ๊ฐ | TensorFlow Core
๊ฐ์ ์ด ํํ ๋ฆฌ์ผ์์๋ ์ด๋ฏธ์ง ํ์ ๊ณผ ๊ฐ์ ๋ฌด์์(๊ทธ๋ฌ๋ ์ฌ์ค์ ์ธ) ๋ณํ์ ์ ์ฉํ์ฌ ํ๋ จ ์ธํธ์ ๋ค์์ฑ์ ์ฆ๊ฐ์ํค๋ ๊ธฐ์ ์ธ ๋ฐ์ดํฐ ์ฆ๊ฐ์ ์๋ฅผ ๋ณด์ฌ์ค๋๋ค. ๋ ๊ฐ์ง ๋ฐฉ๋ฒ์ผ๋ก ๋ฐ์ดํฐ ์ฆ
www.tensorflow.org
์ฐ์ Data augmentation ์ ๋ฌด์์ ๋ณํ์ ์ ์ฉํ์ฌ ํ๋ จ ์ธํธ์ ๋ค์์ฑ์ ์ฆ๊ฐ์ํค๋ ๊ธฐ์ ์ด๋ค.
์ด ์ฌ์ง๊ณผ ๊ฐ์ด ์ฌ๋ ๋์๋ ํ์ ์ ํ๋ ํ๋๋ฅผ ํ๋ ๊ฐ์ ๊ฝ์ด๋ผ๋ ๊ฑธ ํ๋ณํ ์ ์์ง๋ง, ์ปดํจํฐ ์ ์ฅ์์๋ ์๋ก ๋ค๋ฅธ ์ฌ์ง์ผ๋ก ์ ๋ ฅ๋๋ค๋ ๊ฒ์ด๋ค. ๋ฐ๋ผ์ ์ด๋ฌํ ๋ฌด์์ ๋ณํ์ ์์ผ ํ๋ จ์ธํธ์ ๋ค์ํ๋ฅผ ํ๊ณ ์ ํ๋ค.
์ด ๊ธ์์๋ MNIST์ ์ด๋ฌํ Noise๋ฅผ ์
ํ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ ๊ฒ์ด๋ค.
(1) (28, 28) ํฌ๊ธฐ์ ๋๋ค ๋
ธ์ด์ฆ ์์ฑํ๊ธฐ
- np.random.random
print(np.random.random((2, 2)))
np.random.random() ํจ์๋ฅผ ์ฌ์ฉํ๋ฉด 0-1์ฌ์ด์ ์ค์๊ฐ ๋์ค๊ฒ ๋๋ค. ๊ดํธ ์์ ์ฌ์ด์ฆ๋ฅผ ์ ๋ ฅํด์ฃผ๋ฉด
์ด๋ ๊ฒ (2, 2) ํํ๋ก ๋๋ค๊ฐ์ด ๋์ค๋ ๊ฒ์ ์ ์ ์๋ค.
np.random.random((28,28)).shape
๋ฐ๋ผ์ ์ด๋ ๊ฒ ํด์ฃผ๋ฉด (28, 28) ์ฌ์ด์ฆ์ ๋๋ค ๋
ธ์ด์ฆ๊ฐ ์์ฑ๋๋ค.
์ด๋ฅผ plt.imshow()์ ๋ฃ์ด ํ์ธํด๋ณด๋ฉด ์์์ ๋ณด์๋ ๋
ธ์ด์ฆ ๊ทธ๋ฆผ์ ๋ณผ ์ ์์ ๊ฒ์ด๋ค.
๊ทธ๋ฐ๋ฐ ์์์ ๋ณด์๋ ๊ทธ๋ฆผ๋ณด๋ค๋ ์งํ๋ค. ๋
ธ์ด์ฆ๋ฅผ ์ฃผ๊ธฐ์ ๋๋ฌด ์ธ๋ค.
- np.random.normal
print(np.random.normal(0.0, 0.1, (28, 28)))
๊ทธ๋์ np.random.normal๋ก ํ๊ท ๊ณผ ํ์คํธ์ฐจ๋ฅผ ์ง์ ํด์ค๋ค. ํ๊ท 0, ํ์คํธ์ฐจ 0.1 ๋ก ์ง์ ํด์ค๋ค.
์ด๋ฅผ ๊ทธ๋ํ๋ก ๊ทธ๋ ค์ฃผ๋ฉด
์ ๋นํ ๋
ธ์ด์ฆ๊ฐ ์์ฑ๋์๋ค!
(2) ์ด๋ฏธ์ง ํ์ฅ์ ์ ์ฉํด๋ณด๊ธฐ
777๋ฒ์งธ ์ด๋ฏธ์ง์ ๋
ธ์ด์ฆ๋ฅผ ์์๋ณด์.
noisy_image = train_images[777] + np.random.normal(0.5, 0.1, (28, 28))
์ฐจ์ด๋ฅผ ๋ ์ ๋ช ํ ๋ณด๊ธฐ ์ํด ํ๊ท ์ 0.5๋ก ์ค๋ค.
๊ทธ๋ํ๋ฅผ ๊ทธ๋ ค๋ณด๋ ๋ ธ์ด์ฆ๊ฐ ์๊ฒผ์ง๋ง 1์ด ๋๋ ๊ฐ์ด ์๊ฒจ๋ฒ๋ฆฐ๋ค.
noisy_image[noisy_image > 1.0] = 1.0
๊ทธ๋์ 1.0์ด ๋๋ ๊ฐ์ 1.0์ผ๋ก ๋์ฒดํ๋ค๋ ์ฝ๋๋ฅผ ์์ฑํด์ฃผ๋ฉด
0๊ณผ 1์ฌ์ด์ ๊ฐ์ผ๋ก ์ด๋ฃจ์ด์ง ๋
ธ์ด์ฆ ์ด๋ฏธ์ง๊ฐ ์์ฑ๋๋ค.
(3) ๋ชจ๋ ์ด๋ฏธ์ง์ ๋
ธ์ด์ฆ ์ ์ฉํ๊ธฐ
train_noisy_images = train_images + np.random.normal(0.5, 0.1, train_images.shape)
train_noisy_images[train_noisy_images > 1.0] = 1.0
test_noisy_images = test_images + np.random.normal(0.5, 0.1, test_images.shape)
test_noisy_images[test_noisy_images > 1.0] = 1.0
์ต์ข
์ ์ผ๋ก train์ด๋ฏธ์ง์ test์ด๋ฏธ์ง ๋ชจ๋ ๋
ธ์ด์ฆ๋ฅผ ์ ์ฉ์ํค๋ ์ฝ๋์ด๋ค.
์ ๋ฒ์๊ฐ์ ์ฌ๋ฌ์ฅ์ ์ด๋ฏธ์ง๋ฅผ ํ๋ฒ์ ์๊ฐํํ๋ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ์ฌ ์ฒซ 5๊ฐ์ ์ด๋ฏธ์ง๋ฅผ ์ถ๋ ฅํด๋ณด๋ฉด ๋ค์๊ณผ ๊ฐ์ด ์ ์์ ์ผ๋ก ๋์ค๋ ๊ฒ์ ๋ณผ ์ ์๋ค.
๋๋์ด
๋ชจ๋ธ๋ง ํ๊ธฐ
(1) ๋ชจ๋ธ๋ง ์ค๋น - ๋ผ๋ฒจ ์ํซ์ธ์ฝ๋ฉ ์์
(๋ฐฐ์น์ฌ์ด์ฆ,) -> (๋ฐฐ์น์ฌ์ด์ฆ, ํด๋์ค ๊ฐ์)
(60000,) (10000,) ์ ํํ์๋ ๋ผ๋ฒจ์ (60000, 10) (10000, 10) ์ ํํ๋ก one-hot encoding ํด์ค ๊ฒ์ด๋ค.
from keras.utils import to_categorical
train_labels = to_categorical( train_labels, 10)
test_labels = to_categorical( test_labels, 10)
keras.utils์ to_categorical์ import ํ์ฌ ์ฌ์ฉํ๋ค. to_categorical(์ํซ์ธ์ฝ๋ฉํ ๋ผ๋ฒจ, ํด๋์ค ๊ฐ์) ์ด๋ ๊ฒ ์ฌ์ฉํ๋ฉด ๋๋ค.
(2) simpleRNN classification ๋ชจ๋ธ ์์ฑ
from keras.layers import simpleRNN
from keras.layers import Dense, Input
from keras.models import Model
inputs = Input(shape=(28, 28))
x1 = simpleRNN(64, activation="tanh")(inputs)
x2 = Dense(10, activation="softmax")(x1)
model = Model(inputs, x2)
keras.layers์ simpleRNN์ผ๋ก ๋ชจ๋ธ ์์ฑ์ ํ๋ค. activation ํจ์๋ ๊ฐ๊ฐ tanh, softmax๋ก ๊ตฌ์ฑ์ด ๋์ด์๋ค.
model.summary()
summaryํจ์๋ฅผ ์ด์ฉํ์ฌ ์์ฝ์ ๋ณด๋ฅผ ์ป์ด์ฌ ์ ์๋ค. ํ๋ผ๋ฏธํฐ์ ๊ฐ์์ ์์ํ shape์ ์ ์ ์๋ค.
(3) loss, optimizer, metrics ์ค์
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics = ["accuracy"])
compile ํจ์๋ฅผ ์ด์ฉํ์ฌ ์์คํจ์๋ categorical crossentropy, optimizer ๋ adam, ์งํ๋ ์ ํ๋๋ก ์ค์ ํด ์ค๋ค.
(4) ํ์ต์ํค๊ธฐ
hist = model.fit(train_noisy_images, train_labels, validation_data=(test_noisy_images, test_labels), epochs=5, verbose=2)
๋ค๋ฅธ ๊ฑด ๋ค ์์ ๊ฐ๋ฅํ์ง๋ง verbose๋ ๋ฌด์์ธ์ง ์ ๋ชจ๋ฅด๊ฒ ์ด์ ์ฐพ์๋ณด์๋ค.
verbose: 'auto', 0, 1, or 2. Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch. 'auto' defaults to 1 for most cases, but 2 when used with ParameterServerStrategy. Note that the progress bar is not particularly useful when logged to a file, so verbose=2 is recommended
<์ถ์ฒ>
https://keras.io/api/models/model_training_apis/
Keras documentation: Model training APIs
Model training APIs compile method Model.compile( optimizer="rmsprop", loss=None, metrics=None, loss_weights=None, weighted_metrics=None, run_eagerly=None, steps_per_execution=None, **kwargs ) Configures the model for training. Arguments optimizer: String
keras.io
** ๋น๊ตํด๋ณด๊ธฐ
- verbose = 1
- verbose = 2
(5) ํ์ต ๊ฒฐ๊ณผ ํ์ธ
plt.plot(hist.history['accuracy'], label='accuracy') plt.plot(hist.history['loss'], label='loss') plt.plot(hist.history['val_accuracy'], label='val_accuracy') plt.plot(hist.history['val_loss'], label='val_loss') plt.legend(loc='upper left') plt.show()
ํ์ตํ ๊ฒฐ๊ณผ๋ฅผ ๊ทธ๋ํ๋ก ๊ทธ๋ ค๋ณด์์ ๋ ์ ํ๋๋ ๋งค์ฐ ๋๊ณ ์ค๋ฅ๋ ๋งค์ฐ ๋ฎ์ ๊ฑธ ๋ณผ ์ ์๋ค. ๊ฐ๋จํ RNN๋ชจ๋ธ๋ก ๊ตฌํ์ ํ์ฌ๋ ์ฑ๋ฅ์ด ๊ด์ฐฎ๋ค!
--- ์์ฑ๋ ๋ชจ๋ธ์ test ์ด๋ฏธ์ง ํ์ฅ์ผ๋ก ๊ฒฐ๊ณผ ํ์ธํด๋ณด๊ธฐ
res = model.predict( test_noisy_images[777:778] )
777๋ฒ์งธ ์ด๋ฏธ์ง๋ฅผ ํ์ธํด๋ณด์.
plt.bar(range(10), res[0], color='red') plt.bar(np.array(range(10)) + 0.35, test_labels[777]) plt.show()
red๊ฐ ์์ธกํ ํ๋ฅ , blue๊ฐ ์ ๋ต์ด๋ค. ๋ณด๋ฉด 1๋ก ์ ์์ธกํ์ง๋ง, 7๊ณผ 8๋ก ์์ธกํ ๊ฒ์ด ๋ฏธ์ธํ๊ฒ ๋ณด์ธ๋ค. ์ฑ๋ฅ์ ๋์์ง ์์๋ณด์ธ๋ค.
(6) ํ
์คํธ ๋ฐ์ดํฐ์
์ผ๋ก ํ๊ฐํ๊ธฐ
loss, acc = model.evaluate(test_noisy_images, test_labels, verbose=2) print(loss, acc)
evaluate์ ํ ์คํธ ๋ฐ์ดํฐ์ ์ ๋ฃ์ด์ฃผ๋ฉด ๋๋ค.
์ ํ๋ 95%๋ก ๋ชจ๋ธ ํ๊ฐ๊น์ง ๋ง์ณค๋ค.
(7) ๋ชจ๋ธ ์ ์ฅํ๊ณ ๋ถ๋ฌ์ค๊ธฐ
# ๋ชจ๋ธ ์ ์ฅ
model.save("./mnist_rnn.h5")
# ๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ
new_model = tf.keras.models.load_model('./mnist_rnn.h5')
h5๋ก ์ ์ฅํด์ฃผ๋ฉด ๋๋ค.
** ํน์ ์ฝ๋ฉ์ผ๋ก ํ๋ค๋ฉด, ์ฝ๋ฉ์ ์ ์ฅ๋ ๋ชจ๋ธ์ ์ปดํจํฐ์ ์ ์ฅํ๋ ์ฝ๋
from google.colab import files
files.download('./mnist_rnn.h5')
'๋ฐ์ดํฐ ๋ถ์ ์ด๋ก > ๋ฅ๋ฌ๋' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[celeba ํ๋ก์ ํธ] 2. celeba ๋ฐ์ดํฐ์ ์ ์ฒ๋ฆฌ, ์๊ฐํ (0) | 2021.06.07 |
---|---|
[celeba ํ๋ก์ ํธ] 1. celeba ๋ฐ์ดํฐ ์ดํด๋ณด๊ธฐ (0) | 2021.06.07 |
[MNIST ํ๋ก์ ํธ] 2. MNIST ๋ฐ์ดํฐ์ ์ ์ฒ๋ฆฌ, ์๊ฐํ (0) | 2021.06.07 |
[MNIST ํ๋ก์ ํธ] 1. MNIST ๋ฐ์ดํฐ ์์๋ณด๊ธฐ (0) | 2021.06.07 |
๋ฅ๋ฌ๋์ ๋ํ์ฌ (0) | 2021.05.24 |