@@ -16,95 +16,51 @@ EndProject | |||||
Global | Global | ||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution | GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||||
Debug|Any CPU = Debug|Any CPU | Debug|Any CPU = Debug|Any CPU | ||||
Debug|x64 = Debug|x64 | |||||
Debug-Minimal|Any CPU = Debug-Minimal|Any CPU | Debug-Minimal|Any CPU = Debug-Minimal|Any CPU | ||||
Debug-Minimal|x64 = Debug-Minimal|x64 | |||||
Publish|Any CPU = Publish|Any CPU | Publish|Any CPU = Publish|Any CPU | ||||
Publish|x64 = Publish|x64 | |||||
Release|Any CPU = Release|Any CPU | Release|Any CPU = Release|Any CPU | ||||
Release|x64 = Release|x64 | |||||
EndGlobalSection | EndGlobalSection | ||||
GlobalSection(ProjectConfigurationPlatforms) = postSolution | GlobalSection(ProjectConfigurationPlatforms) = postSolution | ||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | ||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.ActiveCfg = Debug|x64 | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.Build.0 = Debug|x64 | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | ||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | ||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.ActiveCfg = Release|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.ActiveCfg = Release|Any CPU | ||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.Build.0 = Release|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.Build.0 = Release|Any CPU | ||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.Build.0 = Release|Any CPU | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.ActiveCfg = Release|Any CPU | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.Build.0 = Release|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | ||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.Build.0 = Debug|Any CPU | {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.Build.0 = Debug|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | ||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | ||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.ActiveCfg = Release|Any CPU | {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.ActiveCfg = Release|Any CPU | ||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.Build.0 = Release|Any CPU | {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.Build.0 = Release|Any CPU | ||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.Build.0 = Release|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.ActiveCfg = Release|Any CPU | {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.Build.0 = Release|Any CPU | {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.ActiveCfg = Release|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.Build.0 = Release|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | ||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.Build.0 = Debug|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.Build.0 = Debug|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | ||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | ||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.ActiveCfg = Release|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.ActiveCfg = Release|Any CPU | ||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.Build.0 = Release|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.Build.0 = Release|Any CPU | ||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.Build.0 = Release|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.ActiveCfg = Release|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.Build.0 = Release|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.ActiveCfg = Release|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.Build.0 = Release|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | ||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.Build.0 = Debug|Any CPU | {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.Build.0 = Debug|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | ||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | ||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.ActiveCfg = Release|Any CPU | {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.ActiveCfg = Release|Any CPU | ||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.Build.0 = Release|Any CPU | {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.Build.0 = Release|Any CPU | ||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.Build.0 = Release|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.ActiveCfg = Release|Any CPU | {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.Build.0 = Release|Any CPU | {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.ActiveCfg = Release|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.Build.0 = Release|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | ||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.Build.0 = Debug|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.Build.0 = Debug|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | ||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | ||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.ActiveCfg = Release|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.ActiveCfg = Release|Any CPU | ||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.Build.0 = Release|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.Build.0 = Release|Any CPU | ||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.Build.0 = Release|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.ActiveCfg = Release|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.ActiveCfg = Release|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.Build.0 = Release|Any CPU | |||||
EndGlobalSection | EndGlobalSection | ||||
GlobalSection(SolutionProperties) = preSolution | GlobalSection(SolutionProperties) = preSolution | ||||
HideSolutionNode = FALSE | HideSolutionNode = FALSE | ||||
@@ -43,7 +43,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public partial class c_api | public partial class c_api | ||||
{ | { | ||||
public const string TensorFlowLibName = "tensorflow"; | |||||
public const string TensorFlowLibName = @"D:\SciSharp\tensorflow-google\bazel-bin\tensorflow\tensorflow.dll"; | |||||
public static string StringPiece(IntPtr handle) | public static string StringPiece(IntPtr handle) | ||||
{ | { | ||||
@@ -0,0 +1,34 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Eager | |||||
{ | |||||
public class EagerOperation : Operation | |||||
{ | |||||
public int NumInputs; | |||||
public Tensor[] Inputs { get; set; } | |||||
public EagerOperation() : base(IntPtr.Zero) { } | |||||
public override InputList inputs | |||||
{ | |||||
get | |||||
{ | |||||
if (_inputs_val == null) | |||||
{ | |||||
var retval = new Tensor[NumInputs]; | |||||
for (int i = 0; i < NumInputs; i++) | |||||
{ | |||||
} | |||||
_inputs_val = new InputList(Inputs); | |||||
} | |||||
return _inputs_val; | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -10,5 +10,8 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
public static explicit operator TFE_TensorHandle(EagerTensor tensor) | public static explicit operator TFE_TensorHandle(EagerTensor tensor) | ||||
=> tensor.tfe_tensor_handle; | => tensor.tfe_tensor_handle; | ||||
public static implicit operator IntPtr(EagerTensor tensor) | |||||
=> tensor.EagerTensorHandle; | |||||
} | } | ||||
} | } |
@@ -24,31 +24,10 @@ namespace Tensorflow.Eager | |||||
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | ||||
} | } | ||||
public EagerTensor(int value, string device_name) : base(value) | |||||
{ | |||||
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | |||||
EagerTensorHandle = c_api.TFE_EagerTensorFromHandle(tf.context, tfe_tensor_handle); | |||||
} | |||||
public EagerTensor(float value, string device_name) : base(value) | |||||
{ | |||||
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | |||||
EagerTensorHandle = c_api.TFE_EagerTensorFromHandle(tf.context, tfe_tensor_handle); | |||||
} | |||||
public EagerTensor(float[] value, string device_name) : base(value) | |||||
{ | |||||
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | |||||
} | |||||
public EagerTensor(double[] value, string device_name) : base(value) | |||||
{ | |||||
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | |||||
} | |||||
public EagerTensor(NDArray value, string device_name) : base(value) | public EagerTensor(NDArray value, string device_name) : base(value) | ||||
{ | { | ||||
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); | ||||
EagerTensorHandle = c_api.TFE_EagerTensorFromHandle(tf.context, tfe_tensor_handle); | |||||
} | } | ||||
public override string ToString() | public override string ToString() | ||||
@@ -56,23 +35,24 @@ namespace Tensorflow.Eager | |||||
switch (rank) | switch (rank) | ||||
{ | { | ||||
case -1: | case -1: | ||||
return $"tf.Tensor: shape=<unknown>, dtype={dtype.as_numpy_name()}, numpy={GetFormattedString()}"; | |||||
return $"tf.Tensor: shape=<unknown>, dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}"; | |||||
case 0: | case 0: | ||||
return $"tf.Tensor: shape=(), dtype={dtype.as_numpy_name()}, numpy={GetFormattedString()}"; | |||||
return $"tf.Tensor: shape=(), dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}"; | |||||
default: | default: | ||||
return $"tf.Tensor: shape=({string.Join(",", shape)}), dtype={dtype.as_numpy_name()}, numpy={GetFormattedString()}"; | |||||
return $"tf.Tensor: shape=({string.Join(",", shape)}), dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}"; | |||||
} | } | ||||
} | } | ||||
private string GetFormattedString() | |||||
public static string GetFormattedString(TF_DataType dtype, NDArray nd) | |||||
{ | { | ||||
var nd = numpy(); | |||||
switch (dtype) | switch (dtype) | ||||
{ | { | ||||
case TF_DataType.TF_STRING: | case TF_DataType.TF_STRING: | ||||
return $"b'{(string)nd}'"; | return $"b'{(string)nd}'"; | ||||
case TF_DataType.TF_BOOL: | case TF_DataType.TF_BOOL: | ||||
return (nd.GetByte(0) > 0).ToString(); | return (nd.GetByte(0) > 0).ToString(); | ||||
case TF_DataType.TF_RESOURCE: | |||||
return "<unprintable>"; | |||||
default: | default: | ||||
return nd.ToString(); | return nd.ToString(); | ||||
} | } | ||||
@@ -11,8 +11,18 @@ namespace Tensorflow | |||||
public static extern void TFE_RegisterGradientFunction(_gradient_function_callback callbackPointer); | public static extern void TFE_RegisterGradientFunction(_gradient_function_callback callbackPointer); | ||||
[UnmanagedFunctionPointer(CallingConvention.StdCall)] | [UnmanagedFunctionPointer(CallingConvention.StdCall)] | ||||
public delegate void _gradient_function_callback(string op_name, int num_inputs, IntPtr attrs, int num_attrs); | |||||
public delegate IntPtr _gradient_function_callback(string op_name, int num_inputs, IntPtr[] op_inputs, int num_attrs, IntPtr[] output_grads); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern IntPtr VSpace_Handle(VSpace_callback_Ones ones, VSpace_callback_AggregateGrads aggregate_grads); | |||||
[UnmanagedFunctionPointer(CallingConvention.StdCall)] | |||||
public delegate IntPtr VSpace_callback_Ones(long[] shape, int dims, TF_DataType dtype); | |||||
[UnmanagedFunctionPointer(CallingConvention.StdCall)] | |||||
public delegate IntPtr VSpace_callback_AggregateGrads(IntPtr gradients, int num_grads); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TFE_RegisterVSpace(IntPtr vspace); | |||||
/// <summary> | /// <summary> | ||||
/// Return a new options object. | /// Return a new options object. | ||||
/// </summary> | /// </summary> | ||||
@@ -330,7 +340,10 @@ namespace Tensorflow | |||||
string name, | string name, | ||||
IntPtr[] args, | IntPtr[] args, | ||||
int input_size, | int input_size, | ||||
TFE_FastPathExecute_SetOpAttrs set_op_attrs, | |||||
IntPtr status); | IntPtr status); | ||||
[UnmanagedFunctionPointer(CallingConvention.StdCall)] | |||||
public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); | public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); | ||||
@@ -342,7 +355,8 @@ namespace Tensorflow | |||||
public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor); | public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor); | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TFE_TapeGradient(IntPtr tape, IntPtr[] target, int target_size, | |||||
public static extern IntPtr TFE_TapeGradient(IntPtr tape, | |||||
IntPtr[] target, int target_size, | |||||
IntPtr[] sources, int source_size, | IntPtr[] sources, int source_size, | ||||
IntPtr status); | IntPtr status); | ||||
} | } | ||||
@@ -11,6 +11,8 @@ namespace Tensorflow.Eager | |||||
public partial class wrap_tfe_src | public partial class wrap_tfe_src | ||||
{ | { | ||||
static int kFastPathExecuteInputStartIndex = 0; | static int kFastPathExecuteInputStartIndex = 0; | ||||
[Obsolete] | |||||
public static EagerTensor TFE_FastPathExecute(Context ctx, | public static EagerTensor TFE_FastPathExecute(Context ctx, | ||||
string device_name, | string device_name, | ||||
string opName, | string opName, | ||||
@@ -203,7 +205,7 @@ namespace Tensorflow.Eager | |||||
/// <param name="attr_value"></param> | /// <param name="attr_value"></param> | ||||
/// <param name="attr_list_sizes"></param> | /// <param name="attr_list_sizes"></param> | ||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
private static void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr, | |||||
public static void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr, | |||||
string attr_name, object attr_value, | string attr_name, object attr_value, | ||||
Dictionary<string, long> attr_list_sizes, | Dictionary<string, long> attr_list_sizes, | ||||
Status status) | Status status) | ||||
@@ -74,12 +74,12 @@ namespace Tensorflow.Gradients | |||||
} | } | ||||
using var status = new Status(); | using var status = new Status(); | ||||
var et = c_api.TFE_TapeGradient(_tape, | |||||
new IntPtr[] { (target as EagerTensor).EagerTensorHandle }, 1, | |||||
new IntPtr[] { (sources as EagerTensor).EagerTensorHandle }, 1, | |||||
var et = c_api.TFE_TapeGradient(_tape, | |||||
new [] { (target as EagerTensor).EagerTensorHandle }, 1, | |||||
new [] { (sources as EagerTensor).EagerTensorHandle }, 1, | |||||
status); | status); | ||||
status.Check(true); | status.Check(true); | ||||
return et; | |||||
return new EagerTensor(et); | |||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
@@ -24,9 +24,9 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class ops | public partial class ops | ||||
{ | { | ||||
static Dictionary<string, Func<Operation, Tensor[], Tensor[]>> gradientFunctions = null; | |||||
public static Dictionary<string, Func<Operation, Tensor[], Tensor[]>> gradientFunctions = null; | |||||
private static void RegisterFromAssembly() | |||||
public static void RegisterFromAssembly() | |||||
{ | { | ||||
if (gradientFunctions == null) | if (gradientFunctions == null) | ||||
{ | { | ||||
@@ -40,9 +40,9 @@ namespace Tensorflow | |||||
public int NumInputs => c_api.TF_OperationNumInputs(_handle); | public int NumInputs => c_api.TF_OperationNumInputs(_handle); | ||||
private TF_DataType[] _input_types => _inputs_val._inputs.Select(x => x.dtype).ToArray(); | private TF_DataType[] _input_types => _inputs_val._inputs.Select(x => x.dtype).ToArray(); | ||||
private InputList _inputs_val; | |||||
protected InputList _inputs_val; | |||||
public InputList inputs | |||||
public virtual InputList inputs | |||||
{ | { | ||||
get | get | ||||
{ | { | ||||
@@ -152,8 +152,14 @@ namespace Tensorflow | |||||
{ | { | ||||
if(tf.context.executing_eagerly()) | if(tf.context.executing_eagerly()) | ||||
{ | { | ||||
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, "Pack", name, null, values, "axis", axis); | |||||
return _result; | |||||
using var status = new Status(); | |||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Pack", name, | |||||
values.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), 1, | |||||
(op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "axis", axis, null, status), | |||||
status); | |||||
status.Check(true); | |||||
return new EagerTensor(tensor); | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("Pack", name: name, args: new { values, axis }); | var _op = _op_def_lib._apply_op_helper("Pack", name: name, args: new { values, axis }); | ||||
@@ -41,6 +41,18 @@ namespace Tensorflow | |||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor add_n(Tensor[] inputs, string name = null) | public static Tensor add_n(Tensor[] inputs, string name = null) | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
using var status = new Status(); | |||||
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"AddN", name, | |||||
inputs.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), inputs.Length, | |||||
null, | |||||
status); | |||||
status.Check(true); | |||||
return new EagerTensor(_result); | |||||
} | |||||
var _op = _op_def_lib._apply_op_helper("AddN", name, args: new { inputs }); | var _op = _op_def_lib._apply_op_helper("AddN", name, args: new { inputs }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -121,10 +133,18 @@ namespace Tensorflow | |||||
{ | { | ||||
try | try | ||||
{ | { | ||||
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Mean", name, null, | |||||
input, axis, "keep_dims", keep_dims); | |||||
return _result; | |||||
using var status = new Status(); | |||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Mean", name, | |||||
new IntPtr[] | |||||
{ | |||||
(input as EagerTensor).EagerTensorHandle, | |||||
(axis as EagerTensor).EagerTensorHandle | |||||
}, 2, | |||||
(op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "keep_dims", keep_dims, null, status), | |||||
status); | |||||
status.Check(true); | |||||
return new EagerTensor(tensor); | |||||
} | } | ||||
catch (Exception) | catch (Exception) | ||||
{ | { | ||||
@@ -196,17 +216,15 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
using (var status = new Status()) | |||||
{ | |||||
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Add", name, new IntPtr[] | |||||
{ | |||||
(x as EagerTensor).EagerTensorHandle, | |||||
(y as EagerTensor).EagerTensorHandle | |||||
}, 2, status); | |||||
status.Check(true); | |||||
return new EagerTensor(_result); | |||||
} | |||||
using var status = new Status(); | |||||
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Add", name, new IntPtr[] | |||||
{ | |||||
(x as EagerTensor).EagerTensorHandle, | |||||
(y as EagerTensor).EagerTensorHandle | |||||
}, 2, null, status); | |||||
status.Check(true); | |||||
return new EagerTensor(_result); | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | ||||
@@ -574,10 +592,18 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Cast", name, null, | |||||
x, "DstT", DstT, "Truncate", Truncate); | |||||
return _result; | |||||
using var status = new Status(); | |||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Cast", name, | |||||
new IntPtr[] { (x as EagerTensor).EagerTensorHandle }, 1, | |||||
(op) => | |||||
{ | |||||
wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "DstT", DstT, null, status); | |||||
wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "Truncate", Truncate, null, status); | |||||
}, | |||||
status); | |||||
status.Check(true); | |||||
return new EagerTensor(tensor); | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }); | var _op = _op_def_lib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }); | ||||
@@ -619,17 +645,15 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
using (var status = new Status()) | |||||
{ | |||||
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Sub", name, new IntPtr[] | |||||
{ | |||||
(x as EagerTensor).EagerTensorHandle, | |||||
(y as EagerTensor).EagerTensorHandle | |||||
}, 2, status); | |||||
status.Check(true); | |||||
return new EagerTensor(_result); | |||||
} | |||||
using var status = new Status(); | |||||
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Sub", name, new IntPtr[] | |||||
{ | |||||
(x as EagerTensor).EagerTensorHandle, | |||||
(y as EagerTensor).EagerTensorHandle | |||||
}, 2, null, status); | |||||
status.Check(true); | |||||
return new EagerTensor(_result); | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y }); | ||||
@@ -717,11 +741,11 @@ namespace Tensorflow | |||||
{ | { | ||||
using var status = new Status(); | using var status = new Status(); | ||||
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | ||||
"Mul", name, new IntPtr[] | |||||
{ | |||||
(x as EagerTensor).EagerTensorHandle, | |||||
(y as EagerTensor).EagerTensorHandle | |||||
}, 2, status); | |||||
"Mul", name, new IntPtr[] | |||||
{ | |||||
(x as EagerTensor).EagerTensorHandle, | |||||
(y as EagerTensor).EagerTensorHandle | |||||
}, 2, null, status); | |||||
status.Check(true); | status.Check(true); | ||||
return new EagerTensor(_result); | return new EagerTensor(_result); | ||||
} | } | ||||
@@ -757,17 +781,15 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
using (var status = new Status()) | |||||
{ | |||||
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
using var status = new Status(); | |||||
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"RealDiv", name, new IntPtr[] | "RealDiv", name, new IntPtr[] | ||||
{ | { | ||||
(x as EagerTensor).EagerTensorHandle, | (x as EagerTensor).EagerTensorHandle, | ||||
(y as EagerTensor).EagerTensorHandle | (y as EagerTensor).EagerTensorHandle | ||||
}, 2, status); | |||||
status.Check(true); | |||||
return new EagerTensor(_result); | |||||
} | |||||
}, 2, null, status); | |||||
status.Check(true); | |||||
return new EagerTensor(_result); | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("RealDiv", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("RealDiv", name, args: new { x, y }); | ||||
@@ -962,8 +984,16 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, "Range", name, null, start, limit, delta); | |||||
return _result; | |||||
using var status = new Status(); | |||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Range", name, new IntPtr[] | |||||
{ | |||||
(start as EagerTensor).EagerTensorHandle, | |||||
(limit as EagerTensor).EagerTensorHandle, | |||||
(delta as EagerTensor).EagerTensorHandle | |||||
}, 3, null, status); | |||||
status.Check(true); | |||||
return new EagerTensor(tensor); | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("Range", name, new { start, limit, delta }); | var _op = _op_def_lib._apply_op_helper("Range", name, new { start, limit, delta }); | ||||
@@ -14,6 +14,8 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | |||||
using System.Linq; | |||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -27,10 +29,16 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"AssignVariableOp", name, null, | |||||
resource, value); | |||||
return _result; | |||||
using var status = new Status(); | |||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"AssignVariableOp", name, | |||||
new[] | |||||
{ | |||||
(resource as EagerTensor).EagerTensorHandle, | |||||
(value as EagerTensor).EagerTensorHandle | |||||
}, 2, null, status); | |||||
status.Check(true); | |||||
return tensor; | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("AssignVariableOp", name, new { resource, value }); | var _op = _op_def_lib._apply_op_helper("AssignVariableOp", name, new { resource, value }); | ||||
@@ -42,10 +50,13 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"VarIsInitializedOp", name, null, | |||||
resource); | |||||
return _result; | |||||
using var status = new Status(); | |||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"VarIsInitializedOp", name, | |||||
new[] { (resource as EagerTensor).EagerTensorHandle }, | |||||
1, null, status); | |||||
status.Check(true); | |||||
return new EagerTensor(tensor); | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("VarIsInitializedOp", name, new { resource }); | var _op = _op_def_lib._apply_op_helper("VarIsInitializedOp", name, new { resource }); | ||||
@@ -67,10 +78,17 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"VarHandleOp", name, null, | |||||
"container", container, "shared_name", shared_name, "dtype", dtype, "shape", shape.dims); | |||||
return _result; | |||||
using var status = new Status(); | |||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"VarHandleOp", name, null, 0, op => | |||||
{ | |||||
wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "container", container, null, status); | |||||
wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "shared_name", shared_name, null, status); | |||||
wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "dtype", dtype, null, status); | |||||
wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "shape", shape.dims, null, status); | |||||
}, status); | |||||
status.Check(true); | |||||
return new EagerTensor(tensor); | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("VarHandleOp", name, new { | var _op = _op_def_lib._apply_op_helper("VarHandleOp", name, new { | ||||
@@ -94,10 +112,13 @@ namespace Tensorflow | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
{ | { | ||||
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"ReadVariableOp", name, null, | |||||
resource, "dtype", dtype); | |||||
return _result; | |||||
using var status = new Status(); | |||||
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"ReadVariableOp", name, new IntPtr[] { (resource as EagerTensor).EagerTensorHandle }, 1, | |||||
(op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "dtype", dtype, null, status), | |||||
status); | |||||
status.Check(true); | |||||
return new EagerTensor(tensor); | |||||
} | } | ||||
var _op = _op_def_lib._apply_op_helper("ReadVariableOp", name, new | var _op = _op_def_lib._apply_op_helper("ReadVariableOp", name, new | ||||
@@ -31,7 +31,7 @@ https://tensorflownet.readthedocs.io</Description> | |||||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | ||||
<SignAssembly>true</SignAssembly> | <SignAssembly>true</SignAssembly> | ||||
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | ||||
<Platforms>AnyCPU;x64</Platforms> | |||||
<Platforms>AnyCPU</Platforms> | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
@@ -40,20 +40,10 @@ https://tensorflownet.readthedocs.io</Description> | |||||
<PlatformTarget>x64</PlatformTarget> | <PlatformTarget>x64</PlatformTarget> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> | |||||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||||
<DefineConstants>TRACE;DEBUG;SERIALIZABLE_</DefineConstants> | |||||
<PlatformTarget>x64</PlatformTarget> | |||||
</PropertyGroup> | |||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | ||||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> | |||||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||||
</PropertyGroup> | |||||
<ItemGroup> | <ItemGroup> | ||||
<Compile Remove="Distribute\**" /> | <Compile Remove="Distribute\**" /> | ||||
<Compile Remove="Models\**" /> | <Compile Remove="Models\**" /> | ||||
@@ -127,6 +127,8 @@ namespace Tensorflow | |||||
return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
case float val: | case float val: | ||||
return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
case float[,] val: | |||||
return new EagerTensor(val, ctx.device_name); | |||||
case double val: | case double val: | ||||
return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
case float[] val: | case float[] val: | ||||
@@ -202,6 +202,7 @@ namespace Tensorflow | |||||
TF_DataType.TF_INT32 => "int32", | TF_DataType.TF_INT32 => "int32", | ||||
TF_DataType.TF_FLOAT => "float32", | TF_DataType.TF_FLOAT => "float32", | ||||
TF_DataType.TF_BOOL => "bool", | TF_DataType.TF_BOOL => "bool", | ||||
TF_DataType.TF_RESOURCE => "resource", | |||||
_ => type.ToString() | _ => type.ToString() | ||||
}; | }; | ||||
@@ -18,6 +18,7 @@ using Google.Protobuf; | |||||
using NumSharp; | using NumSharp; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Eager; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -221,5 +222,10 @@ namespace Tensorflow | |||||
return array_ops.identity(value); | return array_ops.identity(value); | ||||
}); | }); | ||||
} | } | ||||
public override string ToString() | |||||
{ | |||||
return $"tf.Variable: '{name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={EagerTensor.GetFormattedString(dtype, numpy())}"; | |||||
} | |||||
} | } | ||||
} | } |
@@ -14,6 +14,9 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | |||||
using System.Linq; | |||||
using System.Runtime.InteropServices; | |||||
using System.Threading; | using System.Threading; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
@@ -39,6 +42,45 @@ namespace Tensorflow | |||||
public tensorflow() | public tensorflow() | ||||
{ | { | ||||
_constructThreadingObjects(); | _constructThreadingObjects(); | ||||
InitGradientEnvironment(); | |||||
} | |||||
private unsafe void InitGradientEnvironment() | |||||
{ | |||||
var vspace = c_api.VSpace_Handle((shape, dims, dtype) => | |||||
{ | |||||
var ones = constant_op.constant(1.0f, dtype: dtype) as EagerTensor; | |||||
return ones.EagerTensorHandle; | |||||
}, (gradients, num_grads) => | |||||
{ | |||||
var input_grads = new EagerTensor[num_grads]; | |||||
for (int i = 0; i < num_grads; i++) | |||||
input_grads[i] = new EagerTensor(*((IntPtr*)gradients + i)); | |||||
var add_n = gen_math_ops.add_n(input_grads); | |||||
return (add_n as EagerTensor).EagerTensorHandle; | |||||
}); | |||||
ops.RegisterFromAssembly(); | |||||
c_api.TFE_RegisterGradientFunction((op_name, num_inputs, op_inputs, num_attrs, output_grads) => | |||||
{ | |||||
var output_grad_tensors = output_grads.Select(x => new EagerTensor(x)).ToArray(); | |||||
var input_tensors = new EagerTensor[num_inputs]; | |||||
for (int i = 0; i < num_inputs; i++) | |||||
input_tensors[i] = new EagerTensor(op_inputs[op_inputs.Length == 1 ? 0 : i]); | |||||
var gradients = ops.gradientFunctions[op_name](new EagerOperation | |||||
{ | |||||
NumInputs = num_inputs, | |||||
Inputs = input_tensors | |||||
}, output_grad_tensors); | |||||
var ret_tensors = Marshal.AllocHGlobal(sizeof(IntPtr) * num_inputs); | |||||
Marshal.Copy(gradients.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), 0, ret_tensors, 2); | |||||
// Marshal.FreeHGlobal(ret_tensors); | |||||
return ret_tensors; | |||||
}); | |||||
} | } | ||||
public ResourceVariable Variable<T>(T data, | public ResourceVariable Variable<T>(T data, | ||||
@@ -0,0 +1,27 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using Tensorflow; | |||||
using static Tensorflow.Binding; | |||||
namespace TensorFlowNET.UnitTest.Gradient | |||||
{ | |||||
[TestClass] | |||||
public class GradientEagerTest : PythonTest | |||||
{ | |||||
[TestMethod] | |||||
public void ConstantSq() | |||||
{ | |||||
// Calcute the gradient of w * w | |||||
// by Automatic Differentiation in Eager mode | |||||
// in tensorflow.net 2.x that is in development intensively | |||||
var w = tf.constant(1.5f); | |||||
using var tape = tf.GradientTape(); | |||||
tape.watch(w); | |||||
var loss = w * w; | |||||
var grad = tape.gradient(loss, w); | |||||
print(grad); | |||||
} | |||||
} | |||||
} |
@@ -30,11 +30,10 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="FluentAssertions" Version="5.10.3" /> | <PackageReference Include="FluentAssertions" Version="5.10.3" /> | ||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.5.0" /> | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.6.1" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.1.1" /> | <PackageReference Include="MSTest.TestAdapter" Version="2.1.1" /> | ||||
<PackageReference Include="MSTest.TestFramework" Version="2.1.1" /> | <PackageReference Include="MSTest.TestFramework" Version="2.1.1" /> | ||||
<PackageReference Include="NumSharp.Lite" Version="0.1.7" /> | <PackageReference Include="NumSharp.Lite" Version="0.1.7" /> | ||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.1.0" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -7,7 +7,7 @@ | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.5.0" /> | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.6.1" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.1.1" /> | <PackageReference Include="MSTest.TestAdapter" Version="2.1.1" /> | ||||
<PackageReference Include="MSTest.TestFramework" Version="2.1.1" /> | <PackageReference Include="MSTest.TestFramework" Version="2.1.1" /> | ||||
<PackageReference Include="coverlet.collector" Version="1.2.1"> | <PackageReference Include="coverlet.collector" Version="1.2.1"> | ||||