Browse Source

RefVariable, variable_scope

tags/v0.1.0-Tensor
Esther2013 6 years ago
parent
commit
71e1fe6299
12 changed files with 173 additions and 32 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Eager/Context.cs
  2. +5
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  4. +48
    -8
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  5. +7
    -1
      src/TensorFlowNET.Core/Variables/VariableV1.cs
  6. +35
    -0
      src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
  7. +3
    -1
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Variables/variables.py.cs
  9. +7
    -1
      src/TensorFlowNET.Core/ops.GraphKeys.cs
  10. +44
    -0
      src/TensorFlowNET.Core/ops.name_scope.cs
  11. +21
    -14
      src/TensorFlowNET.Core/ops.py.cs
  12. +0
    -4
      src/TensorFlowNET.Core/tf.cs

+ 1
- 1
src/TensorFlowNET.Core/Eager/Context.cs View File

@@ -2,7 +2,7 @@
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Eager
namespace Tensorflow
{
public class Context
{


+ 5
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -152,6 +152,11 @@ namespace Tensorflow
return false;
}

public string get_name_scope()
{
return _name_stack;
}

public string name_scope(string name)
{
string new_stack = "";


+ 1
- 1
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow
{
public class OpDefLibrary
{
public unsafe Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null)
public Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null)
{
var g = ops.get_default_graph();
var op_def = g.GetOpDef(op_type_name);


+ 48
- 8
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -8,22 +8,62 @@ namespace Tensorflow
{
public bool _in_graph_mode = true;
public Tensor _initial_value;
public string _graph_key;
public bool _trainable;
public Tensor _variable;

public RefVariable(object initial_value,
public RefVariable(object initial_value,
bool trainable = true,
List<string> collections = null,
bool validate_shape = true,
string caching_device = "",
string name = "",
TF_DataType trainable = TF_DataType.DtInvalid,
bool validate_shape = true) :
base(initial_value, name, trainable, validate_shape)
TF_DataType dtype = TF_DataType.DtInvalid) :
base(initial_value, trainable, collections, validate_shape, caching_device, name, dtype)
{
_init_from_args(initial_value, name, trainable);
_init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype);
}

private void _init_from_args(object initial_value,
bool trainable = true,
List<string> collections = null,
bool validate_shape = true,
string caching_device = "",
string name = "",
TF_DataType trainable = TF_DataType.DtInvalid)
TF_DataType dtype = TF_DataType.DtInvalid)
{
name = ops.name_scope("", "Variable", initial_value);
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");
if (initial_value is null)
throw new ValueError("initial_value must be specified.");

var init_from_fn = false;

if(collections == null)
{
collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES };
}

// Store the graph key so optimizers know how to only retrieve variables from
// this graph.
_graph_key = ops.get_default_graph()._graph_key;

_trainable = trainable;
if (!collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES))
collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);

ops.init_scope();
name = new ops.name_scope(name, "Variable", init_from_fn ? new List<object>() : new List<object> { initial_value });
if (init_from_fn)
{

}
else
{
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");
}

var shape = _initial_value.shape;
dtype = _initial_value.dtype;
_variable = gen_state_ops.variable_v2(shape, dtype, name);
}
}
}

+ 7
- 1
src/TensorFlowNET.Core/Variables/VariableV1.cs View File

@@ -16,7 +16,13 @@ namespace Tensorflow
/// </summary>
public class VariableV1
{
public VariableV1(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true)
public VariableV1(object initial_value,
bool trainable = true,
List<string> collections = null,
bool validate_shape = true,
string caching_device = "",
string name = "",
TF_DataType dtype = TF_DataType.DtInvalid)
{

}


+ 35
- 0
src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs View File

@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class gen_state_ops
{
public static OpDefLibrary _op_def_lib = new OpDefLibrary();

/// <summary>
/// Holds state in the form of a tensor that persists across steps.
/// Outputs a ref to the tensor state so it may be read or modified.
/// </summary>
/// <param name="shape">The shape of the variable tensor.</param>
/// <param name="dtype">The type of elements in the variable tensor.</param>
/// <param name="name"></param>
/// <param name="container"></param>
/// <param name="shared_name"></param>
/// <returns></returns>
public static Tensor variable_v2(long[] shape, TF_DataType dtype, string name = "", string container = "", string shared_name = "")
{
var keywords = new Dictionary<string, object>();
keywords.Add("dtype", dtype);
keywords.Add("shape", shape);

var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, keywords: keywords);

var _result = _op.outputs;
var _inputs_flat = _op.inputs;

return new Tensor(_op, 0, dtype);
}
}
}

