From 7dbcb6c1475f031afa9d434467d29e58b69d9e08 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Thu, 14 Mar 2019 16:04:08 -0500 Subject: [PATCH] softmax_cross_entropy_with_logits --- src/TensorFlowNET.Core/APIs/tf.array.cs | 10 ++-- src/TensorFlowNET.Core/APIs/tf.control.cs | 12 ++++ src/TensorFlowNET.Core/APIs/tf.math.cs | 3 + src/TensorFlowNET.Core/APIs/tf.nn.cs | 12 ++++ .../Keras/Engine/Network.cs | 1 + .../Keras/Layers/BatchNormalization.cs | 3 +- src/TensorFlowNET.Core/Keras/Layers/Dense.cs | 22 +++++++ .../Keras/{Engine => Layers}/Layer.cs | 5 +- src/TensorFlowNET.Core/Layers/Layer.cs | 3 +- .../Operations/NnOps/gen_nn_ops.cs | 18 ++++++ .../Operations/array_ops.py.cs | 57 +++++++++++++++++++ .../Operations/gen_array_ops.cs | 26 +++++++++ .../Operations/math_ops.py.cs | 5 ++ src/TensorFlowNET.Core/Operations/nn_ops.cs | 33 +++++++++++ src/TensorFlowNET.Core/ops.py.cs | 5 +- src/TensorFlowNET.Core/tf.cs | 1 + .../TextClassification/cnn_models/VdCnn.cs | 12 +++- 17 files changed, 215 insertions(+), 13 deletions(-) create mode 100644 src/TensorFlowNET.Core/APIs/tf.control.cs rename src/TensorFlowNET.Core/Keras/{Engine => Layers}/Layer.cs (98%) diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index b10e1d77..91b59842 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -34,9 +34,11 @@ namespace Tensorflow public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1) => gen_array_ops.squeeze(input, axis, name); - public static Tensor one_hot(Tensor indices, int depth) - { - throw new NotImplementedException("one_hot"); - } + public static Tensor one_hot(Tensor indices, int depth, + Tensor on_value = null, + Tensor off_value = null, + TF_DataType dtype = TF_DataType.DtInvalid, + int axis = -1, + string name = null) => array_ops.one_hot(indices, depth, dtype: dtype, axis: axis, name: name); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.control.cs b/src/TensorFlowNET.Core/APIs/tf.control.cs new file mode 100644 index 00000000..1b6bc3b8 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.control.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public static partial class tf + { + public static _ControlDependenciesController control_dependencies(Operation[] control_inputs) + => ops.control_dependencies(control_inputs); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index ad548864..0534b645 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -36,6 +36,9 @@ namespace Tensorflow public static Tensor reduce_sum(Tensor input, int[] axis = null) => math_ops.reduce_sum(input); + public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name); + public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) => math_ops.cast(x, dtype, name); diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 87b646d7..1288508c 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -46,6 +46,18 @@ namespace Tensorflow public static Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null) => gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name); + + public static Tensor bias_add(Tensor value, RefVariable bias, string data_format = null, string name = null) + { + return Python.with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope => + { + name = scope; + return gen_nn_ops.bias_add(value, bias, data_format: data_format, name: name); + }); + } + + public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null) + => nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name); } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Network.cs b/src/TensorFlowNET.Core/Keras/Engine/Network.cs index 6eff46c4..43594022 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Network.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Network.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Keras.Layers; namespace Tensorflow.Keras.Engine { diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index 64f44386..c93c07c0 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -3,11 +3,10 @@ using System.Collections.Generic; using System.Linq; using System.Text; using Tensorflow.Keras.Utils; -using Tensorflow.Layers; namespace Tensorflow.Keras.Layers { - public class BatchNormalization : Layer + public class BatchNormalization : Tensorflow.Layers.Layer { private bool _USE_V2_BEHAVIOR = true; private float momentum; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs index f27d0963..323b6658 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -4,6 +4,7 @@ using System.Linq; using System.Text; using Tensorflow.Keras.Engine; using Tensorflow.Operations.Activation; +using static Tensorflow.tf; namespace Tensorflow.Keras.Layers { @@ -55,5 +56,26 @@ namespace Tensorflow.Keras.Layers built = true; } + + protected override Tensor call(Tensor inputs, Tensor training = null) + { + Tensor outputs = null; + var rank = inputs.rank; + if(rank > 2) + { + throw new NotImplementedException(""); + } + else + { + outputs = gen_math_ops.mat_mul(inputs, kernel); + } + + if (use_bias) + outputs = nn.bias_add(outputs, bias); + if (activation != null) + return activation.Activate(outputs); + + return outputs; + } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs similarity index 98% rename from src/TensorFlowNET.Core/Keras/Engine/Layer.cs rename to src/TensorFlowNET.Core/Keras/Layers/Layer.cs index 12df6b4d..0ed2a4ce 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -2,9 +2,10 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; -namespace Tensorflow.Keras.Engine +namespace Tensorflow.Keras.Layers { /// /// Base layer class. @@ -106,7 +107,7 @@ namespace Tensorflow.Keras.Engine protected virtual Tensor call(Tensor inputs, Tensor training = null) { - throw new NotImplementedException("Layer.call"); + return inputs; } protected virtual string _name_scope() diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 17205c51..aa2a7405 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -2,11 +2,10 @@ using System.Collections.Generic; using System.Linq; using System.Text; -using Tensorflow.Keras.Engine; namespace Tensorflow.Layers { - public class Layer : Keras.Engine.Layer + public class Layer : Keras.Layers.Layer { protected Graph _graph; diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index 4f8cd7ff..8dde6143 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -108,5 +108,23 @@ namespace Tensorflow.Operations return _op.outputs; } + + /// + /// Computes softmax cross entropy cost and gradients to backpropagate. + /// + /// + /// + /// + /// + public static (Tensor, Tensor) softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = null) + { + var _op = _op_def_lib._apply_op_helper("SoftmaxCrossEntropyWithLogits", name: name, args: new + { + features, + labels + }); + + return (_op.outputs[0], _op.outputs[1]); + } } } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index c5641f06..2a972c9c 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -46,6 +46,22 @@ namespace Tensorflow } } + public static Tensor _autopacking_helper(Tensor[] list_or_tuple, TF_DataType dtype, string name) + { + var must_pack = false; + var converted_elems = new List(); + return with(ops.name_scope(name), scope => + { + foreach (var (i, elem) in enumerate(list_or_tuple)) + { + converted_elems.Add(elem); + must_pack = true; + } + + return gen_array_ops.pack(converted_elems.ToArray(), name: scope); + }); + } + public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) => expand_dims_v2(input, axis, name); private static Tensor expand_dims_v2(Tensor input, int axis, string name = null) => gen_array_ops.expand_dims(input, axis, name); @@ -109,6 +125,44 @@ namespace Tensorflow }); } + public static Tensor one_hot(Tensor indices, int depth, + Tensor on_value = null, + Tensor off_value = null, + TF_DataType dtype = TF_DataType.DtInvalid, + int axis = -1, + string name = null) + { + return with(ops.name_scope(name, "one_hot", new { indices, depth, dtype }), scope => + { + name = scope; + var on_exists = false; + var off_exists = false; + var on_dtype = TF_DataType.DtInvalid; + var off_dtype = TF_DataType.DtInvalid; + + if (dtype == TF_DataType.DtInvalid) + dtype = TF_DataType.TF_FLOAT; + + if(!on_exists) + { + on_value = ops.convert_to_tensor(1, dtype, name: "on_value"); + on_dtype = dtype; + } + + if (!off_exists) + { + off_value = ops.convert_to_tensor(0, dtype, name = "off_value"); + off_dtype = dtype; + } + + return gen_array_ops.one_hot(indices, depth, + on_value: on_value, + off_value: off_value, + axis: axis, + name: name); + }); + } + public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = null) { if( x == null && y == null) @@ -298,5 +352,8 @@ namespace Tensorflow return gen_array_ops.transpose(a, perm, name); }); } + + public static Tensor slice(Tensor input, Tb[] begin, Ts[] size, string name = null) + => gen_array_ops.slice(input, begin, size, name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 9260d9fe..393b10fa 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -40,6 +40,13 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor pack(Tensor[] values, int axis = 0, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Pack", name: name, args: new { values, axis }); + + return _op.outputs[0]; + } + public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) { var _op = _op_def_lib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape }); @@ -126,6 +133,17 @@ namespace Tensorflow throw new NotImplementedException("where"); } + public static Tensor one_hot(Tensor indices, int depth, + Tensor on_value = null, + Tensor off_value = null, + TF_DataType dtype = TF_DataType.DtInvalid, + int axis = -1, + string name = null) + { + var _op = _op_def_lib._apply_op_helper("OneHot", name, new { indices, depth, on_value, off_value, axis }); + return _op.outputs[0]; + } + /// /// A placeholder op that passes through `input` when its output is not fed. /// @@ -174,12 +192,20 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("ZerosLike", name, new { x }); return _op.outputs[0]; } + public static Tensor stop_gradient(Tensor x, string name = null) { var _op = _op_def_lib._apply_op_helper("StopGradient", name, args: new { input = x, name }); return _op.outputs[0]; } + + public static Tensor slice(Tensor input, Tb[] begin, Ts[] size, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); + return _op.outputs[0]; + } + /// /// Removes dimensions of size 1 from the shape of a tensor. /// Given a tensor `input`, this operation returns a tensor of the same type with diff --git a/src/TensorFlowNET.Core/Operations/math_ops.py.cs b/src/TensorFlowNET.Core/Operations/math_ops.py.cs index e35a094b..f2abad66 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.py.cs @@ -60,6 +60,11 @@ namespace Tensorflow return gen_math_ops.square(x, name); } + public static Tensor subtract(Tx x, Ty y, string name = null) + { + return gen_math_ops.sub(x, y, name); + } + public static Tensor log(Tensor x, string name = null) { return gen_math_ops.log(x, name); diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index c7aec377..74e76bda 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -41,5 +41,38 @@ namespace Tensorflow return gen_nn_ops.bias_add(value, bias_tensor, data_format: data_format, name: name); }); } + + public static Tensor softmax_cross_entropy_with_logits_v2_helper(Tensor labels, + Tensor logits, + int axis = -1, + string name = null) + { + return Python.with(ops.name_scope(name, "softmax_cross_entropy_with_logits", new { }), scope => + { + var precise_logits = logits; + var input_rank = array_ops.rank(precise_logits); + var shape = logits.getShape(); + + if (axis != -1) + throw new NotImplementedException("softmax_cross_entropy_with_logits_v2_helper axis != -1"); + + var input_shape = array_ops.shape(precise_logits); + + // Do the actual op computation. + // The second output tensor contains the gradients. We use it in + // _CrossEntropyGrad() in nn_grad but not here. + + var (cost, unused_backprop) = gen_nn_ops.softmax_cross_entropy_with_logits(precise_logits, labels, name: name); + + // The output cost shape should be the input minus axis. + var output_shape = array_ops.slice(input_shape, + new int[] { 0 }, + new Tensor[] { math_ops.subtract(input_rank, 1) }); + + cost = array_ops.reshape(cost, output_shape); + + return cost; + }); + } } } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 50fd59c9..aeef11f3 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -434,7 +434,8 @@ namespace Tensorflow public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid, - bool as_ref = false) + bool as_ref = false, + string scope = null) { if (dtype == TF_DataType.DtInvalid) dtype = preferred_dtype; @@ -443,6 +444,8 @@ namespace Tensorflow { case Tensor tensor: return tensor; + case Tensor[] tensors: + return array_ops._autopacking_helper(tensors, dtype, name); case string str: return constant_op.constant(str, dtype: dtype, name: name); case string[] strArray: diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index e1e1331d..349ae39e 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Eager; +using static Tensorflow.ops; namespace Tensorflow { diff --git a/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs index cce3bbef..2f24d87e 100644 --- a/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs +++ b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs @@ -22,6 +22,9 @@ namespace TensorFlowNET.Examples.TextClassification private RefVariable embeddings; private Tensor x_emb; private Tensor x_expanded; + private Tensor logits; + private Tensor predictions; + private Tensor loss; public VdCnn(int alphabet_size, int document_max_len, int num_class) { @@ -55,8 +58,6 @@ namespace TensorFlowNET.Examples.TextClassification Tensor h_flat = null; Tensor fc1_out = null; Tensor fc2_out = null; - Tensor logits = null; - Tensor predictions = null; // First Convolution Layer with(tf.variable_scope("conv-0"), delegate @@ -116,6 +117,13 @@ namespace TensorFlowNET.Examples.TextClassification with(tf.name_scope("loss"), delegate { var y_one_hot = tf.one_hot(y, num_class); + loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot)); + + var update_ops = tf.get_collection(ops.GraphKeys.UPDATE_OPS) as List; + with(tf.control_dependencies(update_ops.ToArray()), delegate + { + + }); }); }