diff --git a/src/TensorFlowNET.Core/APIs/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs index 51189bda..dbd1b246 100644 --- a/src/TensorFlowNET.Core/APIs/tf.ops.cs +++ b/src/TensorFlowNET.Core/APIs/tf.ops.cs @@ -18,8 +18,11 @@ namespace Tensorflow { public static partial class tf { - public static object get_collection(string key, string scope = "") => get_default_graph() - .get_collection(key, scope: scope); + public static Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) + => state_ops.assign(@ref, value, validate_shape, use_locking, name); + + public static object get_collection(string key, string scope = "") + => get_default_graph().get_collection(key, scope: scope); /// /// Returns a context manager that creates hierarchical names for operations. @@ -28,8 +31,7 @@ namespace Tensorflow /// The default name to use if the name argument is None. /// The list of Tensor arguments that are passed to the op function. /// The scope name. - public static ops.NameScope name_scope(string name, - string default_name = "", - object values = null) => new ops.NameScope(name, default_name, values); + public static ops.NameScope name_scope(string name, string default_name = "", object values = null) + => new ops.NameScope(name, default_name, values); } } diff --git a/src/TensorFlowNET.Core/Contrib/Learn/Estimators/tensor_signature.cs b/src/TensorFlowNET.Core/Contrib/Learn/Estimators/tensor_signature.cs new file mode 100644 index 00000000..cf53a8dc --- /dev/null +++ b/src/TensorFlowNET.Core/Contrib/Learn/Estimators/tensor_signature.cs @@ -0,0 +1,39 @@ +using System.Linq; +using NumSharp; +using Tensorflow.Framework; + +namespace Tensorflow.Contrib.Learn.Estimators +{ + public static class tensor_signature + { + public static bool is_compatible_with(this Tensor self, Tensor other) + { + bool _shape_is_compatible_0dim(Shape _this, Shape _other) + { + var __other = tensor_shape.as_shape(_other); + if (_this.Dimensions == null || __other.Dimensions == null) + return true; + + if (_this.NDim != __other.NDim) + return false; + + foreach (var (x_dim, y_dim) in _this.Dimensions.Zip(__other.Dimensions, (x_dim, y_dim) => (x_dim, y_dim))) + { + if (x_dim != y_dim) + return false; + } + + return true; + } + + if (other.is_sparse()) + { + return self.dtype.is_compatible_with(other.dtype); + } + + return self.dtype.is_compatible_with(other.dtype) && + _shape_is_compatible_0dim(self.shape, other.shape) && + !self.is_sparse(); + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs b/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs index 547d4516..8d0ea53b 100644 --- a/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs +++ b/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs @@ -1,8 +1,19 @@ namespace Tensorflow.Framework { - public static class SparseTensor + public interface _TensorLike + { } + + public class SparseTensor : CompositeTensor, _TensorLike { private static Tensor _dense_shape { get; set; } } + + public static class sparse_tensor + { + public static bool is_sparse(this _TensorLike x) + { + return x is SparseTensor; + } + } } diff --git a/src/TensorFlowNET.Core/Framework/tensor_shape.cs b/src/TensorFlowNET.Core/Framework/tensor_shape.cs new file mode 100644 index 00000000..4972c1b4 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/tensor_shape.cs @@ -0,0 +1,34 @@ +using System; +using System.Linq; +using System.Text; +using NumSharp; +using Tensorflow.Contrib.Learn.Estimators; + +namespace Tensorflow.Framework +{ + public static class tensor_shape + { + public static void assert_is_compatible_with(this Tensor self, Tensor other) + { + if (!self.is_compatible_with(other)) + { + var selfDim = self.shape + .Aggregate(new StringBuilder("{"), (sb, i) => sb.Append(i).Append(", "), sb => sb.ToString()) + .Replace(", }", "}"); + + var otherDim = other.shape + .Aggregate(new StringBuilder("{"), (sb, i) => sb.Append(i).Append(", "), sb => sb.ToString()) + .Replace(", }", "}"); + + throw new ArgumentException($"Dimensions {selfDim} and {otherDim} are not compatible"); + } + } + + public static TensorShape as_shape(this Shape shape) + { + if (shape is TensorShape tshape) + return tshape; + return new TensorShape(shape); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index 416e88b4..41bd0ddf 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -14,12 +14,15 @@ limitations under the License. ******************************************************************************/ +using System; +using Tensorflow.Framework; + namespace Tensorflow { /// /// tensorflow\python\ops\resource_variable_ops.py /// - public class resource_variable_ops + public static class resource_variable_ops { public static ITensorOrOperation shape_safe_assign_variable_handle(Tensor handle, int[] shape, Tensor value, string name = null) { @@ -29,9 +32,61 @@ namespace Tensorflow name: name); } + /// + /// + /// + /// + /// + /// + /// + /// + /// If `read_value` is `True`, this method will return the new value of the + /// variable after the assignment has completed.Otherwise, when in graph mode + /// it will return the `Operation` that does the assignment, and when in eager + /// mode it will return `None`. + /// + public static Operation assign(this Tensor self, Tensor value, bool use_locking = false, string name = null, bool read_value = true) + { + var value_tensor = ops.convert_to_tensor(value, dtype: self.dtype); + self.assert_is_compatible_with(value_tensor); + var assign_op = gen_resource_variable_ops.assign_variable_op(self, value_tensor, name: name); + if (read_value) + { + return self._lazy_read(assign_op); + } + + return assign_op; + } + + public static Operation _lazy_read(this Tensor self, Operation op) + { + variable_accessed(self); + throw new NotImplementedException(); + } + + public static void variable_accessed(this Tensor variable) + { + throw new NotImplementedException(); + } + public static bool is_resource_variable(VariableV1 var) { return var is ResourceVariable; } + + /// + /// Represents a future for a read of a variable. + /// Pretends to be the tensor if anyone looks. + /// + public class _UnreadVariable : BaseResourceVariable + { + } + + /// + /// A python variable from an existing handle. + /// + public class BaseResourceVariable : VariableV1 + { + } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index a2a77a7e..e7049e7e 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -19,6 +19,7 @@ using System; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; +using Tensorflow.Framework; using static Tensorflow.Python; namespace Tensorflow @@ -27,7 +28,7 @@ namespace Tensorflow /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// - public partial class Tensor : IDisposable, ITensorOrOperation + public partial class Tensor : IDisposable, ITensorOrOperation, _TensorLike { private IntPtr _handle; @@ -109,6 +110,8 @@ namespace Tensorflow this.shape = shape.Dimensions; } + public int[] dims => shape; + /// /// number of dimensions /// 0 Scalar (magnitude only) diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index c809c96b..16b09d05 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -205,5 +205,10 @@ namespace Tensorflow { return (int)type > 100; } + + public static bool is_compatible_with(this TF_DataType self, TF_DataType other) + { + return self.as_datatype_enum() == other.as_datatype_enum(); + } } } diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 473bb7ca..2153e2d7 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -214,6 +214,12 @@ namespace Tensorflow else nparray = Convert.ToString(values); break; + case "Boolean": + if (values.GetType().IsArray) + nparray = np.array((bool[])values, np_dt); + else + nparray = Convert.ToBoolean(values); + break; default: throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index d380975e..78a241c2 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -40,6 +40,8 @@ namespace Tensorflow public override string name => _variable.name; + public Tensor eval() => _variable; + public RefVariable(object initial_value = null, bool trainable = true, List collections = null, diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index e1ab9e20..af34a2ba 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -12,102 +12,102 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -******************************************************************************/ - -using System.Collections.Generic; -using Tensorflow.Eager; - -namespace Tensorflow -{ - public class gen_state_ops - { - public static OpDefLibrary _op_def_lib = new OpDefLibrary(); - public static Execute _execute = new Execute(); - - /// - /// Holds state in the form of a tensor that persists across steps. - /// Outputs a ref to the tensor state so it may be read or modified. - /// - /// The shape of the variable tensor. - /// The type of elements in the variable tensor. - /// - /// - /// - /// - public static Tensor variable_v2(int[] shape, TF_DataType dtype, string name = null, string container = "", string shared_name = "") - { - var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, args: new { dtype, shape, container, shared_name }); - - var _result = _op.outputs; - var _inputs_flat = _op.inputs; - - var _attrs = new Dictionary(); - _attrs["dtype"] = _op.get_attr("dtype"); - _attrs["shape"] = _op.get_attr("shape"); - _attrs["container"] = _op.get_attr("container"); - _attrs["shared_name"] = _op.get_attr("shared_name"); - - _execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name); - - return _result[0]; - } - - /// - /// Update 'ref' by assigning 'value' to it - /// - /// - /// - /// - /// - /// - public static Tensor assign(Tensor @ref, object value, - bool validate_shape = true, - bool use_locking = true, - string name = null) - { - var _op = _op_def_lib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking }); - - var _result = _op.outputs; - var _inputs_flat = _op.inputs; - - var _attrs = new Dictionary(); - _attrs["T"] = _op.get_attr("T"); - _attrs["validate_shape"] = _op.get_attr("validate_shape"); - _attrs["use_locking"] = _op.get_attr("use_locking"); - - _execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name); - - return _result[0]; - } - +******************************************************************************/ + +using System.Collections.Generic; +using Tensorflow.Eager; + +namespace Tensorflow +{ + public class gen_state_ops + { + public static OpDefLibrary _op_def_lib = new OpDefLibrary(); + public static Execute _execute = new Execute(); + + /// + /// Holds state in the form of a tensor that persists across steps. + /// Outputs a ref to the tensor state so it may be read or modified. + /// + /// The shape of the variable tensor. + /// The type of elements in the variable tensor. + /// + /// + /// + /// + public static Tensor variable_v2(int[] shape, TF_DataType dtype, string name = null, string container = "", string shared_name = "") + { + var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, args: new { dtype, shape, container, shared_name }); + + var _result = _op.outputs; + var _inputs_flat = _op.inputs; + + var _attrs = new Dictionary(); + _attrs["dtype"] = _op.get_attr("dtype"); + _attrs["shape"] = _op.get_attr("shape"); + _attrs["container"] = _op.get_attr("container"); + _attrs["shared_name"] = _op.get_attr("shared_name"); + + _execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name); + + return _result[0]; + } + + /// + /// Update 'ref' by assigning 'value' to it + /// + /// + /// + /// + /// + /// + public static Tensor assign(Tensor @ref, object value, + bool validate_shape = true, + bool use_locking = true, + string name = null) + { + var _op = _op_def_lib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking }); + + var _result = _op.outputs; + var _inputs_flat = _op.inputs; + + var _attrs = new Dictionary(); + _attrs["T"] = _op.get_attr("T"); + _attrs["validate_shape"] = _op.get_attr("validate_shape"); + _attrs["use_locking"] = _op.get_attr("use_locking"); + + _execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name); + + return _result[0]; + } + public static Tensor assign(RefVariable @ref, object value, bool validate_shape = true, - bool use_locking = true, - string name = null) - { - var _op = _op_def_lib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking }); - - var _result = _op.outputs; - var _inputs_flat = _op.inputs; - - var _attrs = new Dictionary(); - _attrs["T"] = _op.get_attr("T"); - _attrs["validate_shape"] = _op.get_attr("validate_shape"); - _attrs["use_locking"] = _op.get_attr("use_locking"); - - _execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name); - - return _result[0]; - } - - public static Tensor assign_sub(RefVariable @ref, - Tensor value, - bool use_locking = false, - string name = null) - { - var _op = _op_def_lib._apply_op_helper("AssignSub", name: name, args: new { @ref, value, use_locking }); - - return _op.outputs[0]; + bool use_locking = true, + string name = null) + { + var _op = _op_def_lib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking }); + + var _result = _op.outputs; + var _inputs_flat = _op.inputs; + + var _attrs = new Dictionary(); + _attrs["T"] = _op.get_attr("T"); + _attrs["validate_shape"] = _op.get_attr("validate_shape"); + _attrs["use_locking"] = _op.get_attr("use_locking"); + + _execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name); + + return _result[0]; + } + + public static Tensor assign_sub(RefVariable @ref, + Tensor value, + bool use_locking = false, + string name = null) + { + var _op = _op_def_lib._apply_op_helper("AssignSub", name: name, args: new { @ref, value, use_locking }); + + return _op.outputs[0]; } @@ -125,10 +125,10 @@ namespace Tensorflow // name: A name for the operation(optional). // Returns: // A mutable `Tensor`. Has the same type as `ref`. - public static Tensor assign_add(RefVariable @ref, Tensor value, bool use_locking = false, string name = null) - { - var _op = _op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking }); - return _op.outputs[0]; + public static Tensor assign_add(RefVariable @ref, Tensor value, bool use_locking = false, string name = null) + { + var _op = _op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking }); + return _op.outputs[0]; } /// @@ -142,8 +142,8 @@ namespace Tensorflow /// public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) { - var _op = _op_def_lib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking }); + var _op = _op_def_lib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking }); return _op.outputs[0]; - } - } -} + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index e89844f9..502c3c1e 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -12,99 +12,99 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -******************************************************************************/ - -using System; - -namespace Tensorflow -{ - public class state_ops - { - /// - /// Create a variable Operation. - /// - /// - /// - /// - /// - /// - /// - public static Tensor variable_op_v2(int[] shape, - TF_DataType dtype, - string name = "Variable", - string container = "", - string shared_name = "") => gen_state_ops.variable_v2(shape, - dtype, - name: name, - container: container, - shared_name: shared_name); - - public static Tensor assign(Tensor @ref, object value, - bool validate_shape = true, - bool use_locking = true, - string name = null) - { - if (@ref.dtype.is_ref_dtype()) - return gen_state_ops.assign(@ref, - value, - validate_shape: validate_shape, - use_locking: use_locking, - name: name); - throw new NotImplementedException("state_ops.assign"); - //return @ref.assign(value, name: name); - } - - public static Tensor assign(RefVariable @ref, object value, - bool validate_shape = true, - bool use_locking = true, - string name = null) +******************************************************************************/ + +using System; + +namespace Tensorflow +{ + public class state_ops + { + /// + /// Create a variable Operation. + /// + /// + /// + /// + /// + /// + /// + public static Tensor variable_op_v2(int[] shape, + TF_DataType dtype, + string name = "Variable", + string container = "", + string shared_name = "") => gen_state_ops.variable_v2(shape, + dtype, + name: name, + container: container, + shared_name: shared_name); + + public static Tensor assign(Tensor @ref, object value, + bool validate_shape = true, + bool use_locking = true, + string name = null) + { + if (@ref.dtype.is_ref_dtype()) + return gen_state_ops.assign(@ref, + value, + validate_shape: validate_shape, + use_locking: use_locking, + name: name); + + return @ref.assign((Tensor)value, name: name); + } + + public static Tensor assign(RefVariable @ref, object value, + bool validate_shape = true, + bool use_locking = true, + string name = null) { return gen_state_ops.assign(@ref, value, validate_shape: validate_shape, use_locking: use_locking, - name: name); - } - - public static Tensor assign_sub(RefVariable @ref, - Tensor value, - bool use_locking = false, - string name = null) => gen_state_ops.assign_sub(@ref, - value, - use_locking: use_locking, - name: name); - - //"""Update 'ref' by adding 'value' to it. - // - // This operation outputs "ref" after the update is done. - // This makes it easier to chain operations that need to use the reset value. - // - // Args: - // ref: A mutable `Tensor`. Must be one of the following types: - // `float32`, `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, - // `int8`, `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. - // Should be from a `Variable` node. - // value: A `Tensor`. Must have the same type as `ref`. - // The value to be added to the variable. - // use_locking: An optional `bool`. Defaults to `False`. - // If True, the addition will be protected by a lock; - // otherwise the behavior is undefined, but may exhibit less contention. - // name: A name for the operation (optional). + name: name); + } + + public static Tensor assign_sub(RefVariable @ref, + Tensor value, + bool use_locking = false, + string name = null) => gen_state_ops.assign_sub(@ref, + value, + use_locking: use_locking, + name: name); + + //"""Update 'ref' by adding 'value' to it. + // + // This operation outputs "ref" after the update is done. + // This makes it easier to chain operations that need to use the reset value. + // + // Args: + // ref: A mutable `Tensor`. Must be one of the following types: + // `float32`, `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, + // `int8`, `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. + // Should be from a `Variable` node. + // value: A `Tensor`. Must have the same type as `ref`. + // The value to be added to the variable. + // use_locking: An optional `bool`. Defaults to `False`. + // If True, the addition will be protected by a lock; + // otherwise the behavior is undefined, but may exhibit less contention. + // name: A name for the operation (optional). // // Returns: // Same as "ref". Returned as a convenience for operations that want // to use the new value after the variable has been updated. - public static Tensor assign_add(RefVariable @ref, - Tensor value, - bool use_locking = false, - string name = null) => gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); - + public static Tensor assign_add(RefVariable @ref, + Tensor value, + bool use_locking = false, + string name = null) => gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); + public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) { if (@ref.dtype.is_ref_dtype()) return gen_state_ops.scatter_add(@ref, indices, updates, use_locking: use_locking, name: name); throw new NotImplementedException("scatter_add"); - } - } -} + } + } +} diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 027465bf..b10f41b0 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -46,7 +46,7 @@ namespace Tensorflow trainable: trainable, validate_shape: validate_shape, name: name, - dtype: TF_DataType.DtInvalid); + dtype: dtype); } public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)