Browse Source

Multiplication with variables in BasicOperations example.

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
72d25e058f
8 changed files with 50 additions and 24 deletions
  1. +6
    -0
      TensorFlow.NET.sln
  2. +9
    -15
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  3. +3
    -2
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  4. +4
    -0
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  5. +7
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  6. +18
    -4
      test/TensorFlowNET.Examples/BasicOperations.cs
  7. +1
    -0
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  8. +2
    -1
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 6
- 0
TensorFlow.NET.sln View File

@@ -9,6 +9,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T
EndProject EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}"
EndProject EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}"
EndProject
Global Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU Debug|Any CPU = Debug|Any CPU
@@ -27,6 +29,10 @@ Global
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Debug|Any CPU.Build.0 = Debug|Any CPU
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Release|Any CPU.ActiveCfg = Release|Any CPU
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection EndGlobalSection
GlobalSection(SolutionProperties) = preSolution GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE HideSolutionNode = FALSE


+ 9
- 15
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -40,30 +40,22 @@ namespace Tensorflow
} }


public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
public virtual object run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null)
{ {
var result = _run(fetches, feed_dict); var result = _run(fetches, feed_dict);


return result; return result;
} }


private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
private unsafe object _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null)
{ {
var feed_dict_tensor = new Dictionary<Tensor, object>();
var feed_dict_tensor = new Dictionary<Tensor, NDArray>();


if (feed_dict != null) if (feed_dict != null)
{ {
NDArray np_val = null;
foreach (var feed in feed_dict) foreach (var feed in feed_dict)
{ {
switch (feed.Value)
{
case float value:
np_val = np.asarray(value);
break;
}

feed_dict_tensor[feed.Key] = np_val;
feed_dict_tensor[feed.Key] = feed.Value;
} }
} }


@@ -85,9 +77,9 @@ namespace Tensorflow
return fetch_handler.build_results(null, results); return fetch_handler.build_results(null, results);
} }


private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, object> feed_dict)
private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict)
{ {
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray();
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray();
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();


return _call_tf_sessionrun(feeds, fetches); return _call_tf_sessionrun(feeds, fetches);
@@ -133,12 +125,14 @@ namespace Tensorflow
case TF_DataType.TF_FLOAT: case TF_DataType.TF_FLOAT:
result[i] = *(float*)c_api.TF_TensorData(output_values[i]); result[i] = *(float*)c_api.TF_TensorData(output_values[i]);
break; break;
case TF_DataType.TF_INT16:
result[i] = *(short*)c_api.TF_TensorData(output_values[i]);
break;
case TF_DataType.TF_INT32: case TF_DataType.TF_INT32:
result[i] = *(int*)c_api.TF_TensorData(output_values[i]); result[i] = *(int*)c_api.TF_TensorData(output_values[i]);
break; break;
default: default:
throw new NotImplementedException("can't get output"); throw new NotImplementedException("can't get output");
break;
} }
} }




+ 3
- 2
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -1,4 +1,5 @@
using System;
using NumSharp.Core;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


@@ -15,7 +16,7 @@ namespace Tensorflow
private List<Tensor> _final_fetches = new List<Tensor>(); private List<Tensor> _final_fetches = new List<Tensor>();
private List<object> _targets = new List<object>(); private List<object> _targets = new List<object>();


public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, object> feeds = null, object feed_handles = null)
public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null)
{ {
_fetch_mapper = new _FetchMapper().for_fetch(fetches); _fetch_mapper = new _FetchMapper().for_fetch(fetches);
foreach(var fetch in _fetch_mapper.unique_fetches()) foreach(var fetch in _fetch_mapper.unique_fetches())


+ 4
- 0
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -43,4 +43,8 @@ Docs: https://tensorflownet.readthedocs.io</Description>
<Content CopyToOutputDirectory="PreserveNewest" Include="./runtimes/win-x64/native/tensorflow.dll" Link="tensorflow.dll" Pack="true" PackagePath="runtimes/win-x64/native/tensorflow.dll" /> <Content CopyToOutputDirectory="PreserveNewest" Include="./runtimes/win-x64/native/tensorflow.dll" Link="tensorflow.dll" Pack="true" PackagePath="runtimes/win-x64/native/tensorflow.dll" />
</ItemGroup> </ItemGroup>


<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
</ItemGroup>

</Project> </Project>

+ 7
- 2
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -74,6 +74,9 @@ namespace Tensorflow


switch (nd.dtype.Name) switch (nd.dtype.Name)
{ {
case "Int16":
Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size);
break;
case "Int32": case "Int32":
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size); Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size);
break; break;
@@ -161,6 +164,8 @@ namespace Tensorflow
{ {
switch (type.Name) switch (type.Name)
{ {
case "Int16":
return TF_DataType.TF_INT16;
case "Int32": case "Int32":
return TF_DataType.TF_INT32; return TF_DataType.TF_INT32;
case "Single": case "Single":
@@ -169,9 +174,9 @@ namespace Tensorflow
return TF_DataType.TF_DOUBLE; return TF_DataType.TF_DOUBLE;
case "String": case "String":
return TF_DataType.TF_STRING; return TF_DataType.TF_STRING;
default:
throw new NotImplementedException("ToTFDataType error");
} }

return TF_DataType.DtInvalid;
} }


public void Dispose() public void Dispose()


+ 18
- 4
test/TensorFlowNET.Examples/BasicOperations.cs View File

@@ -1,4 +1,5 @@
using System;
using NumSharp.Core;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow; using Tensorflow;
@@ -43,11 +44,24 @@ namespace TensorFlowNET.Examples
// Launch the default graph. // Launch the default graph.
using(sess = tf.Session()) using(sess = tf.Session())
{ {
// var feed_dict = new Dictionary<string, >
var feed_dict = new Dictionary<Tensor, NDArray>();
feed_dict.Add(a, (short)2);
feed_dict.Add(b, (short)3);
// Run every operation with variable input // Run every operation with variable input
// Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict: {a: 2, b: 3})}");
// Console.WriteLine($"Multiplication with variables: {}");
Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}");
Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}");
} }

// ----------------
// More in details:
// Matrix Multiplication from TensorFlow official tutorial

// Create a Constant op that produces a 1x2 matrix. The op is
// added as a node to the default graph.
//
// The value returned by the constructor represents the output
// of the Constant op.

} }
} }
} }

+ 1
- 0
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -10,6 +10,7 @@
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup> </ItemGroup>




+ 2
- 1
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp.Core;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
@@ -31,7 +32,7 @@ namespace TensorFlowNET.UnitTest


using(var sess = tf.Session()) using(var sess = tf.Session())
{ {
var feed_dict = new Dictionary<Tensor, object>();
var feed_dict = new Dictionary<Tensor, NDArray>();
feed_dict.Add(a, 3.0f); feed_dict.Add(a, 3.0f);
feed_dict.Add(b, 2.0f); feed_dict.Add(b, 2.0f);




Loading…
Cancel
Save