Browse Source

add VariableScope and _VariableStore

tags/v0.8.0
haiping008 6 years ago
parent
commit
bf45277be8
11 changed files with 265 additions and 12 deletions
  1. +12
    -0
      src/TensorFlowNET.Core/Operations/IInitializer.cs
  2. +34
    -0
      src/TensorFlowNET.Core/Operations/tf.init_ops.cs
  3. +5
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  4. +14
    -0
      src/TensorFlowNET.Core/Variables/VariableAggregation.cs
  5. +30
    -1
      src/TensorFlowNET.Core/Variables/VariableScope.cs
  6. +16
    -0
      src/TensorFlowNET.Core/Variables/_ReuseMode.cs
  7. +80
    -0
      src/TensorFlowNET.Core/Variables/_VariableStore.cs
  8. +10
    -0
      src/TensorFlowNET.Core/Variables/tf.variable.cs
  9. +17
    -11
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs
  10. +21
    -0
      test/TensorFlowNET.UnitTest/TrainSaverTest.cs
  11. +26
    -0
      test/TensorFlowNET.UnitTest/python/train_saver.py

+ 12
- 0
src/TensorFlowNET.Core/Operations/IInitializer.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public interface IInitializer
{
Tensor call(TensorShape shape, TF_DataType dtype);
object get_config();
}
}

+ 34
- 0
src/TensorFlowNET.Core/Operations/tf.init_ops.cs View File

@@ -0,0 +1,34 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public static partial class tf
{
public static IInitializer zeros_initializer => new Zeros();
public class Zeros : IInitializer
{
private TF_DataType dtype;

public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT)
{
this.dtype = dtype;
}

public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
{
if (dtype == TF_DataType.DtInvalid)
dtype = this.dtype;

return array_ops.zeros(shape, dtype);
}

public object get_config()
{
return new { dtype = dtype.name() };
}
}
}
}

+ 5
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -71,6 +71,11 @@ namespace Tensorflow
type;
}

public static int name(this TF_DataType type)
{
return (int)type;
}

public static DataType as_base_dtype(this DataType type)
{
return (int)type > 100 ?


+ 14
- 0
src/TensorFlowNET.Core/Variables/VariableAggregation.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public enum VariableAggregation
{
NONE = 0,
SUM = 1,
MEAN = 2,
ONLY_FIRST_REPLICA = 3 // ONLY_FIRST_TOWER
}
}

+ 30
- 1
src/TensorFlowNET.Core/Variables/VariableScope.cs View File

@@ -6,6 +6,35 @@ namespace Tensorflow
{
public class VariableScope
{
public bool? use_resource { get; set; }
public bool use_resource { get; set; }
private _ReuseMode _reuse { get; set; }

private object _regularizer;
private TF_DataType _dtype;
public string name { get; set; }

public VariableScope()
{
_reuse = _ReuseMode.AUTO_REUSE;
}

public RefVariable get_variable(_VariableStore var_store,
string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation= VariableAggregation.NONE)
{
string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name;
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(""), scope =>
{
if (dtype == TF_DataType.DtInvalid)
dtype = _dtype;

return var_store.get_variable(full_name);

});

}
}
}

+ 16
- 0
src/TensorFlowNET.Core/Variables/_ReuseMode.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Mode for variable access within a variable scope.
/// </summary>
public enum _ReuseMode
{
// Indicates that variables are to be fetched if they already exist or
// otherwise created.
AUTO_REUSE = 1
}
}

+ 80
- 0
src/TensorFlowNET.Core/Variables/_VariableStore.cs View File

@@ -0,0 +1,80 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Variable store that carries a number of named Variables.
/// </summary>
public class _VariableStore
{
private Dictionary<string, object> _vars;
private Dictionary<string, object> _partitioned_vars;
private bool _store_eager_variables;

public _VariableStore()
{
_vars = new Dictionary<string, object>();
_partitioned_vars = new Dictionary<string, object>();
_store_eager_variables = false;
}

public RefVariable get_variable(string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null,
bool trainable = false,
bool validate_shape = true,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation = VariableAggregation.NONE)
{
dtype = dtype.as_base_dtype();
trainable = variable_scope._get_trainable_value(synchronization, trainable);

return _true_getter(name,
shape: shape,
dtype: dtype,
initializer: initializer,
trainable: trainable,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
}

private RefVariable _true_getter(string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null,
bool trainable = false,
bool validate_shape = true,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation = VariableAggregation.NONE)
{
return _get_single_variable(name: name);
}

private RefVariable _get_single_variable(string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null,
bool reuse = false,
bool trainable = false,
bool validate_shape = false,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation = VariableAggregation.NONE)
{
if (_vars.ContainsKey(name))
{
if (!reuse)
{
var var = _vars[name];

}
throw new NotImplementedException("_get_single_variable");
}

throw new NotImplementedException("_get_single_variable");
}
}
}

+ 10
- 0
src/TensorFlowNET.Core/Variables/tf.variable.cs View File

@@ -11,5 +11,15 @@ namespace Tensorflow
var g = variables.global_variables();
return variables.variables_initializer(g.ToArray());
}

public static RefVariable get_variable(string name,
TensorShape shape = null,
IInitializer initializer = null,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation = VariableAggregation.NONE)
{
var store = variable_scope._get_default_variable_store();
return variable_scope.get_variable_scope().get_variable(store, name, shape: shape);
}
}
}

+ 17
- 11
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -6,6 +6,7 @@ namespace Tensorflow
{
public class variable_scope
{
public static string _VARSTORE_KEY = "__variable_store";
public static string _VARSCOPESTORE_KEY = "__varscope";
public static bool _DEFAULT_USE_RESOURCE = false;

@@ -32,6 +33,17 @@ namespace Tensorflow
}
}

public static _VariableStore _get_default_variable_store()
{
var store = ops.get_collection(_VARSTORE_KEY);
if (store != null)
return (store as List<_VariableStore>)[0];

var store1 = new _VariableStore();
ops.add_to_collection(_VARSTORE_KEY, store1);
return store1;
}

public static VariableScope get_variable_scope()
{
return get_variable_scope_store().current_scope;
@@ -65,24 +77,18 @@ namespace Tensorflow
return ret;
}

public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null)
public static bool _get_trainable_value(VariableSynchronization synchronization, bool trainable = true)
{
if(synchronization == VariableSynchronization.ON_READ)
if (synchronization == VariableSynchronization.ON_READ)
{
if (trainable.Value)
if (trainable)
throw new ValueError("Synchronization value can be set to " +
"VariableSynchronization.ON_READ only for non-trainable variables. " +
"You have specified trainable=True and " +
"synchronization=VariableSynchronization.ON_READ.");
else
trainable = false;
}
else if (!trainable.HasValue)
{
trainable = true;
}

return trainable.Value;
return trainable;
}
}
}

+ 21
- 0
test/TensorFlowNET.UnitTest/TrainSaverTest.cs View File

@@ -0,0 +1,21 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.UnitTest
{
[TestClass]
public class TrainSaverTest
{
[TestMethod]
public void Save()
{
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);
var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer);


}
}
}

+ 26
- 0
test/TensorFlowNET.UnitTest/python/train_saver.py View File

@@ -0,0 +1,26 @@

import tensorflow as tf

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
inc_v1.op.run()
dec_v2.op.run()
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in path: %s" % save_path)

Loading…
Cancel
Save