+ 3
- 1
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -26,7 +26,9 @@ namespace Tensorflow
}
else
{
return new RefVariable(initial_value);
return new RefVariable(initial_value,
name: name,
dtype: dtype);
}
}



+ 1
- 1
src/TensorFlowNET.Core/Variables/variables.py.cs View File

@@ -12,7 +12,7 @@ namespace Tensorflow
/// <returns></returns>
public static object trainable_variables()
{
return ops.get_collection(ops.GraphKey.TRAINABLE_VARIABLES);
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES);
}
}
}

src/TensorFlowNET.Core/ops.GraphKeys.py.cs → src/TensorFlowNET.Core/ops.GraphKeys.cs View File

@@ -15,12 +15,18 @@ namespace Tensorflow
/// specified, but it is also possible to pass an explicit list of
/// variables.
/// </summary>
public static class GraphKey
public static class GraphKeys
{
/// <summary>
/// the subset of `Variable` objects that will be trained by an optimizer.
/// </summary>
public static string TRAINABLE_VARIABLES = "trainable_variables";

/// <summary>
/// Key to collect Variable objects that are global (shared across machines).
/// Default collection for all variables, except local ones.
/// </summary>
public static string GLOBAL_VARIABLES = "variables";
}
}
}

+ 44
- 0
src/TensorFlowNET.Core/ops.name_scope.cs View File

@@ -0,0 +1,44 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public partial class ops
{
public class name_scope
{
public string _name;
public string _default_name;
public object _values;
public Context _ctx;
public string _name_scope;

public name_scope(string name, string default_name, List<object> values)
{
_name = name;
_default_name = default_name;
_values = values;
_ctx = new Context();

_name_scope = __enter__();
}

public string __enter__()
{
if (String.IsNullOrEmpty(_name))
{
_name = _default_name;
}

var g = get_default_graph();
return g.name_scope(_name);
}

public static implicit operator string(name_scope ns)
{
return ns._name_scope;
}
}
}
}

+ 21
- 14
src/TensorFlowNET.Core/ops.py.cs View File

@@ -97,20 +97,6 @@ namespace Tensorflow
return node_def;
}

public static string name_scope(string name, string default_name = "", object values = null)
{
string _name = "";
if (String.IsNullOrEmpty(name))
{
_name = default_name;
}

var g = get_default_graph();
var _name_scope = g.name_scope(_name);

return _name_scope;
}

public static string _name_from_scope_name(string name)
{
if (name.EndsWith("/"))
@@ -123,6 +109,27 @@ namespace Tensorflow
}
}

/// <summary>
/// A context manager that lifts ops out of control-flow scopes and function-building graphs.
/// </summary>
/// <returns></returns>
public static void init_scope()
{
// Retrieve the active name scope: entering an `init_scope` preserves
// the name scope of the current context.
var default_graph = get_default_graph();
var scope = default_graph.get_name_scope();
if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/"))
// Names that end with trailing slashes are treated by `name_scope` as
// absolute.
scope += "/";
// inner_device_stack = default_graph._device_function_stack
// var outer_context = default_graph.as_default;

var outer_graph = get_default_graph();
// outer_device_stack = None
}

public static int uid()
{
return 1;


+ 0
- 4
src/TensorFlowNET.Core/tf.cs View File

@@ -1,10 +1,6 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using TF_DataType = Tensorflow.DataType;
using attr_value_pb2 = Tensorflow;
using Tensorflow.Eager;

namespace Tensorflow
{


Loading…
Cancel
Save