Browse Source

unit test of how to reenter a premade variable scope safely

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
c976b818fc
6 changed files with 147 additions and 33 deletions
  1. +14
    -2
      src/TensorFlowNET.Core/APIs/tf.init.cs
  2. +39
    -11
      src/TensorFlowNET.Core/Variables/PureVariableScope.cs
  3. +6
    -5
      src/TensorFlowNET.Core/Variables/VariableScope.cs
  4. +64
    -15
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs
  5. +6
    -0
      src/TensorFlowNET.Core/ops.py.cs
  6. +18
    -0
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 14
- 2
src/TensorFlowNET.Core/APIs/tf.init.cs View File

@@ -9,9 +9,21 @@ namespace Tensorflow
public static IInitializer zeros_initializer => new Zeros(); public static IInitializer zeros_initializer => new Zeros();
public static IInitializer glorot_uniform_initializer => new GlorotUniform(); public static IInitializer glorot_uniform_initializer => new GlorotUniform();
public static variable_scope variable_scope(string name_or_scope,
public static variable_scope variable_scope(string name,
string default_name = null, string default_name = null,
object values = null) => new variable_scope(name_or_scope, default_name, values);
object values = null,
bool auxiliary_name_scope = true) => new variable_scope(name,
default_name,
values,
auxiliary_name_scope);

public static variable_scope variable_scope(VariableScope scope,
string default_name = null,
object values = null,
bool auxiliary_name_scope = true) => new variable_scope(scope,
default_name,
values,
auxiliary_name_scope);


