
src/classifiers/classifier.js


import PorterStemmer from '../stemmers/porter_stemmer';

import util from 'util';
import events from 'events';
import os from 'os';

try {
    var Threads = require('webworker-threads');
} catch (e) {
    // webworker-threads is an optional dependency, so only rethrow
    // errors other than the module being absent
    if (e.code !== 'MODULE_NOT_FOUND') throw e;
}

function checkThreadSupport() {
    if (typeof Threads === 'undefined') {
        throw new Error('parallel classification requires the optional dependency webworker-threads');
    }
}

// Base classifier: wraps a concrete classifier implementation and handles
// tokenizing/stemming, feature extraction, training, and (de)serialization
export const Classifier = function(classifier, stemmer) {
    this.classifier = classifier;
    this.docs = [];
    this.features = {};
    this.stemmer = stemmer || PorterStemmer;
    this.lastAdded = 0;
    this.events = new events.EventEmitter();
};

function addDocument(text, classification) {
    // Ignore further processing if classification is undefined
    if (typeof classification === 'undefined') return;

    // If classification is a string, trim whitespace from both ends
    if (typeof classification === 'string') {
        classification = classification.trim();
    }

    if (typeof text === 'string') {
        text = this.stemmer.tokenizeAndStem(text, this.keepStops);
    }

    if (text.length === 0) {
        // ignore empty documents
        return;
    }

    this.docs.push({
        label: classification,
        text
    });

    // Count global occurrences of each token across all documents
    for (const token of text) {
        this.features[token] = (this.features[token] || 0) + 1;
    }
}
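
// Usage sketch (not part of the module; names other than Classifier are
// assumptions): Classifier wraps a concrete implementation, e.g. a
// BayesClassifier from the surrounding library, which is not defined here.
//
//     const classifier = new Classifier(new BayesClassifier());
//     classifier.addDocument('my unit tests failed', 'software');
//     classifier.addDocument(['pre', 'tokenized', 'text'], 'other'); // arrays skip stemming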

function removeDocument(text, classification) {
    const docs = this.docs;
    let pos;

    if (typeof text === 'string') {
        text = this.stemmer.tokenizeAndStem(text, this.keepStops);
    }

    // Find the last document matching both the text and the label
    for (let i = 0, ii = docs.length; i < ii; i++) {
        const doc = docs[i];
        if (doc.text.join(' ') === text.join(' ') &&
            doc.label === classification) {
            pos = i;
        }
    }

    // Remove if there's a match
    if (!isNaN(pos)) {
        this.docs.splice(pos, 1);

        for (let i = 0, ii = text.length; i < ii; i++) {
            delete this.features[text[i]];
        }
    }
}

function textToFeatures(observation) {
    const features = [];

    if (typeof observation === 'string') {
        observation = this.stemmer.tokenizeAndStem(observation, this.keepStops);
    }

    // Build a binary bag-of-words vector over every known feature
    for (const feature in this.features) {
        features.push(observation.includes(feature) ? 1 : 0);
    }

    return features;
}
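
// Illustration (a sketch, not part of the module): textToFeatures produces a
// binary vector with one slot per known feature, in insertion order. Assuming
// the known features are exactly ['fail', 'test', 'build'] (stemming may alter
// the actual tokens):
//
//     classifier.textToFeatures('the test build');
//     // => [0, 1, 1]   (no 'fail', has 'test', has 'build')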

// Worker-side version of textToFeatures. This function is serialized and
// eval'd inside webworker threads, where FEATURES is a global injected by
// the thread pool before any batch is processed.
function docsToFeatures(docs) {
    const parsedDocs = [];

    for (let i = 0; i < docs.length; i++) {
        const features = [];

        for (const feature in FEATURES) {
            features.push(docs[i].observation.includes(feature) ? 1 : 0);
        }

        parsedDocs.push({
            index: docs[i].index,
            features
        });
    }

    return JSON.stringify(parsedDocs);
}

// Feed any not-yet-trained documents to the underlying classifier, emitting
// progress events, then train it
function train() {
    const totalDocs = this.docs.length;
    for(let i = this.lastAdded; i < totalDocs; i++) {
        const features = this.textToFeatures(this.docs[i].text);
        this.classifier.addExample(features, this.docs[i].label);
        this.events.emit('trainedWithDocument', {index: i, total: totalDocs, doc: this.docs[i]});
        this.lastAdded++;
    }
    this.events.emit('doneTraining', true);
    this.classifier.train();
}
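
// Usage sketch (not part of the module): training is synchronous but reports
// progress through the instance's EventEmitter.
//
//     classifier.events.on('trainedWithDocument', e =>
//         console.log(`trained ${e.index + 1}/${e.total}`));
//     classifier.events.on('doneTraining', () => console.log('done'));
//     classifier.train();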

function trainParallel(numThreads, callback) {
    checkThreadSupport();

    // Allow trainParallel(callback): shift arguments when only one is given
    if (!callback) {
        callback = numThreads;
        numThreads = undefined;
    }

    // Default to one thread per CPU core
    if (isNaN(numThreads)) {
        numThreads = os.cpus().length;
    }

    const totalDocs = this.docs.length;
    const threadPool = Threads.createPool(numThreads);
    const docFeatures = {};
    let finished = 0;
    const self = this;

    // Init pool; send the features array and the parsing function
    threadPool.all.eval(`var FEATURES = ${JSON.stringify(this.features)}`);
    threadPool.all.eval(docsToFeatures);

    // Convert docs to observation objects
    const obsDocs = [];
    for (let i = this.lastAdded; i < totalDocs; i++) {
        let observation = this.docs[i].text;
        if (typeof observation === 'string') {
            observation = this.stemmer.tokenizeAndStem(observation, this.keepStops);
        }
        obsDocs.push({
            index: i,
            observation
        });
    }

    // Called when a batch completes processing
    const onFeaturesResult = docs => {
        setTimeout(() => {
            self.events.emit('processedBatch', {
                size: docs.length,
                docs: totalDocs,
                batches: numThreads,
                index: finished
            });
        });

        for (let j = 0; j < docs.length; j++) {
            docFeatures[docs[j].index] = docs[j].features;
        }
    };

    // Called when all batches finish processing
    const onFinished = err => {
        if (err) {
            threadPool.destroy();
            return callback(err);
        }

        for (let j = self.lastAdded; j < totalDocs; j++) {
            self.classifier.addExample(docFeatures[j], self.docs[j].label);
            self.events.emit('trainedWithDocument', {
                index: j,
                total: totalDocs,
                doc: self.docs[j]
            });
            self.lastAdded++;
        }

        self.events.emit('doneTraining', true);
        self.classifier.train();

        threadPool.destroy();
        callback(null);
    };

    // Split the docs and start processing
    const batchSize = Math.ceil(obsDocs.length / numThreads);
    let lastError;

    for (let i = 0; i < numThreads; i++) {
        const batchDocs = obsDocs.slice(i * batchSize, (i + 1) * batchSize);
        const batchJson = JSON.stringify(batchDocs);

        threadPool.any.eval(`docsToFeatures(${batchJson})`, (err, docs) => {
            lastError = err || lastError;
            finished++;

            if (docs) {
                docs = JSON.parse(docs);
                onFeaturesResult(docs);
            }

            if (finished >= numThreads) {
                onFinished(lastError);
            }
        });
    }
}
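
// Usage sketch (not part of the module): parallel training requires the
// optional webworker-threads dependency and reports completion through the
// callback rather than a return value.
//
//     classifier.trainParallel(4, err => {
//         if (err) return console.error('training failed', err);
//         console.log(classifier.classify('the build is green'));
//     });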

function trainParallelBatches(options) {
    checkThreadSupport();

    let numThreads = options && options.numThreads;
    let batchSize = options && options.batchSize;

    if (isNaN(numThreads)) {
        numThreads = os.cpus().length;
    }

    if (isNaN(batchSize)) {
        batchSize = 2500;
    }

    const totalDocs = this.docs.length;
    const threadPool = Threads.createPool(numThreads);
    const docFeatures = {};
    let finished = 0;
    const self = this;

    let abort = false;
    const onError = err => {
        if (!err || abort) return;
        abort = true;
        threadPool.destroy(true);
        self.events.emit('doneTrainingError', err);
    };

    // Init pool; send the features array and the parsing function
    const str = JSON.stringify(this.features);
    threadPool.all.eval(`var FEATURES = ${str};`, onError);
    threadPool.all.eval(docsToFeatures, onError);

    // Convert docs to observation objects
    let obsDocs = [];
    for (let i = this.lastAdded; i < totalDocs; i++) {
        let observation = this.docs[i].text;
        if (typeof observation === 'string') {
            observation = this.stemmer.tokenizeAndStem(observation, this.keepStops);
        }
        obsDocs.push({
            index: i,
            observation
        });
    }

    // Split the docs into batches
    const obsBatches = [];
    for (let i = 0; ; i++) {
        const batch = obsDocs.slice(i * batchSize, (i + 1) * batchSize);
        if (!batch.length) break;
        obsBatches.push(batch);
    }
    obsDocs = null;
    self.events.emit('startedTraining', {
        docs: totalDocs,
        batches: obsBatches.length
    });

    // Called when a batch completes processing
    const onFeaturesResult = docs => {
        self.events.emit('processedBatch', {
            size: docs.length,
            docs: totalDocs,
            batches: obsBatches.length,
            index: finished
        });

        for (let j = 0; j < docs.length; j++) {
            docFeatures[docs[j].index] = docs[j].features;
        }
    };

    // Called when all batches finish processing
    const onFinished = () => {
        threadPool.destroy(true);
        abort = true;

        for (let j = self.lastAdded; j < totalDocs; j++) {
            self.classifier.addExample(docFeatures[j], self.docs[j].label);
            self.events.emit('trainedWithDocument', {
                index: j,
                total: totalDocs,
                doc: self.docs[j]
            });
            self.lastAdded++;
        }

        self.events.emit('doneTraining', true);
        self.classifier.train();
    };

    // Called to send the next batch to be processed
    let batchIndex = 0;
    const sendNext = () => {
        if (abort) return;
        if (batchIndex >= obsBatches.length) {
            return;
        }

        sendBatch(JSON.stringify(obsBatches[batchIndex]));
        batchIndex++;
    };

    // Called to send a batch of docs to the threads
    const sendBatch = batchJson => {
        if (abort) return;
        threadPool.any.eval(`docsToFeatures(${batchJson});`, (err, docs) => {
            if (err) {
                return onError(err);
            }

            finished++;

            if (docs) {
                docs = JSON.parse(docs);
                setTimeout(onFeaturesResult.bind(null, docs));
            }

            if (finished >= obsBatches.length) {
                setTimeout(onFinished);
            }

            setTimeout(sendNext);
        });
    };

    // Start processing
    for (let i = 0; i < numThreads; i++) {
        sendNext();
    }
}
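
// Usage sketch (not part of the module): the batched variant takes an options
// object and signals completion (or failure) through events instead of a
// callback.
//
//     classifier.events.on('doneTraining', () => console.log('done'));
//     classifier.events.on('doneTrainingError', err => console.error(err));
//     classifier.trainParallelBatches({ numThreads: 4, batchSize: 1000 });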

// Rebuild the underlying classifier from scratch and retrain on all docs
function retrain() {
    this.classifier = new (this.classifier.constructor)();
    this.lastAdded = 0;
    this.train();
}

function retrainParallel(numThreads, callback) {
    this.classifier = new (this.classifier.constructor)();
    this.lastAdded = 0;
    this.trainParallel(numThreads, callback);
}

function getClassifications(observation) {
    return this.classifier.getClassifications(this.textToFeatures(observation));
}

function classify(observation) {
    return this.classifier.classify(this.textToFeatures(observation));
}
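
// Usage sketch (not part of the module): classify returns the single best
// label, while getClassifications returns every label with its score (the
// exact shape depends on the wrapped classifier implementation).
//
//     classifier.classify('tests are failing');           // e.g. 'software'
//     classifier.getClassifications('tests are failing'); // e.g. [{label: 'software', value: ...}, ...]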

// Reattach the non-serializable parts (stemmer and EventEmitter) to a
// classifier object revived from JSON
function restore(classifier, stemmer) {
    classifier.stemmer = stemmer || PorterStemmer;
    classifier.events = new events.EventEmitter();
    return classifier;
}

function save(filename, callback) {
    const data = JSON.stringify(this);
    const fs = require('fs');
    const classifier = this;
    fs.writeFile(filename, data, 'utf8', err => {
        if(callback) {
            callback(err, err ? null : classifier);
        }
    });
}

function load(filename, callback) {
    const fs = require('fs');

    fs.readFile(filename, 'utf8', (err, data) => {
        let classifier;

        if(!err) {
            classifier = JSON.parse(data);
        }

        if(callback)
            callback(err, classifier);
    });
}
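
// Usage sketch (not part of the module): load only parses the JSON; the
// result is a plain object until it is passed through restore (typically by
// a subclass) to reattach the stemmer and event emitter.
//
//     classifier.save('classifier.json', err => {
//         if (err) throw err;
//         Classifier.load('classifier.json', (err, raw) => {
//             if (err) throw err;
//             const revived = Classifier.restore(raw);
//         });
//     });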

function setOptions(options) {
    this.keepStops = Boolean(options.keepStops);
}

Classifier.prototype.addDocument = addDocument;
Classifier.prototype.removeDocument = removeDocument;
Classifier.prototype.train = train;
if (Threads) {
  Classifier.prototype.trainParallel = trainParallel;
  Classifier.prototype.trainParallelBatches = trainParallelBatches;
  Classifier.prototype.retrainParallel = retrainParallel;
}
Classifier.prototype.retrain = retrain;
Classifier.prototype.classify = classify;
Classifier.prototype.textToFeatures = textToFeatures;
Classifier.prototype.save = save;
Classifier.prototype.getClassifications = getClassifications;
Classifier.prototype.setOptions = setOptions;
Classifier.restore = restore;
Classifier.load = load;

export default Classifier;