Home Reference Source

src/supervised/trees/random-forest.js

// Internal dependencies
import { Classifier } from '../base';
import DecisionTree from './decision-tree';
import * as Arrays from '../../arrays';
import * as Random from '../../random';

/**
 * Random forest learner. Trains a collection of decision trees, each on either a bootstrapped
 * resample of the training samples or on the full training set, and predicts the class that
 * receives the most votes among the individual trees.
 */
export default class RandomForest extends Classifier {
  /**
   * Constructor. Initialize class members and store user-defined options.
   *
   * @param {Object} [optionsUser] - User-defined options for random forest
   * @param {number} [optionsUser.numTrees = 10] - Number of decision trees to build
   * @param {string} [optionsUser.criterion = 'gini'] - Splitting criterion. Either 'gini', for the
   *   Gini coefficient, or 'entropy' for the Shannon entropy
   * @param {number|string} [optionsUser.numFeatures = 1.0] - Number of features to subsample at
   *   each node. Either a number (float), in which case the input fraction of features is used
   *   (e.g., 1.0 for all features), or a string. If string, 'sqrt' and 'log2' are supported,
   *   causing the algorithm to use sqrt(n) and log2(n) features, respectively (where n is the
   *   total number of features)
   * @param {number} [optionsUser.maxDepth = -1] - Maximum depth of each decision tree. The depth of
   *   a decision tree is the number of nodes in the longest path from the decision tree root to a
   *   leaf. It is an indicator of the complexity of the tree. Use -1 for no maximum depth
   * @param {boolean} [optionsUser.bootstrap = true] - Whether to select samples for each tree by
   *   bootstrapping. If false, all samples are used for each tree. If true, n samples are drawn
   *   with replacement from the full set of samples for each tree (where n is the total number of
   *   samples)
   */
  constructor(optionsUser = {}) {
    super();

    // Parse options: user-supplied values take precedence over the defaults
    const optionsDefault = {
      criterion: 'gini',
      numFeatures: 1.0,
      maxDepth: -1,
      numTrees: 10,
      bootstrap: true,
    };

    const options = {
      ...optionsDefault,
      ...optionsUser,
    };

    // Set options
    this.criterion = options.criterion;
    this.numFeatures = options.numFeatures;
    this.maxDepth = options.maxDepth;
    this.numTrees = options.numTrees;
    this.bootstrap = options.bootstrap;
  }

  /**
   * Train the forest: construct `numTrees` decision trees and train each of them on the
   * (optionally bootstrapped) training data.
   *
   * The trained trees are only stored on the instance once every tree has been trained
   * successfully. If training fails part-way, the instance keeps its previous state (untrained,
   * or the previously trained forest) instead of ending up with a partial forest.
   *
   * @see {@link Classifier#train}
   *
   * @param {Array.<Array.<number>>} X - Features per data point
   * @param {Array.<mixed>} y - Class labels per data point
   * @throws {Error} If the number of data points does not match the number of labels
   */
  train(X, y) {
    if (X.length !== y.length) {
      throw new Error('Number of data points should match number of labels.');
    }

    const numSamples = X.length;

    // All sample indices (0 .. numSamples - 1), used as the population for bootstrapping
    const sampleIndices = Array.from({ length: numSamples }, (_, i) => i);

    // Collect trees locally so a failure while training never leaves a half-built forest behind
    const trees = [];

    for (let i = 0; i < this.numTrees; i += 1) {
      // Construct decision tree
      const tree = new DecisionTree({
        criterion: this.criterion,
        numFeatures: this.numFeatures,
        maxDepth: this.maxDepth,
      });

      // Select the input samples. If bootstrapping is disabled, use all samples. If it is enabled,
      // draw as many samples as there are data points, with replacement
      let treeX = X;
      let treeY = y;

      if (this.bootstrap) {
        const treeSamples = Random.sample(sampleIndices, numSamples, true);
        treeX = treeSamples.map(sampleIndex => X[sampleIndex]);
        treeY = treeSamples.map(sampleIndex => y[sampleIndex]);
      }

      // Train the tree and keep it
      tree.train(treeX, treeY);
      trees.push(tree);
    }

    this.trees = trees;
  }

  /**
   * Make a prediction for each of the input data points.
   *
   * @see {@link Classifier#predict}
   *
   * @param {Array.<Array.<number>>} X - Features per data point
   * @return {Array.<mixed>} Predicted label for each data point
   * @throws {Error} If the model has not been trained yet
   */
  predict(X) {
    if (typeof this.trees === 'undefined') {
      throw new Error('Model has to be trained in order to make predictions.');
    }

    // Make prediction for each data point
    return X.map(x => this.predictSample(x));
  }

  /**
   * Make a prediction for a single sample.
   *
   * @param {Array.<number>} sampleFeatures - Data point features
   * @return {mixed} Prediction. Label that received the most votes among the trees of the forest.
   *   On a tie, the label listed first in the vote counts wins. If there are no votes at all
   *   (e.g., a forest without trees), -1 is returned
   * @throws {Error} If the model has not been trained yet
   */
  predictSample(sampleFeatures) {
    // Fail with the same explicit error as predict() rather than an opaque TypeError
    if (typeof this.trees === 'undefined') {
      throw new Error('Model has to be trained in order to make predictions.');
    }

    // Gather predictions from all trees
    const predictions = this.trees.map(tree => tree.predictSample(sampleFeatures));

    // Count the number of votes for each class; each entry is used as a [label, count] pair
    const predictionCounts = Arrays.valueCounts(predictions);

    // Predict the class with the most votes. The strict comparison keeps the first label on ties,
    // and the [-1, -1] seed is the result when there are no votes
    return predictionCounts.reduce((best, entry) => (entry[1] > best[1] ? entry : best), [-1, -1])[0];
  }
}