public class Zeros : IInitializer public class Zeros : IInitializer
{ {


+ 39
- 11
src/TensorFlowNET.Core/Variables/PureVariableScope.cs View File

@@ -6,7 +6,8 @@ namespace Tensorflow
{ {
public class PureVariableScope : IPython public class PureVariableScope : IPython
{ {
private string _name_or_scope;
private string _name;
private VariableScope _scope;
private string _new_name; private string _new_name;
private string _old_name_scope; private string _old_name_scope;
private bool _reuse; private bool _reuse;
@@ -14,29 +15,56 @@ namespace Tensorflow
private VariableScope _old; private VariableScope _old;
private _VariableScopeStore _var_scope_store; private _VariableScopeStore _var_scope_store;
private VariableScope variable_scope_object; private VariableScope variable_scope_object;
private VariableScope _cached_variable_scope_object;


public PureVariableScope(string name_or_scope,
public PureVariableScope(string name,
string old_name_scope = null, string old_name_scope = null,
TF_DataType dtype = TF_DataType.DtInvalid) TF_DataType dtype = TF_DataType.DtInvalid)
{ {
_name_or_scope = name_or_scope;
_name = name;
_old_name_scope = old_name_scope; _old_name_scope = old_name_scope;
_var_store = variable_scope._get_default_variable_store(); _var_store = variable_scope._get_default_variable_store();
_var_scope_store = variable_scope.get_variable_scope_store(); _var_scope_store = variable_scope.get_variable_scope_store();
} }


public void __enter__()
public PureVariableScope(VariableScope scope,
string old_name_scope = null,
TF_DataType dtype = TF_DataType.DtInvalid)
{ {
_old = _var_scope_store.current_scope;
_new_name = string.IsNullOrEmpty(_old.name) ? _name_or_scope : _old.name + "/" + _name_or_scope;
_reuse = _reuse || _old.resue;
string name_scope = _old_name_scope == null ? _name_or_scope : _old_name_scope;
variable_scope_object = new VariableScope(_reuse,
_scope = scope;
_old_name_scope = old_name_scope;
_var_store = variable_scope._get_default_variable_store();
_var_scope_store = variable_scope.get_variable_scope_store();
_new_name = _scope._name;

string name_scope = _scope._name_scope;
variable_scope_object = new VariableScope(_reuse,
name: _new_name, name: _new_name,
name_scope: name_scope); name_scope: name_scope);


_var_scope_store.open_variable_scope(_new_name);
_cached_variable_scope_object = variable_scope_object;
}

public void __enter__()
{
_old = _var_scope_store.current_scope;
if(_scope != null)
{
_var_scope_store.open_variable_scope(_new_name);
variable_scope_object = _cached_variable_scope_object;
}
else
{
_new_name = string.IsNullOrEmpty(_old._name) ? _name : _old._name + "/" + _name;
_reuse = _reuse || _old.resue;
string name_scope = _old_name_scope == null ? _name : _old_name_scope;

variable_scope_object = new VariableScope(_reuse,
name: _new_name,
name_scope: name_scope);

_var_scope_store.open_variable_scope(_new_name);
}
_var_scope_store.current_scope = variable_scope_object; _var_scope_store.current_scope = variable_scope_object;
} }




+ 6
- 5
src/TensorFlowNET.Core/Variables/VariableScope.cs View File

@@ -14,16 +14,17 @@ namespace Tensorflow
public bool resue; public bool resue;


private TF_DataType _dtype; private TF_DataType _dtype;
public string name { get; set; }
public string name_scope { get; set; }
public string _name { get; set; }
public string _name_scope { get; set; }
public string original_name_scope => _name_scope;


public VariableScope(bool reuse, public VariableScope(bool reuse,
string name = "", string name = "",
string name_scope = "", string name_scope = "",
TF_DataType dtype = TF_DataType.TF_FLOAT) TF_DataType dtype = TF_DataType.TF_FLOAT)
{ {
this.name = name;
this.name_scope = name_scope;
_name = name;
_name_scope = name_scope;
_reuse = _ReuseMode.AUTO_REUSE; _reuse = _ReuseMode.AUTO_REUSE;
_dtype = dtype; _dtype = dtype;
} }
@@ -37,7 +38,7 @@ namespace Tensorflow
VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation= VariableAggregation.NONE) VariableAggregation aggregation= VariableAggregation.NONE)
{ {
string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name;
string full_name = !string.IsNullOrEmpty(this._name) ? this._name + "/" + name : name;
return with(new ops.name_scope(null), scope => return with(new ops.name_scope(null), scope =>
{ {
if (dtype == TF_DataType.DtInvalid) if (dtype == TF_DataType.DtInvalid)


+ 64
- 15
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -1,5 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text; using System.Text;


namespace Tensorflow namespace Tensorflow
@@ -12,42 +13,83 @@ namespace Tensorflow


private bool _use_resource; private bool _use_resource;
public bool UseResource => _use_resource; public bool UseResource => _use_resource;
private string _name_or_scope;
private string _name;
private VariableScope _scope;
private string _default_name; private string _default_name;
private object _values; private object _values;
private string _current_name_scope;
private ops.name_scope _current_name_scope;
private bool _auxiliary_name_scope;
private PureVariableScope _cached_pure_variable_scope; private PureVariableScope _cached_pure_variable_scope;


public variable_scope(string name_or_scope, string default_name = "", object values = null)
public variable_scope(string name,
string default_name = "",
object values = null,
bool auxiliary_name_scope = true)
{ {
_name_or_scope = name_or_scope;
_name = name;
_default_name = default_name; _default_name = default_name;
_values = values; _values = values;
_current_name_scope = null; _current_name_scope = null;


_use_resource = false; _use_resource = false;
if (_default_name == null && _name_or_scope == null)
throw new TypeError("If default_name is None then name_or_scope is required");
if (_default_name == null && _name == null)
throw new TypeError("If default_name is None then name is required");

_auxiliary_name_scope = auxiliary_name_scope;
}

public variable_scope(VariableScope scope,
string default_name = "",
object values = null,
bool auxiliary_name_scope = true)
{
_scope = scope;
_default_name = default_name;
_values = values;
_current_name_scope = null;

_use_resource = false;
if (_default_name == null && _scope == null)
throw new TypeError("If default_name is None then scope is required");

_auxiliary_name_scope = auxiliary_name_scope;
} }


public void __enter__() public void __enter__()
{ {
_enter_scope_uncached();
_scope = _enter_scope_uncached();
} }


public VariableScope _enter_scope_uncached()
private VariableScope _enter_scope_uncached()
{ {
ops.name_scope current_name_scope = null;
if(_name_or_scope != null)
ops.name_scope current_name_scope;
if (_auxiliary_name_scope)
// Create a new name scope later
current_name_scope = null;
else
{ {
var name_scope = _name_or_scope;
// Reenter the current name scope
string name_scope = ops.get_name_scope();
if(!string.IsNullOrEmpty(name_scope))
// Hack to reenter
name_scope += "/";
current_name_scope = new ops.name_scope(name_scope);
}

if (_name != null || _scope != null)
{
var name_scope = _name == null ? _scope._name.Split('/').Last() : _name;
if (name_scope != null || current_name_scope != null) if (name_scope != null || current_name_scope != null)
current_name_scope = new ops.name_scope(name_scope); current_name_scope = new ops.name_scope(name_scope);
current_name_scope.__enter__(); current_name_scope.__enter__();
string current_name_scope_name = current_name_scope;
var current_name_scope_name = current_name_scope;
_current_name_scope = current_name_scope; _current_name_scope = current_name_scope;
string old_name_scope = current_name_scope_name; string old_name_scope = current_name_scope_name;
var pure_variable_scope = new PureVariableScope(_name_or_scope, old_name_scope: old_name_scope);
PureVariableScope pure_variable_scope = null;
if(_scope == null)
pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope);
else
pure_variable_scope = new PureVariableScope(_scope, old_name_scope: old_name_scope);
pure_variable_scope.__enter__(); pure_variable_scope.__enter__();
VariableScope entered_pure_variable_scope = pure_variable_scope; VariableScope entered_pure_variable_scope = pure_variable_scope;
_cached_pure_variable_scope = pure_variable_scope; _cached_pure_variable_scope = pure_variable_scope;
@@ -149,14 +191,21 @@ namespace Tensorflow
return trainable.Value; return trainable.Value;
} }


public static implicit operator VariableScope(variable_scope scope)
{
return scope._scope;
}

public void __exit__() public void __exit__()
{ {
if (_current_name_scope != null)
_current_name_scope.__exit__();
} }


public void Dispose() public void Dispose()
{ {
if (_current_name_scope != null)
_current_name_scope.Dispose();
} }
} }
} }

+ 6
- 0
src/TensorFlowNET.Core/ops.py.cs View File

@@ -475,5 +475,11 @@ namespace Tensorflow
return name; return name;
} }
} }

public static string get_name_scope()
{
var g = get_default_graph();
return g.get_name_scope();
}
} }
} }

+ 18
- 0
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -47,10 +47,28 @@ namespace TensorFlowNET.UnitTest
}); });
} }


/// <summary>
/// how to reenter a premade variable scope safely
/// </summary>
[TestMethod] [TestMethod]
public void ReenterVariableScope() public void ReenterVariableScope()
{ {
variable_scope vs = null;
with(tf.variable_scope("foo"), v => vs = v);


// Re-enter the variable scope.
with(tf.variable_scope(vs, auxiliary_name_scope: false), v =>
{
var vs1 = (VariableScope)v;
// Restore the original name_scope.
with(tf.name_scope(vs1.original_name_scope), delegate
{
var v1 = tf.get_variable("v", new TensorShape(1));
Assert.AreEqual(v1.name, "foo/v:0");
var c1 = tf.constant(new int[] { 1 }, name: "c");
Assert.AreEqual(c1.name, "foo/c:0");
});
});
} }


[TestMethod] [TestMethod]


Loading…
Cancel
Save