// Home Reference Source


// Internal dependencies
import Neighbors from './base';
import * as Arrays from '../../arrays';

 * k-nearest neighbours learner. Classifies points based on the (possibly weighted) vote
 * of its k nearest neighbours (euclidian distance).
export default class KNN extends Neighbors {
   * Constructor. Initialize class members and store user-defined options.
   * @param {Object} [optionsUser] - User-defined options for KNN
   * @param {number} [optionsUser.numNeighbours = 3] - Number of nearest neighbours to consider for
   *   the majority vote
  constructor(optionsUser = {}) {

    // Parse options
    const optionsDefault = {
      numNeighbours: 3,

    const options = {

    // Set options
    this.numNeighbours = options.numNeighbours;

   * @see {@link Classifier#train}
  train(X, y) {
    if (X.length !== y.length) {
      throw new Error('Number of data points should match number of labels.');

    // Store data points
    this.training = { X, y };

   * @see {@link Classifier#predict}
  predict(X) {
    if (typeof this.training === 'undefined') {
      throw new Error('Model has to be trained in order to make predictions.');

    if (X[0].length !== this.training.X[0].length) {
      throw new Error('Number of features of test data should match number of features of training data.');

    // Make prediction for each data point
    const predictions = X.map(x => this.predictSample(x));

    return predictions;

   * Make a prediction for a single sample.
   * @param {Array.<number>} sampleFeatures - Data point features
   * @return {mixed} Prediction. Label of class with highest prevalence among k nearest neighbours
  predictSample(sampleFeatures) {
    // Calculate distances to all other data points
    const distances = Arrays.zipWithIndex(
        x => Arrays.norm(Arrays.sum(sampleFeatures, Arrays.scale(x, -1)))

    // Sort training data points based on distance
    distances.sort((a, b) => {
      if (a[0] > b[0]) return 1;
      if (a[0] < b[0]) return -1;
      return 0;

    // Number of nearest neighbours to consider
    const k = Math.min(this.numNeighbours, distances.length);

    // Take top k distances
    const distancesTopKClasses = distances.slice(0, k).map(x => this.training.y[x[1]]);

    // Count the number of neighbours per class
    const votes = Arrays.valueCounts(distancesTopKClasses);

    // Get class index with highest number of votes
    let highest = -1;
    let highestLabel = -1;

    votes.forEach((vote) => {
      if (vote[1] > highest) {
        highest = vote[1];
        highestLabel = vote[0];

    return highestLabel;