Showing
1 changed file
with
99 additions
and
0 deletions
dominant_colors.py
0 → 100644
1 | +import cv2 | ||
2 | +import numpy as np | ||
3 | +from sklearn.cluster import KMeans | ||
4 | +import matplotlib.pyplot as plt | ||
5 | +from mpl_toolkits.mplot3d import Axes3D | ||
6 | + | ||
7 | +class DominantColors: | ||
8 | + | ||
9 | + CLUSTERS = None | ||
10 | + IMAGE = None | ||
11 | + COLORS = None | ||
12 | + LABELS = None | ||
13 | + | ||
14 | + def __init__(self, image, clusters=3): | ||
15 | + self.CLUSTERS = clusters | ||
16 | + self.IMAGE = image | ||
17 | + | ||
18 | + | ||
19 | + def dominantColors(self): | ||
20 | + | ||
21 | + #read image | ||
22 | + #img = cv2.imread(self.IMAGE) | ||
23 | + img = self.IMAGE | ||
24 | + | ||
25 | + #convert to rgb from bgr | ||
26 | + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | ||
27 | + | ||
28 | + #reshaping to a list of pixels | ||
29 | + img = img.reshape((img.shape[0] * img.shape[1], 3)) | ||
30 | + | ||
31 | + #save image after operations | ||
32 | + self.IMAGE = img | ||
33 | + | ||
34 | + #using k-means to cluster pixels | ||
35 | + kmeans = KMeans(n_clusters = self.CLUSTERS) | ||
36 | + kmeans.fit(img) | ||
37 | + | ||
38 | + #the cluster centers are our dominant colors. | ||
39 | + self.COLORS = kmeans.cluster_centers_ | ||
40 | + | ||
41 | + #save labels | ||
42 | + self.LABELS = kmeans.labels_ | ||
43 | + | ||
44 | + #returning after converting to integer from float | ||
45 | + return self.COLORS.astype(int) | ||
46 | + | ||
47 | + | ||
48 | + def rgb_to_hex(self, rgb): | ||
49 | + return '#%02x%02x%02x' % (int(rgb[0]), int(rgb[1]), int(rgb[2])) | ||
50 | + | ||
51 | + | ||
52 | + def plotClusters(self): | ||
53 | + #plotting | ||
54 | + fig = plt.figure() | ||
55 | + ax = Axes3D(fig) | ||
56 | + for label, pix in zip(self.LABELS, self.IMAGE): | ||
57 | + ax.scatter(pix[0], pix[1], pix[2], color = self.rgb_to_hex(self.COLORS[label])) | ||
58 | + plt.show() | ||
59 | + | ||
60 | + | ||
61 | + def plotHistogram(self): | ||
62 | + | ||
63 | + #labels form 0 to no. of clusters | ||
64 | + numLabels = np.arange(0, self.CLUSTERS+1) | ||
65 | + | ||
66 | + #create frequency count tables | ||
67 | + (hist, _) = np.histogram(self.LABELS, bins = numLabels) | ||
68 | + hist = hist.astype("float") | ||
69 | + hist /= hist.sum() | ||
70 | + | ||
71 | + #appending frequencies to cluster centers | ||
72 | + colors = self.COLORS | ||
73 | + | ||
74 | + #descending order sorting as per frequency count | ||
75 | + colors = colors[(-hist).argsort()] | ||
76 | + hist = hist[(-hist).argsort()] | ||
77 | + | ||
78 | + #creating empty chart | ||
79 | + chart = np.zeros((50, 500, 3), np.uint8) | ||
80 | + start = 0 | ||
81 | + | ||
82 | + #creating color rectangles | ||
83 | + for i in range(self.CLUSTERS): | ||
84 | + end = start + hist[i] * 500 | ||
85 | + | ||
86 | + #getting rgb values | ||
87 | + r = colors[i][0] | ||
88 | + g = colors[i][1] | ||
89 | + b = colors[i][2] | ||
90 | + | ||
91 | + #using cv2.rectangle to plot colors | ||
92 | + cv2.rectangle(chart, (int(start), 0), (int(end), 50), (r,g,b), -1) | ||
93 | + start = end | ||
94 | + | ||
95 | + #display chart | ||
96 | + plt.figure() | ||
97 | + plt.axis("off") | ||
98 | + plt.imshow(chart) | ||
99 | + plt.show() |
-
Please register or login to post a comment