@@ -9,6 +9,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T | |||||
EndProject | EndProject | ||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" | ||||
EndProject | EndProject | ||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{6ACED8FF-F08E-40E6-A75D-D01BAAA41072}" | |||||
EndProject | |||||
Global | Global | ||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution | GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||||
Debug|Any CPU = Debug|Any CPU | Debug|Any CPU = Debug|Any CPU | ||||
@@ -27,6 +29,10 @@ Global | |||||
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU | {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU | {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
{1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU | {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
{6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
{6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
{6ACED8FF-F08E-40E6-A75D-D01BAAA41072}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
EndGlobalSection | EndGlobalSection | ||||
GlobalSection(SolutionProperties) = preSolution | GlobalSection(SolutionProperties) = preSolution | ||||
HideSolutionNode = FALSE | HideSolutionNode = FALSE | ||||
@@ -50,7 +50,7 @@ namespace Tensorflow | |||||
// Perform input type inference | // Perform input type inference | ||||
var inputs = new List<Tensor>(); | var inputs = new List<Tensor>(); | ||||
var input_types = new List<DataType>(); | |||||
var input_types = new List<TF_DataType>(); | |||||
foreach (var input_arg in op_def.InputArg) | foreach (var input_arg in op_def.InputArg) | ||||
{ | { | ||||
@@ -106,7 +106,7 @@ namespace Tensorflow | |||||
} | } | ||||
// Determine output types (possibly using attrs) | // Determine output types (possibly using attrs) | ||||
var output_types = new List<DataType>(); | |||||
var output_types = new List<TF_DataType>(); | |||||
foreach (var arg in op_def.OutputArg) | foreach (var arg in op_def.OutputArg) | ||||
{ | { | ||||
@@ -116,7 +116,7 @@ namespace Tensorflow | |||||
} | } | ||||
else if (!String.IsNullOrEmpty(arg.TypeAttr)) | else if (!String.IsNullOrEmpty(arg.TypeAttr)) | ||||
{ | { | ||||
output_types.Add(attr_protos[arg.TypeAttr].Type); | |||||
output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); | |||||
} | } | ||||
} | } | ||||
@@ -24,7 +24,7 @@ namespace Tensorflow | |||||
var status = new Status(); | var status = new Status(); | ||||
var desc = c_api.TF_NewOperation(g.Handle, opType, oper_name); | var desc = c_api.TF_NewOperation(g.Handle, opType, oper_name); | ||||
c_api.TF_SetAttrType(desc, "dtype", DataType.DtInt32); | |||||
c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32); | |||||
c_api.TF_FinishOperation(desc, status.Handle); | c_api.TF_FinishOperation(desc, status.Handle); | ||||
} | } | ||||
@@ -39,7 +39,7 @@ namespace Tensorflow | |||||
_outputs = new Tensor[num_outputs]; | _outputs = new Tensor[num_outputs]; | ||||
for (int i = 0; i < num_outputs; i++) | for (int i = 0; i < num_outputs; i++) | ||||
{ | { | ||||
_outputs[i] = new Tensor(this, i, TF_DataType.DtFloat); | |||||
_outputs[i] = new Tensor(this, i, TF_DataType.TF_FLOAT); | |||||
} | } | ||||
_graph._add_op(this); | _graph._add_op(this); | ||||
@@ -2,6 +2,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Runtime.InteropServices; | |||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -113,10 +114,9 @@ namespace Tensorflow | |||||
run_metadata: IntPtr.Zero, | run_metadata: IntPtr.Zero, | ||||
status: status.Handle); | status: status.Handle); | ||||
var result = output_values.Select(x => new Tensor(x).buffer).Select(x => | |||||
{ | |||||
return (object)*(float*)x; | |||||
}).ToArray(); | |||||
var result = output_values.Select(x => c_api.TF_TensorData(x)) | |||||
.Select(x => (object)*(float*)x) | |||||
.ToArray(); | |||||
return result; | return result; | ||||
} | } | ||||
@@ -0,0 +1,34 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public enum TF_DataType | |||||
{ | |||||
TF_FLOAT = 1, | |||||
TF_DOUBLE = 2, | |||||
TF_INT32 = 3, // Int32 tensors are always in 'host' memory. | |||||
TF_UINT8 = 4, | |||||
TF_INT16 = 5, | |||||
TF_INT8 = 6, | |||||
TF_STRING = 7, | |||||
TF_COMPLEX64 = 8, // Single-precision complex | |||||
TF_COMPLEX = 8, // Old identifier kept for API backwards compatibility | |||||
TF_INT64 = 9, | |||||
TF_BOOL = 10, | |||||
TF_QINT8 = 11, // Quantized int8 | |||||
TF_QUINT8 = 12, // Quantized uint8 | |||||
TF_QINT32 = 13, // Quantized int32 | |||||
TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops. | |||||
TF_QINT16 = 15, // Quantized int16 | |||||
TF_QUINT16 = 16, // Quantized uint16 | |||||
TF_UINT16 = 17, | |||||
TF_COMPLEX128 = 18, // Double-precision complex | |||||
TF_HALF = 19, | |||||
TF_RESOURCE = 20, | |||||
TF_VARIANT = 21, | |||||
TF_UINT32 = 22, | |||||
TF_UINT64 = 23 | |||||
} | |||||
} |
@@ -0,0 +1,15 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Runtime.InteropServices; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
[StructLayout(LayoutKind.Sequential)] | |||||
public struct TF_Tensor | |||||
{ | |||||
public TF_DataType dtype; | |||||
public IntPtr shape; | |||||
public IntPtr buffer; | |||||
} | |||||
} |
@@ -1,5 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Runtime.InteropServices; | |||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -10,8 +11,8 @@ namespace Tensorflow | |||||
public Operation op => _op; | public Operation op => _op; | ||||
private readonly int _value_index; | private readonly int _value_index; | ||||
public int value_index => _value_index; | public int value_index => _value_index; | ||||
private DataType _dtype; | |||||
public DataType dtype => _dtype; | |||||
private TF_DataType _dtype; | |||||
public TF_DataType dtype => _dtype; | |||||
public Graph graph => _op.graph; | public Graph graph => _op.graph; | ||||
@@ -19,14 +20,19 @@ namespace Tensorflow | |||||
private readonly IntPtr _handle; | private readonly IntPtr _handle; | ||||
public IntPtr handle => _handle; | public IntPtr handle => _handle; | ||||
public IntPtr buffer => c_api.TF_TensorData(_handle); | |||||
private TF_Tensor tensor; | |||||
public IntPtr buffer => c_api.TF_TensorData(tensor.buffer); | |||||
public Tensor(IntPtr handle) | public Tensor(IntPtr handle) | ||||
{ | { | ||||
_handle = handle; | _handle = handle; | ||||
tensor = Marshal.PtrToStructure<TF_Tensor>(handle); | |||||
_dtype = tensor.dtype; | |||||
} | } | ||||
public Tensor(Operation op, int value_index, DataType dtype) | |||||
public Tensor(Operation op, int value_index, TF_DataType dtype) | |||||
{ | { | ||||
_op = op; | _op = op; | ||||
_value_index = value_index; | _value_index = value_index; |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class TensorBuffer | |||||
{ | |||||
} | |||||
} |
@@ -17,7 +17,10 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Google.Protobuf" Version="3.6.1" /> | <PackageReference Include="Google.Protobuf" Version="3.6.1" /> | ||||
<PackageReference Include="NumSharp" Version="0.6.0" /> | |||||
</ItemGroup> | |||||
<ItemGroup> | |||||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -7,12 +7,9 @@ using size_t = System.UIntPtr; | |||||
using TF_Graph = System.IntPtr; | using TF_Graph = System.IntPtr; | ||||
using TF_Operation = System.IntPtr; | using TF_Operation = System.IntPtr; | ||||
using TF_Status = System.IntPtr; | using TF_Status = System.IntPtr; | ||||
using TF_Tensor = System.IntPtr; | |||||
using TF_Session = System.IntPtr; | using TF_Session = System.IntPtr; | ||||
using TF_SessionOptions = System.IntPtr; | using TF_SessionOptions = System.IntPtr; | ||||
using TF_DataType = Tensorflow.DataType; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public static class c_api | public static class c_api | ||||
@@ -54,8 +51,19 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static unsafe extern TF_Status TF_NewStatus(); | public static unsafe extern TF_Status TF_NewStatus(); | ||||
/// <summary> | |||||
/// Return a new tensor that holds the bytes data[0,len-1] | |||||
/// </summary> | |||||
/// <param name="dataType"></param> | |||||
/// <param name="dims"></param> | |||||
/// <param name="num_dims"></param> | |||||
/// <param name="data"></param> | |||||
/// <param name="len">num_bytes, ex: 6 * sizeof(float)</param> | |||||
/// <param name="deallocator"></param> | |||||
/// <param name="deallocator_arg"></param> | |||||
/// <returns></returns> | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe TF_Tensor TF_NewTensor(TF_DataType dataType, Int64 dims, int num_dims, IntPtr data, size_t len, tf.Deallocator deallocator, IntPtr deallocator_arg); | |||||
public static extern unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, size_t len, tf.Deallocator deallocator, IntPtr deallocator_arg); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe int TF_OperationNumOutputs(TF_Operation oper); | public static extern unsafe int TF_OperationNumOutputs(TF_Operation oper); | ||||
@@ -66,10 +74,25 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status); | public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status); | ||||
/// <summary> | |||||
/// | |||||
/// </summary> | |||||
/// <param name="session"></param> | |||||
/// <param name="run_options"></param> | |||||
/// <param name="inputs"></param> | |||||
/// <param name="input_values"></param> | |||||
/// <param name="ninputs"></param> | |||||
/// <param name="outputs"></param> | |||||
/// <param name="output_values"></param> | |||||
/// <param name="noutputs"></param> | |||||
/// <param name="target_opers"></param> | |||||
/// <param name="ntargets"></param> | |||||
/// <param name="run_metadata"></param> | |||||
/// <param name="status"></param> | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe void TF_SessionRun(TF_Session session, IntPtr run_options, | public static extern unsafe void TF_SessionRun(TF_Session session, IntPtr run_options, | ||||
TF_Output[] inputs, TF_Tensor[] input_values, int ninputs, | |||||
TF_Output[] outputs, TF_Tensor[] output_values, int noutputs, | |||||
TF_Output[] inputs, IntPtr[] input_values, int ninputs, | |||||
TF_Output[] outputs, IntPtr[] output_values, int noutputs, | |||||
TF_Operation[] target_opers, int ntargets, | TF_Operation[] target_opers, int ntargets, | ||||
IntPtr run_metadata, | IntPtr run_metadata, | ||||
TF_Status status); | TF_Status status); | ||||
@@ -78,7 +101,10 @@ namespace Tensorflow | |||||
public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe IntPtr TF_TensorData(TF_Tensor tensor); | |||||
public static extern unsafe IntPtr TF_TensorData(IntPtr tensor); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern unsafe TF_DataType TF_TensorType(IntPtr tensor); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opts, TF_Status status); | public static extern TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opts, TF_Status status); | ||||
@@ -10,7 +10,7 @@ namespace Tensorflow | |||||
{ | { | ||||
public static OpDefLibrary _op_def_lib = _InitOpDefLibrary(); | public static OpDefLibrary _op_def_lib = _InitOpDefLibrary(); | ||||
public static Tensor placeholder(DataType dtype, TensorShape shape = null) | |||||
public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null) | |||||
{ | { | ||||
/*var g = ops.get_default_graph(); | /*var g = ops.get_default_graph(); | ||||
var op = new Operation(g, "Placeholder", "feed"); | var op = new Operation(g, "Placeholder", "feed"); | ||||
@@ -17,7 +17,7 @@ namespace Tensorflow | |||||
var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); | var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); | ||||
var tensor = new Tensor(_op, 0, DataType.DtFloat); | |||||
var tensor = new Tensor(_op, 0, TF_DataType.TF_FLOAT); | |||||
return tensor; | return tensor; | ||||
} | } | ||||
@@ -10,7 +10,7 @@ namespace Tensorflow | |||||
{ | { | ||||
public static class tf | public static class tf | ||||
{ | { | ||||
public static DataType float32 = DataType.DtFloat; | |||||
public static TF_DataType float32 = TF_DataType.TF_FLOAT; | |||||
public static Context context = new Context(); | public static Context context = new Context(); | ||||
@@ -23,7 +23,7 @@ namespace Tensorflow | |||||
return gen_math_ops.add(a, b); | return gen_math_ops.add(a, b); | ||||
} | } | ||||
public static unsafe Tensor placeholder(DataType dtype, TensorShape shape = null) | |||||
public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null) | |||||
{ | { | ||||
return gen_array_ops.placeholder(dtype, shape); | return gen_array_ops.placeholder(dtype, shape); | ||||
} | } | ||||
@@ -42,7 +42,7 @@ namespace Tensorflow | |||||
var attrs = new Dictionary<string, AttrValue>(); | var attrs = new Dictionary<string, AttrValue>(); | ||||
attrs["dtype"] = dtype_value; | attrs["dtype"] = dtype_value; | ||||
attrs["value"] = tensor_value; | attrs["value"] = tensor_value; | ||||
var const_tensor = g.create_op("Const", null, new TF_DataType[] { dtype_value.Type }, attrs: attrs).outputs[0]; | |||||
var const_tensor = g.create_op("Const", null, new TF_DataType[] { (TF_DataType)dtype_value.Type }, attrs: attrs).outputs[0]; | |||||
return const_tensor; | return const_tensor; | ||||
} | } | ||||
@@ -55,7 +55,7 @@ namespace Tensorflow | |||||
public static Deallocator FreeTensorDataDelegate = FreeTensorData; | public static Deallocator FreeTensorDataDelegate = FreeTensorData; | ||||
[MonoPInvokeCallback(typeof(Deallocator))] | [MonoPInvokeCallback(typeof(Deallocator))] | ||||
internal static void FreeTensorData(IntPtr data, IntPtr len, IntPtr closure) | |||||
public static void FreeTensorData(IntPtr data, IntPtr len, IntPtr closure) | |||||
{ | { | ||||
Marshal.FreeHGlobal(data); | Marshal.FreeHGlobal(data); | ||||
} | } | ||||
@@ -5,10 +5,6 @@ | |||||
<TargetFramework>netcoreapp2.1</TargetFramework> | <TargetFramework>netcoreapp2.1</TargetFramework> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
<ItemGroup> | |||||
<PackageReference Include="NumSharp" Version="0.6.0" /> | |||||
</ItemGroup> | |||||
<ItemGroup> | <ItemGroup> | ||||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -6,14 +6,19 @@ | |||||
<IsPackable>false</IsPackable> | <IsPackable>false</IsPackable> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||||
<DefineConstants>DEBUG;TRACE</DefineConstants> | |||||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||||
</PropertyGroup> | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> | <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> | ||||
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | ||||
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | ||||
<PackageReference Include="NumSharp" Version="0.6.0" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -0,0 +1,38 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using NumSharp.Core; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Runtime.InteropServices; | |||||
using System.Text; | |||||
using Tensorflow; | |||||
namespace TensorFlowNET.UnitTest | |||||
{ | |||||
[TestClass] | |||||
public class TensorTest | |||||
{ | |||||
[TestMethod] | |||||
public unsafe void NewTF_Tensor() | |||||
{ | |||||
var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | |||||
var data = Marshal.AllocHGlobal(sizeof(float) * nd.size); | |||||
Marshal.Copy(nd.Data<float>(), 0, data, nd.size); | |||||
var handle = c_api.TF_NewTensor(TF_DataType.TF_FLOAT, | |||||
nd.shape.Select(x => (long)x).ToArray(), // shape | |||||
nd.ndim, | |||||
data, | |||||
(UIntPtr)(nd.size * sizeof(float)), | |||||
tf.FreeTensorData, | |||||
IntPtr.Zero); | |||||
Assert.AreNotEqual(handle, IntPtr.Zero); | |||||
var tensor = new Tensor(handle); | |||||
Assert.AreEqual(tensor.dtype, TF_DataType.TF_FLOAT); | |||||
} | |||||
} | |||||
} |