Browse Source

Multiplication with variables in BasicOperations example.

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
72d25e058f
8 changed files with 50 additions and 24 deletions
  1. +6
    -0
      TensorFlow.NET.sln
  2. +9
    -15
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  3. +3
    -2
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  4. +4
    -0
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  5. +7
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  6. +18
    -4
      test/TensorFlowNET.Examples/BasicOperations.cs
  7. +1
    -0
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  8. +2
    -1
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 6
- 0
TensorFlow.NET.sln View File

@@ -9,6 +9,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -27,6 +29,10 @@ Global
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Debug|Any CPU.Build.0 = Debug|Any CPU
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Release|Any CPU.ActiveCfg = Release|Any CPU
{EC622ADF-8DAE-474B-B18E-9598A4F91BA2}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE


+ 9
- 15
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -40,30 +40,22 @@ namespace Tensorflow
}

public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
public virtual object run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null)
{
var result = _run(fetches, feed_dict);

return result;
}

private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
private unsafe object _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null)
{
var feed_dict_tensor = new Dictionary<Tensor, object>();
var feed_dict_tensor = new Dictionary<Tensor, NDArray>();

if (feed_dict != null)
{
NDArray np_val = null;
foreach (var feed in feed_dict)
{
switch (feed.Value)
{
case float value:
np_val = np.asarray(value);
break;
}

feed_dict_tensor[feed.Key] = np_val;
feed_dict_tensor[feed.Key] = feed.Value;
}
}

@@ -85,9 +77,9 @@ namespace Tensorflow
return fetch_handler.build_results(null, results);
}

private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, object> feed_dict)
private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict)
{
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray();
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray();
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();

return _call_tf_sessionrun(feeds, fetches);
@@ -133,12 +125,14 @@ namespace Tensorflow
case TF_DataType.TF_FLOAT:
result[i] = *(float*)c_api.TF_TensorData(output_values[i]);
break;
case TF_DataType.TF_INT16:
result[i] = *(short*)c_api.TF_TensorData(output_values[i]);
break;
case TF_DataType.TF_INT32:
result[i] = *(int*)c_api.TF_TensorData(output_values[i]);
break;
default:
throw new NotImplementedException("can't get output");
break;
}
}



+ 3
- 2
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -1,4 +1,5 @@
using System;
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Text;

@@ -15,7 +16,7 @@ namespace Tensorflow
private List<Tensor> _final_fetches = new List<Tensor>();
private List<object> _targets = new List<object>();

public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, object> feeds = null, object feed_handles = null)
public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, NDArray> feeds = null, object feed_handles = null)
{
_fetch_mapper = new _FetchMapper().for_fetch(fetches);
foreach(var fetch in _fetch_mapper.unique_fetches())


+ 4
- 0
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -43,4 +43,8 @@ Docs: https://tensorflownet.readthedocs.io</Description>
<Content CopyToOutputDirectory="PreserveNewest" Include="./runtimes/win-x64/native/tensorflow.dll" Link="tensorflow.dll" Pack="true" PackagePath="runtimes/win-x64/native/tensorflow.dll" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
</ItemGroup>

</Project>

+ 7
- 2
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -74,6 +74,9 @@ namespace Tensorflow

switch (nd.dtype.Name)
{
case "Int16":
Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size);
break;
case "Int32":
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size);
break;
@@ -161,6 +164,8 @@ namespace Tensorflow
{
switch (type.Name)
{
case "Int16":
return TF_DataType.TF_INT16;
case "Int32":
return TF_DataType.TF_INT32;
case "Single":
@@ -169,9 +174,9 @@ namespace Tensorflow
return TF_DataType.TF_DOUBLE;
case "String":
return TF_DataType.TF_STRING;
default:
throw new NotImplementedException("ToTFDataType error");
}

return TF_DataType.DtInvalid;
}

public void Dispose()


+ 18
- 4
test/TensorFlowNET.Examples/BasicOperations.cs View File

@@ -1,4 +1,5 @@
using System;
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
@@ -43,11 +44,24 @@ namespace TensorFlowNET.Examples
// Launch the default graph.
using(sess = tf.Session())
{
// var feed_dict = new Dictionary<string, >
var feed_dict = new Dictionary<Tensor, NDArray>();
feed_dict.Add(a, (short)2);
feed_dict.Add(b, (short)3);
// Run every operation with variable input
// Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict: {a: 2, b: 3})}");
// Console.WriteLine($"Multiplication with variables: {}");
Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}");
Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}");
}

// ----------------
// More in details:
// Matrix Multiplication from TensorFlow official tutorial

// Create a Constant op that produces a 1x2 matrix. The op is
// added as a node to the default graph.
//
// The value returned by the constructor represents the output
// of the Constant op.

}
}
}

+ 1
- 0
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -10,6 +10,7 @@
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup>



+ 2
- 1
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Text;
@@ -31,7 +32,7 @@ namespace TensorFlowNET.UnitTest

using(var sess = tf.Session())
{
var feed_dict = new Dictionary<Tensor, object>();
var feed_dict = new Dictionary<Tensor, NDArray>();
feed_dict.Add(a, 3.0f);
feed_dict.Add(b, 2.0f);



Loading…
Cancel
Save