@@ -0,0 +1,12 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public interface IInitializer | |||
{ | |||
Tensor call(TensorShape shape, TF_DataType dtype); | |||
object get_config(); | |||
} | |||
} |
@@ -0,0 +1,34 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public static partial class tf | |||
{ | |||
public static IInitializer zeros_initializer => new Zeros(); | |||
public class Zeros : IInitializer | |||
{ | |||
private TF_DataType dtype; | |||
public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
{ | |||
this.dtype = dtype; | |||
} | |||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = this.dtype; | |||
return array_ops.zeros(shape, dtype); | |||
} | |||
public object get_config() | |||
{ | |||
return new { dtype = dtype.name() }; | |||
} | |||
} | |||
} | |||
} |
@@ -71,6 +71,11 @@ namespace Tensorflow | |||
type; | |||
} | |||
public static int name(this TF_DataType type) | |||
{ | |||
return (int)type; | |||
} | |||
public static DataType as_base_dtype(this DataType type) | |||
{ | |||
return (int)type > 100 ? | |||
@@ -0,0 +1,14 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public enum VariableAggregation | |||
{ | |||
NONE = 0, | |||
SUM = 1, | |||
MEAN = 2, | |||
ONLY_FIRST_REPLICA = 3 // ONLY_FIRST_TOWER | |||
} | |||
} |
@@ -6,6 +6,35 @@ namespace Tensorflow | |||
{ | |||
public class VariableScope | |||
{ | |||
public bool? use_resource { get; set; } | |||
public bool use_resource { get; set; } | |||
private _ReuseMode _reuse { get; set; } | |||
private object _regularizer; | |||
private TF_DataType _dtype; | |||
public string name { get; set; } | |||
public VariableScope() | |||
{ | |||
_reuse = _ReuseMode.AUTO_REUSE; | |||
} | |||
public RefVariable get_variable(_VariableStore var_store, | |||
string name, | |||
TensorShape shape = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
VariableAggregation aggregation= VariableAggregation.NONE) | |||
{ | |||
string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name; | |||
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(""), scope => | |||
{ | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = _dtype; | |||
return var_store.get_variable(full_name); | |||
}); | |||
} | |||
} | |||
} |
@@ -0,0 +1,16 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Mode for variable access within a variable scope. | |||
/// </summary> | |||
public enum _ReuseMode | |||
{ | |||
// Indicates that variables are to be fetched if they already exist or | |||
// otherwise created. | |||
AUTO_REUSE = 1 | |||
} | |||
} |
@@ -0,0 +1,80 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Variable store that carries a number of named Variables. | |||
/// </summary> | |||
public class _VariableStore | |||
{ | |||
private Dictionary<string, object> _vars; | |||
private Dictionary<string, object> _partitioned_vars; | |||
private bool _store_eager_variables; | |||
public _VariableStore() | |||
{ | |||
_vars = new Dictionary<string, object>(); | |||
_partitioned_vars = new Dictionary<string, object>(); | |||
_store_eager_variables = false; | |||
} | |||
public RefVariable get_variable(string name, | |||
TensorShape shape = null, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
IInitializer initializer = null, | |||
bool trainable = false, | |||
bool validate_shape = true, | |||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
VariableAggregation aggregation = VariableAggregation.NONE) | |||
{ | |||
dtype = dtype.as_base_dtype(); | |||
trainable = variable_scope._get_trainable_value(synchronization, trainable); | |||
return _true_getter(name, | |||
shape: shape, | |||
dtype: dtype, | |||
initializer: initializer, | |||
trainable: trainable, | |||
validate_shape: validate_shape, | |||
synchronization: synchronization, | |||
aggregation: aggregation); | |||
} | |||
private RefVariable _true_getter(string name, | |||
TensorShape shape = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
IInitializer initializer = null, | |||
bool trainable = false, | |||
bool validate_shape = true, | |||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
VariableAggregation aggregation = VariableAggregation.NONE) | |||
{ | |||
return _get_single_variable(name: name); | |||
} | |||
private RefVariable _get_single_variable(string name, | |||
TensorShape shape = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
IInitializer initializer = null, | |||
bool reuse = false, | |||
bool trainable = false, | |||
bool validate_shape = false, | |||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
VariableAggregation aggregation = VariableAggregation.NONE) | |||
{ | |||
if (_vars.ContainsKey(name)) | |||
{ | |||
if (!reuse) | |||
{ | |||
var var = _vars[name]; | |||
} | |||
throw new NotImplementedException("_get_single_variable"); | |||
} | |||
throw new NotImplementedException("_get_single_variable"); | |||
} | |||
} | |||
} |
@@ -11,5 +11,15 @@ namespace Tensorflow | |||
var g = variables.global_variables(); | |||
return variables.variables_initializer(g.ToArray()); | |||
} | |||
public static RefVariable get_variable(string name, | |||
TensorShape shape = null, | |||
IInitializer initializer = null, | |||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
VariableAggregation aggregation = VariableAggregation.NONE) | |||
{ | |||
var store = variable_scope._get_default_variable_store(); | |||
return variable_scope.get_variable_scope().get_variable(store, name, shape: shape); | |||
} | |||
} | |||
} |
@@ -6,6 +6,7 @@ namespace Tensorflow | |||
{ | |||
public class variable_scope | |||
{ | |||
public static string _VARSTORE_KEY = "__variable_store"; | |||
public static string _VARSCOPESTORE_KEY = "__varscope"; | |||
public static bool _DEFAULT_USE_RESOURCE = false; | |||
@@ -32,6 +33,17 @@ namespace Tensorflow | |||
} | |||
} | |||
public static _VariableStore _get_default_variable_store() | |||
{ | |||
var store = ops.get_collection(_VARSTORE_KEY); | |||
if (store != null) | |||
return (store as List<_VariableStore>)[0]; | |||
var store1 = new _VariableStore(); | |||
ops.add_to_collection(_VARSTORE_KEY, store1); | |||
return store1; | |||
} | |||
public static VariableScope get_variable_scope() | |||
{ | |||
return get_variable_scope_store().current_scope; | |||
@@ -65,24 +77,18 @@ namespace Tensorflow | |||
return ret; | |||
} | |||
public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null) | |||
public static bool _get_trainable_value(VariableSynchronization synchronization, bool trainable = true) | |||
{ | |||
if(synchronization == VariableSynchronization.ON_READ) | |||
if (synchronization == VariableSynchronization.ON_READ) | |||
{ | |||
if (trainable.Value) | |||
if (trainable) | |||
throw new ValueError("Synchronization value can be set to " + | |||
"VariableSynchronization.ON_READ only for non-trainable variables. " + | |||
"You have specified trainable=True and " + | |||
"synchronization=VariableSynchronization.ON_READ."); | |||
else | |||
trainable = false; | |||
} | |||
else if (!trainable.HasValue) | |||
{ | |||
trainable = true; | |||
} | |||
return trainable.Value; | |||
return trainable; | |||
} | |||
} | |||
} |
@@ -0,0 +1,21 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow; | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
[TestClass] | |||
public class TrainSaverTest | |||
{ | |||
[TestMethod] | |||
public void Save() | |||
{ | |||
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | |||
var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer); | |||
} | |||
} | |||
} |
@@ -0,0 +1,26 @@ | |||
| |||
import tensorflow as tf | |||
# Create some variables. | |||
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer) | |||
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer) | |||
inc_v1 = v1.assign(v1+1) | |||
dec_v2 = v2.assign(v2-1) | |||
# Add an op to initialize the variables. | |||
init_op = tf.global_variables_initializer() | |||
# Add ops to save and restore all the variables. | |||
saver = tf.train.Saver() | |||
# Later, launch the model, initialize the variables, do some work, and save the | |||
# variables to disk. | |||
with tf.Session() as sess: | |||
sess.run(init_op) | |||
# Do some work with the model. | |||
inc_v1.op.run() | |||
dec_v2.op.run() | |||
# Save the variables to disk. | |||
save_path = saver.save(sess, "/tmp/model.ckpt") | |||
print("Model saved in path: %s" % save_path) |