@@ -75,7 +75,7 @@ namespace Tensorflow | |||||
/// <param name="time_major"></param> | /// <param name="time_major"></param> | ||||
/// <returns>A pair (outputs, state)</returns> | /// <returns>A pair (outputs, state)</returns> | ||||
public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, | public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, | ||||
int? sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid, | |||||
Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid, | |||||
int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) | int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) | ||||
=> rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype, | => rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype, | ||||
parallel_iterations: parallel_iterations, swap_memory: swap_memory, | parallel_iterations: parallel_iterations, swap_memory: swap_memory, | ||||
@@ -24,7 +24,8 @@ namespace Tensorflow | |||||
int _num_units; | int _num_units; | ||||
Func<Tensor, string, Tensor> _activation; | Func<Tensor, string, Tensor> _activation; | ||||
protected override int state_size => _num_units; | |||||
public override int state_size => _num_units; | |||||
public override int output_size => _num_units; | |||||
public BasicRNNCell(int num_units, | public BasicRNNCell(int num_units, | ||||
Func<Tensor, string, Tensor> activation = null, | Func<Tensor, string, Tensor> activation = null, | ||||
@@ -24,15 +24,15 @@ namespace Tensorflow.Operations | |||||
{ | { | ||||
internal class rnn | internal class rnn | ||||
{ | { | ||||
public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, | |||||
int? sequence_length = null, Tensor initial_state = null, | |||||
public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs_tensor, | |||||
Tensor sequence_length = null, Tensor initial_state = null, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) | int? parallel_iterations = null, bool swap_memory = false, bool time_major = false) | ||||
{ | { | ||||
with(tf.variable_scope("rnn"), scope => | with(tf.variable_scope("rnn"), scope => | ||||
{ | { | ||||
VariableScope varscope = scope; | VariableScope varscope = scope; | ||||
var flat_input = nest.flatten(inputs); | |||||
var flat_input = nest.flatten(inputs_tensor); | |||||
if (!time_major) | if (!time_major) | ||||
{ | { | ||||
@@ -42,24 +42,146 @@ namespace Tensorflow.Operations | |||||
parallel_iterations = parallel_iterations ?? 32; | parallel_iterations = parallel_iterations ?? 32; | ||||
if (sequence_length.HasValue) | |||||
if (sequence_length != null) | |||||
throw new NotImplementedException("dynamic_rnn sequence_length has value"); | throw new NotImplementedException("dynamic_rnn sequence_length has value"); | ||||
var batch_size = _best_effort_input_batch_size(flat_input); | var batch_size = _best_effort_input_batch_size(flat_input); | ||||
Tensor state = null; | |||||
if (initial_state != null) | if (initial_state != null) | ||||
{ | |||||
var state = initial_state; | |||||
} | |||||
state = initial_state; | |||||
else | else | ||||
state = cell.get_initial_state(batch_size: batch_size, dtype: dtype); | |||||
var inputs = nest.pack_sequence_as(structure: inputs_tensor, flat_sequence: flat_input); | |||||
var (outputs, final_state) = _dynamic_rnn_loop( | |||||
cell, | |||||
inputs as Tensor, | |||||
state, | |||||
parallel_iterations: parallel_iterations.Value, | |||||
swap_memory: swap_memory, | |||||
sequence_length: sequence_length, | |||||
dtype: dtype); | |||||
}); | |||||
throw new NotImplementedException(""); | |||||
} | |||||
/// <summary> | |||||
/// Internal implementation of Dynamic RNN. | |||||
/// </summary> | |||||
/// <param name="cell"></param> | |||||
/// <param name="inputs"></param> | |||||
/// <param name="initial_state"></param> | |||||
/// <param name="parallel_iterations"></param> | |||||
/// <param name="swap_memory"></param> | |||||
/// <param name="sequence_length"></param> | |||||
/// <param name="dtype"></param> | |||||
/// <returns></returns> | |||||
private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, Tensor initial_state, | |||||
int parallel_iterations, bool swap_memory, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
{ | |||||
var state = initial_state; | |||||
var state_size = cell.state_size; | |||||
var flat_input = nest.flatten(inputs); | |||||
var flat_output_size = nest.flatten(cell.output_size); | |||||
// Construct an initial output | |||||
var input_shape = array_ops.shape(flat_input[0]); | |||||
var time_steps = input_shape.slice(0); | |||||
var batch_size = _best_effort_input_batch_size(flat_input); | |||||
var inputs_got_shape = flat_input.Select(input_ => input_.TensorShape.with_rank_at_least(3)).ToArray(); | |||||
var dims = inputs_got_shape[0].Dimensions.Take(2).ToArray(); | |||||
var (const_time_steps, const_batch_size) = (dims[0], dims[1]); | |||||
foreach(var shape in inputs_got_shape) | |||||
{ | |||||
if (shape[2] == -1) | |||||
throw new ValueError("Input size (depth of inputs) must be accessible via shape inference," + | |||||
" but saw value None."); | |||||
var got_time_steps = shape.dims[0]; | |||||
var got_batch_size = shape.dims[1]; | |||||
if (const_time_steps != got_time_steps) | |||||
throw new ValueError("Time steps is not the same for all the elements in the input in a " + | |||||
"batch."); | |||||
if (const_batch_size != got_batch_size) | |||||
throw new ValueError("Batch_size is not the same for all the elements in the input."); | |||||
} | |||||
Func<int, Tensor> _create_zero_arrays = (size_) => | |||||
{ | |||||
var size = rnn_cell_impl._concat(batch_size, size_); | |||||
return array_ops.zeros( | |||||
array_ops.stack(size), dtype: _infer_state_dtype(dtype, state)); | |||||
}; | |||||
// Prepare dynamic conditional copying of state & output | |||||
var flat_zero_output = flat_output_size.Select(output => _create_zero_arrays(output)).ToArray(); | |||||
var zero_output = nest.pack_sequence_as(structure: cell.output_size, flat_sequence: flat_zero_output); | |||||
Tensor min_sequence_length = null, max_sequence_length = null; | |||||
if (sequence_length != null) | |||||
{ | |||||
min_sequence_length = math_ops.reduce_min(sequence_length); | |||||
max_sequence_length = math_ops.reduce_max(sequence_length); | |||||
} | |||||
else | |||||
{ | |||||
max_sequence_length = time_steps; | |||||
} | |||||
var time = array_ops.constant(0, dtype: dtypes.int32, name: "time"); | |||||
string base_name = null; | |||||
with(ops.name_scope("dynamic_rnn"), scope => base_name = scope); | |||||
Func<string, TensorShape, TF_DataType, Tensor> _create_ta = (name, element_shape, dtype_) => | |||||
{ | |||||
new TensorArray(dtype: dtype_, | |||||
size: time_steps, | |||||
element_shape: element_shape, | |||||
tensor_array_name: base_name + name); | |||||
throw new NotImplementedException(""); | |||||
}; | |||||
bool in_graph_mode = true; | |||||
if (in_graph_mode) | |||||
{ | |||||
foreach(var (i, out_size) in enumerate(flat_output_size)) | |||||
{ | { | ||||
cell.get_initial_state(batch_size: batch_size, dtype: dtype); | |||||
_create_ta($"output_{i}", | |||||
new TensorShape(const_batch_size).concatenate( | |||||
_maybe_tensor_shape_from_tensor(out_size)), | |||||
_infer_state_dtype(dtype, state)); | |||||
} | } | ||||
}); | |||||
} | |||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
private static TensorShape _maybe_tensor_shape_from_tensor(Tensor shape) | |||||
=> shape.TensorShape; | |||||
private static TensorShape _maybe_tensor_shape_from_tensor(int shape) | |||||
=> new TensorShape(shape); | |||||
private static TF_DataType _infer_state_dtype(TF_DataType explicit_dtype, Tensor state) | |||||
{ | |||||
if (explicit_dtype != TF_DataType.DtInvalid) | |||||
return explicit_dtype; | |||||
throw new NotImplementedException("_infer_state_dtype"); | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Transposes the batch and time dimensions of a Tensor. | /// Transposes the batch and time dimensions of a Tensor. | ||||
/// </summary> | /// </summary> | ||||
@@ -0,0 +1,57 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
public class rnn_cell_impl | |||||
{ | |||||
public BasicRNNCell BasicRNNCell(int num_units) | |||||
=> new BasicRNNCell(num_units); | |||||
public static Tensor _concat(Tensor prefix, int suffix, bool @static = false) | |||||
{ | |||||
var p = prefix; | |||||
var p_static = tensor_util.constant_value(prefix); | |||||
if (p.NDims == 0) | |||||
p = array_ops.expand_dims(p, 0); | |||||
else if (p.NDims != 1) | |||||
throw new ValueError($"prefix tensor must be either a scalar or vector, but saw tensor: {p}"); | |||||
var s_tensor_shape = new TensorShape(suffix); | |||||
var s_static = s_tensor_shape.NDim > -1 ? | |||||
s_tensor_shape.Dimensions : | |||||
null; | |||||
var s = s_tensor_shape.is_fully_defined() ? | |||||
constant_op.constant(s_tensor_shape.Dimensions, dtype: dtypes.int32) : | |||||
null; | |||||
if (@static) | |||||
{ | |||||
if (p_static is null) return null; | |||||
var shape = new TensorShape(p_static).concatenate(s_static); | |||||
throw new NotImplementedException("RNNCell _concat"); | |||||
} | |||||
else | |||||
{ | |||||
if (p is null || s is null) | |||||
throw new ValueError($"Provided a prefix or suffix of None: {prefix} and {suffix}"); | |||||
return array_ops.concat(new[] { p, s }, 0); | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -15,6 +15,7 @@ | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | using System; | ||||
using Tensorflow.Operations; | |||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.Python; | using static Tensorflow.Python; | ||||
@@ -48,7 +49,9 @@ namespace Tensorflow | |||||
/// difference between TF and Keras RNN cell. | /// difference between TF and Keras RNN cell. | ||||
/// </summary> | /// </summary> | ||||
protected bool _is_tf_rnn_cell = false; | protected bool _is_tf_rnn_cell = false; | ||||
protected virtual int state_size { get; } | |||||
public virtual int state_size { get; } | |||||
public virtual int output_size { get; } | |||||
public RNNCell(bool trainable = true, | public RNNCell(bool trainable = true, | ||||
string name = null, | string name = null, | ||||
@@ -89,12 +92,18 @@ namespace Tensorflow | |||||
private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype) | private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype) | ||||
{ | { | ||||
nest.map_structure(x => | |||||
var output = nest.map_structure(s => | |||||
{ | { | ||||
throw new NotImplementedException(""); | |||||
var c = rnn_cell_impl._concat(batch_size, s); | |||||
var size = array_ops.zeros(c, dtype: dtype); | |||||
var c_static = rnn_cell_impl._concat(batch_size, s, @static: true); | |||||
size.set_shape(c_static); | |||||
return size; | |||||
}, state_size); | }, state_size); | ||||
throw new NotImplementedException(""); | |||||
return output; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -0,0 +1,54 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
/// <summary> | |||||
/// TensorArray is designed to hide an underlying implementation object | |||||
/// and as such accesses many of that object's hidden fields. | |||||
/// | |||||
/// "Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays. | |||||
/// This class is meant to be used with dynamic iteration primitives such as | |||||
/// `while_loop` and `map_fn`. It supports gradient back-propagation via special | |||||
/// "flow" control flow dependencies. | |||||
/// </summary> | |||||
public class TensorArray | |||||
{ | |||||
_GraphTensorArray _implementation; | |||||
public TensorArray(TF_DataType dtype, Tensor size = null, bool? clear_after_read = null, bool? dynamic_size = null, | |||||
string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||||
bool infer_shape = true, TensorShape element_shape = null, | |||||
bool colocate_with_first_write_call = true, string name = null) | |||||
{ | |||||
_implementation = new _GraphTensorArray(dtype, | |||||
size: size, | |||||
dynamic_size: dynamic_size, | |||||
clear_after_read: clear_after_read, | |||||
tensor_array_name: tensor_array_name, | |||||
handle: handle, | |||||
flow: flow, | |||||
infer_shape: infer_shape, | |||||
element_shape: element_shape, | |||||
colocate_with_first_write_call: colocate_with_first_write_call, | |||||
name: name); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,102 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using static Tensorflow.Python; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
internal class _GraphTensorArray | |||||
{ | |||||
TF_DataType _dtype; | |||||
/// <summary> | |||||
/// Used to keep track of what tensors the TensorArray should be | |||||
/// colocated with. We choose to colocate the TensorArray with the | |||||
/// first tensor written to it. | |||||
/// </summary> | |||||
bool _colocate_with_first_write_call; | |||||
bool _infer_shape; | |||||
List<TensorShape> _element_shape; | |||||
object _colocate_with; | |||||
public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | |||||
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||||
bool infer_shape = true, TensorShape element_shape = null, | |||||
bool colocate_with_first_write_call = true, string name = null) | |||||
{ | |||||
clear_after_read = clear_after_read ?? true; | |||||
dynamic_size = dynamic_size ?? false; | |||||
_dtype = dtype; | |||||
_colocate_with_first_write_call = colocate_with_first_write_call; | |||||
if (colocate_with_first_write_call) | |||||
_colocate_with = new Tensor[0]; | |||||
// Record the current static shape for the array elements. The element | |||||
// shape is defined either by `element_shape` or the shape of the tensor | |||||
// of the first write. If `infer_shape` is true, all writes checks for | |||||
// shape equality. | |||||
if(element_shape == null) | |||||
{ | |||||
_infer_shape = infer_shape; | |||||
_element_shape = new List<TensorShape> { }; | |||||
} | |||||
else | |||||
{ | |||||
_infer_shape = true; | |||||
_element_shape = new List<TensorShape> { }; | |||||
} | |||||
with(ops.name_scope(name, "", new { handle, size, flow }), scope => | |||||
{ | |||||
if(handle != null) | |||||
{ | |||||
} | |||||
else | |||||
{ | |||||
Func<(Tensor, Tensor)> create = () => gen_data_flow_ops.tensor_array_v3(size, | |||||
dtype: dtype, | |||||
element_shape: element_shape, | |||||
identical_element_shapes: infer_shape, | |||||
dynamic_size: dynamic_size.Value, | |||||
clear_after_read: clear_after_read.Value, | |||||
tensor_array_name: tensor_array_name, | |||||
name: scope); | |||||
// Construct the TensorArray with an empty device. The first | |||||
// write into the TensorArray from a Tensor with a set device | |||||
// will retroactively set the device value of this op. | |||||
if (colocate_with_first_write_call) | |||||
{ | |||||
ops.colocate_with(ignore_existing: true); | |||||
create(); | |||||
} | |||||
else | |||||
{ | |||||
} | |||||
} | |||||
}); | |||||
} | |||||
} | |||||
} |
@@ -29,6 +29,17 @@ namespace Tensorflow | |||||
public static Tensor prevent_gradient(Tensor input, string message = "", string name = null) | public static Tensor prevent_gradient(Tensor input, string message = "", string name = null) | ||||
=> gen_array_ops.prevent_gradient(input, message: message, name: name); | => gen_array_ops.prevent_gradient(input, message: message, name: name); | ||||
internal static Tensor constant(object value, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | |||||
int[] shape = null, | |||||
string name = "Const", | |||||
bool verify_shape = false) => constant_op._constant_impl(value, | |||||
dtype, | |||||
shape, | |||||
name, | |||||
verify_shape: verify_shape, | |||||
allow_broadcast: false); | |||||
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | ||||
{ | { | ||||
dtype = dtype.as_base_dtype(); | dtype = dtype.as_base_dtype(); | ||||
@@ -27,5 +27,23 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid, | |||||
int[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, | |||||
bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null) | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new | |||||
{ | |||||
size, | |||||
dtype, | |||||
element_shape, | |||||
dynamic_size, | |||||
clear_after_read, | |||||
identical_element_shapes, | |||||
tensor_array_name | |||||
}); | |||||
return (null, null); | |||||
} | |||||
} | } | ||||
} | } |
@@ -1,8 +0,0 @@ | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
public class rnn_cell_impl | |||||
{ | |||||
public BasicRNNCell BasicRNNCell(int num_units) | |||||
=> new BasicRNNCell(num_units); | |||||
} | |||||
} |
@@ -110,6 +110,11 @@ namespace Tensorflow | |||||
this.shape = shape.Dimensions; | this.shape = shape.Dimensions; | ||||
} | } | ||||
public void set_shape(Tensor shape) | |||||
{ | |||||
this.shape = shape is null ? null : shape.shape; | |||||
} | |||||
public int[] dims => shape; | public int[] dims => shape; | ||||
/// <summary> | /// <summary> | ||||
@@ -9,6 +9,8 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public class TensorShape : Shape | public class TensorShape : Shape | ||||
{ | { | ||||
public int[] dims => Dimensions; | |||||
public TensorShape(TensorShapeProto proto) | public TensorShape(TensorShapeProto proto) | ||||
{ | { | ||||
if (proto.UnknownRank) return; | if (proto.UnknownRank) return; | ||||
@@ -45,6 +47,29 @@ namespace Tensorflow | |||||
throw new NotImplementedException("TensorShape is_compatible_with"); | throw new NotImplementedException("TensorShape is_compatible_with"); | ||||
} | } | ||||
public TensorShape with_rank_at_least(int rank) | |||||
{ | |||||
if (rank != this.NDim) | |||||
throw new ValueError($"Shape {this} must have rank at least {rank}"); | |||||
else | |||||
return this; | |||||
} | |||||
/// <summary> | |||||
/// Returns the concatenation of the dimension in `self` and `other`. | |||||
/// </summary> | |||||
/// <param name="other"></param> | |||||
/// <returns></returns> | |||||
public TensorShape concatenate(int[] other_) | |||||
{ | |||||
var other = new TensorShape(other_); | |||||
if (NDim < 0 || other.NDim < 0) | |||||
return new TensorShape(); | |||||
else | |||||
return new TensorShape(NDim + other.NDim); | |||||
} | |||||
public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | ||||
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | ||||
public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); | public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); | ||||
@@ -223,31 +223,27 @@ namespace Tensorflow.Util | |||||
private static void _flatten_recursive<T>(T obj, List<T> list) | private static void _flatten_recursive<T>(T obj, List<T> list) | ||||
{ | { | ||||
if (obj is string) | |||||
{ | |||||
list.Add(obj); | |||||
return; | |||||
} | |||||
if (obj is IDictionary) | |||||
{ | |||||
var dict = obj as IDictionary; | |||||
foreach (var key in _sorted(dict)) | |||||
_flatten_recursive((T)dict[key], list); | |||||
return; | |||||
} | |||||
if (obj is NDArray) | |||||
{ | |||||
list.Add(obj); | |||||
return; | |||||
} | |||||
if (obj is IEnumerable) | |||||
switch(obj) | |||||
{ | { | ||||
var structure = obj as IEnumerable; | |||||
foreach (var child in structure) | |||||
_flatten_recursive((T)child, list); | |||||
return; | |||||
case IDictionary dict: | |||||
foreach (var key in _sorted(dict)) | |||||
_flatten_recursive((T)dict[key], list); | |||||
break; | |||||
case String str: | |||||
list.Add(obj); | |||||
break; | |||||
case NDArray nd: | |||||
list.Add(obj); | |||||
break; | |||||
case IEnumerable structure: | |||||
foreach (var child in structure) | |||||
_flatten_recursive((T)child, list); | |||||
break; | |||||
default: | |||||
list.Add(obj); | |||||
break; | |||||
} | } | ||||
list.Add(obj); | |||||
} | } | ||||
@@ -314,6 +314,11 @@ namespace Tensorflow | |||||
return uid_number++; | return uid_number++; | ||||
} | } | ||||
public static void colocate_with(bool ignore_existing = false) | |||||
{ | |||||
_colocate_with_for_gradient(null, null, ignore_existing); | |||||
} | |||||
public static void colocate_with(Operation op, bool ignore_existing = false) | public static void colocate_with(Operation op, bool ignore_existing = false) | ||||
{ | { | ||||
_colocate_with_for_gradient(op, null, ignore_existing); | _colocate_with_for_gradient(op, null, ignore_existing); | ||||
@@ -40,7 +40,7 @@ Before running verify you installed CUDA and cuDNN | |||||
https://www.tensorflow.org/install/source_windows | https://www.tensorflow.org/install/source_windows | ||||
pacman -S git patch unzip | |||||
`pacman -S git patch unzip` | |||||
1. Build static library | 1. Build static library | ||||
@@ -42,7 +42,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||||
int n_channels = 1; | int n_channels = 1; | ||||
// Hyper-parameters | // Hyper-parameters | ||||
int epochs = 10; | |||||
int epochs = 5; // accuracy > 98% | |||||
int batch_size = 100; | int batch_size = 100; | ||||
float learning_rate = 0.001f; | float learning_rate = 0.001f; | ||||
Datasets<DataSetMnist> mnist; | Datasets<DataSetMnist> mnist; | ||||
@@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||||
Test(sess); | Test(sess); | ||||
}); | }); | ||||
return loss_test < 0.09 && accuracy_test > 0.95; | |||||
return loss_test < 0.05 && accuracy_test > 0.98; | |||||
} | } | ||||
public Graph BuildGraph() | public Graph BuildGraph() | ||||