@@ -9,6 +9,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T | |||||
EndProject | EndProject | ||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" | ||||
EndProject | EndProject | ||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}" | |||||
EndProject | |||||
Global | Global | ||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution | GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||||
Debug|Any CPU = Debug|Any CPU | Debug|Any CPU = Debug|Any CPU | ||||
@@ -27,6 +29,10 @@ Global | |||||
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU | {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU | {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU | {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
EndGlobalSection | EndGlobalSection | ||||
GlobalSection(SolutionProperties) = preSolution | GlobalSection(SolutionProperties) = preSolution | ||||
HideSolutionNode = FALSE | HideSolutionNode = FALSE | ||||
@@ -40,30 +40,22 @@ namespace Tensorflow | |||||
} | } | ||||
public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
public virtual object run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||||
{ | { | ||||
var result = _run(fetches, feed_dict); | var result = _run(fetches, feed_dict); | ||||
return result; | return result; | ||||
} | } | ||||
private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
private unsafe object _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||||
{ | { | ||||
var feed_dict_tensor = new Dictionary<Tensor, object>(); | |||||
var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | |||||
if (feed_dict != null) | if (feed_dict != null) | ||||
{ | { | ||||
NDArray np_val = null; | |||||
foreach (var feed in feed_dict) | foreach (var feed in feed_dict) | ||||
{ | { | ||||
switch (feed.Value) | |||||
{ | |||||
case float value: | |||||
np_val = np.asarray(value); | |||||
break; | |||||
} | |||||
feed_dict_tensor[feed.Key] = np_val; | |||||
feed_dict_tensor[feed.Key] = feed.Value; | |||||
} | } | ||||
} | } | ||||
@@ -85,9 +77,9 @@ namespace Tensorflow | |||||
return fetch_handler.build_results(null, results); | return fetch_handler.build_results(null, results); | ||||
} | } | ||||
private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, object> feed_dict) | |||||
private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict) | |||||
{ | { | ||||
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray(); | |||||
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray(); | |||||
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | ||||
return _call_tf_sessionrun(feeds, fetches); | return _call_tf_sessionrun(feeds, fetches); | ||||
@@ -133,12 +125,14 @@ namespace Tensorflow | |||||
case TF_DataType.TF_FLOAT: | case TF_DataType.TF_FLOAT: | ||||
result[i] = *(float*)c_api.TF_TensorData(output_values[i]); | result[i] = *(float*)c_api.TF_TensorData(output_values[i]); | ||||
break; | break; | ||||
case TF_DataType.TF_INT16: | |||||
result[i] = *(short*)c_api.TF_TensorData(output_values[i]); | |||||
break; | |||||
case TF_DataType.TF_INT32: | case TF_DataType.TF_INT32: | ||||
result[i] = *(int*)c_api.TF_TensorData(output_values[i]); | result[i] = *(int*)c_api.TF_TensorData(output_values[i]); | ||||
break; | break; | ||||
default: | default: | ||||
throw new NotImplementedException("can't get output"); | throw new NotImplementedException("can't get output"); | ||||
break; | |||||
} | } | ||||
} | } | ||||
@@ -1,4 +1,5 @@ | |||||
using System; | |||||
using NumSharp.Core; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
@@ -15,7 +16,7 @@ namespace Tensorflow | |||||
private List<Tensor> _final_fetches = new List<Tensor>(); | private List<Tensor> _final_fetches = new List<Tensor>(); | ||||
private List<object> _targets = new List<object>(); | private List<object> _targets = new List<object>(); | ||||
public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, object> feeds = null, object feed_handles = null) | |||||
public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null) | |||||
{ | { | ||||
_fetch_mapper = new _FetchMapper().for_fetch(fetches); | _fetch_mapper = new _FetchMapper().for_fetch(fetches); | ||||
foreach(var fetch in _fetch_mapper.unique_fetches()) | foreach(var fetch in _fetch_mapper.unique_fetches()) | ||||
@@ -43,4 +43,8 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
<Content CopyToOutputDirectory="PreserveNewest" Include="./runtimes/win-x64/native/tensorflow.dll" Link="tensorflow.dll" Pack="true" PackagePath="runtimes/win-x64/native/tensorflow.dll" /> | <Content CopyToOutputDirectory="PreserveNewest" Include="./runtimes/win-x64/native/tensorflow.dll" Link="tensorflow.dll" Pack="true" PackagePath="runtimes/win-x64/native/tensorflow.dll" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | |||||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
</ItemGroup> | |||||
</Project> | </Project> |
@@ -74,6 +74,9 @@ namespace Tensorflow | |||||
switch (nd.dtype.Name) | switch (nd.dtype.Name) | ||||
{ | { | ||||
case "Int16": | |||||
Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size); | |||||
break; | |||||
case "Int32": | case "Int32": | ||||
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size); | Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size); | ||||
break; | break; | ||||
@@ -161,6 +164,8 @@ namespace Tensorflow | |||||
{ | { | ||||
switch (type.Name) | switch (type.Name) | ||||
{ | { | ||||
case "Int16": | |||||
return TF_DataType.TF_INT16; | |||||
case "Int32": | case "Int32": | ||||
return TF_DataType.TF_INT32; | return TF_DataType.TF_INT32; | ||||
case "Single": | case "Single": | ||||
@@ -169,9 +174,9 @@ namespace Tensorflow | |||||
return TF_DataType.TF_DOUBLE; | return TF_DataType.TF_DOUBLE; | ||||
case "String": | case "String": | ||||
return TF_DataType.TF_STRING; | return TF_DataType.TF_STRING; | ||||
default: | |||||
throw new NotImplementedException("ToTFDataType error"); | |||||
} | } | ||||
return TF_DataType.DtInvalid; | |||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
@@ -1,4 +1,5 @@ | |||||
using System; | |||||
using NumSharp.Core; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow; | using Tensorflow; | ||||
@@ -43,11 +44,24 @@ namespace TensorFlowNET.Examples | |||||
// Launch the default graph. | // Launch the default graph. | ||||
using(sess = tf.Session()) | using(sess = tf.Session()) | ||||
{ | { | ||||
// var feed_dict = new Dictionary<string, > | |||||
var feed_dict = new Dictionary<Tensor, NDArray>(); | |||||
feed_dict.Add(a, (short)2); | |||||
feed_dict.Add(b, (short)3); | |||||
// Run every operation with variable input | // Run every operation with variable input | ||||
// Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict: {a: 2, b: 3})}"); | |||||
// Console.WriteLine($"Multiplication with variables: {}"); | |||||
Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}"); | |||||
Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}"); | |||||
} | } | ||||
// ---------------- | |||||
// More in details: | |||||
// Matrix Multiplication from TensorFlow official tutorial | |||||
// Create a Constant op that produces a 1x2 matrix. The op is | |||||
// added as a node to the default graph. | |||||
// | |||||
// The value returned by the constructor represents the output | |||||
// of the Constant op. | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -10,6 +10,7 @@ | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -1,4 +1,5 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using NumSharp.Core; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
@@ -31,7 +32,7 @@ namespace TensorFlowNET.UnitTest | |||||
using(var sess = tf.Session()) | using(var sess = tf.Session()) | ||||
{ | { | ||||
var feed_dict = new Dictionary<Tensor, object>(); | |||||
var feed_dict = new Dictionary<Tensor, NDArray>(); | |||||
feed_dict.Add(a, 3.0f); | feed_dict.Add(a, 3.0f); | ||||
feed_dict.Add(b, 2.0f); | feed_dict.Add(b, 2.0f); | ||||