src/supervised/base.js
// Standard imports
import * as Arrays from '../arrays';
/**
* Base class for supervised estimators (classifiers or regression models).
*/
export class Estimator {
/**
* Train the supervised learning algorithm on a dataset.
*
* @abstract
*
* @param {Array.<Array.<number>>} X - Features per data point
* @param {Array.<mixed>} y Class labels per data point
*/
train(X, y) { throw new Error('Method must be implemented child class.'); }
/**
* Make a prediction for a data set.
*
* @abstract
*
* @param {Array.<Array.<number>>} X - Features for each data point
* @return {Array.<mixed>} Predictions. Label of class with highest prevalence among k nearest
* neighbours for each sample
*/
predict(X) { throw new Error('Method must be implemented child class.'); }
}
/**
* Base class for classifiers.
*/
export class Classifier extends Estimator {
}
/**
* Base class for multiclass classifiers using the one-vs-all classification method. For a training
* set with k unique class labels, the one-vs-all classifier creates k binary classifiers. Each of
* these classifiers is trained on the entire data set, where the i-th classifier treats all samples
* that do not come from the i-th class as being from the same class. In the prediction phase, the
* one-vs-all classifier runs all k binary classifiers on the test data point, and predicts the
* class that has the highest normalized prediction value
*/
export class OneVsAllClassifier extends Classifier {
/**
* Create a binary classifier for one of the classes.
*
* @abstract
*
* @param {number} classIndex - Class index of the positive class for the binary classifier
* @return {BinaryClassifier} Binary classifier
*/
createClassifier(classIndex) { throw new Error('Method must be implemented child class.'); }
/**
* Create all binary classifiers. Creates one classifier per class.
*
* @param {Array.<number>} y - Class labels for the training data
*/
createClassifiers(y) {
// Get unique labels
const uniqueClassIndices = Arrays.unique(y);
// Initialize label set and classifier for all labels
this.classifiers = uniqueClassIndices.map((classIndex) => {
const classifier = this.createClassifier();
return {
classIndex,
classifier,
};
});
}
/**
* Get the class labels corresponding with each internal class label. Can be used to determine
* which predictino is for which class in predictProba.
*
* @return {Array.<number>} The n-th element in this array contains the class label of what is
* internally class n
*/
getClasses() {
return this.classifiers.map((x, i) => x);
}
/**
* Train all binary classifiers one-by-one
*
* @param {Array.<Array.<number>>} X - Features per data point
* @param {Array.<mixed>} y Class labels per data point
*/
trainBatch(X, y) {
this.classifiers.forEach((classifier) => {
const yOneVsAll = y.map(classIndex => ((classIndex === classifier.classIndex) ? 1 : 0));
classifier.classifier.train(X, yOneVsAll);
});
}
/**
* Train all binary classifiers iteration by iteration, i.e. start with the first training
* iteration for each binary classifier, then execute the second training iteration for each
* binary classifier, and so forth. Can be used when one needs to keep track of information per
* iteration, e.g. accuracy
*/
trainIterative() {
let remainingClassIndices = Arrays.unique(this.training.labels);
let epoch = 0;
while (epoch < 100 && remainingClassIndices.length > 0) {
const remainingClassIndicesNew = remainingClassIndices.slice();
// Loop over all 1-vs-all classifiers
for (const classIndex of remainingClassIndices) {
// Run a single iteration for the classifier
this.classifiers[classIndex].trainIteration();
if (this.classifiers[classIndex].checkConvergence()) {
remainingClassIndicesNew.splice(remainingClassIndicesNew.indexOf(classIndex), 1);
}
}
remainingClassIndices = remainingClassIndicesNew;
// Emit event the outside can hook into
this.emit('iterationCompleted');
epoch += 1;
}
// Emit event the outside can hook into
this.emit('converged');
}
/**
* @see {Classifier#predict}
*/
predict(X) {
// Get predictions from all classifiers for all data points by predicting all data points with
// each classifier (getting an array of predictions for each classifier) and transposing
const datapointsPredictions = Arrays.transpose(this.classifiers.map(classifier => classifier.classifier.predict(X, { output: 'normalized' })));
// Form final prediction by taking index of maximum normalized classifier output
return datapointsPredictions.map(x => Arrays.argMax(x));
}
/**
* Make a probabilistic prediction for a data set.
*
* @param {Array.Array.<number>} X - Features for each data point
* @return {Array.Array.<number>} Probability predictions. Each array element contains the
* probability of that particular class. The array elements are ordered in the order the classes
* appear in the training data (i.e., if class "A" occurs first in the labels list in the
* training, procedure, its probability is returned in the first array element of each
* sub-array)
*/
predictProba(X) {
if (typeof this.classifiers[0].classifier.predictProba !== 'function') {
throw new Error('Base classifier does not implement the predictProba method, which was attempted to be called from the one-vs-all classifier.');
}
// Get probability predictions from all classifiers for all data points by predicting all data
// points with each classifier (getting an array of predictions for each classifier) and
// transposing
const predictions = Arrays.transpose(
this.classifiers.map(classifier =>
classifier.classifier.predictProba(X).map(probs => probs[1])
)
);
// Scale all predictions to yield valid probabilities
return predictions.map(x => Arrays.scale(x, 1 / Arrays.internalSum(x)));
}
/**
* Retrieve the individual binary one-vs-all classifiers.
*
* @return {Array.<Classifier>} List of binary one-vs-all classifiers used as the base classifiers
* for this multiclass classifier
*/
getClassifiers() {
return this.classifiers;
}
}