// Home Reference Source
//
// src/supervised/neighbors/knn.js

// Internal dependencies
import Neighbors from './base';
import * as Arrays from '../../arrays';

/**
 * k-nearest neighbours learner. Classifies a point by an unweighted majority vote among the
 * labels of its k nearest training points. The distance to a training point is obtained by
 * applying `Arrays.norm` to the difference between the two feature vectors (presumably the
 * Euclidean norm; confirm against the Arrays module).
 */
export default class KNN extends Neighbors {
  /**
   * Constructor. Initialize class members and store user-defined options.
   *
   * @param {Object} [optionsUser] - User-defined options for KNN
   * @param {number} [optionsUser.numNeighbours = 3] - Number of nearest neighbours to consider for
   *   the majority vote
   */
  constructor(optionsUser = {}) {
    super();

    // Parse options: user-defined values take precedence over the defaults
    const optionsDefault = {
      numNeighbours: 3,
    };

    const options = {
      ...optionsDefault,
      ...optionsUser,
    };

    // Set options
    this.numNeighbours = options.numNeighbours;
  }

  /**
   * Store the training set. No model is fitted; all work happens at prediction time.
   *
   * @param {Array.<Array.<number>>} X - Features per data point
   * @param {Array.<mixed>} y - Class label per data point
   * @throws {Error} If the number of data points differs from the number of labels
   *
   * @see {@link Classifier#train}
   */
  train(X, y) {
    if (X.length !== y.length) {
      throw new Error('Number of data points should match number of labels.');
    }

    // Stored by reference (no copy): later mutation of X or y by the caller is visible here
    this.training = { X, y };
  }

  /**
   * Predict the class label of each sample in a batch.
   *
   * @param {Array.<Array.<number>>} X - Features per data point
   * @return {Array.<mixed>} Predicted label per data point; empty if X is empty
   * @throws {Error} If the model has not been trained, if it was trained on zero data points,
   *   or if any sample's number of features differs from that of the training data
   *
   * @see {@link Classifier#predict}
   */
  predict(X) {
    if (typeof this.training === 'undefined') {
      throw new Error('Model has to be trained in order to make predictions.');
    }

    // Nothing to classify. Returning early also avoids dereferencing X[0] on an empty batch
    if (X.length === 0) {
      return [];
    }

    // Without training points there is no reference feature count and no neighbour to vote
    if (this.training.X.length === 0) {
      throw new Error('Training data should contain at least one data point.');
    }

    const numFeatures = this.training.X[0].length;

    // Check every sample rather than only the first, so a ragged batch is rejected up front
    if (X.some(x => x.length !== numFeatures)) {
      throw new Error('Number of features of test data should match number of features of training data.');
    }

    // Make prediction for each data point
    return X.map(x => this.predictSample(x));
  }

  /**
   * Make a prediction for a single sample.
   *
   * When several classes share the highest count, the one listed first by `Arrays.valueCounts`
   * wins, because only a strictly greater count replaces the current best.
   *
   * @param {Array.<number>} sampleFeatures - Data point features
   * @return {mixed} Prediction. Label of class with highest prevalence among k nearest
   *   neighbours, or -1 if no vote was counted at all
   */
  predictSample(sampleFeatures) {
    // Calculate distances to all training points as [distance, training index] pairs
    const distances = Arrays.zipWithIndex(
      this.training.X.map(
        x => Arrays.norm(Arrays.sum(sampleFeatures, Arrays.scale(x, -1)))
      )
    );

    // Sort training data points by ascending distance
    distances.sort((a, b) => {
      if (a[0] > b[0]) return 1;
      if (a[0] < b[0]) return -1;
      return 0;
    });

    // Number of nearest neighbours to consider; cannot exceed the number of training points
    const k = Math.min(this.numNeighbours, distances.length);

    // Labels of the k closest training points
    const nearestLabels = distances.slice(0, k).map(x => this.training.y[x[1]]);

    // Count the number of neighbours per class as [label, count] pairs
    const votes = Arrays.valueCounts(nearestLabels);

    // Pick the label with the highest number of votes
    let highest = -1;
    let highestLabel = -1;

    votes.forEach((vote) => {
      if (vote[1] > highest) {
        highest = vote[1];
        highestLabel = vote[0];
      }
    });

    return highestLabel;
  }
}