From 363b58b85c035b61ff317fe01f33d86d60eb0c88 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 13 Feb 2019 09:18:32 -0600 Subject: [PATCH] add _init_from_proto to create variable. --- .../Framework/meta_graph.py.cs | 16 +++-- .../Variables/RefVariable.cs | 58 +++++++++++++++++-- .../Variables/VariableV1.cs | 2 +- 3 files changed, 67 insertions(+), 9 deletions(-) diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs index 060c267a..580bc66d 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs @@ -1,4 +1,5 @@ -using System; +using Google.Protobuf; +using System; using System.Collections.Generic; using System.IO; using System.Linq; @@ -59,7 +60,7 @@ namespace Tensorflow return_elements: return_elements); // Restores all the other collections. - var variable_objects = new Dictionary(); + var variable_objects = new Dictionary(); foreach(var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key)) { // Don't add unbound_inputs to the new graph. @@ -81,8 +82,15 @@ namespace Tensorflow { foreach (var value in col.Value.BytesList.Value) { - var proto = VariableDef.Parser.ParseFrom(value); - throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); + RefVariable variable = null; + if (!variable_objects.ContainsKey(value)) + { + var proto = VariableDef.Parser.ParseFrom(value); + variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names); + variable_objects[value] = variable; + } + + graph.add_to_collection(col.Key, variable); } } else diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index d3dfc91c..5dbff318 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -23,18 +23,68 @@ namespace Tensorflow public string name => _variable.name; - public RefVariable(object initial_value, + public RefVariable(object initial_value = null, bool trainable = true, List collections = null, bool validate_shape = true, string caching_device = "", string name = "", - TF_DataType dtype = TF_DataType.DtInvalid) : - base(initial_value, trainable, collections, validate_shape, caching_device, name, dtype) + VariableDef variable_def = null, + TF_DataType dtype = TF_DataType.DtInvalid, + string import_scope = "") : base(initial_value, + trainable, + collections, + validate_shape, + caching_device, + name, + dtype) { _in_graph_mode = true; - _init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype); + if (variable_def != null) + { + if (initial_value != null) + throw new ValueError("variable_def and initial_value are mutually exclusive."); + _init_from_proto(variable_def, import_scope: import_scope); + } + else + { + _init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype); + } + } + + private void _init_from_proto(VariableDef variable_def, string import_scope = "") + { + var g = ops.get_default_graph(); + + _variable = g.as_graph_element( + ops.prepend_name_scope(variable_def.VariableName, + import_scope: import_scope)) as Tensor; + + _initializer_op = g.as_graph_element( + ops.prepend_name_scope(variable_def.InitializerName, + import_scope: import_scope)) as Operation; + + // Tests whether initial_value_name exists first for backwards compatibility. + if (!string.IsNullOrEmpty(variable_def.InitialValueName)) + _initial_value = g.as_graph_element( + ops.prepend_name_scope(variable_def.InitialValueName, + import_scope: import_scope)) as Tensor; + else + _initial_value = null; + + _trainable = variable_def.Trainable; + _snapshot = g.as_graph_element( + ops.prepend_name_scope(variable_def.SnapshotName, + import_scope: import_scope)) as Tensor; + + if (variable_def.SaveSliceInfoDef != null) + throw new NotImplementedException("save_slice_info_def"); + else + ;// _save_slice_info = null; + + //_caching_device = null; + //_constraint = null; } private void _init_from_args(object initial_value, diff --git a/src/TensorFlowNET.Core/Variables/VariableV1.cs b/src/TensorFlowNET.Core/Variables/VariableV1.cs index 7d310f61..2ef715f2 100644 --- a/src/TensorFlowNET.Core/Variables/VariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/VariableV1.cs @@ -16,7 +16,7 @@ namespace Tensorflow /// public class VariableV1 { - public VariableV1(object initial_value, + public VariableV1(object initial_value = null, bool trainable = true, List collections = null, bool validate_shape = true,