test_birch.py
5.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""
Tests for the birch clustering algorithm.
"""
from scipy import sparse
import numpy as np
import pytest
from sklearn.cluster.tests.common import generate_clustered_data
from sklearn.cluster import Birch
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_blobs
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import ElasticNet
from sklearn.metrics import pairwise_distances_argmin, v_measure_score
from sklearn.utils._testing import assert_almost_equal
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import assert_warns
def test_n_samples_leaves_roots():
# Sanity check for the number of samples in leaves and roots
X, y = make_blobs(n_samples=10)
brc = Birch()
brc.fit(X)
n_samples_root = sum([sc.n_samples_ for sc in brc.root_.subclusters_])
n_samples_leaves = sum([sc.n_samples_ for leaf in brc._get_leaves()
for sc in leaf.subclusters_])
assert n_samples_leaves == X.shape[0]
assert n_samples_root == X.shape[0]
def test_partial_fit():
# Test that fit is equivalent to calling partial_fit multiple times
X, y = make_blobs(n_samples=100)
brc = Birch(n_clusters=3)
brc.fit(X)
brc_partial = Birch(n_clusters=None)
brc_partial.partial_fit(X[:50])
brc_partial.partial_fit(X[50:])
assert_array_almost_equal(brc_partial.subcluster_centers_,
brc.subcluster_centers_)
# Test that same global labels are obtained after calling partial_fit
# with None
brc_partial.set_params(n_clusters=3)
brc_partial.partial_fit(None)
assert_array_equal(brc_partial.subcluster_labels_, brc.subcluster_labels_)
def test_birch_predict():
# Test the predict method predicts the nearest centroid.
rng = np.random.RandomState(0)
X = generate_clustered_data(n_clusters=3, n_features=3,
n_samples_per_cluster=10)
# n_samples * n_samples_per_cluster
shuffle_indices = np.arange(30)
rng.shuffle(shuffle_indices)
X_shuffle = X[shuffle_indices, :]
brc = Birch(n_clusters=4, threshold=1.)
brc.fit(X_shuffle)
centroids = brc.subcluster_centers_
assert_array_equal(brc.labels_, brc.predict(X_shuffle))
nearest_centroid = pairwise_distances_argmin(X_shuffle, centroids)
assert_almost_equal(v_measure_score(nearest_centroid, brc.labels_), 1.0)
def test_n_clusters():
# Test that n_clusters param works properly
X, y = make_blobs(n_samples=100, centers=10)
brc1 = Birch(n_clusters=10)
brc1.fit(X)
assert len(brc1.subcluster_centers_) > 10
assert len(np.unique(brc1.labels_)) == 10
# Test that n_clusters = Agglomerative Clustering gives
# the same results.
gc = AgglomerativeClustering(n_clusters=10)
brc2 = Birch(n_clusters=gc)
brc2.fit(X)
assert_array_equal(brc1.subcluster_labels_, brc2.subcluster_labels_)
assert_array_equal(brc1.labels_, brc2.labels_)
# Test that the wrong global clustering step raises an Error.
clf = ElasticNet()
brc3 = Birch(n_clusters=clf)
with pytest.raises(ValueError):
brc3.fit(X)
# Test that a small number of clusters raises a warning.
brc4 = Birch(threshold=10000.)
assert_warns(ConvergenceWarning, brc4.fit, X)
def test_sparse_X():
# Test that sparse and dense data give same results
X, y = make_blobs(n_samples=100, centers=10)
brc = Birch(n_clusters=10)
brc.fit(X)
csr = sparse.csr_matrix(X)
brc_sparse = Birch(n_clusters=10)
brc_sparse.fit(csr)
assert_array_equal(brc.labels_, brc_sparse.labels_)
assert_array_almost_equal(brc.subcluster_centers_,
brc_sparse.subcluster_centers_)
def check_branching_factor(node, branching_factor):
subclusters = node.subclusters_
assert branching_factor >= len(subclusters)
for cluster in subclusters:
if cluster.child_:
check_branching_factor(cluster.child_, branching_factor)
def test_branching_factor():
# Test that nodes have at max branching_factor number of subclusters
X, y = make_blobs()
branching_factor = 9
# Purposefully set a low threshold to maximize the subclusters.
brc = Birch(n_clusters=None, branching_factor=branching_factor,
threshold=0.01)
brc.fit(X)
check_branching_factor(brc.root_, branching_factor)
brc = Birch(n_clusters=3, branching_factor=branching_factor,
threshold=0.01)
brc.fit(X)
check_branching_factor(brc.root_, branching_factor)
# Raises error when branching_factor is set to one.
brc = Birch(n_clusters=None, branching_factor=1, threshold=0.01)
with pytest.raises(ValueError):
brc.fit(X)
def check_threshold(birch_instance, threshold):
"""Use the leaf linked list for traversal"""
current_leaf = birch_instance.dummy_leaf_.next_leaf_
while current_leaf:
subclusters = current_leaf.subclusters_
for sc in subclusters:
assert threshold >= sc.radius
current_leaf = current_leaf.next_leaf_
def test_threshold():
# Test that the leaf subclusters have a threshold lesser than radius
X, y = make_blobs(n_samples=80, centers=4)
brc = Birch(threshold=0.5, n_clusters=None)
brc.fit(X)
check_threshold(brc, 0.5)
brc = Birch(threshold=5.0, n_clusters=None)
brc.fit(X)
check_threshold(brc, 5.)
def test_birch_n_clusters_long_int():
# Check that birch supports n_clusters with np.int64 dtype, for instance
# coming from np.arange. #16484
X, _ = make_blobs(random_state=0)
n_clusters = np.int64(5)
Birch(n_clusters=n_clusters).fit(X)