diff --git a/package.json b/package.json index 88519c4a..b9f7920a 100644 --- a/package.json +++ b/package.json @@ -72,6 +72,7 @@ "lodash": "4.17.13", "numeric": "1.2.6", "random-js": "1.0.8", + "seedrandom": "^3.0.1", "stopword": "0.1.10" }, "devDependencies": { diff --git a/src/lib/linear_model/stochastic_gradient.ts b/src/lib/linear_model/stochastic_gradient.ts index 8e60a4e6..3015298d 100644 --- a/src/lib/linear_model/stochastic_gradient.ts +++ b/src/lib/linear_model/stochastic_gradient.ts @@ -1,7 +1,8 @@ import * as tf from '@tensorflow/tfjs'; import { cloneDeep, range } from 'lodash'; -import * as Random from 'random-js'; +// import * as Random from 'random-js'; import { IMlModel, Type1DMatrix, Type2DMatrix } from '../types'; +import RandomState, { RandomStateObj } from '../utils/random'; import { validateFitInputs, validateMatrix2D } from '../utils/validation'; export enum TypeLoss { @@ -29,7 +30,7 @@ export class BaseSGD implements IMlModel { protected regFactor: TypeRegFactor; private clone: boolean = true; private weights: tf.Tensor = null; - private randomEngine: Random.MT19937; // Random engine used to + private randomEngine: RandomStateObj; // Random engine used to private randomState: number; /** * @param preprocess - preprocess methodology can be either minmax or null. Default is minmax. @@ -42,7 +43,7 @@ export class BaseSGD implements IMlModel { learning_rate = 0.0001, epochs = 10000, clone = true, - random_state = null, + random_state, loss = TypeLoss.L2, reg_factor = null, }: { @@ -56,7 +57,7 @@ export class BaseSGD implements IMlModel { learning_rate: 0.0001, epochs: 10000, clone: true, - random_state: null, + random_state: undefined, loss: TypeLoss.L2, reg_factor: null, }, @@ -87,11 +88,7 @@ export class BaseSGD implements IMlModel { } // Random Engine - if (Number.isInteger(this.randomState)) { - this.randomEngine = Random.engines.mt19937().seed(this.randomState); - } else { - this.randomEngine = Random.engines.mt19937().autoSeed(); - } + this.randomEngine = new RandomState(this.randomState); } /** @@ -193,9 +190,8 @@ export class BaseSGD implements IMlModel { */ private initializeWeights(nFeatures: number): void { const limit = 1 / Math.sqrt(nFeatures); - const distribution = Random.real(-limit, limit); - const getRand = () => distribution(this.randomEngine); - this.weights = tf.tensor1d(range(0, nFeatures).map(() => getRand())); + const distribution = this.randomEngine.real(-limit, limit); + this.weights = tf.tensor1d(range(0, nFeatures).map(distribution)); } /** diff --git a/src/lib/model_selection/_split.ts b/src/lib/model_selection/_split.ts index 1671c3cc..23f89565 100644 --- a/src/lib/model_selection/_split.ts +++ b/src/lib/model_selection/_split.ts @@ -1,10 +1,18 @@ +import * as tf from '@tensorflow/tfjs'; import * as _ from 'lodash'; -import * as Random from 'random-js'; import { Type1DMatrix, Type2DMatrix } from '../types'; import { ValidationError } from '../utils/Errors'; -import { inferShape } from '../utils/tensors'; -import { validateFitInputs } from '../utils/validation'; +import RandomState, { RandomStateObj } from '../utils/random'; +import { approximateMode, arraySplit, countBin, inferShape, invidualize } from '../utils/tensors'; +import { numSamples, validateFitInputs } from '../utils/validation'; +const testShapes = (X: Type1DMatrix | Type2DMatrix, y: Type1DMatrix) => { + const xShape: Type1DMatrix = inferShape(X); + const yShape: Type1DMatrix = inferShape(y); + if (xShape.length > 0 && yShape.length > 0 && xShape[0] !== yShape[0]) { + throw new ValidationError('X and y must have an identical size'); + } +}; /** * K-Folds cross-validator * @@ -17,7 +25,6 @@ import { validateFitInputs } from '../utils/validation'; * * const kFold = new KFold({ k: 5 }); * const X1 = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; - * console.log(kFold.split(X1, X1)); * * /* [ { trainIndex: [ 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 ], * * testIndex: [ 0, 1, 2, 3 ] }, @@ -47,7 +54,6 @@ export class KFold { this.k = k; this.shuffle = shuffle; } - /** * * @param X - Training data, where n_samples is the number of samples and n_features is the number of features. @@ -55,12 +61,7 @@ export class KFold { * @returns {any[]} */ public split(X: Type1DMatrix = null, y: Type1DMatrix = null): any[] { - const xShape = inferShape(X); - const yShape = inferShape(y); - if (xShape.length > 0 && yShape.length > 0 && xShape[0] !== yShape[0]) { - throw new ValidationError('X and y must have an identical size'); - } - + testShapes(X, y); if (this.k > X.length || this.k > y.length) { throw new ValidationError( `Cannot have number of splits k=${this.k} greater than the number of samples: ${_.size(X)}`, @@ -166,8 +167,7 @@ export function train_test_split( throw new ValidationError('Sum of test_size and train_size does not equal 1'); } // Initiate Random engine - const randomEngine = Random.engines.mt19937(); - randomEngine.seed(random_state); + const randomEngine: RandomStateObj = new RandomState(random_state); // split const xTrain = []; @@ -177,7 +177,7 @@ export function train_test_split( // Getting X_train and y_train while (xTrain.length < trainSizeLength && yTrain.length < trainSizeLength) { - const index = Random.integer(0, X.length - 1)(randomEngine); + const index = randomEngine.rangedInt(0, X.length - 1); // X_train xTrain.push(_X[index]); @@ -189,7 +189,7 @@ export function train_test_split( } while (xTest.length < testSizeLength) { - const index = Random.integer(0, _X.length - 1)(randomEngine); + const index = randomEngine.rangedInt(0, _X.length - 1); // X test xTest.push(_X[index]); _X.splice(index, 1); @@ -208,3 +208,154 @@ export function train_test_split( yTrain: clean(yTrain), }; } + +const rangeValidationError = (type, size, n_samples) => `${type}=${size} should be either +positive and smaller than number of samples ${n_samples} or a float in (0, 1) range`; + +const testRangeValidationError = (test_size, n_samples) => rangeValidationError('test_size', test_size, n_samples); + +const trainRangeValidationError = (test_size, n_samples) => rangeValidationError('test_size', test_size, n_samples); + +/** + * StratifiedShuffleSplit + */ +export class StratifiedShuffleSplit { + private n_splits: number; + private testSize: number; + private trainSize: number; + private rng: RandomStateObj; + private defaultTestSize: number = 0.1; + constructor(n_splits: number = 10, testSize: number = null, trainSize?: number, seed?: number) { + this.n_splits = n_splits; + this.testSize = testSize; + this.trainSize = trainSize; + this.rng = new RandomState(seed); + } + + split = (X: Type1DMatrix | Type2DMatrix = null, y: Type1DMatrix = null): Type2DMatrix => { + const XTensor = tf.tensor(X); + const nSamples = numSamples(XTensor); + + const [nTest, nTrain] = validateShuffleSplit(nSamples, this.testSize, this.trainSize, this.defaultTestSize); + + const [classes, yIndices] = invidualize(y); + const nClasses: number = classes.length; + const classCounts: Type1DMatrix = countBin(yIndices); + if (_.min(classCounts) < 2) { + throw new ValidationError( + `The least populated class in y=${y} has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.`, + ); + } + + if (nTrain < nClasses) { + throw new ValidationError( + `The train_size = ${nTrain} should be greater or equal to the number of classes = ${nClasses}`, + ); + } + + if (nTest < nClasses) { + throw new ValidationError( + `The test_size = ${nTest} should be greater or equal to the number of classes = ${nClasses}`, + ); + } + + const cumsumClassCounts: tf.Tensor1D = tf.cumsum(classCounts); + const classIndices = arraySplit( + yIndices.sort(), + cumsumClassCounts.slice(0, cumsumClassCounts.shape[0] - 1).arraySync(), + ); + + const test = []; + const train = []; + for (let i = 0; i < this.n_splits; i++) { + const n_i: Type1DMatrix = approximateMode(classCounts, nTrain, this.rng); + const classCountsRemaining: Type1DMatrix = classCounts.map((item, index) => n_i[index] - item); + const t_i: Type1DMatrix = approximateMode(classCountsRemaining, nTest, this.rng); + + const tempTest = []; + const tempTrain = []; + + for (let j = 0; j < nClasses; j++) { + const permutation: Type1DMatrix = this.rng.permutation(classCounts[j]); + const permIndicesClassI = permutation.map((val) => classIndices[j][val]); + tempTrain.push.apply(tempTrain, permIndicesClassI.slice(0, n_i[j])); + tempTest.push.apply(tempTest, permIndicesClassI.slice(n_i[j], n_i[j] + t_i[j])); + } + test.push(this.rng.shuffle(tempTest)); + train.push(this.rng.shuffle(tempTrain)); + } + + return [train, test]; + }; +} + +function validateShuffleSplit( + n_samples: number, + test_size: number, + train_size: number, + default_test_size: number, +): Type1DMatrix { + let n_train: number; + let n_test: number; + + if (!test_size && !train_size) { + test_size = default_test_size; + } + + if (test_size) { + if (Number.isInteger(test_size)) { + if (test_size >= n_samples || test_size <= 0) { + throw new ValidationError(testRangeValidationError(test_size, n_samples)); + } + + n_test = test_size; + } else { + if (test_size <= 0 || test_size >= 1) { + throw new ValidationError(testRangeValidationError(test_size, n_samples)); + } + + n_test = Math.ceil(test_size * n_samples); + } + } + + if (train_size) { + if (Number.isInteger(train_size)) { + if (train_size >= n_samples || train_size <= 0) { + throw new ValidationError(trainRangeValidationError(train_size, n_samples)); + } + + n_train = train_size; + } else { + if (train_size <= 0 || train_size >= 1) { + throw new ValidationError(trainRangeValidationError(train_size, n_samples)); + } + + n_train = Math.floor(train_size * n_samples); + } + } + + if (!train_size) { + n_train = n_samples - n_test; + } else if (!test_size) { + n_test = n_samples - n_train; + } + + const total = n_train + n_test; + if (total > n_samples) { + throw new ValidationError( + `The sum of train_size and test_size = ${total}, ` + + 'should be smaller than the number of ' + + `samples ${n_samples}. Reduce test_size and/or ` + + 'train_size.', + ); + } + + if (n_train === 0) { + throw new ValidationError( + `With n_samples=${n_samples}, test_size=${test_size} and train_size=${train_size}, the ` + + 'resulting train set will be empty. Adjust any of the ' + + 'aforementioned parameters.', + ); + } + return [Math.floor(n_test), Math.floor(n_train)]; +} diff --git a/src/lib/model_selection/index.ts b/src/lib/model_selection/index.ts index 8d387b7d..20d125e3 100644 --- a/src/lib/model_selection/index.ts +++ b/src/lib/model_selection/index.ts @@ -1,3 +1,3 @@ -import { KFold, train_test_split } from './_split'; +import { KFold, StratifiedShuffleSplit, train_test_split } from './_split'; -export { KFold, train_test_split }; +export { KFold, train_test_split, StratifiedShuffleSplit }; diff --git a/src/lib/utils/random.ts b/src/lib/utils/random.ts new file mode 100644 index 00000000..ed89a886 --- /dev/null +++ b/src/lib/utils/random.ts @@ -0,0 +1,123 @@ +import * as _ from 'lodash'; +import seedrandom from 'seedrandom'; +import { isNumber } from 'util'; +import { Type1DMatrix } from '../types'; + +/** + * instance of RandomState + * @ignore + */ +export interface RandomStateObj { + next(): number; + shuffle(array: Type1DMatrix): Type1DMatrix; + rangedInt(min: number, max: number): number; + choice( + choiceArray: number | Type1DMatrix, + outputSize: number, + probability?: Type1DMatrix, + ): Type1DMatrix; + permutation(num: number): Type1DMatrix; + rangedReal(min: number, max: number): number; + real(min: number, max: number): (() => number); +} + +/** + * All of Random works lie here + * @ignore + */ +export default class RandomState implements RandomStateObj { + private random; + constructor(seed: string | number = Math.random()) { + this.random = seedrandom(seed.toString()); + } + + next(): number { + return this.random(); + } + + rangedInt(min: number, max: number): number { + return min + Math.floor((max - min) * this.next()); + } + + rangedReal(min: number, max: number): number { + return min + (max - min) * this.next(); + } + + real(min: number, max: number): (() => number) { + const diff = max - min; + return () => min + diff * this.next(); + } + /** + * shuffles 1D array in place + * taken from https://github.com/TimothyGu/knuth-shuffle-seeded/blob/gh-pages/index.js + * var random = new RandomState(4); + * random.shuffle([1, 2, 3, 4, 5]) + * random.shuffle([1, 2, 3, 4, 5]) + * output-1: [5, 3, 4, 1, 2] + * output-2: [3, 4, 2, 5, 1] + * @param array type: Type1DMatrix + * @returns shuffled array + */ + shuffle(array: Type1DMatrix): Type1DMatrix { + let currentIndex = array.length; + + // While there remain elements to shuffle... + while (0 !== currentIndex) { + // Pick a remaining element... + const randomIndex = Math.floor(this.next() * currentIndex--); + + // And swap it with the current element. + const temporaryValue = array[currentIndex]; + array[currentIndex] = array[randomIndex]; + array[randomIndex] = temporaryValue; + } + + return array; + } + + pickRandomIndex(length: number, probability: Type1DMatrix): number { + const theFate: number = this.next(); + const indexToPick = Math.floor(theFate * length); + if (probability && probability[indexToPick] > theFate) { + return this.pickRandomIndex(length, probability); + } + return indexToPick; + } + + choice( + choiceArray: number | Type1DMatrix, + outputSize: number, + probability?: Type1DMatrix, + ): Type1DMatrix { + if (isNumber(choiceArray)) { + choiceArray = _.range(choiceArray); + } + + const lenChoiceArray: number = choiceArray.length; + + const outPutArray: Type1DMatrix = new Array(outputSize); + for (let i = 0; i < outputSize; i++) { + const index = this.pickRandomIndex(lenChoiceArray, probability); + outPutArray.push(choiceArray[index]); + } + return outPutArray; + } + + /** + * generates an array with number and permutates it. + * const random = new RandomState(4); + * random.shuffle(5) + * random.shuffle(5) + * output-1: [4, 2, 3, 0, 1] + * output-2: [2, 3, 1, 4, 0] + * @param num type: number + * @returns shuffled array + */ + permutation(num: number): Type1DMatrix { + return this.shuffle( + Array(num) + .fill(0) + .map(Number.call, Number), + ); + } +} diff --git a/src/lib/utils/tensors.ts b/src/lib/utils/tensors.ts index 83eed70e..97d54420 100644 --- a/src/lib/utils/tensors.ts +++ b/src/lib/utils/tensors.ts @@ -2,6 +2,7 @@ import * as tf from '@tensorflow/tfjs'; import * as _ from 'lodash'; import { Type1DMatrix, Type2DMatrix, TypeMatrix } from '../types'; import { ValidationError, ValidationInconsistentShape } from './Errors'; +import { RandomStateObj } from './random'; import { validateMatrix1D, validateMatrix2D } from './validation'; /** @@ -97,3 +98,173 @@ export const ensure2DMatrix = (X: Type2DMatrix | Type1DMatrix): const matrix1D = validateMatrix1D(X); return _.map(matrix1D, (o) => [o]); }; + +/** + * + * @param array - target matrix + * @ignore + */ +export function invidualize(array: Type1DMatrix = null): Type2DMatrix { + const uniqArray = _.uniq(_.flatten(array)).sort(); + let min = Number.MAX_VALUE; + let max = Number.MIN_VALUE; + + const valueCount = {}; + const uniqIndexMap = uniqArray.reduce((acc, ele, i) => { + if (min > ele) { + min = ele; + } + + if (max < ele) { + max = ele; + } + + return { + ...acc, + [ele]: i, + }; + }, {}); + + const indexMap = array.map((ele) => { + if (valueCount[ele]) { + valueCount[ele] += 1; + } else { + valueCount[ele] = 1; + } + return uniqIndexMap[ele]; + }); + + return [uniqArray, indexMap]; +} + +/** + * + * Count number of occurrences of each value in array of non-negative ints. + * countBin([0, 1, 1, 3, 2, 1, 7]) = [1, 3, 1, 1, 0, 0, 0, 1] + * countBin([0, 1, 1, 2, 2, 2], [0.3, 0.5, 0.2, 0.7, 1., -0.6]) = [ 0.3, 0.7, 1.1] + * countBin([7]) = [0, 0, 0, 0, 0, 0, 0, 1] + * @param array + */ +export function countBin(array: Type1DMatrix, weights?: Type1DMatrix): Type1DMatrix { + if (weights && array.length !== weights.length) { + throw Error(`weights=${weights} and targetArray=${array} should be of same length.`); + } + const min: number = _.min(array); + const max: number = _.max(array); + + const retArray = Array(max - min + 1).fill(0); + if (!weights) { + weights = Array(array.length).fill(1); + } + + const arrToObj = array.reduce((acc, ele, i) => { + if (Math.floor(ele) !== ele) { + throw Error(`Only integer values are acceptable in the values of ${array}`); + } + + acc[ele] = (acc[ele] || 0) + weights[i]; + return acc; + }, {}); + + for (let i = 0; i < retArray.length; i++) { + if (arrToObj[i + min]) { + retArray[i] = arrToObj[i + min]; + } + } + + return [...Array(min).fill(0), ...retArray]; +} + +/** + * Split an array into multiple sub-arrays. + * @param array + * @param indices_or_sections + */ + +export function arraySplit( + array: Type1DMatrix, + indices_or_sections: number | Type1DMatrix, +): Type2DMatrix { + const nTotal: number = array.length; + let nSections: number = null; + let divPoints: tf.Tensor1D = null; + if (indices_or_sections instanceof Array) { + nSections = indices_or_sections.length + 1; + divPoints = tf.tensor([0, ...indices_or_sections, nTotal]); + } else { + if (indices_or_sections <= 0) { + throw Error('The number of sections can not be less than one'); + } + nSections = Math.floor(indices_or_sections); + const nEachSection = Math.floor(nTotal / nSections); + const extras = nTotal % nSections; + divPoints = tf.cumsum([ + 0, + ...Array(extras).fill(nEachSection + 1), + ...Array(nSections - extras).fill(nEachSection), + ]); + } + + const subArrays: Type2DMatrix = []; + for (let i = 0; i < nSections; i++) { + const st = divPoints.get(i); + const end = divPoints.get(i + 1); + subArrays.push(array.slice(st, end)); + } + + return subArrays; +} + +export function approximateMode( + classCounts: Type1DMatrix, + nDraws: number, + rng: RandomStateObj, +): Type1DMatrix { + // this computes a bad approximation to the mode of the + // multivariate hypergeometric given by class_counts and n_draws + const countSum = _.sum(classCounts); + let flooredSum = 0; + // floored means we don't overshoot n_samples, but probably undershoot + const { floored, remainder } = classCounts.reduce( + (acc, val) => { + const value = nDraws * val / countSum; + const flooredVal = Math.floor(value); + const diff = value - flooredVal; + acc.continuous.push(value); + acc.floored.push(flooredVal); + acc.remainder.push(diff); + flooredSum += flooredVal; + return acc; + }, + { floored: [], continuous: [], remainder: [] }, + ); + + let needToAdd = Math.floor(nDraws - flooredSum); + // we add samples according to how much "left over" probability + // they had, until we arrive at n_samples + // need_to_add = int(n_draws - floored.sum()) + if (needToAdd > 0) { + const values = _.sortedUniq(remainder); + for (let i = 0; i < values.length; i++) { + const val = values[i]; + let inds = remainder.reduce((acc, rval, j) => { + if (rval === val) { + acc.push(j); + } + return acc; + }, []); + const addNow = Math.min(inds.length, needToAdd); + inds = rng.choice(inds, addNow); + inds.forEach((k) => { + floored[k] += 1; + }); + + needToAdd -= addNow; + if (needToAdd === 0) { + break; + } + } + } + + return floored; +} diff --git a/src/lib/utils/validation.ts b/src/lib/utils/validation.ts index e2992df7..b56a7851 100644 --- a/src/lib/utils/validation.ts +++ b/src/lib/utils/validation.ts @@ -177,7 +177,7 @@ export const validateFeaturesConsistency = ( export function validateShapesEqual( y_true: Type1DMatrix | Type2DMatrix = null, y_pred: Type1DMatrix | Type2DMatrix = null, -): tf.Tensor[] { +): Type1DMatrix> { const yTrueTensor = tf.tensor(y_true); const yPredTensor = tf.tensor(y_pred); const yTrueShape = yTrueTensor.shape; @@ -197,3 +197,17 @@ export function validateShapesEqual( return [yTrueTensor, yPredTensor]; } + +/** + * get number of samples from an array + * @param array - type matrix or tensor + */ +export function numSamples(array: TypeMatrix | tf.Tensor = null): number { + if (!array) { + throw new ValidationError(`array cant be null`); + } + if (array instanceof tf.Tensor) { + return array.shape[0]; + } + return array.length; +} diff --git a/test/linear_model/__snapshots__/manual_sgd_regressor.snap.ts b/test/linear_model/__snapshots__/manual_sgd_regressor.snap.ts index fb87d235..d3dd38ca 100644 --- a/test/linear_model/__snapshots__/manual_sgd_regressor.snap.ts +++ b/test/linear_model/__snapshots__/manual_sgd_regressor.snap.ts @@ -1,158 +1,158 @@ export const reg_l1_snap = [ - -0.0721491202712059, - 1.8499072790145874, - 2.000120162963867, - 1.763584017753601, - -0.06548641622066498, - 1.2569680213928223, - 1.408263087272644, - -0.09963376820087433, - 0.9498693346977234, - 0.9368916153907776, - 0.12602682411670685, - -0.016304979100823402, - 0.8815138339996338, - 1.0217688083648682, - 0.042637307196855545, - 1.2872182130813599, - 2.0576579570770264, - -0.001897446229122579, - 1.928894281387329, - 1.4728261232376099, - 0.838404655456543, - 2.0542337894439697, - -0.05303708836436272, - 2.005605936050415, - 0.024565843865275383, - 1.1904466152191162, - 1.5692096948623657, - 1.0676918029785156, - 1.246755838394165, - 1.692721962928772, - -0.13608022034168243, - -0.008093821816146374, - 1.7138903141021729, - 1.8486846685409546, - -0.0868988186120987, - 1.3307785987854004, - -0.032723069190979004, - -0.052975043654441833, - 1.5985339879989624, - 1.6508700847625732, - 1.2555382251739502, - -0.05963774397969246, - 1.898187279701233, - 1.5247799158096313, - -0.05670524761080742, - 1.8354581594467163, - 1.2046414613723755, - 1.8430787324905396, - -0.08569683134555817, - 1.1684170961380005, + -0.10187516361474991, + 1.9362775087356567, + -0.030817851424217224, + -0.15670014917850494, + -0.06435448676347733, + 1.3769986629486084, + -0.054627493023872375, + 1.4001959562301636, + 1.768455982208252, + 1.1212642192840576, + 1.5794532299041748, + 0.9440510272979736, + -0.09058281779289246, + 0.13453781604766846, + 2.2322444915771484, + 0.8844969868659973, + 1.173557162284851, + 1.5922285318374634, + -0.23827676475048065, + 1.999155879020691, + 1.9670710563659668, + 1.1842509508132935, + 1.8097987174987793, + 0.8924177885055542, + 1.0067024230957031, + 1.9101625680923462, + -0.04956323280930519, + 0.01447377260774374, + 1.1836856603622437, + 0.012781056575477123, + 1.0901933908462524, + 1.5491446256637573, + 1.2899250984191895, + 1.2064342498779297, + 1.4554414749145508, + 2.0235116481781006, + -0.00013770590885542333, + 0.03823625668883324, + -0.16355983912944794, + 1.4931367635726929, + -0.06580637395381927, + 1.3363806009292603, + 1.0733377933502197, + 1.5785175561904907, + 1.7249342203140259, + 1.2076889276504517, + 1.3013920783996582, + 2.0409014225006104, + 1.4154359102249146, + 1.1597988605499268, ]; export const reg_l12_snap = [ - 0.10738064348697662, - 1.7174302339553833, - 1.8375533819198608, - 1.67069411277771, - 0.06189228966832161, - 1.1819779872894287, - 1.3368662595748901, - 0.0677068829536438, - 0.9299587607383728, - 0.9379181265830994, - 0.35877126455307007, - 0.16697554290294647, - 0.8833820223808289, - 1.0770105123519897, - 0.20129679143428802, - 1.3402129411697388, - 2.03820538520813, - 0.11076389998197556, - 1.7500780820846558, - 1.469829797744751, - 0.8211827278137207, - 1.9208850860595703, - 0.09558970481157303, - 1.892223596572876, - 0.17387846112251282, - 1.138923168182373, - 1.5752410888671875, - 1.0453166961669922, - 1.308498501777649, - 1.45609712600708, - -0.03998078405857086, - 0.16131111979484558, - 1.56316339969635, - 1.8202002048492432, - 0.06510242074728012, - 1.346653699874878, - 0.10171855241060257, - 0.048874109983444214, - 1.4369295835494995, - 1.6515308618545532, - 1.2302302122116089, - 0.09436246007680893, - 1.8291091918945312, - 1.407045602798462, - 0.08854561299085617, - 1.6996674537658691, - 1.2097012996673584, - 1.7823750972747803, - 0.08302222937345505, - 1.217348575592041, + -0.09578026086091995, + 1.9175169467926025, + -0.044050805270671844, + -0.1154128909111023, + -0.07411445677280426, + 1.3834723234176636, + -0.04848529398441315, + 1.3901816606521606, + 1.7331737279891968, + 1.1252330541610718, + 1.6151249408721924, + 0.8965901732444763, + -0.09128819406032562, + 0.15496404469013214, + 2.22098445892334, + 0.8438668251037598, + 1.1376556158065796, + 1.5821704864501953, + -0.20684932172298431, + 1.9915281534194946, + 1.9458564519882202, + 1.1916320323944092, + 1.8428572416305542, + 0.8544819355010986, + 0.9816596508026123, + 1.894813895225525, + -0.034065138548612595, + 0.013514656573534012, + 1.1664959192276, + -0.007028540596365929, + 1.0804977416992188, + 1.542830228805542, + 1.2962431907653809, + 1.1974050998687744, + 1.462317943572998, + 1.9718060493469238, + -0.010182066820561886, + 0.04536003991961479, + -0.15250259637832642, + 1.4982532262802124, + -0.05850651115179062, + 1.3294603824615479, + 1.0583312511444092, + 1.564939260482788, + 1.7103776931762695, + 1.2344924211502075, + 1.3150047063827515, + 2.040947198867798, + 1.4060312509536743, + 1.1478352546691895, ]; export const reg_l2_snap = [ - -0.07334784418344498, - 1.8702806234359741, - 2.027863025665283, - 1.757415533065796, - -0.058840565383434296, - 1.2708457708358765, - 1.4022281169891357, - -0.12441510707139969, - 0.9141921401023865, - 0.9153953790664673, - 0.09964509308338165, - -0.012327668257057667, - 0.8524127006530762, - 0.9993201494216919, - 0.05692737177014351, - 1.260292649269104, - 2.080444574356079, - -0.009513204917311668, - 1.9556868076324463, - 1.4598970413208008, - 0.8319887518882751, - 2.089475631713867, - -0.05327172577381134, - 2.0235347747802734, - -0.0016155339544638991, - 1.2159310579299927, - 1.526334285736084, - 1.057511806488037, - 1.2104287147521973, - 1.7403795719146729, - -0.13033176958560944, - -0.025982998311519623, - 1.7425973415374756, - 1.8811771869659424, - -0.08381890505552292, - 1.3183915615081787, - -0.027015937492251396, - -0.058081306517124176, - 1.6068274974822998, - 1.6342045068740845, - 1.2668280601501465, - -0.07258858531713486, - 1.9060038328170776, - 1.5319108963012695, - -0.06700124591588974, - 1.8551995754241943, - 1.1812701225280762, - 1.8531256914138794, - -0.07763924449682236, - 1.150630235671997, + -0.1128714382648468, + 1.927524209022522, + -0.019957974553108215, + -0.15421128273010254, + -0.06358052790164948, + 1.3845850229263306, + -0.05573194846510887, + 1.4279766082763672, + 1.7814819812774658, + 1.1407538652420044, + 1.5841346979141235, + 0.9630727171897888, + -0.08736681193113327, + 0.13895414769649506, + 2.242431879043579, + 0.8999905586242676, + 1.1831176280975342, + 1.597848653793335, + -0.2701653838157654, + 2.004774570465088, + 1.967844009399414, + 1.1722862720489502, + 1.8184894323349, + 0.9122092127799988, + 1.0183953046798706, + 1.9147017002105713, + -0.057779841125011444, + -0.007610837463289499, + 1.1790218353271484, + 0.007723023183643818, + 1.0842615365982056, + 1.569187045097351, + 1.3011060953140259, + 1.2352399826049805, + 1.470099925994873, + 2.021085023880005, + 0.009398099966347218, + 0.029045211151242256, + -0.18687404692173004, + 1.4855189323425293, + -0.07027280330657959, + 1.3430536985397339, + 1.0793861150741577, + 1.584331750869751, + 1.7265660762786865, + 1.1936516761779785, + 1.2927387952804565, + 2.037707567214966, + 1.420951008796692, + 1.1704827547073364, ]; diff --git a/test/linear_model/stochastic_gradient.test.ts b/test/linear_model/stochastic_gradient.test.ts index cba70cf1..bc14a373 100644 --- a/test/linear_model/stochastic_gradient.test.ts +++ b/test/linear_model/stochastic_gradient.test.ts @@ -193,6 +193,7 @@ describe('linear_model:SGDRegressor', () => { const reg = new SGDRegressor(); reg.fit(xTrain, yTrain); const result = reg.predict(xTest); + const similarity = assertArrayAlmostEqual(reg_l2_snap, result, 1); expect(similarity).toBeGreaterThanOrEqual(accuracyExpected1); diff --git a/test/model_selection/_split.test.ts b/test/model_selection/_split.test.ts index 60fcd006..33523522 100644 --- a/test/model_selection/_split.test.ts +++ b/test/model_selection/_split.test.ts @@ -1,6 +1,7 @@ import * as _ from 'lodash'; -import { KFold, train_test_split } from '../../src/lib/model_selection/_split'; +import { KFold, StratifiedShuffleSplit, train_test_split } from '../../src/lib/model_selection/_split'; import { ValidationError } from '../../src/lib/utils/Errors'; +// import { x_1, y_1 as ySnap } from './__snapshots__/_split.test'; describe('_split:KFold', () => { const X1 = [[1], [2], [3], [4], [5], [6], [7], [8], [9], [10]]; @@ -119,10 +120,10 @@ describe('_split:train_test_split', () => { train_size: 0.67, }); - expect(_.isEqual(xTrain, [[4, 5], [6, 7], [2, 3]])).toBe(true); - expect(_.isEqual(yTrain, [2, 3, 1])).toBe(true); - expect(_.isEqual(xTest, [[0, 1], [8, 9]])).toBe(true); - expect(_.isEqual(yTest, [0, 4])).toBe(true); + expect(_.isEqual(xTrain, [[0, 1], [2, 3]])).toBe(true); + expect(_.isEqual(yTrain, [0, 1])).toBe(true); + expect(_.isEqual(xTest, [[4, 5], [6, 7]])).toBe(true); + expect(_.isEqual(yTest, [2, 3])).toBe(true); }); it('Should split X1, y1 with random_state 100 test_size: .50 train_size: .50', () => { @@ -132,19 +133,19 @@ describe('_split:train_test_split', () => { train_size: 0.5, }); - expect(_.isEqual(xTrain, [[0, 1], [6, 7], [2, 3]])).toBe(true); - expect(_.isEqual(yTrain, [0, 3, 1])).toBe(true); - expect(_.isEqual(xTest, [[8, 9], [4, 5]])).toBe(true); - expect(_.isEqual(yTest, [4, 2])).toBe(true); + expect(_.isEqual(xTrain, [[4, 5], [2, 3], [8, 9]])).toBe(true); + expect(_.isEqual(yTrain, [2, 1, 4])).toBe(true); + expect(_.isEqual(xTest, [[0, 1], [6, 7]])).toBe(true); + expect(_.isEqual(yTest, [0, 3])).toBe(true); }); it('Should use default test and train sizes', () => { const { xTrain, yTrain, xTest, yTest } = train_test_split(X1, y1); - expect(_.isEqual(xTrain, [[8, 9], [6, 7], [0, 1]])).toBe(true); - expect(_.isEqual(yTrain, [4, 3, 0])).toBe(true); - expect(_.isEqual(xTest, [[4, 5]])).toBe(true); - expect(_.isEqual(yTest, [2])).toBe(true); + expect(_.isEqual(xTrain, [[6, 7], [0, 1], [8, 9], [4, 5]])).toBe(true); + expect(_.isEqual(yTrain, [3, 0, 4, 2])).toBe(true); + expect(_.isEqual(xTest, [[2, 3]])).toBe(true); + expect(_.isEqual(yTest, [1])).toBe(true); }); it('Should sum of test_size and train_size attempting to match the input size throw an error', () => { @@ -161,9 +162,76 @@ describe('_split:train_test_split', () => { it('Should split X2 y2 with random_state: 42 test_size: .33 and train_size: .67', () => { const { xTrain, yTrain, xTest, yTest } = train_test_split(X2, y2); - expect(_.isEqual(xTrain, [['five'], ['four'], ['one']])).toBe(true); - expect(_.isEqual(yTrain, ['e', 'd', 'a'])).toBe(true); - expect(_.isEqual(xTest, [['three']])).toBe(true); - expect(_.isEqual(yTest, ['c'])).toBe(true); + + expect(_.isEqual(xTrain, [['four'], ['one'], ['five'], ['three']])).toBe(true); + expect(_.isEqual(yTrain, ['d', 'a', 'e', 'c'])).toBe(true); + expect(_.isEqual(xTest, [['two']])).toBe(true); + expect(_.isEqual(yTest, ['b'])).toBe(true); + }); +}); + +describe('_split:StratifiedShuffleSplit', () => { + it('Check that error is raised if there is a class with only one sample', () => { + const X = [0, 1, 2, 3, 4, 5, 6]; + const y = [0, 1, 1, 1, 2, 2, 2]; + const initAndCall = (...values) => { + const sss = new StratifiedShuffleSplit(...values); + const [train, test] = sss.split(X, y); + }; + expect(() => initAndCall(3, 0.2)).toThrow(); + expect(() => initAndCall(3, 0.2)).toThrowError( + `The least populated class in y=${y} has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.`, + ); + + // expect(() => initAndCall(3, 2)).toThrow(); + // expect(() => initAndCall(3, 2)).toThrowError( + // `The least populated class in y=${y} has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.`, + // ); + }); + + it('Check that error is raised if the test/train set size is smaller than n_classes', () => { + const X = [0, 1, 2, 3, 4, 5, 6, 7, 8]; + const y = [0, 0, 0, 1, 1, 1, 2, 2, 2]; + const initAndCall = (...values) => { + const sss = new StratifiedShuffleSplit(...values); + const [train, test] = sss.split(X, y); + }; + expect(() => initAndCall(3, 2)).toThrow(); + expect(() => initAndCall(3, 2)).toThrowError( + 'The test_size = 2 should be greater or equal to the number of classes = 3', + ); + + expect(() => initAndCall(3, 3, 2)).toThrow(); + expect(() => initAndCall(3, 3, 2)).toThrowError( + 'The train_size = 2 should be greater or equal to the number of classes = 3', + ); }); + + it('Test stratified shuffle split respects test size.', () => { + const y = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]; + const testSize = 5; + const trainSize = 10; + const [trainSet, testSet] = new StratifiedShuffleSplit(6, testSize, trainSize, 0).split( + new Array(y.length).fill(1), + y, + ); + for (let i = 0; i < trainSet.length; i++) { + const train = trainSet[i]; + const test = testSet[i]; + expect(train.length).toBe(trainSize); + expect(test.length).toBe(testSize); + expect(train.length + test.length).toBe(y.length); + } + }); + + // it('Test stratified shuffle split multilabel many labels.', () => { + // const y = ySnap; + // const X = x_1; + // const [trainSet, testSet] = new StratifiedShuffleSplit(6, 0.5, undefined, 0).split(X, y); + // for (let i = 0; i < trainSet.length; i++) { + // const train = trainSet[i]; + // const test = testSet[i]; + // expect(train.length + test.length).toBe(y.length); + // } + // }); }); diff --git a/yarn.lock b/yarn.lock index 4f8202ef..013b718d 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6626,6 +6626,13 @@ kleur@^3.0.2: resolved "https://registry.yarnpkg.com/kleur/-/kleur-3.0.2.tgz#83c7ec858a41098b613d5998a7b653962b504f68" integrity sha512-3h7B2WRT5LNXOtQiAaWonilegHcPSf9nLVXlSTci8lu1dZUuui61+EsPEZqSVxY7rXYmB2DVKMQILxaO5WL61Q== +knuth-shuffle-seeded@^1.0.6: + version "1.0.6" + resolved "https://registry.yarnpkg.com/knuth-shuffle-seeded/-/knuth-shuffle-seeded-1.0.6.tgz#01f1b65733aa7540ee08d8b0174164d22081e4e1" + integrity sha1-AfG2VzOqdUDuCNiwF0Fk0iCB5OE= + dependencies: + seed-random "~2.2.0" + koa-compose@^3.0.0, koa-compose@^3.2.1: version "3.2.1" resolved "https://registry.yarnpkg.com/koa-compose/-/koa-compose-3.2.1.tgz#a85ccb40b7d986d8e5a345b3a1ace8eabcf54de7" @@ -9232,10 +9239,20 @@ section-matter@^1.0.0: extend-shallow "^2.0.1" kind-of "^6.0.0" +seed-random@~2.2.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/seed-random/-/seed-random-2.2.0.tgz#2a9b19e250a817099231a5b99a4daf80b7fbed54" + integrity sha1-KpsZ4lCoFwmSMaW5mk2vgLf77VQ= + seedrandom@2.4.3: version "2.4.3" resolved "https://registry.yarnpkg.com/seedrandom/-/seedrandom-2.4.3.tgz#2438504dad33917314bff18ac4d794f16d6aaecc" +seedrandom@^3.0.1: + version "3.0.1" + resolved "https://registry.yarnpkg.com/seedrandom/-/seedrandom-3.0.1.tgz#eb3dde015bcf55df05a233514e5df44ef9dce083" + integrity sha512-1/02Y/rUeU1CJBAGLebiC5Lbo5FnB22gQbIFFYTLkwvp1xdABZJH1sn4ZT1MzXmPpzv+Rf/Lu2NcsLJiK4rcDg== + seedrandom@~2.4.3: version "2.4.4" resolved "https://registry.yarnpkg.com/seedrandom/-/seedrandom-2.4.4.tgz#b25ea98632c73e45f58b77cfaa931678df01f9ba" pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy