Browse Source

Rollback Fix collections typing #448

tags/v0.13
Oceania2018 5 years ago
parent
commit
945ac02415
7 changed files with 146 additions and 97 deletions
  1. +5
    -5
      src/TensorFlowNET.Core/APIs/tf.variable.cs
  2. +15
    -3
      src/TensorFlowNET.Core/Binding.FuncTools.cs
  3. +29
    -25
      src/TensorFlowNET.Core/Framework/meta_graph.cs
  4. +56
    -38
      src/TensorFlowNET.Core/Graphs/Graph.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Training/TrainingUtil.cs
  6. +39
    -24
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs
  7. +1
    -1
      src/TensorFlowNET.Core/ops.cs

+ 5
- 5
src/TensorFlowNET.Core/APIs/tf.variable.cs View File

@@ -23,14 +23,14 @@ namespace Tensorflow
{
public VariableV1[] global_variables(string scope = null)
{
return (ops.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope))
return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>)
.ToArray();
}

public Operation global_variables_initializer()
{
var g = variables.global_variables();
return variables.variables_initializer(g?.ToArray());
return variables.variables_initializer(g.ToArray());
}

/// <summary>
@@ -54,9 +54,9 @@ namespace Tensorflow
{
var scope = Tensorflow.variable_scope.get_variable_scope();
var store = Tensorflow.variable_scope._get_default_variable_store();
return scope.get_variable(store,
name,
shape: shape,
return scope.get_variable(store,
name,
shape: shape,
dtype: dtype,
use_resource: use_resource,
validate_shape: validate_shape,


+ 15
- 3
src/TensorFlowNET.Core/Binding.FuncTools.cs View File

@@ -10,11 +10,23 @@ namespace Tensorflow
{
public static class functools
{
public static Func<Tin, Tout> partial<Tin, Tout>(Func<Tin, Tout> func, Tin arg)
=> (arg0) => func(arg0);
public static PartialFunc<Tin, Tout> partial<Tin, Tout>(Func<Tin, Tout> func, Tin arg)
=> new PartialFunc<Tin, Tout>
{
args = arg,
invoke = func
};

public static Func<Tin1, Tin2, Tout> partial<Tin1, Tin2, Tout>(Func<Tin1, Tin2, Tout> func, (Tin1, Tin2) args)
=> (arg1, arg2) => func(arg1, arg2);
=> (arg1, arg2) => func(args.Item1, args.Item2);
}

public class PartialFunc<Tin, Tout>
{
public Tin args { get; set; }
public object[] keywords { get; set; }

public Func<Tin, Tout> invoke { get; set; }
}
}
}

+ 29
- 25
src/TensorFlowNET.Core/Framework/meta_graph.cs View File

@@ -46,9 +46,9 @@ namespace Tensorflow

if (!string.IsNullOrEmpty(unbound_inputs_col_name))
{
foreach(var col in meta_graph_def.CollectionDef)
foreach (var col in meta_graph_def.CollectionDef)
{
if(col.Key == unbound_inputs_col_name)
if (col.Key == unbound_inputs_col_name)
{
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
}
@@ -78,7 +78,7 @@ namespace Tensorflow

// Restores all the other collections.
var variable_objects = new Dictionary<ByteString, VariableV1>();
foreach(var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
{
// Don't add unbound_inputs to the new graph.
if (col.Key == unbound_inputs_col_name)
@@ -87,7 +87,7 @@ namespace Tensorflow
switch (col.Value.KindCase)
{
case KindOneofCase.NodeList:
foreach(var value in col.Value.NodeList.Value)
foreach (var value in col.Value.NodeList.Value)
{
var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names));
graph.add_to_collection(col.Key, col_op);
@@ -115,7 +115,7 @@ namespace Tensorflow
}
else
{
foreach(var value in col.Value.BytesList.Value)
foreach (var value in col.Value.BytesList.Value)
{
switch (col.Key)
{
@@ -139,7 +139,7 @@ namespace Tensorflow
}
}
}
break;
default:
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
@@ -173,8 +173,8 @@ namespace Tensorflow
string unbound_inputs_col_name = "unbound_inputs",
bool clear_devices = false,
SaverDef saver_def = null,
bool clear_extraneous_savers= false,
bool strip_default_attrs= false,
bool clear_extraneous_savers = false,
bool strip_default_attrs = false,
byte[] meta_info_def = null)
{
var graph = ops.get_default_graph();
@@ -236,12 +236,12 @@ namespace Tensorflow
meta_graph_def.GraphDef = graph_def;

// Fills in meta_info_def.stripped_op_list using the ops from graph_def.
if (meta_graph_def.MetaInfoDef.StrippedOpList == null ||
if (meta_graph_def.MetaInfoDef.StrippedOpList == null ||
meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0)
meta_graph_def.MetaInfoDef.StrippedOpList = stripped_op_list_for_graph(meta_graph_def.GraphDef);

var clist = graph.get_all_collection_keys();
foreach(var ctype in clist)
foreach (var ctype in clist)
{
if (clear_extraneous_savers)
{
@@ -256,30 +256,34 @@ namespace Tensorflow
return meta_graph_def;
}

private static void add_collection_def(MetaGraphDef meta_graph_def,
string key,
private static void add_collection_def(MetaGraphDef meta_graph_def,
string key,
Graph graph = null,
string export_scope = "")
{
if (!meta_graph_def.CollectionDef.ContainsKey(key))
meta_graph_def.CollectionDef[key] = new CollectionDef();
var col_def = meta_graph_def.CollectionDef[key];
col_def.NodeList = new Types.NodeList();
col_def.BytesList = new Types.BytesList();
foreach (object value in graph.get_collection(key))

switch (graph.get_collection(key))
{
switch (value)
{
case RefVariable x:
case List<RefVariable> collection_list:
col_def.BytesList = new Types.BytesList();
foreach (var x in collection_list)
{
var proto = x.to_proto(export_scope);
col_def.BytesList.Value.Add(proto.ToByteString());
break;
case ITensorOrOperation x2:
col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope));
break;
default:
break;
}
}

break;
case List<object> collection_list:
col_def.NodeList = new Types.NodeList();
foreach (var x in collection_list)
if (x is ITensorOrOperation x2)
col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope));
break;
case List<Operation> collection_list:
break;
}
}



+ 56
- 38
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -77,7 +77,7 @@ namespace Tensorflow
/// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks>
public partial class Graph : DisposableObject
#if !SERIALIZABLE
,IEnumerable<Operation>
, IEnumerable<Operation>
#endif
{
private Dictionary<int, ITensorOrOperation> _nodes_by_id;
@@ -100,15 +100,13 @@ namespace Tensorflow
/// </summary>
private bool _finalized = false;


/// <summary>
/// Arbitrary collections of objects inside the graph.
/// TODO: Access might be slow (-> O(n)) depending on size.
/// Arbitrary collections of objects.
/// </summary>
private readonly ICollection<(string name, string scope, object item)> _collections = new List<(string name, string scope, object item)>();
private Dictionary<string, object> _collections = new Dictionary<string, object>();

public bool building_function;
public bool building_function;
public Graph()
{
_handle = c_api.TF_NewGraph();
@@ -230,14 +228,16 @@ namespace Tensorflow
throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}.");
}

public void add_to_collection(string name, object value)
public void add_to_collection<T>(string name, T value)
{
_check_not_finalized();
_collections.Add((name, null, value));
if (_collections.ContainsKey(name))
(_collections[name] as List<T>).Add(value);
else
_collections[name] = new List<T> { value };
}


public void add_to_collections(List<string> names, object value)
public void add_to_collections<T>(List<string> names, T value)
{
foreach (string name in names)
add_to_collection(name, value);
@@ -278,6 +278,12 @@ namespace Tensorflow

_create_op_helper(op, true);

/*Console.Write($"create_op: {op_type} '{node_def.Name}'");
Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}");
Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}");
Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}");
Console.WriteLine();*/

return op;
}

@@ -394,7 +400,7 @@ namespace Tensorflow
_names_in_use[name_key] = 1;

// Return the new name with the original capitalization of the given name.
name = $"{name}_{i-1}";
name = $"{name}_{i - 1}";
}
return name;
}
@@ -407,8 +413,8 @@ namespace Tensorflow
TF_Output[] return_outputs = new TF_Output[num_return_outputs];
unsafe
{
var tf_output_ptr = (TF_Output*) return_output_handle;
for (int i = 0; i < num_return_outputs; i++)
var tf_output_ptr = (TF_Output*)return_output_handle;
for (int i = 0; i < num_return_outputs; i++)
return_outputs[i] = *(tf_output_ptr + i);
return return_outputs;
}
@@ -416,34 +422,46 @@ namespace Tensorflow

public string[] get_all_collection_keys()
{
return (from c in _collections where !c.name.StartsWith("__") select c.name).ToArray();
return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray();
}

public List<object> get_collection(string name, string scope = null)
public object get_collection(string name, string scope = null)
{
return get_collection<object>(name, scope);
}
return _collections.ContainsKey(name) ? _collections[name] : null;
}

public List<T> get_collection<T>(string name, string scope = null)
{
return (from c in _collections
where c.name == name &&
(scope == null || c.scope == scope) &&
implementationOf<T>(c.item)
select (T)(c.item)).ToList();
}
private static bool implementationOf<T>(object item)
{
return (item.GetType() == typeof(T) || item.GetType().IsSubclassOf(typeof(T)));
}
{
List<T> t = default;
var collection = _collections.ContainsKey(name) ? _collections[name] : new List<T>();
switch (collection)
{
case List<VariableV1> list:
t = list.Select(x => (T)(object)x).ToList();
break;
case List<ResourceVariable> list:
t = list.Select(x => (T)(object)x).ToList();
break;
case List<RefVariable> list:
t = list.Select(x => (T)(object)x).ToList();
break;
case List<Tensor> list:
t = list.Select(x => (T)(object)x).ToList();
break;
case List<Operation> list:
t = list.Select(x => (T)(object)x).ToList();
break;
default:
throw new NotImplementedException($"get_collection<{typeof(T).FullName}>");
}
return t;
}

public List<T> get_collection_ref<T>(string name)
{
return get_collection<T>(name);
if (!_collections.ContainsKey(name))
_collections[name] = new List<T>();
return _collections[name] as List<T>;
}

public void prevent_feeding(Tensor tensor)
@@ -497,7 +515,7 @@ namespace Tensorflow
string debugString = string.Empty;
public override string ToString()
{
return $"{graph_key}, ({_handle})";
return $"{graph_key}, ({_handle})";
/*if (string.IsNullOrEmpty(debugString))
{
int len = 0;
@@ -514,7 +532,7 @@ namespace Tensorflow
IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator()
=> GetEnumerable().GetEnumerator();
IEnumerator IEnumerable.GetEnumerator()
IEnumerator IEnumerable.GetEnumerator()
=> throw new NotImplementedException();
#endif


+ 1
- 1
src/TensorFlowNET.Core/Training/TrainingUtil.cs View File

@@ -16,7 +16,7 @@ namespace Tensorflow.Train
// Create in proper graph and base name_scope.
var g = graph.as_default();
g.name_scope(null);
var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new TensorShape(), dtype: dtypes.int64,
var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new int[0], dtype: dtypes.int64,
initializer: tf.zeros_initializer,
trainable: false,
aggregation: VariableAggregation.OnlyFirstReplica,


+ 39
- 24
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow
protected Graph _graph;
bool _building_function;

public variable_scope(string name,
public variable_scope(string name,
string default_name = "",
Tensor[] values = null,
bool? reuse = null,
@@ -113,7 +113,7 @@ namespace Tensorflow
{
// Reenter the current name scope
string name_scope = ops.get_name_scope();
if(!string.IsNullOrEmpty(name_scope))
if (!string.IsNullOrEmpty(name_scope))
// Hack to reenter
name_scope += "/";
current_name_scope = ops.name_scope(name_scope);
@@ -128,8 +128,8 @@ namespace Tensorflow
string current_name_scope_name = current_name_scope;
_current_name_scope = current_name_scope;
string old_name_scope = _scope == null ? current_name_scope_name : _scope.original_name_scope;
if(_scope == null)
if (_scope == null)
pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope);
else
pure_variable_scope = new PureVariableScope(_scope, old_name_scope: old_name_scope);
@@ -179,7 +179,7 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.DtInvalid,
int[] shape = null,
bool validate_shape = false,
bool ? use_resource = null,
bool? use_resource = null,
VariableSynchronization synchronization = VariableSynchronization.Auto,
VariableAggregation aggregation = VariableAggregation.None)
{
@@ -189,7 +189,7 @@ namespace Tensorflow
use_resource = get_variable_scope().use_resource;
}

if(!use_resource.HasValue)
if (!use_resource.HasValue)
use_resource = _DEFAULT_USE_RESOURCE;

if (use_resource.Value)
@@ -204,7 +204,7 @@ namespace Tensorflow
}
else
{
return new RefVariable(initial_value,
return new RefVariable(initial_value,
trainable: trainable.Value,
validate_shape: validate_shape,
collections: collections,
@@ -215,13 +215,13 @@ namespace Tensorflow

public static _VariableStore _get_default_variable_store()
{
var store = ops.get_collection<_VariableStore>(_VARSTORE_KEY).FirstOrDefault();
if (store == null)
{
store = new _VariableStore();
ops.add_to_collection(_VARSTORE_KEY, store);
}
return store;
var store = ops.get_collection(_VARSTORE_KEY);
if (store != null)
return (store as List<_VariableStore>)[0];
var store1 = new _VariableStore();
ops.add_to_collection(_VARSTORE_KEY, store1);
return store1;
}

public static VariableScope get_variable_scope()
@@ -231,15 +231,30 @@ namespace Tensorflow

public static _VariableScopeStore get_variable_scope_store()
{
var scope_store = ops.get_collection<_VariableScopeStore>(_VARSCOPESTORE_KEY).FirstOrDefault();
if (scope_store == null)
scope_store = ops.get_collection<RefVariable>(_VARSCOPESTORE_KEY).FirstOrDefault();
_VariableScopeStore ret = null;
var scope_store = ops.get_collection(_VARSCOPESTORE_KEY);
if (scope_store == null)
{
scope_store = new _VariableScopeStore();
ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store);
ret = new _VariableScopeStore();
ops.add_to_collection(_VARSCOPESTORE_KEY, ret);
}
else
{
switch (scope_store)
{
case List<RefVariable> values:
ret = values[0];
break;
case List<_VariableScopeStore> values:
ret = values[0];
break;
default:
throw new InvalidOperationException("get_variable_scope_store");
}

}
return scope_store;

return ret;
}

public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = true)
@@ -256,7 +271,7 @@ namespace Tensorflow
{
trainable = true;
}
return trainable.Value;
}

@@ -279,7 +294,7 @@ namespace Tensorflow
}

// TODO for Switch/Case
public static RefVariable get_variable(string embeddingMatrix, IInitializer initializer, bool use_resource,
public static RefVariable get_variable(string embeddingMatrix, IInitializer initializer, bool use_resource,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
bool trainable = false,
@@ -290,12 +305,12 @@ namespace Tensorflow

public void __init__()
{
}

public void __del__()
{
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/ops.cs View File

@@ -63,7 +63,7 @@ namespace Tensorflow
/// list contains the values in the order under which they were
/// collected.
/// </returns>
public static List<object> get_collection(string key, string scope = null)
public static object get_collection(string key, string scope = null)
{
return get_default_graph().get_collection(key, scope);
}


Loading…
Cancel
Save