diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs
index d6de08f4..c1e76d11 100644
--- a/src/TensorFlowNET.Core/APIs/tf.train.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.train.cs
@@ -54,7 +54,7 @@ namespace Tensorflow
                     clear_devices,
                     import_scope).Item1;
 
-            public (MetaGraphDef, Dictionary) export_meta_graph(string filename = "",
+            public (MetaGraphDef, Dictionary) export_meta_graph(string filename = "",
                 bool as_text = false,
                 bool clear_devices = false,
                 bool clear_extraneous_savers = false,
diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs
similarity index 98%
rename from src/TensorFlowNET.Core/Framework/meta_graph.py.cs
rename to src/TensorFlowNET.Core/Framework/meta_graph.cs
index aaa5d225..3f5a2777 100644
--- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs
+++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs
@@ -167,7 +167,7 @@ namespace Tensorflow
         ///
         ///
         ///
-        public static (MetaGraphDef, Dictionary) export_scoped_meta_graph(string filename = "",
+        public static (MetaGraphDef, Dictionary) export_scoped_meta_graph(string filename = "",
             GraphDef graph_def = null,
             bool as_text = false,
             string unbound_inputs_col_name = "unbound_inputs",
@@ -179,8 +179,8 @@ namespace Tensorflow
         {
             var graph = ops.get_default_graph();
 
-            var var_list = new Dictionary();
-            var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES);
+            var var_list = new Dictionary();
+            var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES);
 
             if (variables != null)
             {
diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs
index e98ec21e..baa7145d 100644
--- a/src/TensorFlowNET.Core/Gradients/array_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs
@@ -190,6 +190,29 @@ namespace Tensorflow.Gradients
             return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
         }
 
+        /// <summary>
+        /// Gradient for "Pad": slices the incoming gradient starting at the leading paddings, with the size of the op's first input.
+        /// </summary>
+        [RegisterGradient("Pad")]
+        public static Tensor[] _PadGrad(Operation op, Tensor[] grads)
+        {
+            var grad = grads[0];
+            var x = op.inputs[0];
+            var a = op.inputs[1];
+            var pad_before = array_ops.slice(a, new[] { 0, 0 },
+                new[] { array_ops.stack(new object[] { array_ops.rank(x), 1 }) });
+
+            // Make it a 1-D tensor.
+            var begin = array_ops.reshape(pad_before, new[] { -1 });
+            var sizes = array_ops.shape(x);
+            var x_grad = array_ops.slice(grad, new[] { begin }, new[] { sizes });
+
+            if (len(op.inputs) == 3)
+                return new Tensor[] { x_grad, null, null };
+            else
+                return new Tensor[] { x_grad, null };
+        }
+
         [RegisterGradient("Squeeze")]
         public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
         {
diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs
index d8447163..76b6a7b5 100644
--- a/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs
@@ -36,56 +36,54 @@ namespace Tensorflow.Gradients
         ///
         ///
         [RegisterGradient("Switch")]
-        public Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
+        public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
         {
+            var grad = grads[0];
+            var graph = ops.get_default_graph();
+            var op_ctxt = op._get_control_flow_context();
+            var grad_ctxt = graph._get_control_flow_context();
+            switch (op_ctxt)
+            {
+                case WhileContext cwhile:
+                    throw new NotImplementedException("_SwitchGrad WhileContext");
+                case CondContext ccond:
+                    {
+                        var zero_grad = grads[1 - op_ctxt.branch];
+                        // At this point, we have created zero_grad guarded by the right switch.
+                        // Unfortunately, we may still get None here for not trainable data types.
+                        if(zero_grad == null)
+                        {
+                            throw new NotImplementedException("_SwitchGrad CondContext zero_grad");
+                        }
+
+                        return new Tensor[]
+                        {
+                            merge(grads, name: "cond_grad")[0],
+                            null
+                        };
+                    }
+            }
+
+            // Any other context is not ported yet; fall through to the generic failure below.
             throw new NotImplementedException("_SwitchGrad");
-            //graph = ops.get_default_graph()
-            //# pylint: disable=protected-access
-            //op_ctxt = op._get_control_flow_context()
-            //grad_ctxt = graph._get_control_flow_context()
-            //# pylint: enable=protected-access
-            //if isinstance(op_ctxt, WhileContext):
-            //  merge_grad = grad_ctxt.grad_state.switch_map.get(op)
-            //  if merge_grad is not None:
-            //    # This is the second time this Switch is visited. It comes from
-            //    # the non-exit branch of the Switch, so update the second input
-            //    # to the Merge.
-            //    # TODO(yuanbyu): Perform shape inference with this new input.
-            //    if grad[1] is not None:
-            //      # pylint: disable=protected-access
-            //      control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1],
-            //                                           enforce_shape_invariant=False)
-            //      # pylint: enable=protected-access
-            //    return None, None
-            //  elif grad[0] is not None:
-            //    # This is the first time this Switch is visited. It comes from
-            //    # the Exit branch, which is grad[0]. grad[1] is empty at this point.
-            //    # Use grad[0] for both inputs to merge for now, but update the second
-            //    # input of merge when we see this Switch the second time.
-            //    merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
-            //    grad_ctxt.grad_state.switch_map[op] = merge_grad
-            //    return merge_grad, None
-            //  else:
-            //    # This is the first time this Switch is visited. It comes from the
-            //    # Identity branch. Such a Switch has `None` gradient for the Exit branch,
-            //    # meaning the output is not differentiable.
-            //    return None, None
-            //elif isinstance(op_ctxt, CondContext):
-            //  zero_grad = grad[1 - op_ctxt.branch]
-            //  # At this point, we have created zero_grad guarded by the right switch.
-            //  # Unfortunately, we may still get None here for not trainable data types.
-            //  if zero_grad is None:
-            //    # For resource variables we get None always on the other branch, so bypass
-            //    # this.
-            //    if op.inputs[0].dtype == dtypes.resource:
-            //      return merge(
-            //          [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None
-            //    return None, None
-            //  return merge(grad, name="cond_grad")[0], None
-            //else:
-            //  false_grad = switch(grad[0], op.inputs[1])[0]
-            //  true_grad = switch(grad[1], op.inputs[1])[1]
-            //  return merge([false_grad, true_grad])[0], None
+        }
+
+        /// <summary>
+        /// Returns the value of an available element of `inputs`.
+        /// </summary>
+        /// <param name="inputs">Tensors to merge; "RefMerge" is used only when every input has a reference dtype, otherwise "Merge".</param>
+        /// <param name="name">Optional name, used for the name scope (default scope name "Merge").</param>
+        /// <returns>The tensor array returned by the generated merge op wrapper.</returns>
+        internal static Tensor[] merge(Tensor[] inputs, string name = null)
+        {
+            return tf_with(ops.name_scope(name, "Merge", inputs), scope =>
+            {
+                name = scope;
+                if (inputs.All(x => x.dtype.is_ref_dtype()))
+                    return gen_control_flow_ops.ref_merge(inputs, name: name);
+                else
+                    return gen_control_flow_ops.merge(inputs, name: name);
+            });
         }
 
         ///
diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
index a4508d3c..f10a5fed 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
@@ -108,7 +108,7 @@ namespace Tensorflow
                 {
                     // generate gradient subgraph for op.
                     var op = queue.Dequeue();
-                    if(tf.get_default_graph()._nodes_by_name.Count > 18577)
+                    if(tf.get_default_graph()._nodes_by_name.Count >= 20611)
                     {
 
                     }
diff --git a/src/TensorFlowNET.Core/Gradients/image_grad.cs b/src/TensorFlowNET.Core/Gradients/image_grad.cs
new file mode 100644
index 00000000..23b19de9
--- /dev/null
+++ b/src/TensorFlowNET.Core/Gradients/image_grad.cs
@@ -0,0 +1,54 @@
+/*****************************************************************************
+   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Framework;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Gradients
+{
+    [RegisterGradient("image_grad")]
+    public class image_grad
+    {
+        [RegisterGradient("ResizeNearestNeighbor")]
+        public static Tensor[] _ResizeNearestNeighborGrad(Operation op, Tensor[] grads)
+        {
+            var grad = grads[0];
+            var image = op.inputs[0];
+            // Always read the size from the runtime shape (slice [1:3]).
+            // The previous version threw NotImplementedException whenever the
+            // static shape was fully defined; the runtime slice does not depend
+            // on the static shape, so no separate branch is needed.
+            // NOTE(review): assumes dims 1 and 2 are height/width - confirm.
+            Tensor image_shape = array_ops.shape(image)["1:3"];
+
+            grad = gen_image_ops.resize_nearest_neighbor_grad(
+                grad,
+                image_shape,
+                align_corners: op.get_attr<bool>("align_corners"),
+                half_pixel_centers: op.get_attr<bool>("half_pixel_centers"));
+
+            return new Tensor[]
+            {
+                grad,
+                null
+            };
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index 8b7cf7a0..7119a4ad 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -440,6 +440,9 @@ namespace Tensorflow
                 case List list:
                     t = list.Select(x => (T)(object)x).ToList();
                     break;
+                case List list:
+                    t = list.Select(x => (T)(object)x).ToList();
+                    break;
                 case List list:
                     t = list.Select(x => (T)(object)x).ToList();
                     break;
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs
similarity index 99%
rename from src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
rename to src/TensorFlowNET.Core/Operations/control_flow_ops.cs
index 54ccf590..e8b5f0eb 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs
@@ -518,7 +518,7 @@ namespace Tensorflow
                 inputs = inputs.Select(inp =>
                             ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true))
                         .ToArray();
-                return gen_control_flow_ops.merge(inputs, name).Item1;
+                return gen_control_flow_ops.merge(inputs, name)[0];
             });
         }
 
diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs
similarity index 94%
rename from src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
rename to src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs
index 163a50e4..5f0ceb48 100644
--- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs
@@ -148,11 +148,18 @@ namespace Tensorflow
             return new []{_op.outputs[0], _op.outputs[1]};
         }
 
-        public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null)
+        public static Tensor[] ref_merge(Tensor[] inputs, string name = null)
+        {
+            var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs });
+
+            return _op.outputs;
+        }
+
+        public static Tensor[] merge(Tensor[] inputs, string name = null)
         {
             var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs });
 
-            return (_op.outputs[0], _op.outputs[1]);
+            return _op.outputs;
         }
     }
 }
diff --git a/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs
index b052d9d6..143d4fe8 100644
--- a/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs
@@ -183,5 +183,22 @@ namespace Tensorflow
 
             return op.output;
         }
+
+        /// <summary>
+        /// Applies the "ResizeNearestNeighborGrad" op definition to the given arguments and returns the op's output.
+        /// </summary>
+        public static Tensor resize_nearest_neighbor_grad<Tsize>(Tensor grads, Tsize size, bool align_corners = false,
+            bool half_pixel_centers = false, string name = null)
+        {
+            var op = _op_def_lib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: new
+            {
+                grads,
+                size,
+                align_corners,
+                half_pixel_centers
+            });
+
+            return op.output;
+        }
     }
 }