@@ -0,0 +1,12 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public interface IInitializer | |||||
{ | |||||
Tensor call(TensorShape shape, TF_DataType dtype); | |||||
object get_config(); | |||||
} | |||||
} |
@@ -0,0 +1,34 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public static partial class tf | |||||
{ | |||||
public static IInitializer zeros_initializer => new Zeros(); | |||||
public class Zeros : IInitializer | |||||
{ | |||||
private TF_DataType dtype; | |||||
public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
{ | |||||
this.dtype = dtype; | |||||
} | |||||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
{ | |||||
if (dtype == TF_DataType.DtInvalid) | |||||
dtype = this.dtype; | |||||
return array_ops.zeros(shape, dtype); | |||||
} | |||||
public object get_config() | |||||
{ | |||||
return new { dtype = dtype.name() }; | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -71,6 +71,11 @@ namespace Tensorflow | |||||
type; | type; | ||||
} | } | ||||
public static int name(this TF_DataType type) | |||||
{ | |||||
return (int)type; | |||||
} | |||||
public static DataType as_base_dtype(this DataType type) | public static DataType as_base_dtype(this DataType type) | ||||
{ | { | ||||
return (int)type > 100 ? | return (int)type > 100 ? | ||||
@@ -0,0 +1,14 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public enum VariableAggregation | |||||
{ | |||||
NONE = 0, | |||||
SUM = 1, | |||||
MEAN = 2, | |||||
ONLY_FIRST_REPLICA = 3 // ONLY_FIRST_TOWER | |||||
} | |||||
} |
@@ -6,6 +6,35 @@ namespace Tensorflow | |||||
{ | { | ||||
public class VariableScope | public class VariableScope | ||||
{ | { | ||||
public bool? use_resource { get; set; } | |||||
public bool use_resource { get; set; } | |||||
private _ReuseMode _reuse { get; set; } | |||||
private object _regularizer; | |||||
private TF_DataType _dtype; | |||||
public string name { get; set; } | |||||
public VariableScope() | |||||
{ | |||||
_reuse = _ReuseMode.AUTO_REUSE; | |||||
} | |||||
public RefVariable get_variable(_VariableStore var_store, | |||||
string name, | |||||
TensorShape shape = null, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | |||||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||||
VariableAggregation aggregation= VariableAggregation.NONE) | |||||
{ | |||||
string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name; | |||||
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(""), scope => | |||||
{ | |||||
if (dtype == TF_DataType.DtInvalid) | |||||
dtype = _dtype; | |||||
return var_store.get_variable(full_name); | |||||
}); | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,16 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
/// <summary> | |||||
/// Mode for variable access within a variable scope. | |||||
/// </summary> | |||||
public enum _ReuseMode | |||||
{ | |||||
// Indicates that variables are to be fetched if they already exist or | |||||
// otherwise created. | |||||
AUTO_REUSE = 1 | |||||
} | |||||
} |
@@ -0,0 +1,80 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
/// <summary> | |||||
/// Variable store that carries a number of named Variables. | |||||
/// </summary> | |||||
public class _VariableStore | |||||
{ | |||||
private Dictionary<string, object> _vars; | |||||
private Dictionary<string, object> _partitioned_vars; | |||||
private bool _store_eager_variables; | |||||
public _VariableStore() | |||||
{ | |||||
_vars = new Dictionary<string, object>(); | |||||
_partitioned_vars = new Dictionary<string, object>(); | |||||
_store_eager_variables = false; | |||||
} | |||||
public RefVariable get_variable(string name, | |||||
TensorShape shape = null, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
IInitializer initializer = null, | |||||
bool trainable = false, | |||||
bool validate_shape = true, | |||||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||||
VariableAggregation aggregation = VariableAggregation.NONE) | |||||
{ | |||||
dtype = dtype.as_base_dtype(); | |||||
trainable = variable_scope._get_trainable_value(synchronization, trainable); | |||||
return _true_getter(name, | |||||
shape: shape, | |||||
dtype: dtype, | |||||
initializer: initializer, | |||||
trainable: trainable, | |||||
validate_shape: validate_shape, | |||||
synchronization: synchronization, | |||||
aggregation: aggregation); | |||||
} | |||||
private RefVariable _true_getter(string name, | |||||
TensorShape shape = null, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | |||||
IInitializer initializer = null, | |||||
bool trainable = false, | |||||
bool validate_shape = true, | |||||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||||
VariableAggregation aggregation = VariableAggregation.NONE) | |||||
{ | |||||
return _get_single_variable(name: name); | |||||
} | |||||
private RefVariable _get_single_variable(string name, | |||||
TensorShape shape = null, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | |||||
IInitializer initializer = null, | |||||
bool reuse = false, | |||||
bool trainable = false, | |||||
bool validate_shape = false, | |||||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||||
VariableAggregation aggregation = VariableAggregation.NONE) | |||||
{ | |||||
if (_vars.ContainsKey(name)) | |||||
{ | |||||
if (!reuse) | |||||
{ | |||||
var var = _vars[name]; | |||||
} | |||||
throw new NotImplementedException("_get_single_variable"); | |||||
} | |||||
throw new NotImplementedException("_get_single_variable"); | |||||
} | |||||
} | |||||
} |
@@ -11,5 +11,15 @@ namespace Tensorflow | |||||
var g = variables.global_variables(); | var g = variables.global_variables(); | ||||
return variables.variables_initializer(g.ToArray()); | return variables.variables_initializer(g.ToArray()); | ||||
} | } | ||||
public static RefVariable get_variable(string name, | |||||
TensorShape shape = null, | |||||
IInitializer initializer = null, | |||||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||||
VariableAggregation aggregation = VariableAggregation.NONE) | |||||
{ | |||||
var store = variable_scope._get_default_variable_store(); | |||||
return variable_scope.get_variable_scope().get_variable(store, name, shape: shape); | |||||
} | |||||
} | } | ||||
} | } |
@@ -6,6 +6,7 @@ namespace Tensorflow | |||||
{ | { | ||||
public class variable_scope | public class variable_scope | ||||
{ | { | ||||
public static string _VARSTORE_KEY = "__variable_store"; | |||||
public static string _VARSCOPESTORE_KEY = "__varscope"; | public static string _VARSCOPESTORE_KEY = "__varscope"; | ||||
public static bool _DEFAULT_USE_RESOURCE = false; | public static bool _DEFAULT_USE_RESOURCE = false; | ||||
@@ -32,6 +33,17 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
public static _VariableStore _get_default_variable_store() | |||||
{ | |||||
var store = ops.get_collection(_VARSTORE_KEY); | |||||
if (store != null) | |||||
return (store as List<_VariableStore>)[0]; | |||||
var store1 = new _VariableStore(); | |||||
ops.add_to_collection(_VARSTORE_KEY, store1); | |||||
return store1; | |||||
} | |||||
public static VariableScope get_variable_scope() | public static VariableScope get_variable_scope() | ||||
{ | { | ||||
return get_variable_scope_store().current_scope; | return get_variable_scope_store().current_scope; | ||||
@@ -65,24 +77,18 @@ namespace Tensorflow | |||||
return ret; | return ret; | ||||
} | } | ||||
public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null) | |||||
public static bool _get_trainable_value(VariableSynchronization synchronization, bool trainable = true) | |||||
{ | { | ||||
if(synchronization == VariableSynchronization.ON_READ) | |||||
if (synchronization == VariableSynchronization.ON_READ) | |||||
{ | { | ||||
if (trainable.Value) | |||||
if (trainable) | |||||
throw new ValueError("Synchronization value can be set to " + | throw new ValueError("Synchronization value can be set to " + | ||||
"VariableSynchronization.ON_READ only for non-trainable variables. " + | "VariableSynchronization.ON_READ only for non-trainable variables. " + | ||||
"You have specified trainable=True and " + | "You have specified trainable=True and " + | ||||
"synchronization=VariableSynchronization.ON_READ."); | "synchronization=VariableSynchronization.ON_READ."); | ||||
else | |||||
trainable = false; | |||||
} | } | ||||
else if (!trainable.HasValue) | |||||
{ | |||||
trainable = true; | |||||
} | |||||
return trainable.Value; | |||||
return trainable; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -0,0 +1,21 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow; | |||||
namespace TensorFlowNET.UnitTest | |||||
{ | |||||
[TestClass] | |||||
public class TrainSaverTest | |||||
{ | |||||
[TestMethod] | |||||
public void Save() | |||||
{ | |||||
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | |||||
var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,26 @@ | |||||
| |||||
import tensorflow as tf | |||||
# Create some variables. | |||||
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer) | |||||
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer) | |||||
inc_v1 = v1.assign(v1+1) | |||||
dec_v2 = v2.assign(v2-1) | |||||
# Add an op to initialize the variables. | |||||
init_op = tf.global_variables_initializer() | |||||
# Add ops to save and restore all the variables. | |||||
saver = tf.train.Saver() | |||||
# Later, launch the model, initialize the variables, do some work, and save the | |||||
# variables to disk. | |||||
with tf.Session() as sess: | |||||
sess.run(init_op) | |||||
# Do some work with the model. | |||||
inc_v1.op.run() | |||||
dec_v2.op.run() | |||||
# Save the variables to disk. | |||||
save_path = saver.save(sess, "/tmp/model.ckpt") | |||||
print("Model saved in path: %s" % save_path) |