diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index 4d112805..fdf00590 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -9,4 +9,4 @@ community_bridge: # Replace with a single Community Bridge project-name e.g., cl liberapay: # Replace with a single Liberapay username issuehunt: # Replace with a single IssueHunt username otechie: # Replace with a single Otechie username -custom: ['https://paypal.me/pools/c/8fK9eKwbbL']# Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] +custom: ['https://bit.ly/2op1mu5']# Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs index d6de08f4..3a790327 100644 --- a/src/TensorFlowNET.Core/APIs/tf.train.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -43,7 +43,8 @@ namespace Tensorflow public ExponentialMovingAverage ExponentialMovingAverage(float decay) => new ExponentialMovingAverage(decay); - public Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list); + public Saver Saver(VariableV1[] var_list = null, int max_to_keep = 5) + => new Saver(var_list: var_list, max_to_keep: max_to_keep); public string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text); @@ -54,7 +55,7 @@ namespace Tensorflow clear_devices, import_scope).Item1; - public (MetaGraphDef, Dictionary) export_meta_graph(string filename = "", + public (MetaGraphDef, Dictionary) export_meta_graph(string filename = "", bool as_text = false, bool clear_devices = false, bool clear_extraneous_savers = false, diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs similarity index 98% rename from src/TensorFlowNET.Core/Framework/meta_graph.py.cs rename to src/TensorFlowNET.Core/Framework/meta_graph.cs index aaa5d225..3f5a2777 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -167,7 +167,7 @@ namespace Tensorflow /// /// /// - public static (MetaGraphDef, Dictionary) export_scoped_meta_graph(string filename = "", + public static (MetaGraphDef, Dictionary) export_scoped_meta_graph(string filename = "", GraphDef graph_def = null, bool as_text = false, string unbound_inputs_col_name = "unbound_inputs", @@ -179,8 +179,8 @@ namespace Tensorflow { var graph = ops.get_default_graph(); - var var_list = new Dictionary(); - var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES); + var var_list = new Dictionary(); + var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES); if (variables != null) { diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index e98ec21e..f07d2825 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -190,6 +190,26 @@ namespace Tensorflow.Gradients return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; } + [RegisterGradient("Pad")] + public static Tensor[] _PadGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + var a = op.inputs[1]; + var size = array_ops.stack(new object[] { array_ops.rank(x), 1 }); + var pad_before = array_ops.slice(a, new[] { 0, 0 }, size); + + // Make it a 1-D tensor. 
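+            // The reshaped "pad before" amounts then serve as the slice offsets:
+            // slicing the incoming gradient by the original input shape recovers the gradient w.r.t. x.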
+ var begin = array_ops.reshape(pad_before, new[] { -1 }); + var sizes = array_ops.shape(x); + var x_grad = array_ops.slice(grad, begin, sizes); + + if (len(op.inputs) == 3) + return new Tensor[] { x_grad, null, null }; + else + return new Tensor[] { x_grad, null }; + } + [RegisterGradient("Squeeze")] public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads) { diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs index 670731e0..76b6a7b5 100644 --- a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs @@ -36,56 +36,54 @@ namespace Tensorflow.Gradients /// /// [RegisterGradient("Switch")] - public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads) + public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads) { + var grad = grads[0]; + var graph = ops.get_default_graph(); + var op_ctxt = op._get_control_flow_context(); + var grad_ctxt = graph._get_control_flow_context(); + switch (op_ctxt) + { + case WhileContext cwhile: + throw new NotImplementedException("_SwitchGrad WhileContext"); + case CondContext ccond: + { + var zero_grad = grads[1 - op_ctxt.branch]; + // At this point, we have created zero_grad guarded by the right switch. + // Unfortunately, we may still get None here for not trainable data types. + if(zero_grad == null) + { + throw new NotImplementedException("_SwitchGrad CondContext zero_grad"); + } + + return new Tensor[] + { + merge(grads, name: "cond_grad")[0], + null + }; + } + default: + throw new NotImplementedException("_SwitchGrad WhileContext"); + } throw new NotImplementedException("_SwitchGrad"); - //graph = ops.get_default_graph() - //# pylint: disable=protected-access - //op_ctxt = op._get_control_flow_context() - //grad_ctxt = graph._get_control_flow_context() - //# pylint: enable=protected-access - //if isinstance(op_ctxt, WhileContext): - // merge_grad = grad_ctxt.grad_state.switch_map.get(op) - // if merge_grad is not None: - // # This is the second time this Switch is visited. It comes from - // # the non-exit branch of the Switch, so update the second input - // # to the Merge. - // # TODO(yuanbyu): Perform shape inference with this new input. - // if grad[1] is not None: - // # pylint: disable=protected-access - // control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1], - // enforce_shape_invariant=False) - // # pylint: enable=protected-access - // return None, None - // elif grad[0] is not None: - // # This is the first time this Switch is visited. It comes from - // # the Exit branch, which is grad[0]. grad[1] is empty at this point. - // # Use grad[0] for both inputs to merge for now, but update the second - // # input of merge when we see this Switch the second time. - // merge_grad = merge([grad[0], grad[0]], name="b_switch")[0] - // grad_ctxt.grad_state.switch_map[op] = merge_grad - // return merge_grad, None - // else: - // # This is the first time this Switch is visited. It comes from the - // # Identity branch. Such a Switch has `None` gradient for the Exit branch, - // # meaning the output is not differentiable. - // return None, None - //elif isinstance(op_ctxt, CondContext): - // zero_grad = grad[1 - op_ctxt.branch] - // # At this point, we have created zero_grad guarded by the right switch. - // # Unfortunately, we may still get None here for not trainable data types. - // if zero_grad is None: - // # For resource variables we get None always on the other branch, so bypass - // # this. 
- // if op.inputs[0].dtype == dtypes.resource: - // return merge( - // [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None - // return None, None - // return merge(grad, name="cond_grad")[0], None - //else: - // false_grad = switch(grad[0], op.inputs[1])[0] - // true_grad = switch(grad[1], op.inputs[1])[1] - // return merge([false_grad, true_grad])[0], None + } + + /// + /// Returns the value of an available element of `inputs`. + /// + /// + /// + /// + internal static Tensor[] merge(Tensor[] inputs, string name = null) + { + return tf_with(ops.name_scope(name, "Merge", inputs), scope => + { + name = scope; + if (inputs.Count(x => x.dtype.is_ref_dtype()) == inputs.Length) + return gen_control_flow_ops.ref_merge(inputs, name: name); + else + return gen_control_flow_ops.merge(inputs, name: name); + }); } /// diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 7252301a..15ad511b 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -108,10 +108,7 @@ namespace Tensorflow { // generate gradient subgraph for op. var op = queue.Dequeue(); - if(tf.get_default_graph()._nodes_by_name.Count > 18505) - { - } _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); //if (loop_state != null) //loop_state.EnterGradWhileContext(op, before: true); @@ -216,7 +213,7 @@ namespace Tensorflow in_grad.Tag == null && // maybe a IndexedSlice t_in.dtype != TF_DataType.TF_RESOURCE) { - in_grad.shape = t_in.shape; + in_grad.set_shape(t_in.TensorShape); } _SetGrad(grads, t_in, in_grad); diff --git a/src/TensorFlowNET.Core/Gradients/image_grad.cs b/src/TensorFlowNET.Core/Gradients/image_grad.cs new file mode 100644 index 00000000..23b19de9 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/image_grad.cs @@ -0,0 +1,54 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Framework; +using static Tensorflow.Binding; + +namespace Tensorflow.Gradients +{ + [RegisterGradient("image_grad")] + public class image_grad + { + [RegisterGradient("ResizeNearestNeighbor")] + public static Tensor[] _ResizeNearestNeighborGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var image = op.inputs[0]; + var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); + Tensor image_shape = null; + if (shape.is_fully_defined()) + throw new NotImplementedException("_ResizeNearestNeighborGrad shape.is_fully_defined"); + else + image_shape = array_ops.shape(image)["1:3"]; + + grad = gen_image_ops.resize_nearest_neighbor_grad( + grad, + image_shape, + align_corners: op.get_attr("align_corners"), + half_pixel_centers: op.get_attr("half_pixel_centers")); + + return new Tensor[] + { + grad, + null + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs index 7b5d2ea7..967b3c21 100644 --- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs @@ -166,6 +166,94 @@ namespace Tensorflow.Gradients }; } + [RegisterGradient("FusedBatchNorm")] + public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads) + => _BaseFusedBatchNormGrad(op, 0, grads); + + /// + /// Return the gradients for the 3 inputs of BatchNorm. + /// + /// + /// + /// + /// + public static Tensor[] _BaseFusedBatchNormGrad(Operation op, int version, Tensor[] grads) + { + var x = op.inputs[0]; + var grad_y = grads[0]; + var scale = op.inputs[1]; + var epsilon = op.get_attr("epsilon"); + var data_format = op.get_attr("data_format"); + var is_training = op.get_attr("is_training"); + Func grad_fun = null; + + switch (version) + { + case 2: + throw new NotImplementedException(""); + case 1: + throw new NotImplementedException(""); + default: + grad_fun = gen_nn_ops.fused_batch_norm_grad; + break; + } + + if (is_training) + { + return grad_fun(new FusedBatchNormParams + { + YBackprop = grad_y, + X = x, + Scale = scale, + ReserveSpace1 = op.outputs[3], + ReserveSpace2 = op.outputs[4], + ReserveSpace3 = version == 2 ? op.outputs[5] : null, + Epsilon = epsilon, + DataFormat = data_format, + IsTraining = is_training + }); + } + else + { + var pop_mean = op.inputs[3]; + var pop_var = op.inputs[4]; + if (data_format == "NCHW") + throw new NotImplementedException(""); + + var results = grad_fun(new FusedBatchNormParams + { + YBackprop = grad_y, + X = x, + Scale = scale, + ReserveSpace1 = op.outputs[3], + ReserveSpace2 = op.outputs[4], + ReserveSpace3 = version == 2 ? 
op.outputs[5] : null, + Epsilon = epsilon, + DataFormat = data_format, + IsTraining = is_training + }); + + var (dx, dscale, doffset) = (results[0], results[1], results[2]); + if (data_format == "NCHW") + throw new NotImplementedException(""); + + return new Tensor[] + { + dx, + dscale, + doffset, + null, + null + }; + } + } + + [RegisterGradient("BatchNormWithGlobalNormalization")] + public static Tensor _BatchNormWithGlobalNormalizationGrad(Operation op, Tensor[] grads) + { + throw new NotImplementedException("BatchNormWithGlobalNormalization"); + } + private static bool IsZero(Tensor g) { if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type)) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 8b7cf7a0..7119a4ad 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -440,6 +440,9 @@ namespace Tensorflow case List list: t = list.Select(x => (T)(object)x).ToList(); break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; case List list: t = list.Select(x => (T)(object)x).ToList(); break; diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index aa314efb..ce2295c8 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -27,20 +27,6 @@ namespace Tensorflow.Operations /// public class CondContext : ControlFlowContext, IProtoBuf { - - - /// - /// The boolean tensor for the cond predicate - /// - private Tensor _pred; - - public Tensor pred => _pred; - - /// - /// 0 or 1 representing this branch - /// - private int _branch; - private Dictionary _external_values = new Dictionary(); /// diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 2a76c52c..c076cbc7 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -45,10 +45,19 @@ namespace Tensorflow.Operations /// The predicate tensor in this branch /// protected Tensor _pivot; - public Tensor pivot - { - get => _pivot; - } + public Tensor pivot => _pivot; + + /// + /// The boolean tensor for the cond predicate + /// + protected Tensor _pred; + public Tensor pred => _pred; + + /// + /// 0 or 1 representing this branch + /// + protected int _branch; + public int branch => _branch; protected Stack _context_stack; protected ControlFlowContext _outer_context; diff --git a/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs b/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs new file mode 100644 index 00000000..689fa5fe --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + public class FusedBatchNormParams + { + public string Name { get; set; } + public Tensor YBackprop { get; set; } + public Tensor X { get; set; } + public Tensor Scale { get; set; } + public Tensor ReserveSpace1 { get; set; } + public Tensor ReserveSpace2 { get; set; } + public Tensor ReserveSpace3 { get; set; } + public float Epsilon { get; set; } + public string DataFormat { get; set; } + public bool IsTraining { get; set; } + + public FusedBatchNormParams() + { + Epsilon = 0.0001f; + DataFormat = 
"NHWC"; + IsTraining = true; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index 82085683..4e376d19 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -156,6 +156,35 @@ namespace Tensorflow.Operations return op.output; } + /// + /// Gradient for batch normalization. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] fused_batch_norm_grad(FusedBatchNormParams @params) + { + var op = _op_def_lib._apply_op_helper("FusedBatchNormGrad", name: @params.Name, args: new + { + y_backprop = @params.YBackprop, + x = @params.X, + scale = @params.Scale, + reserve_space_1 = @params.ReserveSpace1, + reserve_space_2 = @params.ReserveSpace2, + epsilon = @params.Epsilon, + data_format = @params.DataFormat, + is_training = @params.IsTraining + }); + return op.outputs; + } + public static Tensor[] fused_batch_norm(Tensor x, Tensor scale, Tensor offset, diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs index 6d6403c9..c80e99f6 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -53,7 +53,7 @@ namespace Tensorflow for (int i = 0; i < NumInputs; i++) { var tf_output = Input(i); - var op = new Operation(tf_output.oper); + var op = GetOperation(tf_output.oper); retval[i] = op.outputs[tf_output.index]; } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Instance.cs b/src/TensorFlowNET.Core/Operations/Operation.Instance.cs new file mode 100644 index 00000000..6f6c8226 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Operation.Instance.cs @@ -0,0 +1,41 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; + +namespace Tensorflow +{ + public partial class Operation + { + // cache the mapping between managed and unmanaged op + // some data is stored in managed instance, so when + // create Operation by IntPtr, it will lost some data. + private static Dictionary OpInstances = new Dictionary(); + + /// + /// Get operation by handle + /// + /// + /// + public Operation GetOperation(IntPtr handle) + { + return OpInstances.ContainsKey(handle) ? 
+ OpInstances[handle] : + new Operation(handle); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index b6811917..6118602c 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -84,9 +84,10 @@ namespace Tensorflow _control_flow_context = _graph._get_control_flow_context(); // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. + OpInstances[_handle] = this; } - public Operation(Graph g, string opType, string oper_name) + /*public Operation(Graph g, string opType, string oper_name) { _graph = g; @@ -102,7 +103,7 @@ namespace Tensorflow // Dict mapping op name to file and line information for op colocation // context managers. _control_flow_context = graph._get_control_flow_context(); - } + }*/ /// /// Creates an `Operation`. @@ -151,11 +152,6 @@ namespace Tensorflow } } - if(node_def.Name == "define_loss/conv_lobj_branch/batch_normalization/cond/FusedBatchNorm_1") - { - - } - // Dict mapping op name to file and line information for op colocation // context managers. _control_flow_context = graph._get_control_flow_context(); @@ -164,7 +160,7 @@ namespace Tensorflow if (op_def == null) op_def = g.GetOpDef(node_def.Op); - var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); + var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); // Initialize self._outputs. @@ -180,6 +176,8 @@ namespace Tensorflow if (_handle != IntPtr.Zero) _control_flow_post_processing(); + + OpInstances[_handle] = this; } public void run(FeedItem[] feed_dict = null, Session session = null) @@ -220,6 +218,9 @@ namespace Tensorflow return grouped_inputs.ToArray(); } + public T get_attr(string name) + => (T)get_attr(name); + public object get_attr(string name) { AttrValue x = null; diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 3e2276c6..12094e41 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -611,7 +611,7 @@ namespace Tensorflow }); } - public static Tensor slice(Tensor input, Tb[] begin, Ts[] size, string name = null) + public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) => gen_array_ops.slice(input, begin, size, name: name); public static Tensor stack(object values, int axis = 0, string name = "stack") diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs similarity index 95% rename from src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs rename to src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 571457b9..e8b5f0eb 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -518,7 +518,7 @@ namespace Tensorflow inputs = inputs.Select(inp => ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true)) .ToArray(); - return gen_control_flow_ops.merge(inputs, name).Item1; + return gen_control_flow_ops.merge(inputs, name)[0]; }); } @@ -557,8 +557,31 @@ namespace Tensorflow throw new NotImplementedException("ZerosLikeOutsideLoop"); return array_ops.zeros_like(val, optimize: false); } - - throw new 
NotImplementedException("ZerosLikeOutsideLoop"); + else + { + var op_ctxt = op._get_control_flow_context(); + if(op_ctxt != null) + { + // We are in a cond context. Use a switch to create zeros only when needed. + var pred = op_ctxt.pred; + var branch = op_ctxt.branch; + var switch_val = @switch(op.inputs[0], pred)[1 - branch]; + var pivot = array_ops.identity(switch_val); + if (val.dtype == dtypes.resource) + throw new NotImplementedException(""); + var zeros_shape = array_ops.shape_internal(switch_val, optimize: false); + // Ensure ops created within array_ops.zeros are dominated by switch in + // cond context. + return tf_with(ops.control_dependencies(new[] { pivot }), delegate + { + return array_ops.zeros(zeros_shape, dtype: val.dtype); + }); + } + else + { + return array_ops.zeros_like(val, optimize: false); + } + } } /// diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 59b43766..36837477 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -475,7 +475,7 @@ namespace Tensorflow return op.output; } - public static Tensor slice(Tensor input, Tb[] begin, Ts[] size, string name = null) + public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) { var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); return _op.outputs[0]; diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs similarity index 94% rename from src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs rename to src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs index 163a50e4..5f0ceb48 100644 --- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs @@ -148,11 +148,18 @@ namespace Tensorflow return new []{_op.outputs[0], _op.outputs[1]}; } - public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null) + public static Tensor[] ref_merge(Tensor[] inputs, string name = null) + { + var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs }); + + return _op.outputs; + } + + public static Tensor[] merge(Tensor[] inputs, string name = null) { var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs }); - return (_op.outputs[0], _op.outputs[1]); + return _op.outputs; } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs index b052d9d6..143d4fe8 100644 --- a/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs @@ -183,5 +183,19 @@ namespace Tensorflow return op.output; } + + public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tsize size, bool align_corners = false, + bool half_pixel_centers = false, string name = null) + { + var op = _op_def_lib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: new + { + grads, + size, + align_corners, + half_pixel_centers + }); + + return op.output; + } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index f3ad2efd..fb8e2457 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -105,10 +105,13 @@ namespace Tensorflow if (_handle == IntPtr.Zero) { - var status = new Status(); - c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); 
- status.Check(); - } else + using (var status = new Status()) + { + c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); + status.Check(); + } + } + else { for (int i = 0; i < rank; i++) dims[i] = c_api.TF_Dim(_handle, i); @@ -119,14 +122,15 @@ namespace Tensorflow set { - var status = new Status(); - - if (value == null) - c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); - else - c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); + using (var status = new Status()) + { + if (value == null) + c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); + else + c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); - status.Check(true); + status.Check(true); + } } } @@ -142,16 +146,7 @@ namespace Tensorflow /// public void set_shape(TensorShape shape) { - this.shape = (int[]) shape.dims.Clone(); - } - - /// - /// Updates the shape of this tensor. - /// - [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] - public void SetShape(TensorShape shape) - { - this.shape = (int[]) shape.dims.Clone(); + this.shape = shape.rank > 0 ? shape.dims : null; } /// diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index fe0dc5e9..3827229d 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -33,6 +33,7 @@ namespace Tensorflow public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? public static TF_DataType float16 = TF_DataType.TF_HALF; public static TF_DataType float64 = TF_DataType.TF_DOUBLE; + public static TF_DataType resource = TF_DataType.TF_RESOURCE; /// /// diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 6ab9feae..846de1ea 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -227,29 +227,30 @@ namespace Tensorflow throw new NotImplementedException("_create_c_op"); } - var status = new Status(); - - // Add control inputs - foreach (var control_input in control_inputs) - c_api.TF_AddControlInput(op_desc, control_input); - - // Add attrs - foreach (var attr in node_def.Attr) + using (var status = new Status()) { - var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. - var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak - Marshal.Copy(bytes, 0, proto, bytes.Length); - uint len = (uint) bytes.Length; - c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); + // Add control inputs + foreach (var control_input in control_inputs) + c_api.TF_AddControlInput(op_desc, control_input); - status.Check(true); - } + // Add attrs + foreach (var attr in node_def.Attr) + { + var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. 
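+                    // Copy the serialized AttrValue into unmanaged memory for TF_SetAttrValueProto,
+                    // then free the buffer once the attr has been applied to the op description.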
+ var protoHandle = Marshal.AllocHGlobal(bytes.Length); + Marshal.Copy(bytes, 0, protoHandle, bytes.Length); + uint len = (uint)bytes.Length; + c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status); + status.Check(true); + Marshal.FreeHGlobal(protoHandle); + } - var c_op = c_api.TF_FinishOperation(op_desc, status); + var c_op = c_api.TF_FinishOperation(op_desc, status); - status.Check(true); + status.Check(true); - return c_op; + return c_op; + } } } diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs index 73f5c213..8cd4a252 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs @@ -54,6 +54,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO List first_stage_trainable_var_list; Operation train_op_with_frozen_variables; Operation train_op_with_all_variables; + Saver loader; + Saver saver; #endregion public bool Run() @@ -74,7 +76,9 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO public void Train(Session sess) { - + sess.run(tf.global_variables_initializer()); + print($"=> Restoring weights from: {cfg.TRAIN.INITIAL_WEIGHT} ... "); + loader.restore(sess, cfg.TRAIN.INITIAL_WEIGHT); } public void Test(Session sess) @@ -184,6 +188,21 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO }); }); + tf_with(tf.name_scope("loader_and_saver"), delegate + { + loader = tf.train.Saver(net_var); + saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10); + }); + + tf_with(tf.name_scope("summary"), delegate + { + tf.summary.scalar("learn_rate", learn_rate); + tf.summary.scalar("giou_loss", giou_loss); + tf.summary.scalar("conf_loss", conf_loss); + tf.summary.scalar("prob_loss", prob_loss); + tf.summary.scalar("total_loss", loss); + }); + return graph; } diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs index b5c46151..39308da8 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs @@ -60,7 +60,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO public TrainConfig(string root) { _root = root; - INITIAL_WEIGHT = Path.Combine(_root, "data", "checkpoint", "yolov3_coco_demo.ckpt"); + INITIAL_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco_demo.ckpt"); ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt"); } }
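
A minimal usage sketch (not part of the patch) showing how the new Saver(var_list, max_to_keep) overload is expected to be wired up, mirroring the loader/saver pair added to the YOLO example; the variable, session setup, and checkpoint path below are illustrative assumptions rather than code from this change set.

using Tensorflow;
using static Tensorflow.Binding;

public static class SaverUsageSketch
{
    public static void Run()
    {
        // A single variable stands in for a real model's weights (illustrative only),
        // so that global_variables() has something for the Saver to track.
        var w = tf.Variable(0.5f, name: "w");

        using (var sess = tf.Session())
        {
            sess.run(tf.global_variables_initializer());

            // New overload from this change set: keep at most 10 checkpoints,
            // as in the YOLO example's "loader_and_saver" scope.
            var saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10);

            // Hypothetical checkpoint path; the YOLO example reads its restore path
            // from cfg.TRAIN.INITIAL_WEIGHT instead.
            saver.save(sess, "./checkpoint/model.ckpt");
            saver.restore(sess, "./checkpoint/model.ckpt");
        }
    }
}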