src/unsupervised/neighbors/k-means.js
// Internal dependencies
import Clusterer from '../base';
import * as Arrays from '../../arrays';
import * as Random from '../../random';
/**
* k-means clusterer.
*/
export default class KMeans extends Clusterer {
/**
* Constructor. Initialize class members and store user-defined options.
*
* @param {Object} [optionsUser] - User-defined options for KNN
* @param {number} [optionsUser.numClusters = 8] - Number of clusters to assign in total
* @param {string} [optionsUser.initialization = 'random'] - Initialization procedure for cluster
* centers. Either 'random', for randomly selecting (without replacement) a datapoint for each
* cluster center, or 'kmeans++', for initializing cluster centroids with the
* [kmeans++ procedure](https://en.wikipedia.org/wiki/K-means%2B%2B)
*/
constructor(optionsUser = {}) {
super();
// Parse options
const optionsDefault = {
numClusters: 2,
initialization: 'kmeans++',
};
const options = {
...optionsDefault,
...optionsUser,
};
// Set options
this.numClusters = options.numClusters;
this.initialization = options.initialization;
}
/**
* Initialize the centroids of each of the clusters based on the user's settings
*
* @param {Array.<Array.<number>>} X - Features per data point
*/
initializeCentroids(X) {
if (this.initialization === 'kmeans++') {
// Clear list of centroids
this.centroids = [];
// Get indices [0, ..., n-1] for n datapoints
let indices = [...Array(this.numSamples)].map((x, i) => i);
for (let i = 0; i < this.numClusters; i += 1) {
let weights;
if (this.centroids.length) {
// Step 1. Compute the distance of each sample to the nearest cluster centroid
const minDistances = indices.map(x =>
// Minimize distance to nearest centroid by maximizing negative squared distance
Math.min(
// Calculate negative squared distance from sample to each centroid
...this.centroids.map(centroid =>
Arrays.norm(Arrays.sum(centroid, Arrays.scale(X[x], -1)))
)
)
);
if (minDistances.filter(x => x > 0).length > 0) {
// Step 2a. Calculate squared distances, which will be used as the weights for sampling
// a data point for the new cluster centroid
weights = Arrays.power(minDistances, 2);
} else {
// Step 2b. If all remaining samples have distance 0 to the nearest cluster centroid,
// there are (too many) samples with the exact same coordinates. This is a rare case.
// However, it can happen, for example when you have 3 clusters and 3 samples, and 2 of
// the samples have the same features
weights = 'uniform';
}
} else {
weights = 'uniform';
}
// Step 4. Choose a data point from the remaining data points at random, with the computed
// sample weights. Use it as the new cluster centroid, and remove it from the list of
// potential cluster centroids
const sampleIndex = Random.sample(indices, 1, false, weights)[0];
this.centroids.push(X[sampleIndex]);
indices = indices.filter(x => x !== sampleIndex);
}
} else {
// Random initialization. Each centroid is chosen randomly without replacement from the data
// points
// Get indices [0, ..., n-1] for n datapoints
const indices = [...Array(this.numSamples)].map((x, i) => i);
// Sample a random index (without replacement) for each cluster, and use its features as
// the initial centroid for that cluster
this.centroids = Random.sample(indices, this.numClusters).map(x => X[x]);
}
}
/**
* @see {@link Clusterer#train}
*/
train(X) {
// Number of features per sample
this.numSamples = Arrays.getShape(X)[0];
this.numFeatures = Arrays.getShape(X)[1];
// Check whether there aren't more clusters than samples
if (this.numSamples < this.numClusters) {
throw new Error(`Too many clusters (numClusters=${this.numClusters}) for the number for the
number of samples (numSamples=${this.numSamples}). The number of clusters should be equal to
or greater than the number of samples.`);
}
// Initialize cluster centroids
this.initializeCentroids(X);
// Keep track of current and last cluster assignments for all samples
let assignments = [];
let assignmentsPrevious;
let epoch = 0;
do {
// Recalculate clusters
if (assignments.length > 0) {
// For each cluster, calculate the new centroid as the mean of the features of all samples
// assigned to that cluster
this.centroids = this.centroids.map((centroid, clusterId) => {
const clusterNumSamples = assignments.filter(x => x === clusterId).length;
// If there are no samples assigned to this cluster, keep the centroid the same. This
// is to prevent unstable behaviour from happening
if (clusterNumSamples === 0) {
return centroid;
}
// The new cluster centroid is the mean of all samples assigned this cluster
return Arrays.scale(
// Sum of all assigned samples
Arrays.sum(
...(X.filter((x, i) => assignments[i] === clusterId))
),
// Divide by the number of assignments
1 / clusterNumSamples,
);
});
}
// Store previous assignments
assignmentsPrevious = assignments.slice();
// Assign clusters to samples
assignments = this.cluster(X);
epoch += 1;
} while (!Arrays.equal(assignments, assignmentsPrevious) && epoch < 100);
}
/**
* @see {@link Clusterer#cluster}
*/
cluster(X) {
return X.map(x =>
// Minimize distance to centroid by maximizing negative squared distance
Arrays.argMax(
// Calculate negative squared distance from sample to centroid
this.centroids.map(centroid => -Arrays.norm(Arrays.sum(centroid, Arrays.scale(x, -1))))
)
);
}
}