(๋ณธ ํ๋ก์ ํธ ์ฝ๋๋ ํจ์บ ๋ฅ๋ฌ๋ ๊ฐ์๋ฅผ ์ฐธ๊ณ ํ ์ฝ๋์ด๋ค)
์ด๋ฒ์๋ ๋ ์ ๋ช ํ ๋ฐ์ดํฐ์ธ fashion MNIST๋ฅผ ์ด์ฉํ์ฌ ๋ฉํฐ๋ ์ด๋ธ ๋ถ๋ฅ๋ฅผ ํด๋ณผ ๊ฒ์ด๋ค.
์ฌ๊ธฐ์ ๋ฉํฐ๋ ์ด๋ธ์ด ๋ฌด์์ธ์ง ์์๋ณด๊ณ ๋์ด๊ฐ์.
Multiclass vs multi-label
Binary Classification ์ ํด๋์ค๊ฐ 2๊ฐ์ง์ธ ๊ฒฝ์ฐ์ด๋ค. ์ฌ์ง์ ๋์์๋ ๊ฒ ์ฒ๋ผ (์คํธ, ๋ซ์คํธ), ์ ๋ฒ ํ๋ก์ ํธ์์ ํ์๋ ์ฑ๋ณ (๋จ, ๋ ), ์์์ฌ๋ถ (์์, ์์์) ์ด๋ฐ์์ด๋ค.
MultiClass Classification ์ ์ฌ๋ฌ๊ฐ์ ํด๋์ค๋ฅผ ๊ฐ์ง๊ณ ์๋ ๊ฒฝ์ฐ์ด๋ค. ์ ์ฌ์ง์ฒ๋ผ ์ฌ์ง์ ๊ฐ์์ง ํ๋ง๋ฆฌ๊ฐ ์๊ณ ์ฌ๋ฌ ํด๋์ค๋ค ์ค ํ ์ข ๋ฅ๋ฅผ ์์ธกํด ์ฃผ๋ ๊ฒ์ด๋ค. ์ด๋ฒ์ ํ fashion Mnist๋ฅผ ๋ฉํฐ๋ ์ด๋ธ๋ก ํ์ง ์๊ณ ๊ทธ๋๋ก ๋ถ๋ฅ๋ชจ๋ธ์ ๋ง๋ ๋ค๋ฉด ๋ฉํฐํด๋์ค ๋ถ๋ฅ๋ชจ๋ธ์ด ๋ ๊ฒ์ด๋ค.
Multi-label Classification ์ ์ฌ๋ฌ๊ฐ์ ํด๋์ค๋ฅผ ๊ฐ์ง๊ณ ์๊ณ , ๋ผ๋ฒจ๋ง๋ ์ฌ๋ฌ๊ฐ๋ก ๋์ด์๋ ๊ฒฝ์ฐ์ด๋ค. ์ ์ฌ์ง์ ๋ณด๋ฉด ์ฌ์ง ์์ ๊ณ ์์ด์ ์๊ฐ ์์ผ๋ ์ฌ๋ฌ ํด๋์ค๋ค ์ค ๋๊ฐ์ง์ ๋ผ๋ฒจ๋ง์ด ๋์ด์๋ ๊ฒ์ด๋ค.
๋ณธ ํ๋ก์ ํธ์์๋ ๋ฉํฐ ๋ ์ด๋ธ ๋ถ๋ฅ ๋ชจ๋ธ์ ๋ง๋ค ๊ฒ์ด๊ธฐ ๋๋ฌธ์ ํ ์ฌ์ง์ ์๋ฅ๋ฅผ ๋ฌด์์๋ก ๋ถ์ฌ์ฃผ๋ ์์ ์ ํ์ฌ ํ ์ฌ์ง์ ์๋ฅ๊ฐ ์ต๋ 4๊ฐ์ง๊ฐ ๋ค์ด๊ฐ ์ ์๋ ๋ฐ์ดํฐ๋ก ๋ณํ์ ํ๋ค.
<multi-label ์ฌ์ง ์ถ์ฒ>
https://www.kaggle.com/c/lish-moa/discussion/180500
Mechanisms of Action (MoA) Prediction
Can you improve the algorithm that classifies drugs based on their biological activity?
www.kaggle.com
์ด์ fashion MNIST๋ฅผ ์์๋ณด๋๋ก ํ์!
์ด๋ฒ์๋ keras์์ ์ ๊ณตํด์ฃผ๋ datasets์์ ๋ถ๋ฌ์ ์ฌ์ฉํ๋ค. ์๋์ผ๋ก ์ค์นํ๋ ค๋ฉด ๋ฐ์ ๋งํฌ๋ฅผ ์ด์ฉํ๋ฉด ๋๋ค.
<fashion MNIST ์ถ์ฒ ๋ฐ ๋ค์ด>
https://www.kaggle.com/zalando-research/fashionmnist
Fashion MNIST
An MNIST-like dataset of 70,000 28x28 labeled fashion images
www.kaggle.com
MNIST ๋ฐ์ดํฐ์ ํฌ๊ธฐ๊ฐ ๋์ผํ๊ฒ 28x28 ์ด๋ค. train dataset์ด 60,000์ฅ, test dataset์ด 10,000 ์ฅ์ธ ๊ฒ๋ ๋์ผํ๋ค.
Labels
Each training and test example is assigned to one of the following labels:
- 0 T-shirt/top
- 1 Trouser
- 2 Pullover
- 3 Dress
- 4 Coat
- 5 Sandal
- 6 Shirt
- 7 Sneaker
- 8 Bag
- 9 Ankle boot
ํด๋์ค๋ ์ด 10๊ฐ๋ก, ํฐ์ ํธ, ๋๋ ์ค, ์ ์ธ , ์๋ค, ๊ฐ๋ฐฉ ๋ฑ๋ฑ ์ฌ๋ฌ ์๋ฅ ์ข ๋ฅ๊ฐ ํฌํจ๋์ด ์๋ค. ์ด๋ ์๋ฅ์ ์ข ๋ฅ์ธ์ง ๋ถ๋ฅํด๋ด๋ ๋ชจ๋ธ์ ์์ฑํ๋ ๊ฒ์ด ๋ชฉํ์ด๋ค.
fashion MNIST ๋ฐ์ดํฐ์ ์์๋ณด๊ธฐ
์ด์ ๋ฐ์ดํฐ์ ์ ์์๋ณด๋ ์ ์ฐจ๋ ์ต์ํด์ก์ ๊ฒ์ด๋ค. ๋ฐ์ดํฐ๋ฅผ ๋ถ๋ฌ์์ ๋ฐ์ดํฐ์ ํฌ๊ธฐ, ๋ฒ์, ํ์ ์ ํ์ธํ๊ณ , ์ด๋ป๊ฒ ์๊ฒผ๋์ง ์๊ฐํ๋ฅผ ํด๋ณธ๋ค.
(1) ๋ฐ์ดํฐ ๋ถ๋ฌ์ค๊ธฐ
fashion_mnist = keras.datasets.fashion_mnist
((train_images, train_labels), (test_images, test_labels)) = fashion_mnist.load_data()
keras์ datasets์์ fashion MNIST๋ฅผ ๋ถ๋ฌ์จ๋ค.
labels = ["T-shirt/top", # index 0
"Trouser", # index 1
"Pullover", # index 2
"Dress", # index 3
"Coat", # index 4
"Sandal", # index 5
"Shirt", # index 6
"Sneaker", # index 7
"Bag", # index 8
"Ankle boot"] # index 9
def idx2label(idx):
return labels[idx]
๋ ์ด๋ธ์ ํ ์คํธ๋ฅผ ๋ฆฌ์คํธ์ ์ ์ฅํด์ ์ธ๋ฑ์ค๋ฅผ ์ด์ฉํ์ฌ ํ ์คํธ๋ฅผ ๋ถ๋ฌ์ฌ ์ ์๋ค.
idx2label ํจ์๋ฅผ ๊ตฌํํ์ฌ ๋ ์ด๋ธ์ ํจ์์ ๋ฃ์ผ๋ฉด ๋ ์ด๋ธ์ ํ ์คํธ๋ฅผ ๋ถ๋ฌ์ค๋๋ก ํ๋ ์ฝ๋. ์๊ฐํ์์ ์ฌ์ฉํ ์์ ์ด๋ค.
(2) ๋ฐ์ดํฐ์ ํฌ๊ธฐ ํ์ธ
print(f"train_images: {train_images.shape}")
print(f"train_labels: {train_labels.shape}")
print(f"test_images: {test_images.shape}")
print(f"test_labels: {test_labels.shape}")
train_images: (60000, 28, 28)
train_labels: (60000,)
test_images: (10000, 28, 28)
test_labels: (10000,)
๊ธฐ์กด MNIST์ ๊ฐ์ ํํ๋ฅผ ๋๋ ๊ฒ์ ์ ์ ์๋ค.
(3) ๋ฐ์ดํฐ์ ๋ฒ์ ํ์ธ
- image ์์ 0์ด ์๋ ๊ฐ ์ถ๋ ฅํด๋ณด๊ธฐ
train_images[train_images!=0][:50]
test_images[train_images!=0][:50]
๋๋ฌด ๋ง์ผ๋ 50๊น์ง๋ง ์ถ๋ ฅํด๋ณธ๋ค. 0์ ์ ์ธํ๊ณ 255๊น์ง์ ์ ์๋ค๋ก ์ด๋ฃจ์ด์ ธ ์์ผ๋ฉด ์ ์!
- image์ ์ต์๊ฐ, ์ต๋๊ฐ ๊ตฌํด๋ณด๊ธฐ
print(train_images.min(), train_images.max())
print(test_images.min(), test_images.max())
๋๋ค 0 255 ๊ฐ ๋์ค๋ฉด ์ ์!
***์ด๋ฏธ์ง์ ๊ฐ์ ๋ํด์ ๊ฐ์ฅ ํฐ index, ๊ฐ์ฅ ์์ index๋ฅผ ๊ตฌํด๋ณด๊ณ ์๊ฐํ ํด๋ณด๊ธฐ
์ด๋ฏธ์ง์ ๊ฐ๋ค์ ๋ชจ๋ ๋ํด์ ์ซ์๊ฐ ํฌ๋ค๋ฉด ์ท์ ํฌ๊ธฐ๊ฐ ํฌ๊ณ ์์ด ๋ฐ์ ๊ฒ์ด๊ณ , ์ซ์๊ฐ ์๋ค๋ฉด ์ท์ ํฌ๊ธฐ๊ฐ ์์ผ๋ฉด์ ์์ด ์ด๋์ธ ๊ฒ์ผ๋ก ์์ํ ์ ์๋ค. ์ ๋ง ๊ทธ๋ฐ์ง ํ์ธํด ๋ณด์.
print(train_images.reshape((60000, -1)).sum(axis=1).argmax())
print(train_images.reshape((60000, -1)).sum(axis=1).argmin())
axis=1 ๋ฐฉํฅ์ผ๋ก ๋ค ๋ํด์ฃผ๋ฉด ๊ฐ ์ด๋ฏธ์ง์ ๋ํ ๊ฐ๋ค์ ํฉ์ด ๋์ฌ ๊ฒ์ด๋ค. ๊ทธ์ค์์ ์ต๋๊ฐ์ index์ ์ต์๊ฐ์ index๋ฅผ
์ธ๋ฑ์ค๋ 55023 9230๊ฐ ๋์๋ค. ์ฌ์ง์ ์ถ๋ ฅํด๋ณด๋ฉด,
์์ํ๋ ๋๋ก ํฉ์ด ํฐ ์ด๋ฏธ์ง๋ ๋ฐ์ ๋ถ๋ถ์ด ๋ง๊ณ , ํฉ์ด ์์ ์ด๋ฏธ์ง๋ ์ด๋์ด ๋ถ๋ถ์ด ๋๋ถ๋ถ์ด๋ค.
(4) ๋ฐ์ดํฐ ํ์ ํ์ธ
print(train_images.dtype)
print(train_labels.dtype)
print(test_images.dtype)
print(test_labels.dtype)
๋ชจ๋ uint8 ์ด ๋์ค๋ฉด ์ ์!
์ด๋ฅผ ํตํด ์ ์ ์๋ ๊ฒ์ ์ ์ฒ๋ฆฌ ํ ๋ 0-1 ์ฌ์ด์ float ํํ๋ก ๋ฐ๊ฟ ์ฃผ์ด์ผ ๋๋ค๋ผ๋ ๊ฒ.
(5) ๋ฐ์ดํฐ ํ์ฅ์ฉ ์๊ฐํ ํด๋ณด๊ธฐ
def show(idx):
plt.imshow(train_images[idx], cmap='gray')
plt.title(idx2label(train_labels[idx]))
plt.show()
์๊ฐํ ํ๋ ํจ์๋ฅผ ๊ตฌํํด์ฃผ์ด ํธ๋ฆฌํ๊ฒ ์ฌ์ง์ ํ์ธํ ์ ์๋ค.
show(777)
train image์ 777๋ฒ์งธ ์ฌ์ง์ sandal
show(77)
train image์ 77๋ฒ์งธ ์ฌ์ง์ shirt ์์ ์ ์ ์๋ค.
๋ค์ ํฌ์คํธ์์๋ fashion MNIST ์ ์ฒ๋ฆฌ์ ์ฌ๋ฌ์ฅ ์๊ฐํํ๋ ๋ฐฉ๋ฒ์ ๋ํ์ฌ ์์ฑํ ์์ ์ด๋ค.