@@ -0,0 +1,14 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
/// <summary> | |||||
/// in order to limit function return value | |||||
/// is Tensor or Operation | |||||
/// </summary> | |||||
public interface IReturnTensorOrOperation | |||||
{ | |||||
} | |||||
} |
@@ -7,7 +7,7 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public partial class Operation | |||||
public partial class Operation : IReturnTensorOrOperation | |||||
{ | { | ||||
private readonly IntPtr _handle; // _c_op in python | private readonly IntPtr _handle; // _c_op in python | ||||
@@ -12,7 +12,7 @@ namespace Tensorflow | |||||
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | ||||
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | ||||
/// </summary> | /// </summary> | ||||
public partial class Tensor : IDisposable | |||||
public partial class Tensor : IDisposable, IReturnTensorOrOperation | |||||
{ | { | ||||
private readonly IntPtr _handle; | private readonly IntPtr _handle; | ||||
@@ -258,7 +258,7 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
return $"{name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; | |||||
return $"tf.Tensor {name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; | |||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
@@ -6,11 +6,21 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class RefVariable | public partial class RefVariable | ||||
{ | { | ||||
public static Tensor operator +(RefVariable t1, int t2) | |||||
public static Tensor operator +(RefVariable x, int y) => op_helper("add", x, y); | |||||
public static Tensor operator +(RefVariable x, float y) => op_helper("add", x, y); | |||||
public static Tensor operator +(RefVariable x, double y) => op_helper("add", x, y); | |||||
public static Tensor operator -(RefVariable x, int y) => op_helper("sub", x, y); | |||||
public static Tensor operator -(RefVariable x, float y) => op_helper("sub", x, y); | |||||
public static Tensor operator -(RefVariable x, double y) => op_helper("sub", x, y); | |||||
private static Tensor op_helper<T>(string default_name, RefVariable x, T y) | |||||
{ | { | ||||
var tensor1 = t1._AsTensor(); | |||||
var tensor2 = ops.convert_to_tensor(t2, tensor1.dtype, "y"); | |||||
return gen_math_ops.add(tensor1, tensor2); | |||||
var tensor1 = x.value(); | |||||
return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", default_name, new object[] { tensor1, y }), scope => { | |||||
var tensor2 = ops.convert_to_tensor(y, tensor1.dtype.as_base_dtype(), "y"); | |||||
return gen_math_ops.add(tensor1, tensor2, scope); | |||||
}); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -171,6 +171,34 @@ namespace Tensorflow | |||||
return op; | return op; | ||||
} | } | ||||
/// <summary> | |||||
/// Assigns a new value to the variable. | |||||
/// </summary> | |||||
/// <param name="value">The new value for this variable.</param> | |||||
/// <param name="use_locking">If `True`, use locking during the assignment.</param> | |||||
/// <param name="name">The name of the operation to be created</param> | |||||
/// <param name="read_value"> | |||||
/// if True, will return something which evaluates to the | |||||
/// new value of the variable; if False will return the assign op. | |||||
/// </param> | |||||
/// <returns> | |||||
/// A `Tensor` that will hold the new value of this variable after | |||||
/// the assignment has completed. | |||||
/// </returns> | |||||
public T assign<T>(Tensor value, bool use_locking = false, string name = "", bool read_value = true) | |||||
where T : IReturnTensorOrOperation | |||||
{ | |||||
var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | |||||
if (read_value) | |||||
return (T)Convert.ChangeType(assign, typeof(T)); | |||||
return (T)Convert.ChangeType(assign.op, typeof(T)); | |||||
} | |||||
public Tensor assign(Tensor value, bool use_locking = false, string name = "") | |||||
{ | |||||
return gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | |||||
} | |||||
public override string ToString() | public override string ToString() | ||||
{ | { | ||||
return $"tf.Variable '{name}' shape={shape} dtype={dtype}"; | return $"tf.Variable '{name}' shape={shape} dtype={dtype}"; | ||||
@@ -15,7 +15,11 @@ namespace TensorFlowNET.UnitTest | |||||
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | ||||
var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer); | var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer); | ||||
var inc_v1 = v1.assign(v1 + 1.0f); | |||||
var dec_v2 = v2.assign(v2 - 1.0f); | |||||
// Add an op to initialize the variables. | |||||
var init_op = tf.global_variables_initializer(); | |||||
} | } | ||||
} | } | ||||
} | } |