diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs index 80cc563a..954f5b11 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Tensorflow @@ -19,14 +20,51 @@ namespace Tensorflow public static void _GradientsHelper(object ys, object xs, - List grad_ys = null, + object grad_ys = null, string name = "gradients", bool colocate_gradients_with_ops = false, bool gate_gradients = false, + object stop_gradients = null, Graph src_graph = null) { if (src_graph == null) src_graph = ops.get_default_graph(); + + var ys1 = _AsList(ys); + var xs1 = _AsList(xs); + List grad_ys1 = null; + List stop_gradients1 = stop_gradients == null ? new List() : _AsList(stop_gradients); + if (grad_ys == null) + grad_ys1 = ys1.Select(x => new Tensor(IntPtr.Zero)).ToList(); + else + grad_ys = _AsList(grad_ys); + + var all = new List(); + all.AddRange(ys1); + all.AddRange(xs1); + all.AddRange(stop_gradients1); + all.AddRange(grad_ys1); + + string grad_scope = ""; + using (var namescope = new ops.name_scope(name, "gradients", values: all)) + grad_scope = namescope; + } + + private static List _AsList(object ys) + { + List ret = null; + + switch (ys) + { + case Tensor value: + ret = new List { value }; + break; + case List value: + ret = value; + break; + } + + return ret; } } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index e0956b09..d2ed0c99 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -23,7 +23,8 @@ namespace Tensorflow private int _next_id_counter; private List _unfetchable_ops = new List(); - private string _name_stack; + public string _name_stack = ""; + public string old_stack = ""; public string _graph_key; public Status Status { get; } @@ -168,23 +169,22 @@ namespace Tensorflow public string name_scope(string name) { + old_stack = _name_stack; + string new_stack = ""; + if (name.EndsWith("/")) - { new_stack = ops._name_from_scope_name(name); - } else - { new_stack = unique_name(name); - } _name_stack = new_stack; return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/"; } - public string unique_name(string name) + public string unique_name(string name, bool mark_as_used = true) { if (!String.IsNullOrEmpty(_name_stack)) { @@ -192,17 +192,45 @@ namespace Tensorflow } var name_key = name.ToLower(); + int i = 0; if (_names_in_use.ContainsKey(name_key)) { - _names_in_use[name_key]++; + foreach (var item in _names_in_use) + { + if (item.Key == name_key) + { + i = _names_in_use[name_key]; + break; + } + + i++; + } } - else + + if (mark_as_used) + if (_names_in_use.ContainsKey(name_key)) + _names_in_use[name_key]++; + else + _names_in_use[name_key] = i + 1; + + if (i > 0) { - _names_in_use[name_key] = 1; - return name; + var base_name_key = name_key; + + // Make sure the composed name key is not already used. + if (_names_in_use.ContainsKey(name_key)) + { + name_key = $"{base_name_key}_{i}"; + i += 1; + } + + if (mark_as_used) + _names_in_use[name_key] = 1; + + name = $"{name}_{i - 1}"; } - return $"{name}_{_names_in_use[name_key]}"; + return name; } public TF_Output[] ReturnOutputs(IntPtr results) diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 26bb6374..a5ee3b23 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -20,7 +20,9 @@ namespace Tensorflow name = op_type_name; } - string scope = new ops.name_scope(name); + string scope = ""; + using (var namescope = new ops.name_scope(name)) + scope = namescope; var default_type_attr_map = new Dictionary(); foreach (var attr_def in op_def.Attr) diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index c073b311..e70684b7 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -51,31 +51,36 @@ namespace Tensorflow collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); ops.init_scope(); - name = new ops.name_scope(name, "Variable", init_from_fn ? new List() : new List { initial_value }); - if (init_from_fn) + var values = init_from_fn ? new List() : new List { initial_value }; + using (var namescope = new ops.name_scope(name, "Variable", values)) { + name = namescope; - } - else - { - _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); - } + if (init_from_fn) + { - var shape = _initial_value.shape; - dtype = _initial_value.dtype; - _variable = gen_state_ops.variable_v2(shape, dtype, name); + } + else + { + _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); + } - // Manually overrides the variable's shape with the initial value's. - if (validate_shape) - { - var initial_value_shape = _initial_value.shape; - } + var shape = _initial_value.shape; + dtype = _initial_value.dtype; + _variable = gen_state_ops.variable_v2(shape, dtype, name); + + // Manually overrides the variable's shape with the initial value's. + if (validate_shape) + { + var initial_value_shape = _initial_value.shape; + } - // If 'initial_value' makes use of other variables, make sure we don't - // have an issue if these other variables aren't initialized first by - // using their initialized_value() method. + // If 'initial_value' makes use of other variables, make sure we don't + // have an issue if these other variables aren't initialized first by + // using their initialized_value() method. - ops.add_to_collections(collections, this); + ops.add_to_collections(collections, this); + } } public Tensor _ref() diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs index 8d00c416..1fc70e46 100644 --- a/src/TensorFlowNET.Core/ops.name_scope.cs +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -6,15 +6,16 @@ namespace Tensorflow { public partial class ops { - public class name_scope + public class name_scope : IDisposable { public string _name; public string _default_name; public object _values; public Context _ctx; public string _name_scope; + private object _g_manager; - public name_scope(string name, string default_name = "", List values = null) + public name_scope(string name, string default_name = "", List values = null) { _name = name; _default_name = default_name; @@ -31,11 +32,23 @@ namespace Tensorflow _name = _default_name; } + Graph g = null; + if (_values is List values) + g = _get_graph_from_inputs(values); + + if (g == null) + g = get_default_graph(); + + return g.name_scope(_name); ; + } + + public void Dispose() + { var g = get_default_graph(); - return g.name_scope(_name); + g._name_stack = g.old_stack; } - public static implicit operator string(name_scope ns) + public static implicit operator string(name_scope ns) { return ns._name_scope; } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 92312c8f..3cc5d18f 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -34,6 +34,17 @@ namespace Tensorflow return tf.Graph(); } + public static Graph _get_graph_from_inputs(List op_input_list, Graph graph = null) + { + foreach(var op_input in op_input_list) + { + // Determine if this is a valid graph_element. + var graph_element = op_input; + } + + return get_default_graph(); + } + public static Tensor convert_to_tensor(object value, string name = "") { var nd = tensor_util.convert_to_numpy_ndarray(value);