Browse Source

Refine the resnet example.

tags/v0.100.5-BERT-load
Yaohui Liu Haiping 2 years ago
parent
commit
4f794e8afa
2 changed files with 16 additions and 12 deletions
  1. +8
    -6
      README.md
  2. +8
    -6
      docs/README-CN.md

+ 8
- 6
README.md View File

@@ -131,7 +131,7 @@ using static Tensorflow.KerasApi;
using Tensorflow;
using Tensorflow.NumPy;

var layers = new LayersApi();
var layers = keras.layers;
// input layer
var inputs = keras.Input(shape: (32, 32, 3), name: "img");
// convolutional layer
@@ -155,17 +155,19 @@ var model = keras.Model(inputs, outputs, name: "toy_resnet");
model.summary();
// compile keras model in tensorflow static graph
model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
loss: keras.losses.CategoricalCrossentropy(from_logits: true),
loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
metrics: new[] { "acc" });
// prepare dataset
var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
// normalize the input
x_train = x_train / 255.0f;
y_train = np_utils.to_categorical(y_train, 10);
// training
model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)],
batch_size: 64,
epochs: 10,
validation_split: 0.2f);
batch_size: 64,
epochs: 10,
validation_split: 0.2f);
// save the model
model.save("./toy_resnet_model");
```

The F# example for linear regression is available [here](docs/Example-fsharp.md).


+ 8
- 6
docs/README-CN.md View File

@@ -130,7 +130,7 @@ using static Tensorflow.KerasApi;
using Tensorflow;
using Tensorflow.NumPy;

var layers = new LayersApi();
var layers = keras.layers;
// input layer
var inputs = keras.Input(shape: (32, 32, 3), name: "img");
// convolutional layer
@@ -154,17 +154,19 @@ var model = keras.Model(inputs, outputs, name: "toy_resnet");
model.summary();
// compile keras model in tensorflow static graph
model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
loss: keras.losses.CategoricalCrossentropy(from_logits: true),
loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
metrics: new[] { "acc" });
// prepare dataset
var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
// normalize the input
x_train = x_train / 255.0f;
y_train = np_utils.to_categorical(y_train, 10);
// training
model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)],
batch_size: 64,
epochs: 10,
validation_split: 0.2f);
batch_size: 64,
epochs: 10,
validation_split: 0.2f);
// save the model
model.save("./toy_resnet_model");
```

此外,Tensorflow.NET也支持用F#搭建上述模型进行训练和推理。


Loading…
Cancel
Save