@@ -2,7 +2,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow.Eager | |||||
namespace Tensorflow | |||||
{ | { | ||||
public class Context | public class Context | ||||
{ | { | ||||
@@ -152,6 +152,11 @@ namespace Tensorflow | |||||
return false; | return false; | ||||
} | } | ||||
public string get_name_scope() | |||||
{ | |||||
return _name_stack; | |||||
} | |||||
public string name_scope(string name) | public string name_scope(string name) | ||||
{ | { | ||||
string new_stack = ""; | string new_stack = ""; | ||||
@@ -10,7 +10,7 @@ namespace Tensorflow | |||||
{ | { | ||||
public class OpDefLibrary | public class OpDefLibrary | ||||
{ | { | ||||
public unsafe Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null) | |||||
public Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null) | |||||
{ | { | ||||
var g = ops.get_default_graph(); | var g = ops.get_default_graph(); | ||||
var op_def = g.GetOpDef(op_type_name); | var op_def = g.GetOpDef(op_type_name); | ||||
@@ -8,22 +8,62 @@ namespace Tensorflow | |||||
{ | { | ||||
public bool _in_graph_mode = true; | public bool _in_graph_mode = true; | ||||
public Tensor _initial_value; | public Tensor _initial_value; | ||||
public string _graph_key; | |||||
public bool _trainable; | |||||
public Tensor _variable; | |||||
public RefVariable(object initial_value, | |||||
public RefVariable(object initial_value, | |||||
bool trainable = true, | |||||
List<string> collections = null, | |||||
bool validate_shape = true, | |||||
string caching_device = "", | |||||
string name = "", | string name = "", | ||||
TF_DataType trainable = TF_DataType.DtInvalid, | |||||
bool validate_shape = true) : | |||||
base(initial_value, name, trainable, validate_shape) | |||||
TF_DataType dtype = TF_DataType.DtInvalid) : | |||||
base(initial_value, trainable, collections, validate_shape, caching_device, name, dtype) | |||||
{ | { | ||||
_init_from_args(initial_value, name, trainable); | |||||
_init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype); | |||||
} | } | ||||
private void _init_from_args(object initial_value, | private void _init_from_args(object initial_value, | ||||
bool trainable = true, | |||||
List<string> collections = null, | |||||
bool validate_shape = true, | |||||
string caching_device = "", | |||||
string name = "", | string name = "", | ||||
TF_DataType trainable = TF_DataType.DtInvalid) | |||||
TF_DataType dtype = TF_DataType.DtInvalid) | |||||
{ | { | ||||
name = ops.name_scope("", "Variable", initial_value); | |||||
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | |||||
if (initial_value is null) | |||||
throw new ValueError("initial_value must be specified."); | |||||
var init_from_fn = false; | |||||
if(collections == null) | |||||
{ | |||||
collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES }; | |||||
} | |||||
// Store the graph key so optimizers know how to only retrieve variables from | |||||
// this graph. | |||||
_graph_key = ops.get_default_graph()._graph_key; | |||||
_trainable = trainable; | |||||
if (!collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES)) | |||||
collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); | |||||
ops.init_scope(); | |||||
name = new ops.name_scope(name, "Variable", init_from_fn ? new List<object>() : new List<object> { initial_value }); | |||||
if (init_from_fn) | |||||
{ | |||||
} | |||||
else | |||||
{ | |||||
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | |||||
} | |||||
var shape = _initial_value.shape; | |||||
dtype = _initial_value.dtype; | |||||
_variable = gen_state_ops.variable_v2(shape, dtype, name); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -16,7 +16,13 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public class VariableV1 | public class VariableV1 | ||||
{ | { | ||||
public VariableV1(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true) | |||||
public VariableV1(object initial_value, | |||||
bool trainable = true, | |||||
List<string> collections = null, | |||||
bool validate_shape = true, | |||||
string caching_device = "", | |||||
string name = "", | |||||
TF_DataType dtype = TF_DataType.DtInvalid) | |||||
{ | { | ||||
} | } | ||||
@@ -0,0 +1,35 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class gen_state_ops | |||||
{ | |||||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||||
/// <summary> | |||||
/// Holds state in the form of a tensor that persists across steps. | |||||
/// Outputs a ref to the tensor state so it may be read or modified. | |||||
/// </summary> | |||||
/// <param name="shape">The shape of the variable tensor.</param> | |||||
/// <param name="dtype">The type of elements in the variable tensor.</param> | |||||
/// <param name="name"></param> | |||||
/// <param name="container"></param> | |||||
/// <param name="shared_name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor variable_v2(long[] shape, TF_DataType dtype, string name = "", string container = "", string shared_name = "") | |||||
{ | |||||
var keywords = new Dictionary<string, object>(); | |||||
keywords.Add("dtype", dtype); | |||||
keywords.Add("shape", shape); | |||||
var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, keywords: keywords); | |||||
var _result = _op.outputs; | |||||
var _inputs_flat = _op.inputs; | |||||
return new Tensor(_op, 0, dtype); | |||||
} | |||||
} | |||||
} |
@@ -26,7 +26,9 @@ namespace Tensorflow | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
return new RefVariable(initial_value); | |||||
return new RefVariable(initial_value, | |||||
name: name, | |||||
dtype: dtype); | |||||
} | } | ||||
} | } | ||||
@@ -12,7 +12,7 @@ namespace Tensorflow | |||||
/// <returns></returns> | /// <returns></returns> | ||||
public static object trainable_variables() | public static object trainable_variables() | ||||
{ | { | ||||
return ops.get_collection(ops.GraphKey.TRAINABLE_VARIABLES); | |||||
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -15,12 +15,18 @@ namespace Tensorflow | |||||
/// specified, but it is also possible to pass an explicit list of | /// specified, but it is also possible to pass an explicit list of | ||||
/// variables. | /// variables. | ||||
/// </summary> | /// </summary> | ||||
public static class GraphKey | |||||
public static class GraphKeys | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// the subset of `Variable` objects that will be trained by an optimizer. | /// the subset of `Variable` objects that will be trained by an optimizer. | ||||
/// </summary> | /// </summary> | ||||
public static string TRAINABLE_VARIABLES = "trainable_variables"; | public static string TRAINABLE_VARIABLES = "trainable_variables"; | ||||
/// <summary> | |||||
/// Key to collect Variable objects that are global (shared across machines). | |||||
/// Default collection for all variables, except local ones. | |||||
/// </summary> | |||||
public static string GLOBAL_VARIABLES = "variables"; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -0,0 +1,44 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class ops | |||||
{ | |||||
public class name_scope | |||||
{ | |||||
public string _name; | |||||
public string _default_name; | |||||
public object _values; | |||||
public Context _ctx; | |||||
public string _name_scope; | |||||
public name_scope(string name, string default_name, List<object> values) | |||||
{ | |||||
_name = name; | |||||
_default_name = default_name; | |||||
_values = values; | |||||
_ctx = new Context(); | |||||
_name_scope = __enter__(); | |||||
} | |||||
public string __enter__() | |||||
{ | |||||
if (String.IsNullOrEmpty(_name)) | |||||
{ | |||||
_name = _default_name; | |||||
} | |||||
var g = get_default_graph(); | |||||
return g.name_scope(_name); | |||||
} | |||||
public static implicit operator string(name_scope ns) | |||||
{ | |||||
return ns._name_scope; | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -97,20 +97,6 @@ namespace Tensorflow | |||||
return node_def; | return node_def; | ||||
} | } | ||||
public static string name_scope(string name, string default_name = "", object values = null) | |||||
{ | |||||
string _name = ""; | |||||
if (String.IsNullOrEmpty(name)) | |||||
{ | |||||
_name = default_name; | |||||
} | |||||
var g = get_default_graph(); | |||||
var _name_scope = g.name_scope(_name); | |||||
return _name_scope; | |||||
} | |||||
public static string _name_from_scope_name(string name) | public static string _name_from_scope_name(string name) | ||||
{ | { | ||||
if (name.EndsWith("/")) | if (name.EndsWith("/")) | ||||
@@ -123,6 +109,27 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
/// <summary> | |||||
/// A context manager that lifts ops out of control-flow scopes and function-building graphs. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
public static void init_scope() | |||||
{ | |||||
// Retrieve the active name scope: entering an `init_scope` preserves | |||||
// the name scope of the current context. | |||||
var default_graph = get_default_graph(); | |||||
var scope = default_graph.get_name_scope(); | |||||
if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/")) | |||||
// Names that end with trailing slashes are treated by `name_scope` as | |||||
// absolute. | |||||
scope += "/"; | |||||
// inner_device_stack = default_graph._device_function_stack | |||||
// var outer_context = default_graph.as_default; | |||||
var outer_graph = get_default_graph(); | |||||
// outer_device_stack = None | |||||
} | |||||
public static int uid() | public static int uid() | ||||
{ | { | ||||
return 1; | return 1; | ||||
@@ -1,10 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Runtime.InteropServices; | |||||
using System.Text; | using System.Text; | ||||
using TF_DataType = Tensorflow.DataType; | |||||
using attr_value_pb2 = Tensorflow; | |||||
using Tensorflow.Eager; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||