Browse Source

Tensor.eval() #145

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
1b7003b2d3
11 changed files with 103 additions and 19 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  3. +11
    -2
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  4. +35
    -0
      src/TensorFlowNET.Core/Operations/math_ops.py.cs
  5. +7
    -3
      src/TensorFlowNET.Core/Sessions/Session.cs
  6. +2
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs
  7. +2
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  8. +2
    -2
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Tensors/tf.constant.cs
  10. +26
    -4
      src/TensorFlowNET.Core/ops.py.cs
  11. +14
    -0
      test/TensorFlowNET.UnitTest/SessionTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -51,7 +51,7 @@ namespace Tensorflow
var input_name = input_arg.Name; var input_name = input_arg.Name;
if (keywords[input_name] is double int_value) if (keywords[input_name] is double int_value)
{ {
keywords[input_name] = constant_op.Constant(int_value, input_name);
keywords[input_name] = constant_op.constant(int_value, input_name);
} }


if (keywords[input_name] is Tensor value) if (keywords[input_name] is Tensor value)


+ 2
- 2
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -43,12 +43,12 @@ namespace Tensorflow
var nd = np.zeros<T>(shape); var nd = np.zeros<T>(shape);
if (shape.Size < 1000) if (shape.Size < 1000)
{ {
return constant_op.Constant(nd, name);
return constant_op.constant(nd, name);
} }
else else
{ {
tShape = constant_op._tensor_shape_tensor_conversion_function(shape.as_shape()); tShape = constant_op._tensor_shape_tensor_conversion_function(shape.as_shape());
var c = constant_op.Constant(0);
var c = constant_op.constant(0);
return gen_array_ops.fill(tShape, c, name); return gen_array_ops.fill(tShape, c, name);
} }
} }


+ 11
- 2
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -38,9 +38,18 @@ namespace Tensorflow
return _op.outputs[0]; return _op.outputs[0];
} }


public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false)
/// <summary>
/// Multiply the matrix "a" by the matrix "b".
/// </summary>
/// <param name="a"></param>
/// <param name="b"></param>
/// <param name="transpose_a"></param>
/// <param name="transpose_b"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false, string name = "")
{ {
var _op = _op_def_lib._apply_op_helper("MatMul", args: new { a, b, transpose_a, transpose_b });
var _op = _op_def_lib._apply_op_helper("MatMul", name, args: new { a, b, transpose_a, transpose_b });


return _op.outputs[0]; return _op.outputs[0];
} }


+ 35
- 0
src/TensorFlowNET.Core/Operations/math_ops.py.cs View File

@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class math_ops
{
public static Tensor matmul(Tensor a, Tensor b,
bool transpose_a = false, bool transpose_b = false,
bool adjoint_a = false, bool adjoint_b = false,
bool a_is_sparse = false, bool b_is_sparse = false,
string name = "")
{
Tensor result = null;

Python.with<ops.name_scope>(new ops.name_scope(name, "MatMul", new Tensor[] { a, b }), scope =>
{
name = scope;

if (transpose_a && adjoint_a)
throw new ValueError("Only one of transpose_a and adjoint_a can be True.");
if (transpose_b && adjoint_b)
throw new ValueError("Only one of transpose_b and adjoint_b can be True.");

a = ops.convert_to_tensor(a, name: "a");
b = ops.convert_to_tensor(b, name: "b");

result = gen_math_ops.mat_mul(a, b, transpose_a, transpose_b, name);
});

return result;
}
}
}

