Browse Source

Finished BasicOperations example.

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
11216bfec1
4 changed files with 95 additions and 12 deletions
  1. +21
    -7
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +30
    -3
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  3. +10
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  4. +34
    -2
      test/TensorFlowNET.Examples/BasicOperations.cs

+ 21
- 7
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -114,22 +114,36 @@ namespace Tensorflow
for (int i = 0; i < fetch_list.Length; i++) for (int i = 0; i < fetch_list.Length; i++)
{ {
var tensor = new Tensor(output_values[i]); var tensor = new Tensor(output_values[i]);
Type type = tensor.dtype.as_numpy_datatype();
var ndims = tensor.shape.Select(x => (int)x).ToArray();

switch (tensor.dtype) switch (tensor.dtype)
{ {
case TF_DataType.TF_STRING: case TF_DataType.TF_STRING:
// wired, don't know why we have to start from offset 9.
var bytes = tensor.Data();
result[i] = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9);
{
// wired, don't know why we have to start from offset 9.
var bytes = tensor.Data();
var output = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9);
result[i] = tensor.NDims == 0 ? output : np.array(output).reshape(ndims);
}
break; break;
case TF_DataType.TF_FLOAT: case TF_DataType.TF_FLOAT:
result[i] = *(float*)c_api.TF_TensorData(output_values[i]);
{
var output = *(float*)c_api.TF_TensorData(output_values[i]);
result[i] = tensor.NDims == 0 ? output : np.array(output).reshape(ndims);
}
break; break;
case TF_DataType.TF_INT16: case TF_DataType.TF_INT16:
result[i] = *(short*)c_api.TF_TensorData(output_values[i]);
{
var output = *(short*)c_api.TF_TensorData(output_values[i]);
result[i] = tensor.NDims == 0 ? output : np.array(output).reshape(ndims);
}
break; break;
case TF_DataType.TF_INT32: case TF_DataType.TF_INT32:
result[i] = *(int*)c_api.TF_TensorData(output_values[i]);
{
var output = *(int*)c_api.TF_TensorData(output_values[i]);
result[i] = tensor.NDims == 0 ? output : np.array(output).reshape(ndims);
}
break; break;
default: default:
throw new NotImplementedException("can't get output"); throw new NotImplementedException("can't get output");


+ 30
- 3
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -22,6 +22,8 @@ namespace Tensorflow
public object value; public object value;
public int value_index { get; } public int value_index { get; }


private Status status = new Status();

private TF_DataType _dtype = TF_DataType.DtInvalid; private TF_DataType _dtype = TF_DataType.DtInvalid;
public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle);
public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
@@ -33,8 +35,17 @@ namespace Tensorflow
get get
{ {
var dims = new long[rank]; var dims = new long[rank];
for (int i = 0; i < rank; i++)
dims[i] = c_api.TF_Dim(_handle, i);

if (_handle == IntPtr.Zero)
{
c_api.TF_GraphGetTensorShape(op.Graph, _as_tf_output(), dims, rank, status);
status.Check();
}
else
{
for (int i = 0; i < rank; i++)
dims[i] = c_api.TF_Dim(_handle, i);
}


return dims; return dims;
} }
@@ -48,7 +59,22 @@ namespace Tensorflow
/// 3 3-Tensor (cube of numbers) /// 3 3-Tensor (cube of numbers)
/// n n-Tensor (you get the idea) /// n n-Tensor (you get the idea)
/// </summary> /// </summary>
public int rank => _handle == IntPtr.Zero ? 0 : c_api.TF_NumDims(_handle);
public int rank
{
get
{
if (_handle == IntPtr.Zero)
{
var output = _as_tf_output();
return c_api.TF_GraphGetTensorNumDims(op.Graph, output, status);
}
else
{
return c_api.TF_NumDims(_handle);
}
}
}

public int NDims => rank; public int NDims => rank;


/// <summary> /// <summary>
@@ -182,6 +208,7 @@ namespace Tensorflow
public void Dispose() public void Dispose()
{ {
c_api.TF_DeleteTensor(_handle); c_api.TF_DeleteTensor(_handle);
status.Dispose();
} }


public static implicit operator IntPtr(Tensor tensor) public static implicit operator IntPtr(Tensor tensor)


+ 10
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -6,6 +6,16 @@ namespace Tensorflow
{ {
public static class dtypes public static class dtypes
{ {
public static Type as_numpy_datatype(this TF_DataType type)
{
switch (type)
{
case TF_DataType.TF_INT32:
return typeof(int);
default:
throw new NotImplementedException("as_numpy_datatype failed");
}
}
public static TF_DataType as_dtype(Type type) public static TF_DataType as_dtype(Type type)
{ {
TF_DataType dtype = TF_DataType.DtInvalid; TF_DataType dtype = TF_DataType.DtInvalid;


+ 34
- 2
test/TensorFlowNET.Examples/BasicOperations.cs View File

@@ -19,7 +19,7 @@ namespace TensorFlowNET.Examples
// Basic constant operations // Basic constant operations
// The value returned by the constructor represents the output // The value returned by the constructor represents the output
// of the Constant op. // of the Constant op.
var a = tf.constant(2);
/*var a = tf.constant(2);
var b = tf.constant(3); var b = tf.constant(3);
// Launch the default graph. // Launch the default graph.
@@ -50,7 +50,7 @@ namespace TensorFlowNET.Examples
// Run every operation with variable input // Run every operation with variable input
Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}"); Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}");
Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}"); Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}");
}
}*/


// ---------------- // ----------------
// More in details: // More in details:
@@ -61,7 +61,39 @@ namespace TensorFlowNET.Examples
// //
// The value returned by the constructor represents the output // The value returned by the constructor represents the output
// of the Constant op. // of the Constant op.
var nd1 = np.array(3, 3).reshape(1, 2);
var matrix1 = tf.constant(nd1);

// Create another Constant that produces a 2x1 matrix.
var nd2 = np.array(2, 2).reshape(2, 1);
var matrix2 = tf.constant(nd2);

// Create a Matmul op that takes 'matrix1' and 'matrix2' as inputs.
// The returned value, 'product', represents the result of the matrix
// multiplication.
var product = tf.matmul(matrix1, matrix2);


// To run the matmul op we call the session 'run()' method, passing 'product'
// which represents the output of the matmul op. This indicates to the call
// that we want to get the output of the matmul op back.
//
// All inputs needed by the op are run automatically by the session. They
// typically are run in parallel.
//
// The call 'run(product)' thus causes the execution of threes ops in the
// graph: the two constants and matmul.
//
// The output of the op is returned in 'result' as a numpy `ndarray` object.
using (sess = tf.Session())
{
var result = sess.run(product);
Console.WriteLine(result);
if((result as NDArray).Data<int>()[0] != 12)
{
throw new Exception("BasicOperations error");
}
// ==> [[ 12.]]
}
} }
} }
} }

Loading…
Cancel
Save