
dynamic_rnn implementation (#323)

tags/v0.12
Haiping committed 6 years ago
commit ca99be32f6
16 changed files with 446 additions and 49 deletions
  1. src/TensorFlowNET.Core/APIs/tf.nn.cs (+1 -1)
  2. src/TensorFlowNET.Core/Operations/BasicRNNCell.cs (+2 -1)
  3. src/TensorFlowNET.Core/Operations/NnOps/rnn.cs (+131 -9)
  4. src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs (+57 -0)
  5. src/TensorFlowNET.Core/Operations/RNNCell.cs (+13 -4)
  6. src/TensorFlowNET.Core/Operations/TensorArray.cs (+54 -0)
  7. src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs (+102 -0)
  8. src/TensorFlowNET.Core/Operations/array_ops.py.cs (+11 -0)
  9. src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs (+18 -0)
  10. src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs (+0 -8)
  11. src/TensorFlowNET.Core/Tensors/Tensor.cs (+5 -0)
  12. src/TensorFlowNET.Core/Tensors/TensorShape.cs (+25 -0)
  13. src/TensorFlowNET.Core/Util/nest.py.cs (+19 -23)
  14. src/TensorFlowNET.Core/ops.py.cs (+5 -0)
  15. tensorflowlib/README.md (+1 -1)
  16. test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs (+2 -2)

src/TensorFlowNET.Core/APIs/tf.nn.cs (+1 -1)

@@ -75,7 +75,7 @@ namespace Tensorflow
         /// <param name="time_major"></param>
         /// <returns>A pair (outputs, state)</returns>
         public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs,
-            int? sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid,
+            Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid,
             int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
             => rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype,
                 parallel_iterations: parallel_iterations, swap_memory: swap_memory,
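For orientation, a minimal call through the updated public API might look like the sketch below. The placeholder helper, the tf.float32 alias, the shapes, and the variable names are illustrative assumptions rather than part of this commit; note that rnn.cs (further down) still throws NotImplementedException when a non-null sequence_length is actually supplied.

    // Hypothetical usage sketch: inputs are [batch, max_time, depth] since time_major defaults to false.
    var inputs = tf.placeholder(tf.float32, shape: new TensorShape(-1, 28, 128));
    var cell = new BasicRNNCell(64);

    // sequence_length is now a Tensor rather than an int?; leaving it null avoids the
    // NotImplementedException path inside rnn.dynamic_rnn.
    var (outputs, state) = tf.nn.dynamic_rnn(cell, inputs, dtype: tf.float32);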


src/TensorFlowNET.Core/Operations/BasicRNNCell.cs (+2 -1)

@@ -24,7 +24,8 @@ namespace Tensorflow
         int _num_units;
         Func<Tensor, string, Tensor> _activation;

-        protected override int state_size => _num_units;
+        public override int state_size => _num_units;
+        public override int output_size => _num_units;

         public BasicRNNCell(int num_units,
             Func<Tensor, string, Tensor> activation = null,
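These members are made public (and output_size added) so that the new rnn._dynamic_rnn_loop below can read them from outside the cell hierarchy, for example:

    // As used in NnOps/rnn.cs in this commit:
    var state_size = cell.state_size;                       // 64 for a BasicRNNCell(64)
    var flat_output_size = nest.flatten(cell.output_size);  // flattened per-output sizes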


src/TensorFlowNET.Core/Operations/NnOps/rnn.cs (+131 -9)

@@ -24,15 +24,15 @@ namespace Tensorflow.Operations
 {
     internal class rnn
     {
-        public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs,
-            int? sequence_length = null, Tensor initial_state = null,
+        public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs_tensor,
+            Tensor sequence_length = null, Tensor initial_state = null,
             TF_DataType dtype = TF_DataType.DtInvalid,
             int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
         {
             with(tf.variable_scope("rnn"), scope =>
             {
                 VariableScope varscope = scope;
-                var flat_input = nest.flatten(inputs);
+                var flat_input = nest.flatten(inputs_tensor);

                 if (!time_major)
                 {
@@ -42,24 +42,146 @@ namespace Tensorflow.Operations

                 parallel_iterations = parallel_iterations ?? 32;

-                if (sequence_length.HasValue)
+                if (sequence_length != null)
                     throw new NotImplementedException("dynamic_rnn sequence_length has value");

                 var batch_size = _best_effort_input_batch_size(flat_input);

+                Tensor state = null;
                 if (initial_state != null)
-                {
-                    var state = initial_state;
-                }
+                    state = initial_state;
                 else
-                {
-                    cell.get_initial_state(batch_size: batch_size, dtype: dtype);
-                }
+                    state = cell.get_initial_state(batch_size: batch_size, dtype: dtype);
+
+                var inputs = nest.pack_sequence_as(structure: inputs_tensor, flat_sequence: flat_input);
+
+                var (outputs, final_state) = _dynamic_rnn_loop(
+                    cell,
+                    inputs as Tensor,
+                    state,
+                    parallel_iterations: parallel_iterations.Value,
+                    swap_memory: swap_memory,
+                    sequence_length: sequence_length,
+                    dtype: dtype);
             });

             throw new NotImplementedException("");
         }

+        /// <summary>
+        /// Internal implementation of Dynamic RNN.
+        /// </summary>
+        /// <param name="cell"></param>
+        /// <param name="inputs"></param>
+        /// <param name="initial_state"></param>
+        /// <param name="parallel_iterations"></param>
+        /// <param name="swap_memory"></param>
+        /// <param name="sequence_length"></param>
+        /// <param name="dtype"></param>
+        /// <returns></returns>
+        private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, Tensor initial_state,
+            int parallel_iterations, bool swap_memory, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid)
+        {
+            var state = initial_state;
+            var state_size = cell.state_size;
+
+            var flat_input = nest.flatten(inputs);
+            var flat_output_size = nest.flatten(cell.output_size);
+
+            // Construct an initial output
+            var input_shape = array_ops.shape(flat_input[0]);
+            var time_steps = input_shape.slice(0);
+            var batch_size = _best_effort_input_batch_size(flat_input);
+            var inputs_got_shape = flat_input.Select(input_ => input_.TensorShape.with_rank_at_least(3)).ToArray();
+
+            var dims = inputs_got_shape[0].Dimensions.Take(2).ToArray();
+            var (const_time_steps, const_batch_size) = (dims[0], dims[1]);
+
+            foreach(var shape in inputs_got_shape)
+            {
+                if (shape[2] == -1)
+                    throw new ValueError("Input size (depth of inputs) must be accessible via shape inference," +
+                        " but saw value None.");
+
+                var got_time_steps = shape.dims[0];
+                var got_batch_size = shape.dims[1];
+
+                if (const_time_steps != got_time_steps)
+                    throw new ValueError("Time steps is not the same for all the elements in the input in a " +
+                        "batch.");
+
+                if (const_batch_size != got_batch_size)
+                    throw new ValueError("Batch_size is not the same for all the elements in the input.");
+            }
+
+            Func<int, Tensor> _create_zero_arrays = (size_) =>
+            {
+                var size = rnn_cell_impl._concat(batch_size, size_);
+                return array_ops.zeros(
+                    array_ops.stack(size), dtype: _infer_state_dtype(dtype, state));
+            };
+
+            // Prepare dynamic conditional copying of state & output
+            var flat_zero_output = flat_output_size.Select(output => _create_zero_arrays(output)).ToArray();
+            var zero_output = nest.pack_sequence_as(structure: cell.output_size, flat_sequence: flat_zero_output);
+
+            Tensor min_sequence_length = null, max_sequence_length = null;
+            if (sequence_length != null)
+            {
+                min_sequence_length = math_ops.reduce_min(sequence_length);
+                max_sequence_length = math_ops.reduce_max(sequence_length);
+            }
+            else
+            {
+                max_sequence_length = time_steps;
+            }
+
+            var time = array_ops.constant(0, dtype: dtypes.int32, name: "time");
+
+            string base_name = null;
+            with(ops.name_scope("dynamic_rnn"), scope => base_name = scope);
+
+            Func<string, TensorShape, TF_DataType, Tensor> _create_ta = (name, element_shape, dtype_) =>
+            {
+                new TensorArray(dtype: dtype_,
+                    size: time_steps,
+                    element_shape: element_shape,
+                    tensor_array_name: base_name + name);
+                throw new NotImplementedException("");
+            };
+
+            bool in_graph_mode = true;
+            if (in_graph_mode)
+            {
+                foreach(var (i, out_size) in enumerate(flat_output_size))
+                {
+                    _create_ta($"output_{i}",
+                        new TensorShape(const_batch_size).concatenate(
+                            _maybe_tensor_shape_from_tensor(out_size)),
+                        _infer_state_dtype(dtype, state));
+                }
+            }
+
+            throw new NotImplementedException("");
+        }
+
+        private static TensorShape _maybe_tensor_shape_from_tensor(Tensor shape)
+            => shape.TensorShape;
+
+        private static TensorShape _maybe_tensor_shape_from_tensor(int shape)
+            => new TensorShape(shape);
+
+        private static TF_DataType _infer_state_dtype(TF_DataType explicit_dtype, Tensor state)
+        {
+            if (explicit_dtype != TF_DataType.DtInvalid)
+                return explicit_dtype;
+
+            throw new NotImplementedException("_infer_state_dtype");
+        }
+
         /// <summary>
         /// Transposes the batch and time dimensions of a Tensor.
         /// </summary>
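As a quick orientation to the shape bookkeeping in _dynamic_rnn_loop: after the time-major transpose, each flattened input is expected to be [max_time, batch_size, depth]; the dynamic sizes are read with array_ops.shape and the static ones from the TensorShape. A hedged sketch with invented sizes (28 time steps, batch 32, depth 128):

    // Mirrors the lines above; the concrete numbers are assumptions for illustration.
    var input_shape = array_ops.shape(flat_input[0]);     // dynamic [28, 32, 128] as a Tensor
    var time_steps  = input_shape.slice(0);               // 28, as a scalar Tensor
    var static_dims = flat_input[0].TensorShape.with_rank_at_least(3).Dimensions;
    var (const_time_steps, const_batch_size) = (static_dims[0], static_dims[1]);   // 28, 32 here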


src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs (+57 -0)

@@ -0,0 +1,57 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;

namespace Tensorflow.Operations
{
public class rnn_cell_impl
{
public BasicRNNCell BasicRNNCell(int num_units)
=> new BasicRNNCell(num_units);

public static Tensor _concat(Tensor prefix, int suffix, bool @static = false)
{
var p = prefix;
var p_static = tensor_util.constant_value(prefix);
if (p.NDims == 0)
p = array_ops.expand_dims(p, 0);
else if (p.NDims != 1)
throw new ValueError($"prefix tensor must be either a scalar or vector, but saw tensor: {p}");

var s_tensor_shape = new TensorShape(suffix);
var s_static = s_tensor_shape.NDim > -1 ?
s_tensor_shape.Dimensions :
null;
var s = s_tensor_shape.is_fully_defined() ?
constant_op.constant(s_tensor_shape.Dimensions, dtype: dtypes.int32) :
null;

if (@static)
{
if (p_static is null) return null;
var shape = new TensorShape(p_static).concatenate(s_static);
throw new NotImplementedException("RNNCell _concat");
}
else
{
if (p is null || s is null)
throw new ValueError($"Provided a prefix or suffix of None: {prefix} and {suffix}");
return array_ops.concat(new[] { p, s }, 0);
}
}
}
}
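In rough terms, _concat glues a runtime batch dimension onto a static suffix to form a single 1-D shape tensor: the scalar prefix is expanded to rank 1 and concatenated with the constant suffix, so the result describes a [prefix, suffix] shape that array_ops.zeros can consume. A hedged sketch of the intended use (the batch_size tensor and the size 64 are assumptions for illustration):

    // batch_size: scalar Tensor known only at run time; 64: static state/output size.
    var shape_vec = rnn_cell_impl._concat(batch_size, 64);           // 1-D tensor, conceptually [batch_size, 64]
    var zeros = array_ops.zeros(shape_vec, dtype: TF_DataType.TF_FLOAT);

    // The @static: true branch works on statically known values instead and is still
    // unfinished here (it throws NotImplementedException after building the TensorShape).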

src/TensorFlowNET.Core/Operations/RNNCell.cs (+13 -4)

@@ -15,6 +15,7 @@
 ******************************************************************************/

 using System;
+using Tensorflow.Operations;
 using Tensorflow.Util;
 using static Tensorflow.Python;

@@ -48,7 +49,9 @@ namespace Tensorflow
         /// difference between TF and Keras RNN cell.
         /// </summary>
         protected bool _is_tf_rnn_cell = false;
-        protected virtual int state_size { get; }
+        public virtual int state_size { get; }
+
+        public virtual int output_size { get; }

         public RNNCell(bool trainable = true,
             string name = null,

@@ -89,12 +92,18 @@ namespace Tensorflow

         private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype)
         {
-            nest.map_structure(x =>
+            var output = nest.map_structure(s =>
             {
-                throw new NotImplementedException("");
+                var c = rnn_cell_impl._concat(batch_size, s);
+                var size = array_ops.zeros(c, dtype: dtype);
+
+                var c_static = rnn_cell_impl._concat(batch_size, s, @static: true);
+                size.set_shape(c_static);
+
+                return size;
             }, state_size);

-            throw new NotImplementedException("");
+            return output;
         }
     }
 }
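_zero_state_tensors now actually builds the zero state instead of throwing: for each state size s it creates zeros of the dynamic shape [batch_size, s] and then pins the static shape via the @static: true variant of _concat. For a BasicRNNCell(64), the initial state that dynamic_rnn requests through cell.get_initial_state therefore amounts to roughly the following (a sketch, assuming get_initial_state routes through _zero_state_tensors; float dtype assumed for illustration):

    // Hedged sketch for state_size == 64.
    var c = rnn_cell_impl._concat(batch_size, 64);            // dynamic shape [batch, 64]
    var state = array_ops.zeros(c, dtype: TF_DataType.TF_FLOAT);
    state.set_shape(rnn_cell_impl._concat(batch_size, 64, @static: true));
    // Note: the @static branch of _concat is itself still partly unimplemented.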

src/TensorFlowNET.Core/Operations/TensorArray.cs (+54 -0)

@@ -0,0 +1,54 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
/// <summary>
/// TensorArray is designed to hide an underlying implementation object
/// and as such accesses many of that object's hidden fields.
///
/// "Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays.
/// This class is meant to be used with dynamic iteration primitives such as
/// `while_loop` and `map_fn`. It supports gradient back-propagation via special
/// "flow" control flow dependencies.
/// </summary>
public class TensorArray
{
_GraphTensorArray _implementation;

public TensorArray(TF_DataType dtype, Tensor size = null, bool? clear_after_read = null, bool? dynamic_size = null,
string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, TensorShape element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
_implementation = new _GraphTensorArray(dtype,
size: size,
dynamic_size: dynamic_size,
clear_after_read: clear_after_read,
tensor_array_name: tensor_array_name,
handle: handle,
flow: flow,
infer_shape: infer_shape,
element_shape: element_shape,
colocate_with_first_write_call: colocate_with_first_write_call,
name: name);
}
}
}
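The only call site in this commit is _create_ta in NnOps/rnn.cs, which allocates one write-once array per flattened RNN output, sized by the dynamic time_steps tensor. A hedged sketch of that construction (the element shape and name are illustrative assumptions):

    // One slot per time step; element_shape is conceptually [batch, output_size].
    var ta = new TensorArray(dtype: TF_DataType.TF_FLOAT,
        size: time_steps,                           // scalar Tensor: number of time steps
        element_shape: new TensorShape(32, 64),     // hypothetical [batch, output_size]
        tensor_array_name: base_name + "output_0");
    // _create_ta still throws NotImplementedException right after this call,
    // so the dynamic_rnn while-loop is not wired up end-to-end yet.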

src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs (+102 -0)

@@ -0,0 +1,102 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Python;

namespace Tensorflow.Operations
{
internal class _GraphTensorArray
{
TF_DataType _dtype;

/// <summary>
/// Used to keep track of what tensors the TensorArray should be
/// colocated with. We choose to colocate the TensorArray with the
/// first tensor written to it.
/// </summary>
bool _colocate_with_first_write_call;

bool _infer_shape;
List<TensorShape> _element_shape;

object _colocate_with;

public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, TensorShape element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
clear_after_read = clear_after_read ?? true;
dynamic_size = dynamic_size ?? false;

_dtype = dtype;

_colocate_with_first_write_call = colocate_with_first_write_call;
if (colocate_with_first_write_call)
_colocate_with = new Tensor[0];

// Record the current static shape for the array elements. The element
// shape is defined either by `element_shape` or the shape of the tensor
// of the first write. If `infer_shape` is true, all writes checks for
// shape equality.
if(element_shape == null)
{
_infer_shape = infer_shape;
_element_shape = new List<TensorShape> { };
}
else
{
_infer_shape = true;
_element_shape = new List<TensorShape> { };
}

with(ops.name_scope(name, "", new { handle, size, flow }), scope =>
{
if(handle != null)
{

}
else
{
Func<(Tensor, Tensor)> create = () => gen_data_flow_ops.tensor_array_v3(size,
dtype: dtype,
element_shape: element_shape,
identical_element_shapes: infer_shape,
dynamic_size: dynamic_size.Value,
clear_after_read: clear_after_read.Value,
tensor_array_name: tensor_array_name,
name: scope);

// Construct the TensorArray with an empty device. The first
// write into the TensorArray from a Tensor with a set device
// will retroactively set the device value of this op.
if (colocate_with_first_write_call)
{
ops.colocate_with(ignore_existing: true);
create();
}
else
{

}
}
});
}
}
}

src/TensorFlowNET.Core/Operations/array_ops.py.cs (+11 -0)

@@ -29,6 +29,17 @@ namespace Tensorflow
         public static Tensor prevent_gradient(Tensor input, string message = "", string name = null)
             => gen_array_ops.prevent_gradient(input, message: message, name: name);
+        internal static Tensor constant(object value,
+            TF_DataType dtype = TF_DataType.DtInvalid,
+            int[] shape = null,
+            string name = "Const",
+            bool verify_shape = false) => constant_op._constant_impl(value,
+                dtype,
+                shape,
+                name,
+                verify_shape: verify_shape,
+                allow_broadcast: false);
+
         public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
         {
             dtype = dtype.as_base_dtype();


src/TensorFlowNET.Core/Operations/gen_data_flow_ops.py.cs → src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs (+18 -0)

@@ -27,5 +27,23 @@ namespace Tensorflow


return _op.outputs[0]; return _op.outputs[0];
} }

public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid,
int[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true,
bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null)
{
var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new
{
size,
dtype,
element_shape,
dynamic_size,
clear_after_read,
identical_element_shapes,
tensor_array_name
});

return (null, null);
}
} }
} }
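TensorArrayV3 is the graph op behind _GraphTensorArray; in upstream TensorFlow it yields a resource handle plus a scalar float "flow" value, which is what the "flow" control-flow dependencies mentioned in TensorArray.cs refer to. In this commit the op is created only for its effect on the graph and its outputs are not surfaced yet, so a caller currently gets a pair of nulls, roughly:

    // Hedged sketch; argument values are illustrative assumptions.
    var (handle, flow) = gen_data_flow_ops.tensor_array_v3(size,
        dtype: TF_DataType.TF_FLOAT,
        identical_element_shapes: true,
        tensor_array_name: "dynamic_rnn_output_0");
    // handle == null && flow == null until the op's outputs are wired through.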

src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs (+0 -8)

@@ -1,8 +0,0 @@
namespace Tensorflow.Operations
{
public class rnn_cell_impl
{
public BasicRNNCell BasicRNNCell(int num_units)
=> new BasicRNNCell(num_units);
}
}

src/TensorFlowNET.Core/Tensors/Tensor.cs (+5 -0)

@@ -110,6 +110,11 @@ namespace Tensorflow
             this.shape = shape.Dimensions;
         }

+        public void set_shape(Tensor shape)
+        {
+            this.shape = shape is null ? null : shape.shape;
+        }
+
         public int[] dims => shape;

         /// <summary>


src/TensorFlowNET.Core/Tensors/TensorShape.cs (+25 -0)

@@ -9,6 +9,8 @@ namespace Tensorflow
     /// </summary>
     public class TensorShape : Shape
     {
+        public int[] dims => Dimensions;
+
         public TensorShape(TensorShapeProto proto)
         {
             if (proto.UnknownRank) return;
@@ -45,6 +47,29 @@ namespace Tensorflow
throw new NotImplementedException("TensorShape is_compatible_with"); throw new NotImplementedException("TensorShape is_compatible_with");
} }


public TensorShape with_rank_at_least(int rank)
{
if (rank != this.NDim)
throw new ValueError($"Shape {this} must have rank at least {rank}");
else
return this;
}

/// <summary>
/// Returns the concatenation of the dimension in `self` and `other`.
/// </summary>
/// <param name="other"></param>
/// <returns></returns>
public TensorShape concatenate(int[] other_)
{
var other = new TensorShape(other_);

if (NDim < 0 || other.NDim < 0)
return new TensorShape();
else
return new TensorShape(NDim + other.NDim);
}

public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);
public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);


src/TensorFlowNET.Core/Util/nest.py.cs (+19 -23)

@@ -223,31 +223,27 @@ namespace Tensorflow.Util
         private static void _flatten_recursive<T>(T obj, List<T> list)
         {
-            if (obj is string)
-            {
-                list.Add(obj);
-                return;
-            }
-            if (obj is IDictionary)
-            {
-                var dict = obj as IDictionary;
-                foreach (var key in _sorted(dict))
-                    _flatten_recursive((T)dict[key], list);
-                return;
-            }
-            if (obj is NDArray)
-            {
-                list.Add(obj);
-                return;
-            }
-            if (obj is IEnumerable)
+            switch(obj)
             {
-                var structure = obj as IEnumerable;
-                foreach (var child in structure)
-                    _flatten_recursive((T)child, list);
-                return;
+                case IDictionary dict:
+                    foreach (var key in _sorted(dict))
+                        _flatten_recursive((T)dict[key], list);
+                    break;
+                case String str:
+                    list.Add(obj);
+                    break;
+                case NDArray nd:
+                    list.Add(obj);
+                    break;
+                case IEnumerable structure:
+                    foreach (var child in structure)
+                        _flatten_recursive((T)child, list);
+                    break;
+                default:
+                    list.Add(obj);
+                    break;
             }
-            list.Add(obj);
         }
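One detail worth noting in the new dispatch: string implements IEnumerable, so the String case has to come before the IEnumerable case or strings would be flattened into individual characters (the old if-chain likewise checked obj is string first). A small self-contained sketch of the same traversal order, minus the NDArray case and the key sorting:

    using System;
    using System.Collections;
    using System.Collections.Generic;

    static class FlattenDemo
    {
        // Minimal standalone re-creation of _flatten_recursive's dispatch order.
        static void Flatten(object obj, List<object> list)
        {
            switch (obj)
            {
                case IDictionary dict:
                    foreach (var key in dict.Keys) Flatten(dict[key], list);
                    break;
                case string s:                  // must precede IEnumerable
                    list.Add(s);
                    break;
                case IEnumerable seq:
                    foreach (var child in seq) Flatten(child, list);
                    break;
                default:
                    list.Add(obj);
                    break;
            }
        }

        static void Main()
        {
            var flat = new List<object>();
            Flatten(new object[] { "ab", new[] { 1, 2 }, new Dictionary<string, object> { ["k"] = 3 } }, flat);
            Console.WriteLine(string.Join(", ", flat));   // ab, 1, 2, 3
        }
    }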


src/TensorFlowNET.Core/ops.py.cs (+5 -0)

@@ -314,6 +314,11 @@ namespace Tensorflow
             return uid_number++;
         }

+        public static void colocate_with(bool ignore_existing = false)
+        {
+            _colocate_with_for_gradient(null, null, ignore_existing);
+        }
+
         public static void colocate_with(Operation op, bool ignore_existing = false)
         {
             _colocate_with_for_gradient(op, null, ignore_existing);


tensorflowlib/README.md (+1 -1)

@@ -40,7 +40,7 @@ Before running verify you installed CUDA and cuDNN


 https://www.tensorflow.org/install/source_windows

-pacman -S git patch unzip
+`pacman -S git patch unzip`

 1. Build static library




test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs (+2 -2)

@@ -42,7 +42,7 @@ namespace TensorFlowNET.Examples.ImageProcess
         int n_channels = 1;

         // Hyper-parameters
-        int epochs = 10;
+        int epochs = 5; // accuracy > 98%
         int batch_size = 100;
         float learning_rate = 0.001f;
         Datasets<DataSetMnist> mnist;

@@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples.ImageProcess
                 Test(sess);
             });

-            return loss_test < 0.09 && accuracy_test > 0.95;
+            return loss_test < 0.05 && accuracy_test > 0.98;
         }

         public Graph BuildGraph()

