@@ -2,6 +2,7 @@ | |||||
TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). | TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). | ||||
[](https://gitter.im/Tensorflow-NET) | [](https://gitter.im/Tensorflow-NET) | ||||
 | |||||
TensorFlow.NET is a member project of SciSharp stack. | TensorFlow.NET is a member project of SciSharp stack. | ||||
@@ -10,11 +11,12 @@ TensorFlow.NET is a member project of SciSharp stack. | |||||
### How to use | ### How to use | ||||
Download the pre-compiled dll [here](tensorflowlib) and place it in the bin folder. | Download the pre-compiled dll [here](tensorflowlib) and place it in the bin folder. | ||||
Import tensorflow.net. | |||||
```cs | ```cs | ||||
// import tensorflow.net | |||||
using using Tensorflow; | |||||
using Tensorflow; | |||||
``` | ``` | ||||
Add two constants. | |||||
```cs | ```cs | ||||
// Create a Constant op | // Create a Constant op | ||||
var a = tf.constant(4.0f); | var a = tf.constant(4.0f); | ||||
@@ -27,6 +29,7 @@ using (var sess = tf.Session()) | |||||
} | } | ||||
``` | ``` | ||||
Feed placeholder. | |||||
```cs | ```cs | ||||
// Create a placeholder op | // Create a placeholder op | ||||
var a = tf.placeholder(tf.float32); | var a = tf.placeholder(tf.float32); | ||||
@@ -4,3 +4,28 @@ | |||||
### 表示一个操作的输出 | ### 表示一个操作的输出 | ||||
##### What is Tensor? | |||||
##### Tensor 是什么? | |||||
TF uses column major order. | |||||
TF 采用的是按列存储模式,如果我们用NumSharp产生一个2 X 3的矩阵,如果按顺序从0到5访问数据的话,是不会得到1-6的数字的,而是得到1,4, 2, 5, 3, 6这个顺序的一组数字。 | |||||
```cs | |||||
// generate a matrix:[[1, 2, 3], [4, 5, 6]] | |||||
var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | |||||
// the index will be 0 2 4 1 3 5, it's column-major order. | |||||
``` | |||||
 | |||||
 |
