Skip to content

Commit 27a2871

Browse files
committed
Fix conv_net.load_weights #956
1 parent f084e6f commit 27a2871

File tree

5 files changed

+13
-10
lines changed

5 files changed

+13
-10
lines changed

src/TensorFlowNET.Examples/GAN/MnistGAN.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ public class MnistGAN : SciSharpExample, IExample
1818
float LeakyReLU_alpha = 0.2f;
1919

2020
#if GPU
21-
int epochs = 2000; // Better effect, but longer time
21+
int epochs = 1000; // Better effect, but longer time
22+
int batch_size = 16;
2223
#else
2324
int epochs = 20;
24-
#endif
2525
int batch_size = 64;
26+
#endif
2627

2728
string imgpath = "dcgan\\imgs";
2829
string modelpath = "dcgan\\models";

src/TensorFlowNET.Examples/ImageProcessing/ImageClassificationKeras.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public class ImageClassificationKeras : SciSharpExample, IExample
1717
{
1818
int batch_size = 32;
1919
int epochs = 3;
20-
Shape img_dim = (180, 180);
20+
Shape img_dim = (64, 64);
2121
IDatasetV2 train_ds, val_ds;
2222
Model model;
2323

src/TensorFlowNET.Examples/ImageProcessing/MnistCnnKerasSubclass.cs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ public bool Run()
6262
PrepareData();
6363

6464
Train();
65-
Test();
65+
66+
// Test();
6667

6768
return accuracy_test > 0.85;
6869
}
@@ -102,7 +103,7 @@ public override void Train()
102103
print($"Test Accuracy: {accuracy_test}");
103104
}
104105

105-
conv_net.save_weights("model.weights");
106+
conv_net.save_weights("weights.h5");
106107
}
107108

108109
public override void Test()
@@ -112,12 +113,14 @@ public override void Test()
112113
NumClasses = num_classes
113114
});
114115

115-
conv_net.load_weights("model.weights");
116-
117-
// Test model on validation set.
116+
// Test model on testing set.
118117
{
119118
x_test = x_test["::100"];
120119
y_test = y_test["::100"];
120+
121+
conv_net.build(x_test.shape);
122+
conv_net.load_weights("weights.h5");
123+
121124
var pred = conv_net.Apply(x_test);
122125
accuracy_test = (float)accuracy(pred, y_test);
123126
print($"Test Accuracy: {accuracy_test}");

src/TensorFlowNET.Examples/Program.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ static void Main(string[] args)
3737
.Where(x => x.GetInterfaces().Contains(typeof(IExample)))
3838
//.Where(x => x.Name == nameof(WeatherPrediction))
3939
//.Where(x => x.Name == nameof(SentimentClassification))
40-
//.Where(x => x.Name == nameof(MnistInYOLOv3))
40+
.Where(x => x.Name == nameof(MnistCnnKerasSubclass))
4141
.ToArray();
4242

4343
Console.WriteLine(Environment.OSVersion, Color.Yellow);

src/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747

4848
<ItemGroup>
4949
<PackageReference Include="Colorful.Console" Version="1.2.15" />
50-
<PackageReference Include="Newtonsoft.Json" Version="13.0.1" />
5150
<PackageReference Include="OpenCvSharp4.runtime.win" Version="4.4.0.20200915" />
5251
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.10.0" Condition="'$(Configuration)'!='GPU'" />
5352
<PackageReference Include="SciSharp.TensorFlow.Redist-Windows-GPU" Version="2.10.0" Condition="'$(Configuration)'=='GPU'" />

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy