From ca99be32f6c453d73c35b831b675b592b5a178d0 Mon Sep 17 00:00:00 2001 From: Haiping Date: Sat, 27 Jul 2019 19:34:38 -0500 Subject: [PATCH] dynamic_rnn implementation (#323) --- src/TensorFlowNET.Core/APIs/tf.nn.cs | 2 +- .../Operations/BasicRNNCell.cs | 3 +- .../Operations/NnOps/rnn.cs | 140 ++++++++++++++++-- .../Operations/NnOps/rnn_cell_impl.cs | 57 +++++++ src/TensorFlowNET.Core/Operations/RNNCell.cs | 17 ++- .../Operations/TensorArray.cs | 54 +++++++ .../Operations/_GraphTensorArray.cs | 102 +++++++++++++ .../Operations/array_ops.py.cs | 11 ++ ...ta_flow_ops.py.cs => gen_data_flow_ops.cs} | 18 +++ .../Operations/rnn_cell_impl.cs | 8 - src/TensorFlowNET.Core/Tensors/Tensor.cs | 5 + src/TensorFlowNET.Core/Tensors/TensorShape.cs | 25 ++++ src/TensorFlowNET.Core/Util/nest.py.cs | 42 +++--- src/TensorFlowNET.Core/ops.py.cs | 5 + tensorflowlib/README.md | 2 +- .../ImageProcessing/DigitRecognitionCNN.cs | 4 +- 16 files changed, 446 insertions(+), 49 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs create mode 100644 src/TensorFlowNET.Core/Operations/TensorArray.cs create mode 100644 src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs rename src/TensorFlowNET.Core/Operations/{gen_data_flow_ops.py.cs => gen_data_flow_ops.cs} (63%) delete mode 100644 src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 9ec1ef58..0bc9d0e5 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -75,7 +75,7 @@ namespace Tensorflow /// /// A pair (outputs, state) public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, - int? sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid, + Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid, int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) => rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype, parallel_iterations: parallel_iterations, swap_memory: swap_memory, diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs index 61da061f..554e9f1a 100644 --- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs @@ -24,7 +24,8 @@ namespace Tensorflow int _num_units; Func _activation; - protected override int state_size => _num_units; + public override int state_size => _num_units; + public override int output_size => _num_units; public BasicRNNCell(int num_units, Func activation = null, diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 79d1df89..3200e13f 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -24,15 +24,15 @@ namespace Tensorflow.Operations { internal class rnn { - public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, - int? sequence_length = null, Tensor initial_state = null, + public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs_tensor, + Tensor sequence_length = null, Tensor initial_state = null, TF_DataType dtype = TF_DataType.DtInvalid, int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) { with(tf.variable_scope("rnn"), scope => { VariableScope varscope = scope; - var flat_input = nest.flatten(inputs); + var flat_input = nest.flatten(inputs_tensor); if (!time_major) { @@ -42,24 +42,146 @@ namespace Tensorflow.Operations parallel_iterations = parallel_iterations ?? 32; - if (sequence_length.HasValue) + if (sequence_length != null) throw new NotImplementedException("dynamic_rnn sequence_length has value"); var batch_size = _best_effort_input_batch_size(flat_input); + Tensor state = null; if (initial_state != null) - { - var state = initial_state; - } + state = initial_state; else + state = cell.get_initial_state(batch_size: batch_size, dtype: dtype); + + var inputs = nest.pack_sequence_as(structure: inputs_tensor, flat_sequence: flat_input); + + var (outputs, final_state) = _dynamic_rnn_loop( + cell, + inputs as Tensor, + state, + parallel_iterations: parallel_iterations.Value, + swap_memory: swap_memory, + sequence_length: sequence_length, + dtype: dtype); + }); + + throw new NotImplementedException(""); + } + + /// + /// Internal implementation of Dynamic RNN. + /// + /// + /// + /// + /// + /// + /// + /// + /// + private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, Tensor initial_state, + int parallel_iterations, bool swap_memory, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid) + { + var state = initial_state; + var state_size = cell.state_size; + + var flat_input = nest.flatten(inputs); + var flat_output_size = nest.flatten(cell.output_size); + + // Construct an initial output + var input_shape = array_ops.shape(flat_input[0]); + var time_steps = input_shape.slice(0); + var batch_size = _best_effort_input_batch_size(flat_input); + var inputs_got_shape = flat_input.Select(input_ => input_.TensorShape.with_rank_at_least(3)).ToArray(); + + var dims = inputs_got_shape[0].Dimensions.Take(2).ToArray(); + var (const_time_steps, const_batch_size) = (dims[0], dims[1]); + + foreach(var shape in inputs_got_shape) + { + if (shape[2] == -1) + throw new ValueError("Input size (depth of inputs) must be accessible via shape inference," + + " but saw value None."); + + var got_time_steps = shape.dims[0]; + var got_batch_size = shape.dims[1]; + + if (const_time_steps != got_time_steps) + throw new ValueError("Time steps is not the same for all the elements in the input in a " + + "batch."); + + if (const_batch_size != got_batch_size) + throw new ValueError("Batch_size is not the same for all the elements in the input."); + } + + Func _create_zero_arrays = (size_) => + { + var size = rnn_cell_impl._concat(batch_size, size_); + return array_ops.zeros( + array_ops.stack(size), dtype: _infer_state_dtype(dtype, state)); + }; + + // Prepare dynamic conditional copying of state & output + var flat_zero_output = flat_output_size.Select(output => _create_zero_arrays(output)).ToArray(); + var zero_output = nest.pack_sequence_as(structure: cell.output_size, flat_sequence: flat_zero_output); + + Tensor min_sequence_length = null, max_sequence_length = null; + if (sequence_length != null) + { + min_sequence_length = math_ops.reduce_min(sequence_length); + max_sequence_length = math_ops.reduce_max(sequence_length); + } + else + { + max_sequence_length = time_steps; + } + + var time = array_ops.constant(0, dtype: dtypes.int32, name: "time"); + + string base_name = null; + with(ops.name_scope("dynamic_rnn"), scope => base_name = scope); + + Func _create_ta = (name, element_shape, dtype_) => + { + new TensorArray(dtype: dtype_, + size: time_steps, + element_shape: element_shape, + tensor_array_name: base_name + name); + throw new NotImplementedException(""); + }; + + bool in_graph_mode = true; + if (in_graph_mode) + { + foreach(var (i, out_size) in enumerate(flat_output_size)) { - cell.get_initial_state(batch_size: batch_size, dtype: dtype); + _create_ta($"output_{i}", + new TensorShape(const_batch_size).concatenate( + _maybe_tensor_shape_from_tensor(out_size)), + _infer_state_dtype(dtype, state)); + + + } - }); + } throw new NotImplementedException(""); } + private static TensorShape _maybe_tensor_shape_from_tensor(Tensor shape) + => shape.TensorShape; + + private static TensorShape _maybe_tensor_shape_from_tensor(int shape) + => new TensorShape(shape); + + private static TF_DataType _infer_state_dtype(TF_DataType explicit_dtype, Tensor state) + { + if (explicit_dtype != TF_DataType.DtInvalid) + return explicit_dtype; + + throw new NotImplementedException("_infer_state_dtype"); + } + /// /// Transposes the batch and time dimensions of a Tensor. /// diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs new file mode 100644 index 00000000..bd210ecd --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs @@ -0,0 +1,57 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow.Operations +{ + public class rnn_cell_impl + { + public BasicRNNCell BasicRNNCell(int num_units) + => new BasicRNNCell(num_units); + + public static Tensor _concat(Tensor prefix, int suffix, bool @static = false) + { + var p = prefix; + var p_static = tensor_util.constant_value(prefix); + if (p.NDims == 0) + p = array_ops.expand_dims(p, 0); + else if (p.NDims != 1) + throw new ValueError($"prefix tensor must be either a scalar or vector, but saw tensor: {p}"); + + var s_tensor_shape = new TensorShape(suffix); + var s_static = s_tensor_shape.NDim > -1 ? + s_tensor_shape.Dimensions : + null; + var s = s_tensor_shape.is_fully_defined() ? + constant_op.constant(s_tensor_shape.Dimensions, dtype: dtypes.int32) : + null; + + if (@static) + { + if (p_static is null) return null; + var shape = new TensorShape(p_static).concatenate(s_static); + throw new NotImplementedException("RNNCell _concat"); + } + else + { + if (p is null || s is null) + throw new ValueError($"Provided a prefix or suffix of None: {prefix} and {suffix}"); + return array_ops.concat(new[] { p, s }, 0); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/RNNCell.cs b/src/TensorFlowNET.Core/Operations/RNNCell.cs index 3b841087..57f46e7b 100644 --- a/src/TensorFlowNET.Core/Operations/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/RNNCell.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System; +using Tensorflow.Operations; using Tensorflow.Util; using static Tensorflow.Python; @@ -48,7 +49,9 @@ namespace Tensorflow /// difference between TF and Keras RNN cell. /// protected bool _is_tf_rnn_cell = false; - protected virtual int state_size { get; } + public virtual int state_size { get; } + + public virtual int output_size { get; } public RNNCell(bool trainable = true, string name = null, @@ -89,12 +92,18 @@ namespace Tensorflow private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype) { - nest.map_structure(x => + var output = nest.map_structure(s => { - throw new NotImplementedException(""); + var c = rnn_cell_impl._concat(batch_size, s); + var size = array_ops.zeros(c, dtype: dtype); + + var c_static = rnn_cell_impl._concat(batch_size, s, @static: true); + size.set_shape(c_static); + + return size; }, state_size); - throw new NotImplementedException(""); + return output; } } } diff --git a/src/TensorFlowNET.Core/Operations/TensorArray.cs b/src/TensorFlowNET.Core/Operations/TensorArray.cs new file mode 100644 index 00000000..858dac47 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/TensorArray.cs @@ -0,0 +1,54 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + /// + /// TensorArray is designed to hide an underlying implementation object + /// and as such accesses many of that object's hidden fields. + /// + /// "Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays. + /// This class is meant to be used with dynamic iteration primitives such as + /// `while_loop` and `map_fn`. It supports gradient back-propagation via special + /// "flow" control flow dependencies. + /// + public class TensorArray + { + _GraphTensorArray _implementation; + + public TensorArray(TF_DataType dtype, Tensor size = null, bool? clear_after_read = null, bool? dynamic_size = null, + string tensor_array_name = null, Tensor handle = null, Tensor flow = null, + bool infer_shape = true, TensorShape element_shape = null, + bool colocate_with_first_write_call = true, string name = null) + { + _implementation = new _GraphTensorArray(dtype, + size: size, + dynamic_size: dynamic_size, + clear_after_read: clear_after_read, + tensor_array_name: tensor_array_name, + handle: handle, + flow: flow, + infer_shape: infer_shape, + element_shape: element_shape, + colocate_with_first_write_call: colocate_with_first_write_call, + name: name); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs new file mode 100644 index 00000000..b4619c05 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs @@ -0,0 +1,102 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Python; + +namespace Tensorflow.Operations +{ + internal class _GraphTensorArray + { + TF_DataType _dtype; + + /// + /// Used to keep track of what tensors the TensorArray should be + /// colocated with. We choose to colocate the TensorArray with the + /// first tensor written to it. + /// + bool _colocate_with_first_write_call; + + bool _infer_shape; + List _element_shape; + + object _colocate_with; + + public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, + bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, + bool infer_shape = true, TensorShape element_shape = null, + bool colocate_with_first_write_call = true, string name = null) + { + clear_after_read = clear_after_read ?? true; + dynamic_size = dynamic_size ?? false; + + _dtype = dtype; + + _colocate_with_first_write_call = colocate_with_first_write_call; + if (colocate_with_first_write_call) + _colocate_with = new Tensor[0]; + + // Record the current static shape for the array elements. The element + // shape is defined either by `element_shape` or the shape of the tensor + // of the first write. If `infer_shape` is true, all writes checks for + // shape equality. + if(element_shape == null) + { + _infer_shape = infer_shape; + _element_shape = new List { }; + } + else + { + _infer_shape = true; + _element_shape = new List { }; + } + + with(ops.name_scope(name, "", new { handle, size, flow }), scope => + { + if(handle != null) + { + + } + else + { + Func<(Tensor, Tensor)> create = () => gen_data_flow_ops.tensor_array_v3(size, + dtype: dtype, + element_shape: element_shape, + identical_element_shapes: infer_shape, + dynamic_size: dynamic_size.Value, + clear_after_read: clear_after_read.Value, + tensor_array_name: tensor_array_name, + name: scope); + + // Construct the TensorArray with an empty device. The first + // write into the TensorArray from a Tensor with a set device + // will retroactively set the device value of this op. + if (colocate_with_first_write_call) + { + ops.colocate_with(ignore_existing: true); + create(); + } + else + { + + } + } + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 98b36bc6..c3f52cb8 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -29,6 +29,17 @@ namespace Tensorflow public static Tensor prevent_gradient(Tensor input, string message = "", string name = null) => gen_array_ops.prevent_gradient(input, message: message, name: name); + internal static Tensor constant(object value, + TF_DataType dtype = TF_DataType.DtInvalid, + int[] shape = null, + string name = "Const", + bool verify_shape = false) => constant_op._constant_impl(value, + dtype, + shape, + name, + verify_shape: verify_shape, + allow_broadcast: false); + public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) { dtype = dtype.as_base_dtype(); diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs similarity index 63% rename from src/TensorFlowNET.Core/Operations/gen_data_flow_ops.py.cs rename to src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs index 17f64a29..2cb9aac6 100644 --- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs @@ -27,5 +27,23 @@ namespace Tensorflow return _op.outputs[0]; } + + public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid, + int[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, + bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null) + { + var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new + { + size, + dtype, + element_shape, + dynamic_size, + clear_after_read, + identical_element_shapes, + tensor_array_name + }); + + return (null, null); + } } } diff --git a/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs deleted file mode 100644 index 72f4b866..00000000 --- a/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs +++ /dev/null @@ -1,8 +0,0 @@ -namespace Tensorflow.Operations -{ - public class rnn_cell_impl - { - public BasicRNNCell BasicRNNCell(int num_units) - => new BasicRNNCell(num_units); - } -} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index e7049e7e..aebca212 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -110,6 +110,11 @@ namespace Tensorflow this.shape = shape.Dimensions; } + public void set_shape(Tensor shape) + { + this.shape = shape is null ? null : shape.shape; + } + public int[] dims => shape; /// diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 8559cbd4..c19ecae7 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -9,6 +9,8 @@ namespace Tensorflow /// public class TensorShape : Shape { + public int[] dims => Dimensions; + public TensorShape(TensorShapeProto proto) { if (proto.UnknownRank) return; @@ -45,6 +47,29 @@ namespace Tensorflow throw new NotImplementedException("TensorShape is_compatible_with"); } + public TensorShape with_rank_at_least(int rank) + { + if (rank != this.NDim) + throw new ValueError($"Shape {this} must have rank at least {rank}"); + else + return this; + } + + /// + /// Returns the concatenation of the dimension in `self` and `other`. + /// + /// + /// + public TensorShape concatenate(int[] other_) + { + var other = new TensorShape(other_); + + if (NDim < 0 || other.NDim < 0) + return new TensorShape(); + else + return new TensorShape(NDim + other.NDim); + } + public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index ae13b31c..5f782ba2 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -223,31 +223,27 @@ namespace Tensorflow.Util private static void _flatten_recursive(T obj, List list) { - if (obj is string) - { - list.Add(obj); - return; - } - if (obj is IDictionary) - { - var dict = obj as IDictionary; - foreach (var key in _sorted(dict)) - _flatten_recursive((T)dict[key], list); - return; - } - if (obj is NDArray) - { - list.Add(obj); - return; - } - if (obj is IEnumerable) + + switch(obj) { - var structure = obj as IEnumerable; - foreach (var child in structure) - _flatten_recursive((T)child, list); - return; + case IDictionary dict: + foreach (var key in _sorted(dict)) + _flatten_recursive((T)dict[key], list); + break; + case String str: + list.Add(obj); + break; + case NDArray nd: + list.Add(obj); + break; + case IEnumerable structure: + foreach (var child in structure) + _flatten_recursive((T)child, list); + break; + default: + list.Add(obj); + break; } - list.Add(obj); } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 90cda74e..8f7fce29 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -314,6 +314,11 @@ namespace Tensorflow return uid_number++; } + public static void colocate_with(bool ignore_existing = false) + { + _colocate_with_for_gradient(null, null, ignore_existing); + } + public static void colocate_with(Operation op, bool ignore_existing = false) { _colocate_with_for_gradient(op, null, ignore_existing); diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md index 77a78a66..63cba815 100644 --- a/tensorflowlib/README.md +++ b/tensorflowlib/README.md @@ -40,7 +40,7 @@ Before running verify you installed CUDA and cuDNN https://www.tensorflow.org/install/source_windows -pacman -S git patch unzip +`pacman -S git patch unzip` 1. Build static library diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs index 25ffc46a..2dc355c4 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs @@ -42,7 +42,7 @@ namespace TensorFlowNET.Examples.ImageProcess int n_channels = 1; // Hyper-parameters - int epochs = 10; + int epochs = 5; // accuracy > 98% int batch_size = 100; float learning_rate = 0.001f; Datasets mnist; @@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples.ImageProcess Test(sess); }); - return loss_test < 0.09 && accuracy_test > 0.95; + return loss_test < 0.05 && accuracy_test > 0.98; } public Graph BuildGraph()