@@ -22,4 +22,12 @@ TensorFlow is an open source project for machine learning especially for deep le | |||||
为了避免混淆,本书中对TensorFlow中定义的特有类不进行翻译,比如Tensor, Graph, Shape这些词都会保留英文名称。 | |||||
为了避免混淆,本书中对TensorFlow中定义的特有类不进行翻译,比如Tensor, Graph, Shape这些词都会保留英文名称。 | |||||
术语简称: | |||||
TF: Google TensorFlow | |||||
TF.NET: Tensorflow.NET |
@@ -7,50 +7,60 @@ namespace Tensorflow | |||||
{ | { | ||||
public class Tensor | public class Tensor | ||||
{ | { | ||||
private readonly Operation _op; | |||||
public Operation op => _op; | |||||
private readonly int _value_index; | |||||
public int value_index => _value_index; | |||||
private TF_DataType _dtype; | |||||
public TF_DataType dtype => _dtype; | |||||
public Operation op { get; } | |||||
public int value_index { get; } | |||||
public TF_DataType dtype { get; } | |||||
public Graph graph => _op.graph; | |||||
public Graph graph => op.graph; | |||||
public string name; | public string name; | ||||
private readonly IntPtr _handle; | |||||
public IntPtr handle => _handle; | |||||
private readonly int _ndim; | |||||
public int ndim => _ndim; | |||||
public IntPtr handle { get; } | |||||
public int ndim { get; } | |||||
public ulong bytesize { get; } | |||||
public ulong dataTypeSize { get;} | |||||
public ulong size => bytesize / dataTypeSize; | |||||
public IntPtr buffer { get; } | |||||
public Tensor(IntPtr handle) | public Tensor(IntPtr handle) | ||||
{ | { | ||||
_handle = handle; | |||||
_dtype = c_api.TF_TensorType(_handle); | |||||
_ndim = c_api.TF_NumDims(_handle); | |||||
this.handle = handle; | |||||
dtype = c_api.TF_TensorType(handle); | |||||
ndim = c_api.TF_NumDims(handle); | |||||
bytesize = c_api.TF_TensorByteSize(handle); | |||||
buffer = c_api.TF_TensorData(handle); | |||||
dataTypeSize = c_api.TF_DataTypeSize(dtype); | |||||
} | } | ||||
public Tensor(Operation op, int value_index, TF_DataType dtype) | public Tensor(Operation op, int value_index, TF_DataType dtype) | ||||
{ | { | ||||
_op = op; | |||||
_value_index = value_index; | |||||
_dtype = dtype; | |||||
this.op = op; | |||||
this.value_index = value_index; | |||||
this.dtype = dtype; | |||||
} | } | ||||
public TF_Output _as_tf_output() | public TF_Output _as_tf_output() | ||||
{ | { | ||||
return c_api_util.tf_output(_op._c_op, _value_index); | |||||
return c_api_util.tf_output(op._c_op, value_index); | |||||
} | |||||
public T[] Data<T>() | |||||
{ | |||||
var data = new T[size]; | |||||
for (ulong i = 0; i < size; i++) | |||||
{ | |||||
data[i] = Marshal.PtrToStructure<T>(buffer + (int)(i * dataTypeSize)); | |||||
} | |||||
return data; | |||||
} | } | ||||
public T Data<T>() | |||||
public byte[] Data() | |||||
{ | { | ||||
/*var buffer = new byte[6 * sizeof(float)]; | |||||
var h1 = c_api.TF_TensorData(handle); | |||||
var bytes = Marshal.PtrToStructure<float>(h1); | |||||
Marshal.Copy(h1, buffer, 0, 24);*/ | |||||
var data = new byte[bytesize]; | |||||
Marshal.Copy(buffer, data, 0, (int)bytesize); | |||||
return default(T); | |||||
return data; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -17,6 +17,14 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static unsafe extern void TF_AddInput(TF_OperationDescription desc, TF_Output input); | public static unsafe extern void TF_AddInput(TF_OperationDescription desc, TF_Output input); | ||||
/// <summary> | |||||
/// returns the sizeof() for the underlying type corresponding to the given TF_DataType enum value. | |||||
/// </summary> | |||||
/// <param name="dt"></param> | |||||
/// <returns></returns> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static unsafe extern ulong TF_DataTypeSize(TF_DataType dt); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static unsafe extern void TF_DeleteSessionOptions(IntPtr opts); | public static unsafe extern void TF_DeleteSessionOptions(IntPtr opts); | ||||
@@ -104,6 +112,14 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | ||||
/// <summary> | |||||
/// Return the size of the underlying data in bytes. | |||||
/// </summary> | |||||
/// <param name="tensor"></param> | |||||
/// <returns></returns> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern unsafe ulong TF_TensorByteSize(IntPtr tensor); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe IntPtr TF_TensorData(IntPtr tensor); | public static extern unsafe IntPtr TF_TensorData(IntPtr tensor); | ||||
@@ -36,6 +36,15 @@ namespace TensorFlowNET.UnitTest | |||||
Assert.AreEqual(tensor.ndim, nd.ndim); | Assert.AreEqual(tensor.ndim, nd.ndim); | ||||
Assert.AreEqual(nd.shape[0], c_api.TF_Dim(handle, 0)); | Assert.AreEqual(nd.shape[0], c_api.TF_Dim(handle, 0)); | ||||
Assert.AreEqual(nd.shape[1], c_api.TF_Dim(handle, 1)); | Assert.AreEqual(nd.shape[1], c_api.TF_Dim(handle, 1)); | ||||
Assert.AreEqual(tensor.bytesize, (uint)nd.size * sizeof(float)); | |||||
// Column major order | |||||
// https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg | |||||
// matrix:[[1, 2, 3], [4, 5, 6]] | |||||
// index: 0 2 4 1 3 5 | |||||
// result: 1 4 2 5 3 6 | |||||
var array = tensor.Data<float>(); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); | |||||
} | } | ||||
} | } | ||||
} | } |