+ 7
- 3
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -7,16 +7,17 @@ namespace Tensorflow
public class Session : BaseSession, IPython public class Session : BaseSession, IPython
{ {
private IntPtr _handle; private IntPtr _handle;
public Status Status { get; }
public Status Status = new Status();
public SessionOptions Options { get; } public SessionOptions Options { get; }
public Graph graph;


public Session(string target = "", Graph graph = null) public Session(string target = "", Graph graph = null)
{ {
Status = new Status();
if(graph == null) if(graph == null)
{ {
graph = tf.get_default_graph(); graph = tf.get_default_graph();
} }
this.graph = graph;
Options = new SessionOptions(); Options = new SessionOptions();
_handle = c_api.TF_NewSession(graph, Options, Status); _handle = c_api.TF_NewSession(graph, Options, Status);
Status.Check(); Status.Check();
@@ -27,9 +28,12 @@ namespace Tensorflow
_handle = handle; _handle = handle;
} }


public Session(Graph graph, SessionOptions opts, Status s)
public Session(Graph graph, SessionOptions opts, Status s = null)
{ {
if (s == null)
s = Status;
_handle = c_api.TF_NewSession(graph, opts, s); _handle = c_api.TF_NewSession(graph, opts, s);
Status.Check(true);
} }


public static implicit operator IntPtr(Session session) => session._handle; public static implicit operator IntPtr(Session session) => session._handle;


+ 2
- 2
src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs View File

@@ -8,12 +8,12 @@ namespace Tensorflow
{ {
public static implicit operator Tensor(double scalar) public static implicit operator Tensor(double scalar)
{ {
return constant_op.Constant(scalar);
return constant_op.constant(scalar);
} }


public static implicit operator Tensor(int scalar) public static implicit operator Tensor(int scalar)
{ {
return constant_op.Constant(scalar);
return constant_op.constant(scalar);
} }


public static implicit operator int(Tensor tensor) public static implicit operator int(Tensor tensor)


+ 2
- 2
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -213,9 +213,9 @@ namespace Tensorflow
/// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param>
/// <param name="session">The `Session` to be used to evaluate this tensor.</param> /// <param name="session">The `Session` to be used to evaluate this tensor.</param>
/// <returns></returns> /// <returns></returns>
public NDArray eval(dynamic feed_dict = null, Session session = null)
public NDArray eval(Dictionary<Tensor, NDArray> feed_dict = null, Session session = null)
{ {
return ops._eval_using_default_session(new Tensor[] { this }, feed_dict, Graph, session)[0];
return ops._eval_using_default_session(this, feed_dict, Graph, session);
} }


public TF_DataType ToTFDataType(Type type) public TF_DataType ToTFDataType(Type type)


+ 2
- 2
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -19,7 +19,7 @@ namespace Tensorflow
/// <param name="name">Optional name for the tensor.</param> /// <param name="name">Optional name for the tensor.</param>
/// <param name="verify_shape">Boolean that enables verification of a shape of values.</param> /// <param name="verify_shape">Boolean that enables verification of a shape of values.</param>
/// <returns></returns> /// <returns></returns>
public static Tensor Constant(NDArray nd, string name = "Const", bool verify_shape = false)
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false)
{ {
Graph g = ops.get_default_graph(); Graph g = ops.get_default_graph();
var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape); var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape);
@@ -76,7 +76,7 @@ namespace Tensorflow
if (string.IsNullOrEmpty(name)) if (string.IsNullOrEmpty(name))
name = "shape_as_tensor"; name = "shape_as_tensor";


return constant_op.Constant(s_list, name);
return constant_op.constant(s_list, name);
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Tensors/tf.constant.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow
{ {
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false)
{ {
return constant_op.Constant(nd, name, verify_shape);
return constant_op.constant(nd, name, verify_shape);
} }


public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "")


+ 26
- 4
src/TensorFlowNET.Core/ops.py.cs View File

@@ -75,7 +75,7 @@ namespace Tensorflow
return val; return val;
default: default:
var nd = tensor_util.convert_to_numpy_ndarray(value); var nd = tensor_util.convert_to_numpy_ndarray(value);
return constant_op.Constant(nd, name);
return constant_op.constant(nd, name);
} }
} }


@@ -240,12 +240,34 @@ namespace Tensorflow
/// of numpy ndarrays that each correspond to the respective element in /// of numpy ndarrays that each correspond to the respective element in
/// "tensors". /// "tensors".
/// </returns> /// </returns>
public static NDArray[] _eval_using_default_session(Tensor[] tensors, dynamic feed_dict, Graph graph, Session session = null)
public static NDArray _eval_using_default_session(Tensor tensor, Dictionary<Tensor, NDArray> feed_dict, Graph graph, Session session = null)
{ {
if (session == null) if (session == null)
{
session = get_default_session(); session = get_default_session();


return null;
if (session == null)
throw new ValueError("Cannot evaluate tensor using `eval()`: No default " +
"session is registered. Use `with " +
"sess.as_default()` or pass an explicit session to " +
"`eval(session=sess)`");

if (session.graph != graph)
throw new ValueError("Cannot use the default session to evaluate tensor: " +
"the tensor's graph is different from the session's " +
"graph. Pass an explicit session to " +
"`eval(session=sess)`.");
}
else
{
if (session.graph != graph)
throw new ValueError("Cannot use the default session to evaluate tensor: " +
"the tensor's graph is different from the session's " +
"graph. Pass an explicit session to " +
"`eval(session=sess)`.");
}

return session.run(tensor, feed_dict);
} }


/// <summary> /// <summary>
@@ -254,7 +276,7 @@ namespace Tensorflow
/// <returns>The default `Session` being used in the current thread.</returns> /// <returns>The default `Session` being used in the current thread.</returns>
public static Session get_default_session() public static Session get_default_session()
{ {
return null;
return tf.Session();
} }
} }
} }

+ 14
- 0
test/TensorFlowNET.UnitTest/SessionTest.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp.Core;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
@@ -74,5 +75,18 @@ namespace TensorFlowNET.UnitTest
graph.Dispose(); graph.Dispose();
s.Dispose(); s.Dispose();
} }

[TestMethod]
public void EvalTensor()
{
var a = constant_op.constant(np.array(3.0).reshape(1, 1));
var b = constant_op.constant(np.array(2.0).reshape(1, 1));
var c = math_ops.matmul(a, b, name: "matmul");
Python.with(tf.Session(), delegate
{
var result = c.eval();
Assert.AreEqual(6, result.Data<double>()[0]);
});
}
} }
} }

Loading…
Cancel
Save