Skip to content

adding initial code for StratifiedShuffleSplit #243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
"lodash": "4.17.13",
"numeric": "1.2.6",
"random-js": "1.0.8",
"seedrandom": "^3.0.1",
"stopword": "0.1.10"
},
"devDependencies": {
Expand Down
20 changes: 8 additions & 12 deletions src/lib/linear_model/stochastic_gradient.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import * as tf from '@tensorflow/tfjs';
import { cloneDeep, range } from 'lodash';
import * as Random from 'random-js';
// import * as Random from 'random-js';
import { IMlModel, Type1DMatrix, Type2DMatrix } from '../types';
import RandomState, { RandomStateObj } from '../utils/random';
import { validateFitInputs, validateMatrix2D } from '../utils/validation';

export enum TypeLoss {
Expand Down Expand Up @@ -29,7 +30,7 @@ export class BaseSGD implements IMlModel<number> {
protected regFactor: TypeRegFactor;
private clone: boolean = true;
private weights: tf.Tensor<tf.Rank.R1> = null;
private randomEngine: Random.MT19937; // Random engine used to
private randomEngine: RandomStateObj; // Random engine used to
private randomState: number;
/**
* @param preprocess - preprocess methodology can be either minmax or null. Default is minmax.
Expand All @@ -42,7 +43,7 @@ export class BaseSGD implements IMlModel<number> {
learning_rate = 0.0001,
epochs = 10000,
clone = true,
random_state = null,
random_state,
loss = TypeLoss.L2,
reg_factor = null,
}: {
Expand All @@ -56,7 +57,7 @@ export class BaseSGD implements IMlModel<number> {
learning_rate: 0.0001,
epochs: 10000,
clone: true,
random_state: null,
random_state: undefined,
loss: TypeLoss.L2,
reg_factor: null,
},
Expand Down Expand Up @@ -87,11 +88,7 @@ export class BaseSGD implements IMlModel<number> {
}

// Random Engine
if (Number.isInteger(this.randomState)) {
this.randomEngine = Random.engines.mt19937().seed(this.randomState);
} else {
this.randomEngine = Random.engines.mt19937().autoSeed();
}
this.randomEngine = new RandomState(this.randomState);
}

/**
Expand Down Expand Up @@ -193,9 +190,8 @@ export class BaseSGD implements IMlModel<number> {
*/
private initializeWeights(nFeatures: number): void {
const limit = 1 / Math.sqrt(nFeatures);
const distribution = Random.real(-limit, limit);
const getRand = () => distribution(this.randomEngine);
this.weights = tf.tensor1d(range(0, nFeatures).map(() => getRand()));
const distribution = this.randomEngine.real(-limit, limit);
this.weights = tf.tensor1d(range(0, nFeatures).map(distribution));
}

/**
Expand Down
181 changes: 166 additions & 15 deletions src/lib/model_selection/_split.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import * as tf from '@tensorflow/tfjs';
import * as _ from 'lodash';
import * as Random from 'random-js';
import { Type1DMatrix, Type2DMatrix } from '../types';
import { ValidationError } from '../utils/Errors';
import { inferShape } from '../utils/tensors';
import { validateFitInputs } from '../utils/validation';
import RandomState, { RandomStateObj } from '../utils/random';
import { approximateMode, arraySplit, countBin, inferShape, invidualize } from '../utils/tensors';
import { numSamples, validateFitInputs } from '../utils/validation';

/**
 * Guards against feature/target inputs whose leading dimensions disagree.
 *
 * Shapes are obtained through `inferShape`; the comparison is only made when
 * both inferred shapes have at least one dimension.
 *
 * @param X - Sample matrix (1D or 2D).
 * @param y - Target values.
 * @throws {ValidationError} When both shapes are non-empty and their first
 * dimensions differ.
 */
const testShapes = (X: Type1DMatrix<any> | Type2DMatrix<any>, y: Type1DMatrix<any>): void => {
  const featureDims = inferShape(X);
  const targetDims = inferShape(y);

  // Nothing to compare when either shape has no dimensions.
  if (!(featureDims.length > 0) || !(targetDims.length > 0)) {
    return;
  }

  if (featureDims[0] === targetDims[0]) {
    return;
  }

  throw new ValidationError('X and y must have an identical size');
};
/**
* K-Folds cross-validator
*
Expand All @@ -17,7 +25,6 @@ import { validateFitInputs } from '../utils/validation';
*
* const kFold = new KFold({ k: 5 });
* const X1 = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
* console.log(kFold.split(X1, X1));
*
* /* [ { trainIndex: [ 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 ],
* * testIndex: [ 0, 1, 2, 3 ] },
Expand Down Expand Up @@ -47,20 +54,14 @@ export class KFold {
this.k = k;
this.shuffle = shuffle;
}

/**
*
* @param X - Training data, where n_samples is the number of samples and n_features is the number of features.
* @param y - The target variable for supervised learning problems.
* @returns {any[]}
*/
public split(X: Type1DMatrix<any> = null, y: Type1DMatrix<any> = null): any[] {
const xShape = inferShape(X);
const yShape = inferShape(y);
if (xShape.length > 0 && yShape.length > 0 && xShape[0] !== yShape[0]) {
throw new ValidationError('X and y must have an identical size');
}

testShapes(X, y);
if (this.k > X.length || this.k > y.length) {
throw new ValidationError(
`Cannot have number of splits k=${this.k} greater than the number of samples: ${_.size(X)}`,
Expand Down Expand Up @@ -166,8 +167,7 @@ export function train_test_split(
throw new ValidationError('Sum of test_size and train_size does not equal 1');
}
// Initiate Random engine
const randomEngine = Random.engines.mt19937();
randomEngine.seed(random_state);
const randomEngine: RandomStateObj = new RandomState(random_state);

// split
const xTrain = [];
Expand All @@ -177,7 +177,7 @@ export function train_test_split(

// Getting X_train and y_train
while (xTrain.length < trainSizeLength && yTrain.length < trainSizeLength) {
const index = Random.integer(0, X.length - 1)(randomEngine);
const index = randomEngine.rangedInt(0, X.length - 1);

// X_train
xTrain.push(_X[index]);
Expand All @@ -189,7 +189,7 @@ export function train_test_split(
}

while (xTest.length < testSizeLength) {
const index = Random.integer(0, _X.length - 1)(randomEngine);
const index = randomEngine.rangedInt(0, _X.length - 1);
// X test
xTest.push(_X[index]);
_X.splice(index, 1);
Expand All @@ -208,3 +208,154 @@ export function train_test_split(
yTrain: clean(yTrain),
};
}

/**
 * Builds the message used when a requested split size is outside its valid range.
 *
 * @param type - Name of the offending parameter as it should appear in the message.
 * @param size - The rejected value.
 * @param n_samples - Number of samples the value was checked against.
 * @returns The formatted error message.
 */
const rangeValidationError = (type: string, size: number, n_samples: number): string =>
  `${type}=${size} should be either\npositive and smaller than number of samples ${n_samples} or a float in (0, 1) range`;

/** Range-error message for an invalid `test_size`. */
const testRangeValidationError = (test_size: number, n_samples: number): string =>
  rangeValidationError('test_size', test_size, n_samples);

/**
 * Range-error message for an invalid `train_size`.
 *
 * The label is `train_size` so the error names the parameter that was
 * actually rejected.
 */
const trainRangeValidationError = (train_size: number, n_samples: number): string =>
  rangeValidationError('train_size', train_size, n_samples);

/**
 * Stratified shuffle-split cross-validator.
 *
 * Produces `n_splits` randomized train/test index sets in which every class
 * found in `y` is represented in both the train and the test portion, with
 * per-class counts chosen by `approximateMode`.
 *
 * NOTE(review): `invidualize`, `countBin`, `approximateMode` and `arraySplit`
 * are project helpers not visible here; this class presumes they behave like
 * NumPy's `unique(return_inverse=True)`, `bincount`, scikit-learn's
 * `_approximate_mode` and `np.split` respectively — confirm against
 * `utils/tensors`.
 */
export class StratifiedShuffleSplit {
  private n_splits: number;
  private testSize: number;
  private trainSize: number;
  private rng: RandomStateObj;
  private defaultTestSize: number = 0.1;

  /**
   * @param n_splits - Number of re-shuffled train/test splits to generate.
   * @param testSize - Test size as an absolute count (integer) or a fraction in (0, 1).
   * @param trainSize - Train size as an absolute count (integer) or a fraction in (0, 1).
   * @param seed - Seed handed to `RandomState`.
   */
  constructor(n_splits: number = 10, testSize: number = null, trainSize?: number, seed?: number) {
    this.n_splits = n_splits;
    this.testSize = testSize;
    this.trainSize = trainSize;
    this.rng = new RandomState(seed);
  }

  /**
   * Generates the stratified train/test index sets.
   *
   * @param X - Samples; only the number of samples is used.
   * @param y - Class label of each sample.
   * @returns `[train, test]`, where `train[i]` and `test[i]` hold the sample
   * indices of the i-th split.
   * @throws {ValidationError} When X and y disagree in length, when a class
   * has fewer than 2 members, or when the train/test size is smaller than the
   * number of classes.
   */
  split = (X: Type1DMatrix<any> | Type2DMatrix<any> = null, y: Type1DMatrix<any> = null): Type2DMatrix<any> => {
    // Same X/y consistency check that KFold.split performs.
    testShapes(X, y);

    const XTensor = tf.tensor(X);
    const nSamples = numSamples(XTensor);
    // The tensor was only needed to count samples; release its memory.
    XTensor.dispose();

    const [nTest, nTrain] = validateShuffleSplit(nSamples, this.testSize, this.trainSize, this.defaultTestSize);

    const [classes, yIndices] = invidualize(y);
    const nClasses: number = classes.length;
    const classCounts: Type1DMatrix<number> = countBin(yIndices);
    if (_.min(classCounts) < 2) {
      throw new ValidationError(
        `The least populated class in y=${y} has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.`,
      );
    }

    if (nTrain < nClasses) {
      throw new ValidationError(
        `The train_size = ${nTrain} should be greater or equal to the number of classes = ${nClasses}`,
      );
    }

    if (nTest < nClasses) {
      throw new ValidationError(
        `The test_size = ${nTest} should be greater or equal to the number of classes = ${nClasses}`,
      );
    }

    // Stable argsort of the encoded labels: sample positions ordered by class.
    // Sorting the labels themselves would discard the positions we need, and
    // Array.prototype.sort would also mutate yIndices and compare as strings.
    const samplesByClass: number[] = _.sortBy(
      _.range(yIndices.length),
      (sampleIndex: number) => yIndices[sampleIndex],
    );

    // Running totals of the class counts, excluding the last one: the cut
    // points that separate one class's samples from the next.
    const splitPoints: number[] = [];
    let runningTotal = 0;
    for (let c = 0; c < nClasses - 1; c++) {
      runningTotal += classCounts[c];
      splitPoints.push(runningTotal);
    }
    const classIndices = arraySplit(samplesByClass, splitPoints);

    const test = [];
    const train = [];
    for (let i = 0; i < this.n_splits; i++) {
      // Per-class number of training samples for this split.
      const n_i: Type1DMatrix<number> = approximateMode(classCounts, nTrain, this.rng);
      // Samples of each class still available once the training ones are taken.
      const classCountsRemaining: Type1DMatrix<number> = classCounts.map((count, index) => count - n_i[index]);
      // Per-class number of test samples, drawn from what remains.
      const t_i: Type1DMatrix<number> = approximateMode(classCountsRemaining, nTest, this.rng);

      const tempTest = [];
      const tempTrain = [];

      for (let j = 0; j < nClasses; j++) {
        const permutation: Type1DMatrix<any> = this.rng.permutation(classCounts[j]);
        const permIndicesClassJ = permutation.map((val) => classIndices[j][val]);
        tempTrain.push(...permIndicesClassJ.slice(0, n_i[j]));
        tempTest.push(...permIndicesClassJ.slice(n_i[j], n_i[j] + t_i[j]));
      }
      // Shuffle order (test first, then train) is kept so seeded runs draw
      // from the generator in a fixed sequence.
      test.push(this.rng.shuffle(tempTest));
      train.push(this.rng.shuffle(tempTrain));
    }

    return [train, test];
  };
}

/**
 * Converts one requested split size into an absolute number of samples.
 *
 * @param size - Absolute count (integer) or fraction of the data set (float in (0, 1)).
 * @param n_samples - Total number of samples.
 * @param roundFraction - Rounding applied to `size * n_samples` for fractional sizes.
 * @param describeError - Builds the message for an out-of-range size.
 * @throws {ValidationError} When `size` is outside its valid range.
 */
function resolveSplitCount(
  size: number,
  n_samples: number,
  roundFraction: (value: number) => number,
  describeError: (size: number, n_samples: number) => string,
): number {
  if (Number.isInteger(size)) {
    if (size >= n_samples || size <= 0) {
      throw new ValidationError(describeError(size, n_samples));
    }
    return size;
  }

  // Phrased as "not strictly inside (0, 1)" so NaN and ±Infinity are rejected too.
  if (!(size > 0 && size < 1)) {
    throw new ValidationError(describeError(size, n_samples));
  }
  return roundFraction(size * n_samples);
}

/**
 * Validates the requested test/train sizes and resolves them to sample counts.
 *
 * A size is "not provided" only when it is `null` or `undefined`; an explicit
 * `0` is an invalid request and is rejected rather than silently replaced by
 * the default. When neither size is provided, `default_test_size` is used as
 * the test size. A missing side is taken as the complement of the other.
 *
 * @param n_samples - Total number of samples.
 * @param test_size - Absolute count or fraction in (0, 1); may be null/undefined.
 * @param train_size - Absolute count or fraction in (0, 1); may be null/undefined.
 * @param default_test_size - Test size used when both sizes are missing.
 * @returns `[n_test, n_train]`.
 * @throws {ValidationError} When a size is out of range, when the sizes add
 * up to more than `n_samples`, or when the resulting train set is empty.
 */
function validateShuffleSplit(
  n_samples: number,
  test_size: number,
  train_size: number,
  default_test_size: number,
): Type1DMatrix<number> {
  const hasTrainSize = train_size != null;
  if (test_size == null && !hasTrainSize) {
    test_size = default_test_size;
  }
  const hasTestSize = test_size != null;

  // Test size is validated before train size; fractional test sizes round up
  // and fractional train sizes round down.
  const requestedTest = hasTestSize
    ? resolveSplitCount(test_size, n_samples, Math.ceil, testRangeValidationError)
    : undefined;
  const requestedTrain = hasTrainSize
    ? resolveSplitCount(train_size, n_samples, Math.floor, trainRangeValidationError)
    : undefined;

  // Whichever side was not requested becomes the complement of the other.
  const n_train = hasTrainSize ? requestedTrain : n_samples - requestedTest;
  const n_test = hasTestSize ? requestedTest : n_samples - requestedTrain;

  const total = n_train + n_test;
  if (total > n_samples) {
    throw new ValidationError(
      `The sum of train_size and test_size = ${total}, ` +
        'should be smaller than the number of ' +
        `samples ${n_samples}. Reduce test_size and/or ` +
        'train_size.',
    );
  }

  if (n_train === 0) {
    throw new ValidationError(
      `With n_samples=${n_samples}, test_size=${test_size} and train_size=${train_size}, the ` +
        'resulting train set will be empty. Adjust any of the ' +
        'aforementioned parameters.',
    );
  }
  return [Math.floor(n_test), Math.floor(n_train)];
}
4 changes: 2 additions & 2 deletions src/lib/model_selection/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import { KFold, train_test_split } from './_split';
import { KFold, StratifiedShuffleSplit, train_test_split } from './_split';

export { KFold, train_test_split };
export { KFold, train_test_split, StratifiedShuffleSplit };
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy