Browse Source

add _init_from_proto to create variable.

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
363b58b85c
3 changed files with 67 additions and 9 deletions
  1. +12
    -4
      src/TensorFlowNET.Core/Framework/meta_graph.py.cs
  2. +54
    -4
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Variables/VariableV1.cs

+ 12
- 4
src/TensorFlowNET.Core/Framework/meta_graph.py.cs View File

@@ -1,4 +1,5 @@
using System;
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
@@ -59,7 +60,7 @@ namespace Tensorflow
return_elements: return_elements);

// Restores all the other collections.
var variable_objects = new Dictionary<string, RefVariable>();
var variable_objects = new Dictionary<ByteString, RefVariable>();
foreach(var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
{
// Don't add unbound_inputs to the new graph.
@@ -81,8 +82,15 @@ namespace Tensorflow
{
foreach (var value in col.Value.BytesList.Value)
{
var proto = VariableDef.Parser.ParseFrom(value);
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
RefVariable variable = null;
if (!variable_objects.ContainsKey(value))
{
var proto = VariableDef.Parser.ParseFrom(value);
variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
variable_objects[value] = variable;
}
graph.add_to_collection(col.Key, variable);
}
}
else


+ 54
- 4
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -23,18 +23,68 @@ namespace Tensorflow

public string name => _variable.name;

public RefVariable(object initial_value,
public RefVariable(object initial_value = null,
bool trainable = true,
List<string> collections = null,
bool validate_shape = true,
string caching_device = "",
string name = "",
TF_DataType dtype = TF_DataType.DtInvalid) :
base(initial_value, trainable, collections, validate_shape, caching_device, name, dtype)
VariableDef variable_def = null,
TF_DataType dtype = TF_DataType.DtInvalid,
string import_scope = "") : base(initial_value,
trainable,
collections,
validate_shape,
caching_device,
name,
dtype)
{
_in_graph_mode = true;

_init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype);
if (variable_def != null)
{
if (initial_value != null)
throw new ValueError("variable_def and initial_value are mutually exclusive.");
_init_from_proto(variable_def, import_scope: import_scope);
}
else
{
_init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype);
}
}

private void _init_from_proto(VariableDef variable_def, string import_scope = "")
{
var g = ops.get_default_graph();

_variable = g.as_graph_element(
ops.prepend_name_scope(variable_def.VariableName,
import_scope: import_scope)) as Tensor;

_initializer_op = g.as_graph_element(
ops.prepend_name_scope(variable_def.InitializerName,
import_scope: import_scope)) as Operation;

// Tests whether initial_value_name exists first for backwards compatibility.
if (!string.IsNullOrEmpty(variable_def.InitialValueName))
_initial_value = g.as_graph_element(
ops.prepend_name_scope(variable_def.InitialValueName,
import_scope: import_scope)) as Tensor;
else
_initial_value = null;

_trainable = variable_def.Trainable;
_snapshot = g.as_graph_element(
ops.prepend_name_scope(variable_def.SnapshotName,
import_scope: import_scope)) as Tensor;

if (variable_def.SaveSliceInfoDef != null)
throw new NotImplementedException("save_slice_info_def");
else
;// _save_slice_info = null;

//_caching_device = null;
//_constraint = null;
}

private void _init_from_args(object initial_value,


+ 1
- 1
src/TensorFlowNET.Core/Variables/VariableV1.cs View File

@@ -16,7 +16,7 @@ namespace Tensorflow
/// </summary>
public class VariableV1
{
public VariableV1(object initial_value,
public VariableV1(object initial_value = null,
bool trainable = true,
List<string> collections = null,
bool validate_shape = true,


Loading…
Cancel
Save