diff --git a/src/TensorFlowNET.Core/Operations/ops.cs b/src/TensorFlowNET.Core/Operations/ops.cs index 395410d1..f7eb3853 100644 --- a/src/TensorFlowNET.Core/Operations/ops.cs +++ b/src/TensorFlowNET.Core/Operations/ops.cs @@ -16,6 +16,18 @@ namespace Tensorflow return tf.Graph(); } + public static Tensor convert_to_tensor() + { + return internal_convert_to_tensor(); + } + + private static Tensor internal_convert_to_tensor() + { + return null; + } + + + public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List inputs) { var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); diff --git a/src/TensorFlowNET.Core/Tensors/RefVariable.cs b/src/TensorFlowNET.Core/Tensors/RefVariable.cs new file mode 100644 index 00000000..129d0618 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/RefVariable.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class RefVariable : Variable + { + public bool _in_graph_mode = true; + + public RefVariable(object initial_value, + TF_DataType trainable, + bool validate_shape = true) : + base(initial_value, trainable, validate_shape) + { + + } + + private void _init_from_args() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Variable.cs b/src/TensorFlowNET.Core/Tensors/Variable.cs index 883c9a44..19253ce1 100644 --- a/src/TensorFlowNET.Core/Tensors/Variable.cs +++ b/src/TensorFlowNET.Core/Tensors/Variable.cs @@ -4,7 +4,21 @@ using System.Text; namespace Tensorflow { + /// + /// A variable maintains state in the graph across calls to `run()`. You add a + /// variable to the graph by constructing an instance of the class `Variable`. + /// + /// The `Variable()` constructor requires an initial value for the variable, + /// which can be a `Tensor` of any type and shape. The initial value defines the + /// type and shape of the variable. After construction, the type and shape of + /// the variable are fixed. The value can be changed using one of the assign methods. + /// https://tensorflow.org/guide/variables + /// public class Variable { + public Variable(object initial_value, TF_DataType trainable, bool validate_shape = true) + { + + } } } diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 95850312..09783aa1 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -11,11 +11,17 @@ namespace Tensorflow public static class tf { public static TF_DataType float32 = TF_DataType.TF_FLOAT; + public static TF_DataType chars = TF_DataType.TF_STRING; public static Context context = new Context(); public static Graph g = new Graph(c_api.TF_NewGraph()); + public static object Variable(T data, TF_DataType dtype) + { + return new Variable(null, TF_DataType.DtInvalid); + } + public static unsafe Tensor add(Tensor a, Tensor b) { return gen_math_ops.add(a, b); diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs new file mode 100644 index 00000000..a761d93e --- /dev/null +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -0,0 +1,17 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class VariableTest + { + public void Creating() + { + var mammal = tf.Variable("Elephant", tf.chars); + } + } +}