diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 019bad25..3ac50b23 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -9,6 +9,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -27,6 +29,10 @@ Global {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU + {EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 47fda701..3d7ad4f4 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -40,30 +40,22 @@ namespace Tensorflow } - public virtual object run(Tensor fetches, Dictionary feed_dict = null) + public virtual object run(Tensor fetches, Dictionary feed_dict = null) { var result = _run(fetches, feed_dict); return result; } - private unsafe object _run(Tensor fetches, Dictionary feed_dict = null) + private unsafe object _run(Tensor fetches, Dictionary feed_dict = null) { - var feed_dict_tensor = new Dictionary(); + var feed_dict_tensor = new Dictionary(); if (feed_dict != null) { - NDArray np_val = null; foreach (var feed in feed_dict) { - switch (feed.Value) - { - case float value: - np_val = np.asarray(value); - break; - } - - feed_dict_tensor[feed.Key] = np_val; + feed_dict_tensor[feed.Key] = feed.Value; } } @@ -85,9 +77,9 @@ namespace Tensorflow return fetch_handler.build_results(null, results); } - private object[] _do_run(List fetch_list, Dictionary feed_dict) + private object[] _do_run(List fetch_list, Dictionary feed_dict) { - var feeds = feed_dict.Select(x => new KeyValuePair(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray(); + var feeds = feed_dict.Select(x => new KeyValuePair(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray(); var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); return _call_tf_sessionrun(feeds, fetches); @@ -133,12 +125,14 @@ namespace Tensorflow case TF_DataType.TF_FLOAT: result[i] = *(float*)c_api.TF_TensorData(output_values[i]); break; + case TF_DataType.TF_INT16: + result[i] = *(short*)c_api.TF_TensorData(output_values[i]); + break; case TF_DataType.TF_INT32: result[i] = *(int*)c_api.TF_TensorData(output_values[i]); break; default: throw new NotImplementedException("can't get output"); - break; } } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index 44f8f261..bb61453a 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp.Core; +using System; using System.Collections.Generic; using System.Text; @@ -15,7 +16,7 @@ namespace Tensorflow private List _final_fetches = new List(); private List _targets = new List(); - public _FetchHandler(Graph graph, Tensor fetches, Dictionary feeds = null, object feed_handles = null) + public _FetchHandler(Graph graph, Tensor fetches, Dictionary feeds = null, object feed_handles = null) { _fetch_mapper = new _FetchMapper().for_fetch(fetches); foreach(var fetch in _fetch_mapper.unique_fetches()) diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 9d4e3b20..bbc7011b 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -43,4 +43,8 @@ Docs: https://tensorflownet.readthedocs.io + + + + diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 0c955505..e4907a81 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -74,6 +74,9 @@ namespace Tensorflow switch (nd.dtype.Name) { + case "Int16": + Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); + break; case "Int32": Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); break; @@ -161,6 +164,8 @@ namespace Tensorflow { switch (type.Name) { + case "Int16": + return TF_DataType.TF_INT16; case "Int32": return TF_DataType.TF_INT32; case "Single": @@ -169,9 +174,9 @@ namespace Tensorflow return TF_DataType.TF_DOUBLE; case "String": return TF_DataType.TF_STRING; + default: + throw new NotImplementedException("ToTFDataType error"); } - - return TF_DataType.DtInvalid; } public void Dispose() diff --git a/test/TensorFlowNET.Examples/BasicOperations.cs b/test/TensorFlowNET.Examples/BasicOperations.cs index c43effd2..b4bd92cf 100644 --- a/test/TensorFlowNET.Examples/BasicOperations.cs +++ b/test/TensorFlowNET.Examples/BasicOperations.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp.Core; +using System; using System.Collections.Generic; using System.Text; using Tensorflow; @@ -43,11 +44,24 @@ namespace TensorFlowNET.Examples // Launch the default graph. using(sess = tf.Session()) { - // var feed_dict = new Dictionary + var feed_dict = new Dictionary(); + feed_dict.Add(a, (short)2); + feed_dict.Add(b, (short)3); // Run every operation with variable input - // Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict: {a: 2, b: 3})}"); - // Console.WriteLine($"Multiplication with variables: {}"); + Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}"); + Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}"); } + + // ---------------- + // More in details: + // Matrix Multiplication from TensorFlow official tutorial + + // Create a Constant op that produces a 1x2 matrix. The op is + // added as a node to the default graph. + // + // The value returned by the constructor represents the output + // of the Constant op. + } } } diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index 116a0b90..de4a6985 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -10,6 +10,7 @@ + diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 4a7fc054..0e0910ff 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp.Core; using System; using System.Collections.Generic; using System.Text; @@ -31,7 +32,7 @@ namespace TensorFlowNET.UnitTest using(var sess = tf.Session()) { - var feed_dict = new Dictionary(); + var feed_dict = new Dictionary(); feed_dict.Add(a, 3.0f); feed_dict.Add(b, 2.0f);