@@ -114,22 +114,36 @@ namespace Tensorflow | |||
for (int i = 0; i < fetch_list.Length; i++) | |||
{ | |||
var tensor = new Tensor(output_values[i]); | |||
Type type = tensor.dtype.as_numpy_datatype(); | |||
var ndims = tensor.shape.Select(x => (int)x).ToArray(); | |||
switch (tensor.dtype) | |||
{ | |||
case TF_DataType.TF_STRING: | |||
// wired, don't know why we have to start from offset 9. | |||
var bytes = tensor.Data(); | |||
result[i] = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9); | |||
{ | |||
// wired, don't know why we have to start from offset 9. | |||
var bytes = tensor.Data(); | |||
var output = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9); | |||
result[i] = tensor.NDims == 0 ? output : np.array(output).reshape(ndims); | |||
} | |||
break; | |||
case TF_DataType.TF_FLOAT: | |||
result[i] = *(float*)c_api.TF_TensorData(output_values[i]); | |||
{ | |||
var output = *(float*)c_api.TF_TensorData(output_values[i]); | |||
result[i] = tensor.NDims == 0 ? output : np.array(output).reshape(ndims); | |||
} | |||
break; | |||
case TF_DataType.TF_INT16: | |||
result[i] = *(short*)c_api.TF_TensorData(output_values[i]); | |||
{ | |||
var output = *(short*)c_api.TF_TensorData(output_values[i]); | |||
result[i] = tensor.NDims == 0 ? output : np.array(output).reshape(ndims); | |||
} | |||
break; | |||
case TF_DataType.TF_INT32: | |||
result[i] = *(int*)c_api.TF_TensorData(output_values[i]); | |||
{ | |||
var output = *(int*)c_api.TF_TensorData(output_values[i]); | |||
result[i] = tensor.NDims == 0 ? output : np.array(output).reshape(ndims); | |||
} | |||
break; | |||
default: | |||
throw new NotImplementedException("can't get output"); | |||
@@ -22,6 +22,8 @@ namespace Tensorflow | |||
public object value; | |||
public int value_index { get; } | |||
private Status status = new Status(); | |||
private TF_DataType _dtype = TF_DataType.DtInvalid; | |||
public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); | |||
public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | |||
@@ -33,8 +35,17 @@ namespace Tensorflow | |||
get | |||
{ | |||
var dims = new long[rank]; | |||
for (int i = 0; i < rank; i++) | |||
dims[i] = c_api.TF_Dim(_handle, i); | |||
if (_handle == IntPtr.Zero) | |||
{ | |||
c_api.TF_GraphGetTensorShape(op.Graph, _as_tf_output(), dims, rank, status); | |||
status.Check(); | |||
} | |||
else | |||
{ | |||
for (int i = 0; i < rank; i++) | |||
dims[i] = c_api.TF_Dim(_handle, i); | |||
} | |||
return dims; | |||
} | |||
@@ -48,7 +59,22 @@ namespace Tensorflow | |||
/// 3 3-Tensor (cube of numbers) | |||
/// n n-Tensor (you get the idea) | |||
/// </summary> | |||
public int rank => _handle == IntPtr.Zero ? 0 : c_api.TF_NumDims(_handle); | |||
public int rank | |||
{ | |||
get | |||
{ | |||
if (_handle == IntPtr.Zero) | |||
{ | |||
var output = _as_tf_output(); | |||
return c_api.TF_GraphGetTensorNumDims(op.Graph, output, status); | |||
} | |||
else | |||
{ | |||
return c_api.TF_NumDims(_handle); | |||
} | |||
} | |||
} | |||
public int NDims => rank; | |||
/// <summary> | |||
@@ -182,6 +208,7 @@ namespace Tensorflow | |||
public void Dispose() | |||
{ | |||
c_api.TF_DeleteTensor(_handle); | |||
status.Dispose(); | |||
} | |||
public static implicit operator IntPtr(Tensor tensor) | |||
@@ -6,6 +6,16 @@ namespace Tensorflow | |||
{ | |||
public static class dtypes | |||
{ | |||
public static Type as_numpy_datatype(this TF_DataType type) | |||
{ | |||
switch (type) | |||
{ | |||
case TF_DataType.TF_INT32: | |||
return typeof(int); | |||
default: | |||
throw new NotImplementedException("as_numpy_datatype failed"); | |||
} | |||
} | |||
public static TF_DataType as_dtype(Type type) | |||
{ | |||
TF_DataType dtype = TF_DataType.DtInvalid; | |||
@@ -19,7 +19,7 @@ namespace TensorFlowNET.Examples | |||
// Basic constant operations | |||
// The value returned by the constructor represents the output | |||
// of the Constant op. | |||
var a = tf.constant(2); | |||
/*var a = tf.constant(2); | |||
var b = tf.constant(3); | |||
// Launch the default graph. | |||
@@ -50,7 +50,7 @@ namespace TensorFlowNET.Examples | |||
// Run every operation with variable input | |||
Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}"); | |||
Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}"); | |||
} | |||
}*/ | |||
// ---------------- | |||
// More in details: | |||
@@ -61,7 +61,39 @@ namespace TensorFlowNET.Examples | |||
// | |||
// The value returned by the constructor represents the output | |||
// of the Constant op. | |||
var nd1 = np.array(3, 3).reshape(1, 2); | |||
var matrix1 = tf.constant(nd1); | |||
// Create another Constant that produces a 2x1 matrix. | |||
var nd2 = np.array(2, 2).reshape(2, 1); | |||
var matrix2 = tf.constant(nd2); | |||
// Create a Matmul op that takes 'matrix1' and 'matrix2' as inputs. | |||
// The returned value, 'product', represents the result of the matrix | |||
// multiplication. | |||
var product = tf.matmul(matrix1, matrix2); | |||
// To run the matmul op we call the session 'run()' method, passing 'product' | |||
// which represents the output of the matmul op. This indicates to the call | |||
// that we want to get the output of the matmul op back. | |||
// | |||
// All inputs needed by the op are run automatically by the session. They | |||
// typically are run in parallel. | |||
// | |||
// The call 'run(product)' thus causes the execution of threes ops in the | |||
// graph: the two constants and matmul. | |||
// | |||
// The output of the op is returned in 'result' as a numpy `ndarray` object. | |||
using (sess = tf.Session()) | |||
{ | |||
var result = sess.run(product); | |||
Console.WriteLine(result); | |||
if((result as NDArray).Data<int>()[0] != 12) | |||
{ | |||
throw new Exception("BasicOperations error"); | |||
} | |||
// ==> [[ 12.]] | |||
} | |||
} | |||
} | |||
} |