starbucksdolcelatte

Created DominantColors class

1 +import cv2
2 +import numpy as np
3 +from sklearn.cluster import KMeans
4 +import matplotlib.pyplot as plt
5 +from mpl_toolkits.mplot3d import Axes3D
6 +
7 +class DominantColors:
8 +
9 + CLUSTERS = None
10 + IMAGE = None
11 + COLORS = None
12 + LABELS = None
13 +
14 + def __init__(self, image, clusters=3):
15 + self.CLUSTERS = clusters
16 + self.IMAGE = image
17 +
18 +
19 + def dominantColors(self):
20 +
21 + #read image
22 + #img = cv2.imread(self.IMAGE)
23 + img = self.IMAGE
24 +
25 + #convert to rgb from bgr
26 + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
27 +
28 + #reshaping to a list of pixels
29 + img = img.reshape((img.shape[0] * img.shape[1], 3))
30 +
31 + #save image after operations
32 + self.IMAGE = img
33 +
34 + #using k-means to cluster pixels
35 + kmeans = KMeans(n_clusters = self.CLUSTERS)
36 + kmeans.fit(img)
37 +
38 + #the cluster centers are our dominant colors.
39 + self.COLORS = kmeans.cluster_centers_
40 +
41 + #save labels
42 + self.LABELS = kmeans.labels_
43 +
44 + #returning after converting to integer from float
45 + return self.COLORS.astype(int)
46 +
47 +
48 + def rgb_to_hex(self, rgb):
49 + return '#%02x%02x%02x' % (int(rgb[0]), int(rgb[1]), int(rgb[2]))
50 +
51 +
52 + def plotClusters(self):
53 + #plotting
54 + fig = plt.figure()
55 + ax = Axes3D(fig)
56 + for label, pix in zip(self.LABELS, self.IMAGE):
57 + ax.scatter(pix[0], pix[1], pix[2], color = self.rgb_to_hex(self.COLORS[label]))
58 + plt.show()
59 +
60 +
61 + def plotHistogram(self):
62 +
63 + #labels form 0 to no. of clusters
64 + numLabels = np.arange(0, self.CLUSTERS+1)
65 +
66 + #create frequency count tables
67 + (hist, _) = np.histogram(self.LABELS, bins = numLabels)
68 + hist = hist.astype("float")
69 + hist /= hist.sum()
70 +
71 + #appending frequencies to cluster centers
72 + colors = self.COLORS
73 +
74 + #descending order sorting as per frequency count
75 + colors = colors[(-hist).argsort()]
76 + hist = hist[(-hist).argsort()]
77 +
78 + #creating empty chart
79 + chart = np.zeros((50, 500, 3), np.uint8)
80 + start = 0
81 +
82 + #creating color rectangles
83 + for i in range(self.CLUSTERS):
84 + end = start + hist[i] * 500
85 +
86 + #getting rgb values
87 + r = colors[i][0]
88 + g = colors[i][1]
89 + b = colors[i][2]
90 +
91 + #using cv2.rectangle to plot colors
92 + cv2.rectangle(chart, (int(start), 0), (int(end), 50), (r,g,b), -1)
93 + start = end
94 +
95 + #display chart
96 + plt.figure()
97 + plt.axis("off")
98 + plt.imshow(chart)
99 + plt.show()