Browse Source

add operator overload for RefVariable.

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
66f390a910
6 changed files with 63 additions and 7 deletions
  1. +14
    -0
      src/TensorFlowNET.Core/IReturnTensorOrOperation.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.cs
  3. +2
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  4. +14
    -4
      src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs
  5. +28
    -0
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  6. +4
    -0
      test/TensorFlowNET.UnitTest/TrainSaverTest.cs

+ 14
- 0
src/TensorFlowNET.Core/IReturnTensorOrOperation.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// in order to limit function return value
/// is Tensor or Operation
/// </summary>
public interface IReturnTensorOrOperation
{
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -7,7 +7,7 @@ using System.Text;

namespace Tensorflow
{
public partial class Operation
public partial class Operation : IReturnTensorOrOperation
{
private readonly IntPtr _handle; // _c_op in python



+ 2
- 2
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -12,7 +12,7 @@ namespace Tensorflow
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions.
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
/// </summary>
public partial class Tensor : IDisposable
public partial class Tensor : IDisposable, IReturnTensorOrOperation
{
private readonly IntPtr _handle;

@@ -258,7 +258,7 @@ namespace Tensorflow
}
}

return $"{name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}";
return $"tf.Tensor {name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}";
}

public void Dispose()


+ 14
- 4
src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs View File

@@ -6,11 +6,21 @@ namespace Tensorflow
{
public partial class RefVariable
{
public static Tensor operator +(RefVariable t1, int t2)
public static Tensor operator +(RefVariable x, int y) => op_helper("add", x, y);
public static Tensor operator +(RefVariable x, float y) => op_helper("add", x, y);
public static Tensor operator +(RefVariable x, double y) => op_helper("add", x, y);
public static Tensor operator -(RefVariable x, int y) => op_helper("sub", x, y);
public static Tensor operator -(RefVariable x, float y) => op_helper("sub", x, y);
public static Tensor operator -(RefVariable x, double y) => op_helper("sub", x, y);
private static Tensor op_helper<T>(string default_name, RefVariable x, T y)
{
var tensor1 = t1._AsTensor();
var tensor2 = ops.convert_to_tensor(t2, tensor1.dtype, "y");
return gen_math_ops.add(tensor1, tensor2);
var tensor1 = x.value();
return Python.with<ops.name_scope, Tensor>(new ops.name_scope("", default_name, new object[] { tensor1, y }), scope => {
var tensor2 = ops.convert_to_tensor(y, tensor1.dtype.as_base_dtype(), "y");
return gen_math_ops.add(tensor1, tensor2, scope);
});
}
}
}

+ 28
- 0
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -171,6 +171,34 @@ namespace Tensorflow
return op;
}

/// <summary>
/// Assigns a new value to the variable.
/// </summary>
/// <param name="value">The new value for this variable.</param>
/// <param name="use_locking">If `True`, use locking during the assignment.</param>
/// <param name="name">The name of the operation to be created</param>
/// <param name="read_value">
/// if True, will return something which evaluates to the
/// new value of the variable; if False will return the assign op.
/// </param>
/// <returns>
/// A `Tensor` that will hold the new value of this variable after
/// the assignment has completed.
/// </returns>
public T assign<T>(Tensor value, bool use_locking = false, string name = "", bool read_value = true)
where T : IReturnTensorOrOperation
{
var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name);
if (read_value)
return (T)Convert.ChangeType(assign, typeof(T));
return (T)Convert.ChangeType(assign.op, typeof(T));
}

public Tensor assign(Tensor value, bool use_locking = false, string name = "")
{
return gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name);
}

public override string ToString()
{
return $"tf.Variable '{name}' shape={shape} dtype={dtype}";


+ 4
- 0
test/TensorFlowNET.UnitTest/TrainSaverTest.cs View File

@@ -15,7 +15,11 @@ namespace TensorFlowNET.UnitTest
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);
var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer);

var inc_v1 = v1.assign(v1 + 1.0f);
var dec_v2 = v2.assign(v2 - 1.0f);

// Add an op to initialize the variables.
var init_op = tf.global_variables_initializer();
}
}
}

Loading…
Cancel
Save