@@ -9,9 +9,21 @@ namespace Tensorflow | |||
public static IInitializer zeros_initializer => new Zeros(); | |||
public static IInitializer glorot_uniform_initializer => new GlorotUniform(); | |||
public static variable_scope variable_scope(string name_or_scope, | |||
public static variable_scope variable_scope(string name, | |||
string default_name = null, | |||
object values = null) => new variable_scope(name_or_scope, default_name, values); | |||
object values = null, | |||
bool auxiliary_name_scope = true) => new variable_scope(name, | |||
default_name, | |||
values, | |||
auxiliary_name_scope); | |||
public static variable_scope variable_scope(VariableScope scope, | |||
string default_name = null, | |||
object values = null, | |||
bool auxiliary_name_scope = true) => new variable_scope(scope, | |||
default_name, | |||
values, | |||
auxiliary_name_scope); | |||
public class Zeros : IInitializer | |||
{ | |||
@@ -6,7 +6,8 @@ namespace Tensorflow | |||
{ | |||
public class PureVariableScope : IPython | |||
{ | |||
private string _name_or_scope; | |||
private string _name; | |||
private VariableScope _scope; | |||
private string _new_name; | |||
private string _old_name_scope; | |||
private bool _reuse; | |||
@@ -14,29 +15,56 @@ namespace Tensorflow | |||
private VariableScope _old; | |||
private _VariableScopeStore _var_scope_store; | |||
private VariableScope variable_scope_object; | |||
private VariableScope _cached_variable_scope_object; | |||
public PureVariableScope(string name_or_scope, | |||
public PureVariableScope(string name, | |||
string old_name_scope = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
_name_or_scope = name_or_scope; | |||
_name = name; | |||
_old_name_scope = old_name_scope; | |||
_var_store = variable_scope._get_default_variable_store(); | |||
_var_scope_store = variable_scope.get_variable_scope_store(); | |||
} | |||
public void __enter__() | |||
public PureVariableScope(VariableScope scope, | |||
string old_name_scope = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
_old = _var_scope_store.current_scope; | |||
_new_name = string.IsNullOrEmpty(_old.name) ? _name_or_scope : _old.name + "/" + _name_or_scope; | |||
_reuse = _reuse || _old.resue; | |||
string name_scope = _old_name_scope == null ? _name_or_scope : _old_name_scope; | |||
variable_scope_object = new VariableScope(_reuse, | |||
_scope = scope; | |||
_old_name_scope = old_name_scope; | |||
_var_store = variable_scope._get_default_variable_store(); | |||
_var_scope_store = variable_scope.get_variable_scope_store(); | |||
_new_name = _scope._name; | |||
string name_scope = _scope._name_scope; | |||
variable_scope_object = new VariableScope(_reuse, | |||
name: _new_name, | |||
name_scope: name_scope); | |||
_var_scope_store.open_variable_scope(_new_name); | |||
_cached_variable_scope_object = variable_scope_object; | |||
} | |||
public void __enter__() | |||
{ | |||
_old = _var_scope_store.current_scope; | |||
if(_scope != null) | |||
{ | |||
_var_scope_store.open_variable_scope(_new_name); | |||
variable_scope_object = _cached_variable_scope_object; | |||
} | |||
else | |||
{ | |||
_new_name = string.IsNullOrEmpty(_old._name) ? _name : _old._name + "/" + _name; | |||
_reuse = _reuse || _old.resue; | |||
string name_scope = _old_name_scope == null ? _name : _old_name_scope; | |||
variable_scope_object = new VariableScope(_reuse, | |||
name: _new_name, | |||
name_scope: name_scope); | |||
_var_scope_store.open_variable_scope(_new_name); | |||
} | |||
_var_scope_store.current_scope = variable_scope_object; | |||
} | |||
@@ -14,16 +14,17 @@ namespace Tensorflow | |||
public bool resue; | |||
private TF_DataType _dtype; | |||
public string name { get; set; } | |||
public string name_scope { get; set; } | |||
public string _name { get; set; } | |||
public string _name_scope { get; set; } | |||
public string original_name_scope => _name_scope; | |||
public VariableScope(bool reuse, | |||
string name = "", | |||
string name_scope = "", | |||
TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
{ | |||
this.name = name; | |||
this.name_scope = name_scope; | |||
_name = name; | |||
_name_scope = name_scope; | |||
_reuse = _ReuseMode.AUTO_REUSE; | |||
_dtype = dtype; | |||
} | |||
@@ -37,7 +38,7 @@ namespace Tensorflow | |||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
VariableAggregation aggregation= VariableAggregation.NONE) | |||
{ | |||
string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name; | |||
string full_name = !string.IsNullOrEmpty(this._name) ? this._name + "/" + name : name; | |||
return with(new ops.name_scope(null), scope => | |||
{ | |||
if (dtype == TF_DataType.DtInvalid) | |||
@@ -1,5 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
namespace Tensorflow | |||
@@ -12,42 +13,83 @@ namespace Tensorflow | |||
private bool _use_resource; | |||
public bool UseResource => _use_resource; | |||
private string _name_or_scope; | |||
private string _name; | |||
private VariableScope _scope; | |||
private string _default_name; | |||
private object _values; | |||
private string _current_name_scope; | |||
private ops.name_scope _current_name_scope; | |||
private bool _auxiliary_name_scope; | |||
private PureVariableScope _cached_pure_variable_scope; | |||
public variable_scope(string name_or_scope, string default_name = "", object values = null) | |||
public variable_scope(string name, | |||
string default_name = "", | |||
object values = null, | |||
bool auxiliary_name_scope = true) | |||
{ | |||
_name_or_scope = name_or_scope; | |||
_name = name; | |||
_default_name = default_name; | |||
_values = values; | |||
_current_name_scope = null; | |||
_use_resource = false; | |||
if (_default_name == null && _name_or_scope == null) | |||
throw new TypeError("If default_name is None then name_or_scope is required"); | |||
if (_default_name == null && _name == null) | |||
throw new TypeError("If default_name is None then name is required"); | |||
_auxiliary_name_scope = auxiliary_name_scope; | |||
} | |||
public variable_scope(VariableScope scope, | |||
string default_name = "", | |||
object values = null, | |||
bool auxiliary_name_scope = true) | |||
{ | |||
_scope = scope; | |||
_default_name = default_name; | |||
_values = values; | |||
_current_name_scope = null; | |||
_use_resource = false; | |||
if (_default_name == null && _scope == null) | |||
throw new TypeError("If default_name is None then scope is required"); | |||
_auxiliary_name_scope = auxiliary_name_scope; | |||
} | |||
public void __enter__() | |||
{ | |||
_enter_scope_uncached(); | |||
_scope = _enter_scope_uncached(); | |||
} | |||
public VariableScope _enter_scope_uncached() | |||
private VariableScope _enter_scope_uncached() | |||
{ | |||
ops.name_scope current_name_scope = null; | |||
if(_name_or_scope != null) | |||
ops.name_scope current_name_scope; | |||
if (_auxiliary_name_scope) | |||
// Create a new name scope later | |||
current_name_scope = null; | |||
else | |||
{ | |||
var name_scope = _name_or_scope; | |||
// Reenter the current name scope | |||
string name_scope = ops.get_name_scope(); | |||
if(!string.IsNullOrEmpty(name_scope)) | |||
// Hack to reenter | |||
name_scope += "/"; | |||
current_name_scope = new ops.name_scope(name_scope); | |||
} | |||
if (_name != null || _scope != null) | |||
{ | |||
var name_scope = _name == null ? _scope._name.Split('/').Last() : _name; | |||
if (name_scope != null || current_name_scope != null) | |||
current_name_scope = new ops.name_scope(name_scope); | |||
current_name_scope.__enter__(); | |||
string current_name_scope_name = current_name_scope; | |||
var current_name_scope_name = current_name_scope; | |||
_current_name_scope = current_name_scope; | |||
string old_name_scope = current_name_scope_name; | |||
var pure_variable_scope = new PureVariableScope(_name_or_scope, old_name_scope: old_name_scope); | |||
PureVariableScope pure_variable_scope = null; | |||
if(_scope == null) | |||
pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope); | |||
else | |||
pure_variable_scope = new PureVariableScope(_scope, old_name_scope: old_name_scope); | |||
pure_variable_scope.__enter__(); | |||
VariableScope entered_pure_variable_scope = pure_variable_scope; | |||
_cached_pure_variable_scope = pure_variable_scope; | |||
@@ -149,14 +191,21 @@ namespace Tensorflow | |||
return trainable.Value; | |||
} | |||
public static implicit operator VariableScope(variable_scope scope) | |||
{ | |||
return scope._scope; | |||
} | |||
public void __exit__() | |||
{ | |||
if (_current_name_scope != null) | |||
_current_name_scope.__exit__(); | |||
} | |||
public void Dispose() | |||
{ | |||
if (_current_name_scope != null) | |||
_current_name_scope.Dispose(); | |||
} | |||
} | |||
} |
@@ -475,5 +475,11 @@ namespace Tensorflow | |||
return name; | |||
} | |||
} | |||
public static string get_name_scope() | |||
{ | |||
var g = get_default_graph(); | |||
return g.get_name_scope(); | |||
} | |||
} | |||
} |
@@ -47,10 +47,28 @@ namespace TensorFlowNET.UnitTest | |||
}); | |||
} | |||
/// <summary> | |||
/// how to reenter a premade variable scope safely | |||
/// </summary> | |||
[TestMethod] | |||
public void ReenterVariableScope() | |||
{ | |||
variable_scope vs = null; | |||
with(tf.variable_scope("foo"), v => vs = v); | |||
// Re-enter the variable scope. | |||
with(tf.variable_scope(vs, auxiliary_name_scope: false), v => | |||
{ | |||
var vs1 = (VariableScope)v; | |||
// Restore the original name_scope. | |||
with(tf.name_scope(vs1.original_name_scope), delegate | |||
{ | |||
var v1 = tf.get_variable("v", new TensorShape(1)); | |||
Assert.AreEqual(v1.name, "foo/v:0"); | |||
var c1 = tf.constant(new int[] { 1 }, name: "c"); | |||
Assert.AreEqual(c1.name, "foo/c:0"); | |||
}); | |||
}); | |||
} | |||
[TestMethod] | |||