diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 84a1aa7e..71b81d70 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -9,6 +9,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "t EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{1FE7246F-9273-42A8-841C-98051356FB67}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -27,6 +29,10 @@ Global {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU + {1FE7246F-9273-42A8-841C-98051356FB67}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1FE7246F-9273-42A8-841C-98051356FB67}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1FE7246F-9273-42A8-841C-98051356FB67}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1FE7246F-9273-42A8-841C-98051356FB67}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/TensorFlowNET.Core/APIs/tf.constant.cs b/src/TensorFlowNET.Core/APIs/tf.constant.cs index d60fb50e..7d1e2a8e 100644 --- a/src/TensorFlowNET.Core/APIs/tf.constant.cs +++ b/src/TensorFlowNET.Core/APIs/tf.constant.cs @@ -9,7 +9,7 @@ namespace Tensorflow { public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) { - return constant_op.Create(nd, name, verify_shape); + return constant_op.Constant(nd, name, verify_shape); } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 48672e32..dc70344a 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -21,14 +21,14 @@ namespace Tensorflow return gen_math_ops.mul(x, y); } - public static Tensor pow(Tensor x, Tensor y) + public static Tensor pow(Tensor x, double y) { return gen_math_ops.pow(x, y); } - public static Tensor reduce_sum(Tensor input, int? axis = null) + public static Tensor reduce_sum(Tensor input, int[] axis = null) { - return gen_math_ops.sum(input, input); + return gen_math_ops.sum(input, axis); } } } diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 9b9076c8..7c4d0c87 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp.Core; +using System; using System.Collections.Generic; using System.IO; using System.Runtime.InteropServices; @@ -21,21 +22,30 @@ namespace Tensorflow string scope = g.unique_name(name) + "/"; + var default_type_attr_map = new Dictionary(); foreach (var attr_def in op_def.Attr) { if (attr_def.Type != "type") continue; var key = attr_def.Name; + if(attr_def.DefaultValue != null) + { + default_type_attr_map[key] = attr_def.DefaultValue.Type; + } } var attrs = new Dictionary(); - - // Perform input type inference var inputs = new List(); var input_types = new List(); - + + // Perform input type inference foreach (var input_arg in op_def.InputArg) { var input_name = input_arg.Name; + if (keywords[input_name] is double int_value) + { + keywords[input_name] = constant_op.Constant(int_value, input_name); + } + if (keywords[input_name] is Tensor value) { if (keywords.ContainsKey(input_name)) diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 66e2811c..f1963f38 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp.Core; +using System; using System.Collections.Generic; using System.IO; using System.Text; @@ -42,6 +43,17 @@ namespace Tensorflow return new Tensor(_op, 0, _op.OutputType(0)); } + public static Tensor real_div(Tensor x, Tensor y) + { + var keywords = new Dictionary(); + keywords.Add("x", x); + keywords.Add("y", y); + + var _op = _op_def_lib._apply_op_helper("RealDiv", name: "truediv", keywords: keywords); + + return new Tensor(_op, 0, _op.OutputType(0)); + } + public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false) { var keywords = new Dictionary(); @@ -55,7 +67,7 @@ namespace Tensorflow return new Tensor(_op, 0, _op.OutputType(0)); } - public static Tensor pow(Tensor x, Tensor y) + public static Tensor pow(Tensor x, double y) { var keywords = new Dictionary(); keywords.Add("x", x); @@ -66,13 +78,15 @@ namespace Tensorflow return new Tensor(_op, 0, _op.OutputType(0)); } - public static Tensor sum(Tensor x, Tensor y) + public static Tensor sum(Tensor input, int[] axis = null) { + if(axis == null) axis = new int[0]; var keywords = new Dictionary(); - keywords.Add("x", x); - keywords.Add("y", y); + keywords.Add("input", input); + keywords.Add("reduction_indices", constant_op.Constant(axis)); + keywords.Add("keep_dims", false); - var _op = _op_def_lib._apply_op_helper("Pow", name: "Pow", keywords: keywords); + var _op = _op_def_lib._apply_op_helper("Sum", keywords: keywords); return new Tensor(_op, 0, _op.OutputType(0)); } diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 66138150..3208c349 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -7,7 +7,7 @@ using System.Text; namespace Tensorflow { - public class BaseSession : IDisposable + public class BaseSession { protected Graph _graph; protected bool _opened; @@ -35,19 +35,12 @@ namespace Tensorflow c_api.TF_DeleteSessionOptions(opts); } - public void Dispose() + public virtual NDArray run(Tensor fetches, Dictionary feed_dict = null) { - + return _run(fetches, feed_dict); } - public virtual object run(Tensor fetches, Dictionary feed_dict = null) - { - var result = _run(fetches, feed_dict); - - return result; - } - - private object _run(Tensor fetches, Dictionary feed_dict = null) + private NDArray _run(Tensor fetches, Dictionary feed_dict = null) { var feed_dict_tensor = new Dictionary(); @@ -77,7 +70,7 @@ namespace Tensorflow return fetch_handler.build_results(null, results); } - private object[] _do_run(List fetch_list, Dictionary feed_dict) + private NDArray[] _do_run(List fetch_list, Dictionary feed_dict) { var feeds = feed_dict.Select(x => new KeyValuePair(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray(); var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); @@ -85,7 +78,7 @@ namespace Tensorflow return _call_tf_sessionrun(feeds, fetches); } - private unsafe object[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list) + private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list) { // Ensure any changes to the graph are reflected in the runtime. _extend_graph(); @@ -109,14 +102,12 @@ namespace Tensorflow status.Check(true); - object[] result = new object[fetch_list.Length]; + var result = new NDArray[fetch_list.Length]; for (int i = 0; i < fetch_list.Length; i++) { var tensor = new Tensor(output_values[i]); - Type type = tensor.dtype.as_numpy_datatype(); - var ndims = tensor.shape.Select(x => (int)x).ToArray(); - + switch (tensor.dtype) { case TF_DataType.TF_STRING: @@ -124,25 +115,25 @@ namespace Tensorflow // wired, don't know why we have to start from offset 9. var bytes = tensor.Data(); var output = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9); - result[i] = fetchValue(tensor, ndims, output); + result[i] = fetchValue(tensor, output); } break; case TF_DataType.TF_FLOAT: { var output = *(float*)c_api.TF_TensorData(output_values[i]); - result[i] = fetchValue(tensor, ndims, output); + result[i] = fetchValue(tensor, output); } break; case TF_DataType.TF_INT16: { var output = *(short*)c_api.TF_TensorData(output_values[i]); - result[i] = fetchValue(tensor, ndims, output); + result[i] = fetchValue(tensor, output); } break; case TF_DataType.TF_INT32: { var output = *(int*)c_api.TF_TensorData(output_values[i]); - result[i] = fetchValue(tensor, ndims, output); + result[i] = fetchValue(tensor, output); } break; default: @@ -153,16 +144,22 @@ namespace Tensorflow return result; } - private object fetchValue(Tensor tensor, int[] ndims, T output) + private NDArray fetchValue(Tensor tensor, T output) { + NDArray nd; + Type type = tensor.dtype.as_numpy_datatype(); + var ndims = tensor.shape.Select(x => (int)x).ToArray(); + if (tensor.NDims == 0) { - return output; + nd = np.array(output).reshape(); } else { - return np.array(output).reshape(ndims); + nd = np.array(output).reshape(ndims); } + + return nd; } /// diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index 70644aac..8cf84c52 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -4,7 +4,7 @@ using System.Text; namespace Tensorflow { - public class Session : BaseSession + public class Session : BaseSession, IDisposable { private IntPtr _handle; public Status Status { get; } @@ -34,5 +34,12 @@ namespace Tensorflow public static implicit operator IntPtr(Session session) => session._handle; public static implicit operator Session(IntPtr handle) => new Session(handle); + + public void Dispose() + { + Options.Dispose(); + Status.Dispose(); + c_api.TF_DeleteSession(_handle, Status); + } } } diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs index 908f516c..83068972 100644 --- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp.Core; +using System; using System.Collections.Generic; using System.Text; @@ -21,7 +22,7 @@ namespace Tensorflow } } - public object build_results(object[] values) + public NDArray build_results(NDArray[] values) { return values[0]; } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index bb61453a..1da14507 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -35,7 +35,7 @@ namespace Tensorflow _final_fetches = _fetches; } - public object build_results(Session session, object[] results) + public NDArray build_results(Session session, NDArray[] results) { return _fetch_mapper.build_results(results); } diff --git a/src/TensorFlowNET.Core/Sessions/c_api.session.cs b/src/TensorFlowNET.Core/Sessions/c_api.session.cs index 591de8a2..646ff4b2 100644 --- a/src/TensorFlowNET.Core/Sessions/c_api.session.cs +++ b/src/TensorFlowNET.Core/Sessions/c_api.session.cs @@ -7,12 +7,25 @@ namespace Tensorflow { public static partial class c_api { + /// + /// Destroy a session object. + /// + /// Even if error information is recorded in *status, this call discards all + /// local resources associated with the session. The session may not be used + /// during or after this call (and the session drops its reference to the + /// corresponding graph). + /// + /// TF_Session* + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TF_DeleteSession(IntPtr session, IntPtr status); + /// /// Destroy an options object. /// /// TF_SessionOptions* [DllImport(TensorFlowLibName)] - public static unsafe extern void TF_DeleteSessionOptions(IntPtr opts); + public static extern void TF_DeleteSessionOptions(IntPtr opts); /// /// Return a new execution session with the associated graph, or NULL on diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index cf49b32a..2dd1ee60 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -43,4 +43,8 @@ Docs: https://tensorflownet.readthedocs.io + + + + diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs new file mode 100644 index 00000000..30bdf488 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs @@ -0,0 +1,34 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public partial class Tensor + { + public static implicit operator Tensor(double scalar) + { + return constant_op.Constant(scalar); + } + + public static implicit operator Tensor(int scalar) + { + return constant_op.Constant(scalar); + } + + public static implicit operator IntPtr(Tensor tensor) + { + return tensor._handle; + } + + public static implicit operator Tensor(IntPtr handle) + { + return new Tensor(handle); + } + + public static implicit operator Tensor(RefVariable var) + { + return var._initial_value; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index f2ac09fc..1025cf50 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -20,5 +20,10 @@ namespace Tensorflow { return gen_math_ops.mul(t1, t2); } + + public static Tensor operator /(Tensor t1, Tensor t2) + { + return gen_math_ops.real_div(t1, t2); + } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index dd37c8b2..d4bc910c 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -19,7 +19,7 @@ namespace Tensorflow public Operation op { get; } public string name; - public object value; + public int value_index { get; } private Status status = new Status(); @@ -90,7 +90,6 @@ namespace Tensorflow public Tensor(NDArray nd) { _handle = Allocate(nd); - value = nd.Data(); } private IntPtr Allocate(NDArray nd) @@ -205,30 +204,24 @@ namespace Tensorflow } } - public void Dispose() - { - c_api.TF_DeleteTensor(_handle); - status.Dispose(); - } - - public static implicit operator Tensor(int scalar) - { - return new Tensor(scalar); - } - - public static implicit operator IntPtr(Tensor tensor) + public override string ToString() { - return tensor._handle; - } + if(NDims == 0) + { + switch (dtype) + { + case TF_DataType.TF_INT32: + return Data()[0].ToString(); + } + } - public static implicit operator Tensor(IntPtr handle) - { - return new Tensor(handle); + return ""; } - public static implicit operator Tensor(RefVariable var) + public void Dispose() { - return var._initial_value; + c_api.TF_DeleteTensor(_handle); + status.Dispose(); } } } diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 5298f78c..9d364b3c 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -19,7 +19,7 @@ namespace Tensorflow /// Optional name for the tensor. /// Boolean that enables verification of a shape of values. /// - public static Tensor Create(NDArray nd, string name = "Const", bool verify_shape = false) + public static Tensor Constant(NDArray nd, string name = "Const", bool verify_shape = false) { Graph g = ops.get_default_graph(); var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape); @@ -44,10 +44,7 @@ namespace Tensorflow attrs: attrs, name: name); - var const_tensor = op.outputs[0]; - const_tensor.value = nd.Data(); - - return const_tensor; + return op.outputs[0]; } } } diff --git a/test/TensorFlowNET.Examples/BasicOperations.cs b/test/TensorFlowNET.Examples/BasicOperations.cs index 8d41b2d4..54e5c3a2 100644 --- a/test/TensorFlowNET.Examples/BasicOperations.cs +++ b/test/TensorFlowNET.Examples/BasicOperations.cs @@ -88,7 +88,7 @@ namespace TensorFlowNET.Examples { var result = sess.run(product); Console.WriteLine(result); - if((result as NDArray).Data()[0] != 12) + if(result.Data()[0] != 12) { throw new Exception("BasicOperations error"); } diff --git a/test/TensorFlowNET.Examples/HelloWorld.cs b/test/TensorFlowNET.Examples/HelloWorld.cs index bfa4c4d5..ce600097 100644 --- a/test/TensorFlowNET.Examples/HelloWorld.cs +++ b/test/TensorFlowNET.Examples/HelloWorld.cs @@ -26,6 +26,10 @@ namespace TensorFlowNET.Examples // Run the op var result = sess.run(hello); Console.WriteLine(result); + if(!result.ToString().Equals("Hello, TensorFlow!")) + { + throw new Exception("HelloWorld error"); + } } } } diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs index 3199a923..19effb6f 100644 --- a/test/TensorFlowNET.Examples/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/LinearRegression.cs @@ -40,8 +40,14 @@ namespace TensorFlowNET.Examples var pred = tf.add(part1, b); // Mean squared error - var pow = tf.pow(pred - Y, 2); - //var cost = tf.reduce_sum(pow) / (2 * n_samples); + var sub = pred - Y; + var pow = tf.pow(sub, 2); + var reduce = tf.reduce_sum(pow); + var cost = reduce / (2d * n_samples); + + // radient descent + // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default + // var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); } } } diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index b6b587eb..b18eb508 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -11,6 +11,7 @@ + diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index c70f584d..42c1b239 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -24,6 +24,7 @@ +