최대 1 분 소요

군집 (Clustering)

K-평균 활용

이미지 분할

import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

# Load the sample photo into an array and display it.
image_path = './images/ladybug.png'
image = plt.imread(image_path)
plt.imshow(image)
<matplotlib.image.AxesImage at 0x1a67a7ce550>

png

image.shape  # array shape of the loaded image; the last axis holds the color channels
(533, 800, 3)
# Flatten the image into a 2-D table: one row per pixel, one column per color
# channel. The channel count is read from the image instead of being hard-coded
# to 3, so a PNG that plt.imread returns as RGBA (4 channels) still works
# (reshape(-1, 3) would fail or mix channels of neighbouring pixels).
X = image.reshape(-1, image.shape[-1])
X.shape
(426400, 3)
# Cluster the pixel colors into 8 groups with K-means.
# n_init is pinned explicitly: scikit-learn's default changed from 10 to 'auto'
# (a single k-means++ initialisation) in version 1.4, which would otherwise
# change the fitted centroids depending on the installed version.
kmeans = KMeans(n_clusters=8, n_init=10, random_state=42)
kmeans.fit(X)
KMeans(random_state=42)
kmeans.labels_  # cluster label assigned to each pixel (each row of X)
array([1, 1, 1, ..., 4, 1, 1])
kmeans.labels_.shape  # one label per pixel, i.e. one per row of X
(426400,)
import numpy as np

# Distinct cluster ids that actually occur among the pixel labels.
pixel_labels = kmeans.labels_
np.unique(pixel_labels)
array([0, 1, 2, 3, 4, 5, 6, 7])
kmeans.cluster_centers_  # centroids of the 8 groups: the RGB combination (color) that represents each group
array([[0.98363745, 0.9359338 , 0.02574807],
       [0.02289337, 0.11064845, 0.00578197],
       [0.21914783, 0.38675755, 0.05800817],
       [0.75775605, 0.21225454, 0.0445884 ],
       [0.09990625, 0.2542204 , 0.01693457],
       [0.61266166, 0.63010883, 0.38751987],
       [0.37212682, 0.5235918 , 0.15730347],
       [0.8845907 , 0.7256049 , 0.03442054]], dtype=float32)
kmeans.labels_  # cluster index of each pixel; used below to look up that cluster's centroid color
array([1, 1, 1, ..., 4, 1, 1])
# Replace every pixel's value with the representative RGB value (centroid) of its group.
centroid_colors = kmeans.cluster_centers_
pixel_clusters = kmeans.labels_
segmented_img = centroid_colors[pixel_clusters]
segmented_img.shape  # still flat: one centroid color per pixel, reshaped to the image's shape below
(426400, 3)
# Fold the flat per-pixel colors back into the original image layout and display them.
target_shape = image.shape
segmented_img = segmented_img.reshape(target_shape)
plt.imshow(segmented_img)
<matplotlib.image.AxesImage at 0x1a67cbd19d0>

png

# Re-run K-means with several palette sizes and keep each recolored image.
# A separate name (`model`) is used so the 8-cluster `kmeans` fitted earlier is
# not silently overwritten by the last model of this loop.
segmented_imgs = []
n_colors = [10, 8, 6, 4, 2]
for n_clusters in n_colors:
    # n_init is pinned so results don't depend on scikit-learn's default,
    # which changed from 10 to 'auto' in version 1.4.
    model = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
    model.fit(X)
    # Replace every pixel by its cluster's centroid color, then restore the image layout.
    recolored = model.cluster_centers_[model.labels_]
    segmented_imgs.append(recolored.reshape(image.shape))
# Show the original image next to each segmented version on a 3-column grid.
# The grid size is derived from n_colors rather than hard-coded as 2x3, and the
# (nrows, ncols, index) form of subplot is used: the 3-digit shorthand
# (e.g. 232 + idx) breaks as soon as the index exceeds the grid or 9.
n_cols = 3
n_panels = len(n_colors) + 1  # +1 for the original image
n_rows = (n_panels + n_cols - 1) // n_cols  # ceiling division
plt.figure(figsize=(10, 2.5 * n_rows))  # (10, 5) for the default 2-row grid
plt.subplot(n_rows, n_cols, 1)
plt.imshow(image)
plt.title('original image')
plt.axis('off')

# Panel 1 holds the original, so the segmented images start at panel 2.
for panel, (n_clusters, seg_img) in enumerate(zip(n_colors, segmented_imgs), start=2):
    plt.subplot(n_rows, n_cols, panel)
    plt.imshow(seg_img)
    plt.title('{} colors'.format(n_clusters))
    plt.axis('off')

png

Reference

댓글 남기기