大きなデータセットがあり、オンデマンドでクラスタを抽出する必要がある場合は、numpy.where
を使用していくつかのスピードアップが表示されます。
from sklearn.cluster import KMeans
from sklearn import datasets
import numpy as np
centers = [[1, 1], [-1, -1], [1, -1]]
iris = datasets.load_iris()
X = iris.data
y = iris.target
km = KMeans(n_clusters=3)
km.fit(X)
ここで提供するcluster_idのインデックスを抽出する関数を定義します。 (ここでは、2つの機能がベンチマークのために、彼らの両方が同じ値を返す、です):
def ClusterIndicesNumpy(clustNum, labels_array): #numpy
return np.where(labels_array == clustNum)[0]
def ClusterIndicesComp(clustNum, labels_array): #list comprehension
return np.array([i for i, x in enumerate(labels_array) if x == clustNum])
はあなたがクラスタ2
内にあるすべてのサンプルをしたいとしましょう:
ClusterIndicesNumpy(2, km.labels_)
array([ 52, 77, 100, 102, 103, 104, 105, 107, 108, 109, 110, 111, 112,
115, 116, 117, 118, 120, 122, 124, 125, 128, 129, 130, 131, 132,
134, 135, 136, 137, 139, 140, 141, 143, 144, 145, 147, 148])
numpyのは、ベンチマークを受賞:
%timeit ClusterIndicesNumpy(2,km.labels_)
100000 loops, best of 3: 4 µs per loop
%timeit ClusterIndicesComp(2,km.labels_)
1000 loops, best of 3: 479 µs per loop
今、あなたはそのようなあなたのクラスタ2のデータ点のすべてを抽出することができます。
X[ClusterIndicesNumpy(2,km.labels_)]
array([[ 6.9, 3.1, 4.9, 1.5],
[ 6.7, 3. , 5. , 1.7],
[ 6.3, 3.3, 6. , 2.5],
... #truncated
上記切り捨て配列から最初の3つの指標をダブルチェック:
print X[52], km.labels_[52]
print X[77], km.labels_[77]
print X[100], km.labels_[100]
[ 6.9 3.1 4.9 1.5] 2
[ 6.7 3. 5. 1.7] 2
[ 6.3 3.3 6. 2.5] 2
内のすべてのデータ点をフィルタリングするために、例えば、 フィルタリングするのは非常に容易です。しかし、データポイントの多くが繰り返されるときに、ラベルを取得するためにそれらのすべてを反復することは効率的ではありません。私はちょうど与えられたクラスタのデータポイントのリストでした。これを行う別の方法はありませんか? – user77005