@@ -2,7 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Eager | |||
namespace Tensorflow | |||
{ | |||
public class Context | |||
{ | |||
@@ -152,6 +152,11 @@ namespace Tensorflow | |||
return false; | |||
} | |||
public string get_name_scope() | |||
{ | |||
return _name_stack; | |||
} | |||
public string name_scope(string name) | |||
{ | |||
string new_stack = ""; | |||
@@ -10,7 +10,7 @@ namespace Tensorflow | |||
{ | |||
public class OpDefLibrary | |||
{ | |||
public unsafe Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null) | |||
public Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null) | |||
{ | |||
var g = ops.get_default_graph(); | |||
var op_def = g.GetOpDef(op_type_name); | |||
@@ -8,22 +8,62 @@ namespace Tensorflow | |||
{ | |||
public bool _in_graph_mode = true; | |||
public Tensor _initial_value; | |||
public string _graph_key; | |||
public bool _trainable; | |||
public Tensor _variable; | |||
public RefVariable(object initial_value, | |||
public RefVariable(object initial_value, | |||
bool trainable = true, | |||
List<string> collections = null, | |||
bool validate_shape = true, | |||
string caching_device = "", | |||
string name = "", | |||
TF_DataType trainable = TF_DataType.DtInvalid, | |||
bool validate_shape = true) : | |||
base(initial_value, name, trainable, validate_shape) | |||
TF_DataType dtype = TF_DataType.DtInvalid) : | |||
base(initial_value, trainable, collections, validate_shape, caching_device, name, dtype) | |||
{ | |||
_init_from_args(initial_value, name, trainable); | |||
_init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype); | |||
} | |||
private void _init_from_args(object initial_value, | |||
bool trainable = true, | |||
List<string> collections = null, | |||
bool validate_shape = true, | |||
string caching_device = "", | |||
string name = "", | |||
TF_DataType trainable = TF_DataType.DtInvalid) | |||
TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
name = ops.name_scope("", "Variable", initial_value); | |||
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | |||
if (initial_value is null) | |||
throw new ValueError("initial_value must be specified."); | |||
var init_from_fn = false; | |||
if(collections == null) | |||
{ | |||
collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES }; | |||
} | |||
// Store the graph key so optimizers know how to only retrieve variables from | |||
// this graph. | |||
_graph_key = ops.get_default_graph()._graph_key; | |||
_trainable = trainable; | |||
if (!collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES)) | |||
collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); | |||
ops.init_scope(); | |||
name = new ops.name_scope(name, "Variable", init_from_fn ? new List<object>() : new List<object> { initial_value }); | |||
if (init_from_fn) | |||
{ | |||
} | |||
else | |||
{ | |||
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | |||
} | |||
var shape = _initial_value.shape; | |||
dtype = _initial_value.dtype; | |||
_variable = gen_state_ops.variable_v2(shape, dtype, name); | |||
} | |||
} | |||
} |
@@ -16,7 +16,13 @@ namespace Tensorflow | |||
/// </summary> | |||
public class VariableV1 | |||
{ | |||
public VariableV1(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true) | |||
public VariableV1(object initial_value, | |||
bool trainable = true, | |||
List<string> collections = null, | |||
bool validate_shape = true, | |||
string caching_device = "", | |||
string name = "", | |||
TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
} | |||
@@ -0,0 +1,35 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class gen_state_ops | |||
{ | |||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
/// <summary> | |||
/// Holds state in the form of a tensor that persists across steps. | |||
/// Outputs a ref to the tensor state so it may be read or modified. | |||
/// </summary> | |||
/// <param name="shape">The shape of the variable tensor.</param> | |||
/// <param name="dtype">The type of elements in the variable tensor.</param> | |||
/// <param name="name"></param> | |||
/// <param name="container"></param> | |||
/// <param name="shared_name"></param> | |||
/// <returns></returns> | |||
public static Tensor variable_v2(long[] shape, TF_DataType dtype, string name = "", string container = "", string shared_name = "") | |||
{ | |||
var keywords = new Dictionary<string, object>(); | |||
keywords.Add("dtype", dtype); | |||
keywords.Add("shape", shape); | |||
var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, keywords: keywords); | |||
var _result = _op.outputs; | |||
var _inputs_flat = _op.inputs; | |||
return new Tensor(_op, 0, dtype); | |||
} | |||
} | |||
} |
@@ -26,7 +26,9 @@ namespace Tensorflow | |||
} | |||
else | |||
{ | |||
return new RefVariable(initial_value); | |||
return new RefVariable(initial_value, | |||
name: name, | |||
dtype: dtype); | |||
} | |||
} | |||
@@ -12,7 +12,7 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
public static object trainable_variables() | |||
{ | |||
return ops.get_collection(ops.GraphKey.TRAINABLE_VARIABLES); | |||
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); | |||
} | |||
} | |||
} |
@@ -15,12 +15,18 @@ namespace Tensorflow | |||
/// specified, but it is also possible to pass an explicit list of | |||
/// variables. | |||
/// </summary> | |||
public static class GraphKey | |||
public static class GraphKeys | |||
{ | |||
/// <summary> | |||
/// the subset of `Variable` objects that will be trained by an optimizer. | |||
/// </summary> | |||
public static string TRAINABLE_VARIABLES = "trainable_variables"; | |||
/// <summary> | |||
/// Key to collect Variable objects that are global (shared across machines). | |||
/// Default collection for all variables, except local ones. | |||
/// </summary> | |||
public static string GLOBAL_VARIABLES = "variables"; | |||
} | |||
} | |||
} |
@@ -0,0 +1,44 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public partial class ops | |||
{ | |||
public class name_scope | |||
{ | |||
public string _name; | |||
public string _default_name; | |||
public object _values; | |||
public Context _ctx; | |||
public string _name_scope; | |||
public name_scope(string name, string default_name, List<object> values) | |||
{ | |||
_name = name; | |||
_default_name = default_name; | |||
_values = values; | |||
_ctx = new Context(); | |||
_name_scope = __enter__(); | |||
} | |||
public string __enter__() | |||
{ | |||
if (String.IsNullOrEmpty(_name)) | |||
{ | |||
_name = _default_name; | |||
} | |||
var g = get_default_graph(); | |||
return g.name_scope(_name); | |||
} | |||
public static implicit operator string(name_scope ns) | |||
{ | |||
return ns._name_scope; | |||
} | |||
} | |||
} | |||
} |
@@ -97,20 +97,6 @@ namespace Tensorflow | |||
return node_def; | |||
} | |||
public static string name_scope(string name, string default_name = "", object values = null) | |||
{ | |||
string _name = ""; | |||
if (String.IsNullOrEmpty(name)) | |||
{ | |||
_name = default_name; | |||
} | |||
var g = get_default_graph(); | |||
var _name_scope = g.name_scope(_name); | |||
return _name_scope; | |||
} | |||
public static string _name_from_scope_name(string name) | |||
{ | |||
if (name.EndsWith("/")) | |||
@@ -123,6 +109,27 @@ namespace Tensorflow | |||
} | |||
} | |||
/// <summary> | |||
/// A context manager that lifts ops out of control-flow scopes and function-building graphs. | |||
/// </summary> | |||
/// <returns></returns> | |||
public static void init_scope() | |||
{ | |||
// Retrieve the active name scope: entering an `init_scope` preserves | |||
// the name scope of the current context. | |||
var default_graph = get_default_graph(); | |||
var scope = default_graph.get_name_scope(); | |||
if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/")) | |||
// Names that end with trailing slashes are treated by `name_scope` as | |||
// absolute. | |||
scope += "/"; | |||
// inner_device_stack = default_graph._device_function_stack | |||
// var outer_context = default_graph.as_default; | |||
var outer_graph = get_default_graph(); | |||
// outer_device_stack = None | |||
} | |||
public static int uid() | |||
{ | |||
return 1; | |||
@@ -1,10 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
using TF_DataType = Tensorflow.DataType; | |||
using attr_value_pb2 = Tensorflow; | |||
using Tensorflow.Eager; | |||
namespace Tensorflow | |||
{ | |||