1
+ // Data helper classes
2
+
3
+ // class ml5.Data
4
+
5
+ // ml5.Data?
6
+ class Data {
7
+ // Need to deal with shape
8
+ constructor ( data ) {
9
+ this . xs = tf . tensor2d ( data . inputs ) ;
10
+ this . ys = tf . tensor2d ( data . targets ) ;
11
+ }
12
+ }
13
+
14
+
15
+ // Helper class for a Batch of Data: ml5.Batch
1
16
class Batch {
2
17
constructor ( ) {
18
+ // Need to deal with shape
3
19
// this.shape = ??;
4
20
this . data = [ ] ;
5
21
}
@@ -12,7 +28,7 @@ class Batch {
12
28
13
29
class NeuralNetwork {
14
30
15
- constructor ( inputs , hidden , outputs ) {
31
+ constructor ( inputs , hidden , outputs , lr ) {
16
32
this . model = tf . sequential ( ) ;
17
33
const hiddenLayer = tf . layers . dense ( {
18
34
units : hidden ,
@@ -21,13 +37,14 @@ class NeuralNetwork {
21
37
} ) ;
22
38
const outputLayer = tf . layers . dense ( {
23
39
units : outputs ,
24
- inputShape : [ hidden ] ,
40
+ // inferred
41
+ // inputShape: [hidden],
25
42
activation : 'sigmoid'
26
43
} ) ;
27
44
this . model . add ( hiddenLayer ) ;
28
45
this . model . add ( outputLayer ) ;
29
46
30
- const LEARNING_RATE = 0.5 ;
47
+ const LEARNING_RATE = lr || 0.5 ;
31
48
const optimizer = tf . train . sgd ( LEARNING_RATE ) ;
32
49
33
50
this . model . compile ( {
@@ -38,28 +55,46 @@ class NeuralNetwork {
38
55
}
39
56
40
57
predict ( inputs ) {
41
- if ( inputs instanceof Batch ) {
42
- return tf . tidy ( ( ) => {
43
- const xs = tf . tensor2d ( inputs . data ) ;
44
- return this . model . predict ( xs ) . dataSync ( ) ;
45
- } ) ;
58
+ return tf . tidy ( ( ) => {
59
+ let data ;
60
+ if ( inputs instanceof Batch ) {
61
+ data = inputs . data ;
62
+ } else {
63
+ data = [ inputs ] ;
64
+ }
65
+ const xs = tf . tensor2d ( data ) ;
66
+ return this . model . predict ( xs ) . dataSync ( ) ;
67
+ } ) ;
68
+ }
69
+
70
+ setTrainingData ( data ) {
71
+ if ( data instanceof Data ) {
72
+ this . trainingData = data ;
46
73
} else {
47
- return tf . tidy ( ( ) => {
48
- const xs = tf . tensor2d ( [ inputs ] ) ;
49
- return this . model . predict ( xs ) . dataSync ( ) ;
50
- } ) ;
74
+ this . trainingData = new Data ( data ) ;
51
75
}
52
76
}
53
77
54
- async train ( data , epochs , callback ) {
55
- const xs = tf . tensor2d ( data . inputs ) ;
56
- const ys = tf . tensor2d ( data . targets ) ;
78
+ async train ( callback , epochs , data ) {
79
+ let xs , ys ;
80
+ if ( data ) {
81
+ xs = tf . tensor2d ( data . inputs ) ;
82
+ ys = tf . tensor2d ( data . targets ) ;
83
+ } else if ( this . trainingData ) {
84
+ xs = this . trainingData . xs ;
85
+ ys = this . trainingData . ys ;
86
+ } else {
87
+ console . log ( "I have no data!" ) ;
88
+ return ;
89
+ }
57
90
await this . model . fit ( xs , ys , {
58
- epochs : epochs ,
91
+ epochs : epochs || 1 ,
59
92
shuffle : true
60
93
} ) ;
61
- xs . dispose ( ) ;
62
- ys . dispose ( ) ;
94
+ if ( data ) {
95
+ xs . dispose ( ) ;
96
+ ys . dispose ( ) ;
97
+ }
63
98
callback ( ) ;
64
99
}
65
- }
100
+ }
0 commit comments