@@ -9,6 +9,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T | |||
EndProject | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" | |||
EndProject | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}" | |||
EndProject | |||
Global | |||
GlobalSection(SolutionConfigurationPlatforms) = preSolution | |||
Debug|Any CPU = Debug|Any CPU | |||
@@ -27,6 +29,10 @@ Global | |||
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Release|Any CPU.Build.0 = Release|Any CPU | |||
EndGlobalSection | |||
GlobalSection(SolutionProperties) = preSolution | |||
HideSolutionNode = FALSE | |||
@@ -40,30 +40,22 @@ namespace Tensorflow | |||
} | |||
public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||
public virtual object run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||
{ | |||
var result = _run(fetches, feed_dict); | |||
return result; | |||
} | |||
private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||
private unsafe object _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null) | |||
{ | |||
var feed_dict_tensor = new Dictionary<Tensor, object>(); | |||
var feed_dict_tensor = new Dictionary<Tensor, NDArray>(); | |||
if (feed_dict != null) | |||
{ | |||
NDArray np_val = null; | |||
foreach (var feed in feed_dict) | |||
{ | |||
switch (feed.Value) | |||
{ | |||
case float value: | |||
np_val = np.asarray(value); | |||
break; | |||
} | |||
feed_dict_tensor[feed.Key] = np_val; | |||
feed_dict_tensor[feed.Key] = feed.Value; | |||
} | |||
} | |||
@@ -85,9 +77,9 @@ namespace Tensorflow | |||
return fetch_handler.build_results(null, results); | |||
} | |||
private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, object> feed_dict) | |||
private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict) | |||
{ | |||
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray(); | |||
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray(); | |||
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||
return _call_tf_sessionrun(feeds, fetches); | |||
@@ -133,12 +125,14 @@ namespace Tensorflow | |||
case TF_DataType.TF_FLOAT: | |||
result[i] = *(float*)c_api.TF_TensorData(output_values[i]); | |||
break; | |||
case TF_DataType.TF_INT16: | |||
result[i] = *(short*)c_api.TF_TensorData(output_values[i]); | |||
break; | |||
case TF_DataType.TF_INT32: | |||
result[i] = *(int*)c_api.TF_TensorData(output_values[i]); | |||
break; | |||
default: | |||
throw new NotImplementedException("can't get output"); | |||
break; | |||
} | |||
} | |||
@@ -1,4 +1,5 @@ | |||
using System; | |||
using NumSharp.Core; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
@@ -15,7 +16,7 @@ namespace Tensorflow | |||
private List<Tensor> _final_fetches = new List<Tensor>(); | |||
private List<object> _targets = new List<object>(); | |||
public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, object> feeds = null, object feed_handles = null) | |||
public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null) | |||
{ | |||
_fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||
foreach(var fetch in _fetch_mapper.unique_fetches()) | |||
@@ -43,4 +43,8 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||
<Content CopyToOutputDirectory="PreserveNewest" Include="./runtimes/win-x64/native/tensorflow.dll" Link="tensorflow.dll" Pack="true" PackagePath="runtimes/win-x64/native/tensorflow.dll" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||
</ItemGroup> | |||
</Project> |
@@ -74,6 +74,9 @@ namespace Tensorflow | |||
switch (nd.dtype.Name) | |||
{ | |||
case "Int16": | |||
Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size); | |||
break; | |||
case "Int32": | |||
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size); | |||
break; | |||
@@ -161,6 +164,8 @@ namespace Tensorflow | |||
{ | |||
switch (type.Name) | |||
{ | |||
case "Int16": | |||
return TF_DataType.TF_INT16; | |||
case "Int32": | |||
return TF_DataType.TF_INT32; | |||
case "Single": | |||
@@ -169,9 +174,9 @@ namespace Tensorflow | |||
return TF_DataType.TF_DOUBLE; | |||
case "String": | |||
return TF_DataType.TF_STRING; | |||
default: | |||
throw new NotImplementedException("ToTFDataType error"); | |||
} | |||
return TF_DataType.DtInvalid; | |||
} | |||
public void Dispose() | |||
@@ -1,4 +1,5 @@ | |||
using System; | |||
using NumSharp.Core; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow; | |||
@@ -43,11 +44,24 @@ namespace TensorFlowNET.Examples | |||
// Launch the default graph. | |||
using(sess = tf.Session()) | |||
{ | |||
// var feed_dict = new Dictionary<string, > | |||
var feed_dict = new Dictionary<Tensor, NDArray>(); | |||
feed_dict.Add(a, (short)2); | |||
feed_dict.Add(b, (short)3); | |||
// Run every operation with variable input | |||
// Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict: {a: 2, b: 3})}"); | |||
// Console.WriteLine($"Multiplication with variables: {}"); | |||
Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}"); | |||
Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}"); | |||
} | |||
// ---------------- | |||
// More in details: | |||
// Matrix Multiplication from TensorFlow official tutorial | |||
// Create a Constant op that produces a 1x2 matrix. The op is | |||
// added as a node to the default graph. | |||
// | |||
// The value returned by the constructor represents the output | |||
// of the Constant op. | |||
} | |||
} | |||
} |
@@ -10,6 +10,7 @@ | |||
</ItemGroup> | |||
<ItemGroup> | |||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | |||
</ItemGroup> | |||
@@ -1,4 +1,5 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using NumSharp.Core; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
@@ -31,7 +32,7 @@ namespace TensorFlowNET.UnitTest | |||
using(var sess = tf.Session()) | |||
{ | |||
var feed_dict = new Dictionary<Tensor, object>(); | |||
var feed_dict = new Dictionary<Tensor, NDArray>(); | |||
feed_dict.Add(a, 3.0f); | |||
feed_dict.Add(b, 2.0f); | |||