From 1a025e622376464dfacb63beefe830cf74a602aa Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 5 Sep 2019 22:59:17 -0500 Subject: [PATCH] add moving_averages, fix ExponentialMovingAverage --- TensorFlow.NET.sln | 6 -- src/TensorFlowNET.Core/Graphs/Graph.cs | 15 ++++- .../Operations/Operation.Output.cs | 15 +++++ .../Operations/c_api.ops.cs | 2 +- .../Train/ExponentialMovingAverage.cs | 36 ++++++++-- src/TensorFlowNET.Core/Train/SlotCreator.cs | 20 +++++- .../Train/moving_averages.cs | 32 +++++++++ .../Variables/RefVariable.cs | 66 ++++++++++++++++--- .../Variables/gen_state_ops.py.cs | 3 +- src/TensorFlowNET.Core/ops.GraphKeys.cs | 10 ++- .../ImageProcessing/YOLO/Main.cs | 40 ++++++++++- .../ImageProcessing/YOLO/YOLOv3.cs | 35 +++++++++- .../TensorFlowNET.Examples.GPU.csproj | 5 +- .../TensorFlowNET.Examples.csproj | 5 +- .../TensorFlowNET.UnitTest.csproj | 5 +- 15 files changed, 261 insertions(+), 34 deletions(-) create mode 100644 src/TensorFlowNET.Core/Train/moving_averages.cs diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 8d230d26..ca14ecbd 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -9,8 +9,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "t EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.Core", "src\KerasNET.Core\Keras.Core.csproj", "{902E188F-A953-43B4-9991-72BAB1697BC3}" -EndProject Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "TensorFlowNET.Examples.FSharp", "test\TensorFlowNET.Examples.FSharp\TensorFlowNET.Examples.FSharp.fsproj", "{62BC3801-F0D3-44A9-A0AC-712F40C8F961}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowBenchmark", "src\TensorFlowNet.Benchmarks\TensorFlowBenchmark.csproj", "{68861442-971A-4196-876E-C9330F0B3C54}" @@ -41,10 +39,6 @@ Global {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU - {902E188F-A953-43B4-9991-72BAB1697BC3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {902E188F-A953-43B4-9991-72BAB1697BC3}.Debug|Any CPU.Build.0 = Debug|Any CPU - {902E188F-A953-43B4-9991-72BAB1697BC3}.Release|Any CPU.ActiveCfg = Release|Any CPU - {902E188F-A953-43B4-9991-72BAB1697BC3}.Release|Any CPU.Build.0 = Release|Any CPU {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.Build.0 = Debug|Any CPU {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.ActiveCfg = Release|Any CPU diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 0dfb68db..cad7a5a6 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -420,7 +420,20 @@ namespace Tensorflow public List get_collection(string name, string scope = null) { - return _collections.ContainsKey(name) ? _collections[name] as List : new List(); + List t = default; + var collection = _collections.ContainsKey(name) ? _collections[name] : new List(); + switch (collection) + { + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + default: + throw new NotImplementedException($"get_collection<{typeof(T).FullName}>"); + } + return t; } public object get_collection_ref(string name) diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index 62c8f378..9701d77a 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -17,6 +17,7 @@ using System; using System.Linq; using System.Runtime.InteropServices; +using static Tensorflow.Binding; namespace Tensorflow { @@ -48,6 +49,20 @@ namespace Tensorflow public TF_Output this[int index] => _tf_output(index); + /// + /// List this operation's output types. + /// + public TF_DataType[] _output_types + { + get + { + var output_types = range(NumOutputs) + .Select(i => OutputType(i)) + .ToArray(); + return output_types; + } + } + public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) { var handle = Marshal.AllocHGlobal(Marshal.SizeOf()); diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index fa24b0ef..a23cd406 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -198,7 +198,7 @@ namespace Tensorflow /// int /// [DllImport(TensorFlowLibName)] - public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, IntPtr consumers, int max_consumers); + public static extern int TF_OperationOutputConsumers(TF_Output oper_out, IntPtr consumers, int max_consumers); [DllImport(TensorFlowLibName)] public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); diff --git a/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs b/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs index e129edce..64068c18 100644 --- a/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs +++ b/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs @@ -13,7 +13,7 @@ namespace Tensorflow.Train bool _zero_debias; string _name; public string name => _name; - List _averages; + Dictionary _averages; public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_debias = false, string name = "ExponentialMovingAverage") @@ -22,7 +22,7 @@ namespace Tensorflow.Train _num_updates = num_updates; _zero_debias = zero_debias; _name = name; - _averages = new List(); + _averages = new Dictionary(); } /// @@ -37,16 +37,38 @@ namespace Tensorflow.Train foreach(var var in var_list) { - if (!_averages.Contains(var)) + if (!_averages.ContainsKey(var)) { ops.init_scope(); - var slot = new SlotCreator(); - var.initialized_value(); - // var avg = slot.create_zeros_slot + var slot_creator = new SlotCreator(); + var value = var.initialized_value(); + var avg = slot_creator.create_slot(var, + value, + name, + colocate_with_primary: true); + ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var); + _averages[var] = avg; } } - throw new NotImplementedException(""); + return tf_with(ops.name_scope(name), scope => + { + var decay = ops.convert_to_tensor(_decay, name: "decay"); + if (_num_updates.HasValue) + { + throw new NotImplementedException("ExponentialMovingAverage.apply"); + } + + var updates = new List(); + foreach (var var in var_list) + { + var zero_debias = false;// _averages[var] in zero_debias_true + var ama = moving_averages.assign_moving_average(_averages[var], var, decay, zero_debias: zero_debias); + updates.Add(ama); + } + + return control_flow_ops.group(updates.ToArray(), name: scope); + }); } } } diff --git a/src/TensorFlowNET.Core/Train/SlotCreator.cs b/src/TensorFlowNET.Core/Train/SlotCreator.cs index 29e073c7..1334b4bd 100644 --- a/src/TensorFlowNET.Core/Train/SlotCreator.cs +++ b/src/TensorFlowNET.Core/Train/SlotCreator.cs @@ -22,6 +22,24 @@ namespace Tensorflow.Train { public class SlotCreator { + /// + /// Create a slot initialized to the given value. + /// + /// + /// + /// + /// + /// + public RefVariable create_slot(RefVariable primary, Tensor val, string name, bool colocate_with_primary = true) + { + var validate_shape = val.TensorShape.is_fully_defined(); + var prefix = primary.op.name; + return tf_with(tf.variable_scope(name: null, prefix + "/" + name), delegate + { + return _create_slot_var(primary, val, "", validate_shape, null, TF_DataType.DtInvalid); + }); + } + /// /// Create a slot initialized to 0 with same shape as the primary object. /// @@ -73,7 +91,7 @@ namespace Tensorflow.Train /// /// /// - private RefVariable _create_slot_var(VariableV1 primary, IInitializer val, string scope, bool validate_shape, + private RefVariable _create_slot_var(VariableV1 primary, object val, string scope, bool validate_shape, TensorShape shape, TF_DataType dtype) { bool use_resource = primary is ResourceVariable; diff --git a/src/TensorFlowNET.Core/Train/moving_averages.cs b/src/TensorFlowNET.Core/Train/moving_averages.cs new file mode 100644 index 00000000..5aee7901 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/moving_averages.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Train +{ + public class moving_averages + { + /// + /// Compute the moving average of a variable. + /// + /// + /// + /// + /// + /// + /// + public static Tensor assign_moving_average(RefVariable variable, RefVariable value, Tensor decay, + bool zero_debias = true, string name = null) + { + tf_with(ops.name_scope(name, "", new { variable, value, decay }), scope => + { + decay = ops.convert_to_tensor(1.0f - decay, name: "decay"); + if (decay.dtype != variable.dtype.as_base_dtype()) + decay = math_ops.cast(decay, variable.dtype.as_base_dtype()); + }); + + throw new NotImplementedException("assign_moving_average"); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index e0e3e0f7..1f7ca41b 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -17,6 +17,7 @@ using Google.Protobuf; using System; using System.Collections.Generic; +using System.Linq; using static Tensorflow.Binding; namespace Tensorflow @@ -176,7 +177,7 @@ namespace Tensorflow // If 'initial_value' makes use of other variables, make sure we don't // have an issue if these other variables aren't initialized first by // using their initialized_value() method. - var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value); + var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value); _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op; @@ -215,9 +216,9 @@ namespace Tensorflow /// Attempt to guard against dependencies on uninitialized variables. /// /// - private Tensor _try_guard_against_uninitialized_dependencies(Tensor initial_value) + private Tensor _try_guard_against_uninitialized_dependencies(string name, Tensor initial_value) { - return _safe_initial_value_from_tensor(initial_value, new Dictionary()); + return _safe_initial_value_from_tensor(name, initial_value, op_cache: new Dictionary()); } /// @@ -226,19 +227,19 @@ namespace Tensorflow /// A `Tensor`. The tensor to replace. /// A dict mapping operation names to `Operation`s. /// A `Tensor` compatible with `tensor`. - private Tensor _safe_initial_value_from_tensor(Tensor tensor, Dictionary op_cache) + private Tensor _safe_initial_value_from_tensor(string name, Tensor tensor, Dictionary op_cache) { var op = tensor.op; var new_op = op_cache.ContainsKey(op.name) ? op_cache[op.name] : null; if(new_op == null) { - new_op = _safe_initial_value_from_op(op, op_cache); + new_op = _safe_initial_value_from_op(name, op, op_cache); op_cache[op.name] = new_op; } return new_op.outputs[tensor.value_index]; } - private Operation _safe_initial_value_from_op(Operation op, Dictionary op_cache) + private Operation _safe_initial_value_from_op(string name, Operation op, Dictionary op_cache) { var op_type = op.node_def.Op; switch (op_type) @@ -250,13 +251,50 @@ namespace Tensorflow case "Variable": case "VariableV2": case "VarHandleOp": - break; + var initialized_value = _find_initialized_value_for_variable(op); + return initialized_value == null ? op : initialized_value.op; } // Recursively build initializer expressions for inputs. + var modified = false; + var new_op_inputs = new List(); + foreach (var op_input in op.inputs) + { + var new_op_input = _safe_initial_value_from_tensor(name, op_input as Tensor, op_cache); + new_op_inputs.Add(new_op_input); + modified = modified || new_op_input != op_input; + } + + // If at least one input was modified, replace the op. + if (modified) + { + var new_op_type = op_type; + if (new_op_type == "RefSwitch") + new_op_type = "Switch"; + var new_op_name = op.node_def.Name + "_" + name; + new_op_name = new_op_name.Replace(":", "_"); + var attrs = new Dictionary(); + attrs[op.node_def.Name] = op.node_def.Attr.ElementAt(0).Value; + /*return op.graph.create_op(new_op_type, new_op_inputs.ToArray(), op._output_types, + name: new_op_name, attrs: attrs);*/ + } return op; } + private Operation _find_initialized_value_for_variable(Operation variable_op) + { + var var_names = new[] { variable_op.node_def.Name, variable_op.node_def.Name + ":0" }; + foreach(var collection_name in new[]{tf.GraphKeys.GLOBAL_VARIABLES, + tf.GraphKeys.LOCAL_VARIABLES }) + { + foreach (var var in variable_op.graph.get_collection(collection_name)) + if (var_names.Contains(var.name)) + return var.initialized_value(); + } + + return null; + } + /// /// Assigns a new value to the variable. /// @@ -318,6 +356,15 @@ namespace Tensorflow return array_ops.identity(_variable, name: "read"); } + /// + /// Returns the Tensor used as the initial value for the variable. + /// + /// + private ITensorOrOperation initial_value() + { + return _initial_value; + } + public Tensor is_variable_initialized(RefVariable variable) { return state_ops.is_variable_initialized(variable); @@ -326,10 +373,9 @@ namespace Tensorflow public Tensor initialized_value() { ops.init_scope(); - throw new NotImplementedException(""); - /*return control_flow_ops.cond(is_variable_initialized(this), + return control_flow_ops.cond(is_variable_initialized(this), read_value, - () => initial_value);*/ + initial_value); } } } diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index 24cb11f5..9c006170 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -149,7 +149,8 @@ namespace Tensorflow public static Tensor is_variable_initialized(RefVariable @ref, string name = null) { - throw new NotImplementedException(""); + var _op = _op_def_lib._apply_op_helper("IsVariableInitialized", name: name, args: new { @ref }); + return _op.output; } } } diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs index c5a06433..dad81af9 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -52,6 +52,8 @@ namespace Tensorflow /// public const string LOSSES_ = "losses"; + public const string MOVING_AVERAGE_VARIABLES = "moving_average_variables"; + /// /// Key to collect Variable objects that are global (shared across machines). /// Default collection for all variables, except local ones. @@ -100,6 +102,12 @@ namespace Tensorflow /// public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_; + /// + /// Key to collect local variables that are local to the machine and are not + /// saved/restored. + /// + public string LOCAL_VARIABLES = "local_variables"; + /// /// Key to collect losses /// @@ -109,7 +117,7 @@ namespace Tensorflow /// Key to collect Variable objects that are global (shared across machines). /// Default collection for all variables, except local ones. /// - public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_; + public string GLOBAL_VARIABLES = GLOBAL_VARIABLES_; public string TRAIN_OP => TRAIN_OP_; diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs index f10ac7d1..fa81af38 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Text; using Tensorflow; using static Tensorflow.Binding; @@ -47,6 +48,9 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO YOLOv3 model; VariableV1[] net_var; Tensor giou_loss, conf_loss, prob_loss; + RefVariable global_step; + Tensor learn_rate; + Tensor loss; #endregion public bool Run() @@ -98,11 +102,45 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO (giou_loss, conf_loss, prob_loss) = model.compute_loss( label_sbbox, label_mbbox, label_lbbox, true_sbboxes, true_mbboxes, true_lbboxes); + loss = giou_loss + conf_loss + prob_loss; }); + Tensor global_step_update = null; + tf_with(tf.name_scope("learn_rate"), scope => + { + global_step = tf.Variable(1.0, dtype: tf.float64, trainable: false, name: "global_step"); + var warmup_steps = tf.constant(warmup_periods * steps_per_period, + dtype: tf.float64, name: "warmup_steps"); + var train_steps = tf.constant((first_stage_epochs + second_stage_epochs) * steps_per_period, + dtype: tf.float64, name: "train_steps"); + + learn_rate = tf.cond( + pred: global_step < warmup_steps, + true_fn: delegate + { + return global_step / warmup_steps * learn_rate_init; + }, + false_fn: delegate + { + return learn_rate_end + 0.5 * (learn_rate_init - learn_rate_end) * + (1 + tf.cos( + (global_step - warmup_steps) / (train_steps - warmup_steps) * Math.PI)); + } + ); + + global_step_update = tf.assign_add(global_step, 1.0f); + }); + + Operation moving_ave = null; tf_with(tf.name_scope("define_weight_decay"), scope => { - var moving_ave = tf.train.ExponentialMovingAverage(moving_ave_decay).apply((RefVariable[])tf.trainable_variables()); + var emv = tf.train.ExponentialMovingAverage(moving_ave_decay); + var vars = tf.trainable_variables().Select(x => (RefVariable)x).ToArray(); + moving_ave = emv.apply(vars); + }); + + tf_with(tf.name_scope("define_first_stage_train"), scope => + { }); return graph; diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs index 1cff167f..5ac85fe9 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs @@ -23,6 +23,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO Tensor conv_mbbox; Tensor conv_sbbox; Tensor pred_sbbox; + Tensor pred_mbbox; + Tensor pred_lbbox; public YOLOv3(Config cfg_, Tensor input_data_, Tensor trainable_) { @@ -46,12 +48,12 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO tf_with(tf.variable_scope("pred_mbbox"), scope => { - pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); + pred_mbbox = decode(conv_mbbox, anchors[1], strides[1]); }); tf_with(tf.variable_scope("pred_lbbox"), scope => { - pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); + pred_lbbox = decode(conv_lbbox, anchors[2], strides[2]); }); } @@ -144,6 +146,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO { Tensor giou_loss = null, conf_loss = null, prob_loss = null; (Tensor, Tensor, Tensor) loss_sbbox = (null, null, null); + (Tensor, Tensor, Tensor) loss_mbbox = (null, null, null); + (Tensor, Tensor, Tensor) loss_lbbox = (null, null, null); tf_with(tf.name_scope("smaller_box_loss"), delegate { @@ -151,6 +155,33 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO anchors: anchors[0], stride: strides[0]); }); + tf_with(tf.name_scope("medium_box_loss"), delegate + { + loss_mbbox = loss_layer(conv_mbbox, pred_mbbox, label_mbbox, true_mbbox, + anchors: anchors[1], stride: strides[1]); + }); + + tf_with(tf.name_scope("bigger_box_loss"), delegate + { + loss_lbbox = loss_layer(conv_lbbox, pred_lbbox, label_lbbox, true_lbbox, + anchors: anchors[2], stride: strides[2]); + }); + + tf_with(tf.name_scope("giou_loss"), delegate + { + giou_loss = loss_sbbox.Item1 + loss_mbbox.Item1 + loss_lbbox.Item1; + }); + + tf_with(tf.name_scope("conf_loss"), delegate + { + conf_loss = loss_sbbox.Item2 + loss_mbbox.Item2 + loss_lbbox.Item2; + }); + + tf_with(tf.name_scope("prob_loss"), delegate + { + prob_loss = loss_sbbox.Item3 + loss_mbbox.Item3 + loss_lbbox.Item3; + }); + return (giou_loss, conf_loss, prob_loss); } diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj index 55e9b27d..710c9465 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj @@ -14,6 +14,10 @@ bin\release-gpu + + + + @@ -23,7 +27,6 @@ - diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index c675bedc..652d9bd1 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -10,6 +10,10 @@ DEBUG;TRACE + + + + @@ -19,7 +23,6 @@ - diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index 702bb2ae..58ab60cf 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -23,6 +23,10 @@ true + + + + @@ -32,7 +36,6 @@ -