Skip to content

Commit 01e2dd5

Browse files
committed
Improvements thanks to #1 feedback
Updating this in preparation of YouTube tutorial. Changes thanks to feedback from @nsthorat. I think if we are to create generic `NeuralNetwork` or `Classifier` classes in ml5, a lot of the work will inevitably involve data helper classes so that the end user can work with vanilla arrays and tensors are all created and managed internally by ml5.
1 parent 79696a3 commit 01e2dd5

File tree

2 files changed

+58
-22
lines changed

2 files changed

+58
-22
lines changed

01_XOR/neuralnetwork.js

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
1+
// Data helper classes
2+
3+
// class ml5.Data
4+
5+
// ml5.Data?
6+
class Data {
7+
// Need to deal with shape
8+
constructor(data) {
9+
this.xs = tf.tensor2d(data.inputs);
10+
this.ys = tf.tensor2d(data.targets);
11+
}
12+
}
13+
14+
15+
// Helper class for a Batch of Data: ml5.Batch
116
class Batch {
217
constructor() {
18+
// Need to deal with shape
319
// this.shape = ??;
420
this.data = [];
521
}
@@ -12,7 +28,7 @@ class Batch {
1228

1329
class NeuralNetwork {
1430

15-
constructor(inputs, hidden, outputs) {
31+
constructor(inputs, hidden, outputs, lr) {
1632
this.model = tf.sequential();
1733
const hiddenLayer = tf.layers.dense({
1834
units: hidden,
@@ -21,13 +37,14 @@ class NeuralNetwork {
2137
});
2238
const outputLayer = tf.layers.dense({
2339
units: outputs,
24-
inputShape: [hidden],
40+
// inferred
41+
// inputShape: [hidden],
2542
activation: 'sigmoid'
2643
});
2744
this.model.add(hiddenLayer);
2845
this.model.add(outputLayer);
2946

30-
const LEARNING_RATE = 0.5;
47+
const LEARNING_RATE = lr || 0.5;
3148
const optimizer = tf.train.sgd(LEARNING_RATE);
3249

3350
this.model.compile({
@@ -38,28 +55,46 @@ class NeuralNetwork {
3855
}
3956

4057
predict(inputs) {
41-
if (inputs instanceof Batch) {
42-
return tf.tidy(() => {
43-
const xs = tf.tensor2d(inputs.data);
44-
return this.model.predict(xs).dataSync();
45-
});
58+
return tf.tidy(() => {
59+
let data;
60+
if (inputs instanceof Batch) {
61+
data = inputs.data;
62+
} else {
63+
data = [inputs];
64+
}
65+
const xs = tf.tensor2d(data);
66+
return this.model.predict(xs).dataSync();
67+
});
68+
}
69+
70+
setTrainingData(data) {
71+
if (data instanceof Data) {
72+
this.trainingData = data;
4673
} else {
47-
return tf.tidy(() => {
48-
const xs = tf.tensor2d([inputs]);
49-
return this.model.predict(xs).dataSync();
50-
});
74+
this.trainingData = new Data(data);
5175
}
5276
}
5377

54-
async train(data, epochs, callback) {
55-
const xs = tf.tensor2d(data.inputs);
56-
const ys = tf.tensor2d(data.targets);
78+
async train(callback, epochs, data) {
79+
let xs, ys;
80+
if (data) {
81+
xs = tf.tensor2d(data.inputs);
82+
ys = tf.tensor2d(data.targets);
83+
} else if (this.trainingData) {
84+
xs = this.trainingData.xs;
85+
ys = this.trainingData.ys;
86+
} else {
87+
console.log("I have no data!");
88+
return;
89+
}
5790
await this.model.fit(xs, ys, {
58-
epochs: epochs,
91+
epochs: epochs || 1,
5992
shuffle: true
6093
});
61-
xs.dispose();
62-
ys.dispose();
94+
if (data) {
95+
xs.dispose();
96+
ys.dispose();
97+
}
6398
callback();
6499
}
65-
}
100+
}

01_XOR/sketch.js

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,21 @@ let counter = 0;
1818
let training = true;
1919

2020
function train() {
21-
nn.train(data, 10, finished);
21+
nn.train(finished);
2222
}
2323

2424
function finished() {
2525
counter++;
2626
statusP.html('training pass: ' + counter + '<br>framerate: ' + floor(frameRate()));
27-
setTimeout(train, 10);
27+
train();
2828
}
2929

3030
let statusP;
3131

3232
function setup() {
3333
createCanvas(400, 400);
3434
nn = new NeuralNetwork(2, 2, 1);
35+
nn.setTrainingData(data);
3536
train();
3637
statusP = createP('0');
3738
}
@@ -66,4 +67,4 @@ function draw() {
6667
}
6768
}
6869
// }
69-
}
70+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy