Browse Source

softmax_cross_entropy_with_logits

tags/v0.8.0
haiping008 6 years ago
parent
commit
7dbcb6c147
17 changed files with 215 additions and 13 deletions
  1. +6
    -4
      src/TensorFlowNET.Core/APIs/tf.array.cs
  2. +12
    -0
      src/TensorFlowNET.Core/APIs/tf.control.cs
  3. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  4. +12
    -0
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  5. +1
    -0
      src/TensorFlowNET.Core/Keras/Engine/Network.cs
  6. +1
    -2
      src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  7. +22
    -0
      src/TensorFlowNET.Core/Keras/Layers/Dense.cs
  8. +3
    -2
      src/TensorFlowNET.Core/Keras/Layers/Layer.cs
  9. +1
    -2
      src/TensorFlowNET.Core/Layers/Layer.cs
  10. +18
    -0
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  11. +57
    -0
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  12. +26
    -0
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  13. +5
    -0
      src/TensorFlowNET.Core/Operations/math_ops.py.cs
  14. +33
    -0
      src/TensorFlowNET.Core/Operations/nn_ops.cs
  15. +4
    -1
      src/TensorFlowNET.Core/ops.py.cs
  16. +1
    -0
      src/TensorFlowNET.Core/tf.cs
  17. +10
    -2
      test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs

+ 6
- 4
src/TensorFlowNET.Core/APIs/tf.array.cs View File

@@ -34,9 +34,11 @@ namespace Tensorflow
/// <summary>
/// Forwards <paramref name="input"/>, <paramref name="axis"/> and <paramref name="name"/>
/// to <c>gen_array_ops.squeeze</c> and returns its result.
/// </summary>
/// <param name="input">Tensor handed to the underlying squeeze call.</param>
/// <param name="axis">Axes handed to the underlying squeeze call; may be null.</param>
/// <param name="name">Optional operation name.</param>
/// <param name="squeeze_dims">Accepted for signature compatibility; not read by this method.</param>
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1)
{
    return gen_array_ops.squeeze(input, axis, name);
}

/// <summary>
/// Unimplemented overload: every call throws <see cref="NotImplementedException"/>
/// with the message "one_hot".
/// </summary>
public static Tensor one_hot(Tensor indices, int depth)
{
throw new NotImplementedException("one_hot");
}
/// <summary>
/// Public entry point for one-hot encoding; delegates to <c>array_ops.one_hot</c>.
/// </summary>
/// <param name="indices">Tensor of indices.</param>
/// <param name="depth">Depth of the one-hot dimension.</param>
/// <param name="on_value">Optional fill value for positions selected by <paramref name="indices"/>.</param>
/// <param name="off_value">Optional fill value for all other positions.</param>
/// <param name="dtype">Requested data type; <c>DtInvalid</c> leaves the choice to <c>array_ops.one_hot</c>.</param>
/// <param name="axis">Axis argument passed through unchanged.</param>
/// <param name="name">Optional operation name.</param>
public static Tensor one_hot(Tensor indices, int depth,
    Tensor on_value = null,
    Tensor off_value = null,
    TF_DataType dtype = TF_DataType.DtInvalid,
    int axis = -1,
    string name = null) => array_ops.one_hot(indices, depth,
        // Forward the caller's fill values instead of silently dropping them.
        on_value: on_value,
        off_value: off_value,
        dtype: dtype,
        axis: axis,
        name: name);
}
}

+ 12
- 0
src/TensorFlowNET.Core/APIs/tf.control.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    public static partial class tf
    {
        /// <summary>
        /// Thin facade over <c>ops.control_dependencies</c>.
        /// </summary>
        /// <param name="control_inputs">Operations handed to <c>ops.control_dependencies</c>.</param>
        /// <returns>The controller object produced by <c>ops.control_dependencies</c>.</returns>
        public static _ControlDependenciesController control_dependencies(Operation[] control_inputs)
        {
            return ops.control_dependencies(control_inputs);
        }
    }
}

+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -36,6 +36,9 @@ namespace Tensorflow
/// <summary>
/// Calls <c>math_ops.reduce_sum</c> with <paramref name="input"/> only.
/// </summary>
/// <param name="input">Tensor to reduce.</param>
/// <param name="axis">
/// Accepted by the signature but not passed on to <c>math_ops.reduce_sum</c>.
/// NOTE(review): looks like an omission; confirm which overloads of
/// <c>math_ops.reduce_sum</c> exist before forwarding it.
/// </param>
public static Tensor reduce_sum(Tensor input, int[] axis = null)
{
    return math_ops.reduce_sum(input);
}

/// <summary>
/// Delegates to <c>math_ops.reduce_mean</c>, forwarding every argument by name.
/// </summary>
/// <param name="input_tensor">Tensor to reduce.</param>
/// <param name="axis">Axes to reduce over; null is forwarded as-is.</param>
/// <param name="keepdims">Forwarded unchanged.</param>
/// <param name="name">Optional operation name.</param>
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
{
    return math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name);
}

/// <summary>
/// Delegates to <c>math_ops.cast</c> with the same three arguments.
/// </summary>
/// <param name="x">Tensor to cast.</param>
/// <param name="dtype">Target data type, forwarded unchanged.</param>
/// <param name="name">Optional operation name.</param>
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
{
    return math_ops.cast(x, dtype, name);
}



+ 12
- 0
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -46,6 +46,18 @@ namespace Tensorflow

/// <summary>
/// Delegates to <c>gen_nn_ops.top_kv2</c> and returns the tensors it produces.
/// </summary>
/// <param name="input">Tensor to search.</param>
/// <param name="k">Forwarded as the <c>k</c> argument (default 1).</param>
/// <param name="sorted">Forwarded as the <c>sorted</c> argument (default true).</param>
/// <param name="name">Optional operation name.</param>
public static Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null)
{
    return gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name);
}

/// <summary>
/// Opens a "BiasAdd" name scope and calls <c>gen_nn_ops.bias_add</c> inside it.
/// </summary>
/// <param name="value">Tensor the bias is added to.</param>
/// <param name="bias">Bias variable.</param>
/// <param name="data_format">Forwarded unchanged; may be null.</param>
/// <param name="name">Optional scope name; replaced by the resolved scope before the op is created.</param>
public static Tensor bias_add(Tensor value, RefVariable bias, string data_format = null, string name = null)
{
    var biasAddScope = ops.name_scope(name, "BiasAdd", new { value, bias });

    return Python.with(biasAddScope, scope =>
    {
        // The op is named after the scope that was actually entered.
        name = scope;
        return gen_nn_ops.bias_add(value, bias, data_format: data_format, name: name);
    });
}

/// <summary>
/// Delegates to <c>nn_ops.softmax_cross_entropy_with_logits_v2_helper</c>.
/// </summary>
/// <param name="labels">Label tensor.</param>
/// <param name="logits">Logits tensor.</param>
/// <param name="axis">Forwarded unchanged (default -1).</param>
/// <param name="name">Optional operation name.</param>
public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null)
{
    return nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name);
}
}
}
}

+ 1
- 0
src/TensorFlowNET.Core/Keras/Engine/Network.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Layers;

namespace Tensorflow.Keras.Engine
{


+ 1
- 2
src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs View File

@@ -3,11 +3,10 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Utils;
using Tensorflow.Layers;

namespace Tensorflow.Keras.Layers
{
public class BatchNormalization : Layer
public class BatchNormalization : Tensorflow.Layers.Layer
{
private bool _USE_V2_BEHAVIOR = true;
private float momentum;


+ 22
- 0
src/TensorFlowNET.Core/Keras/Layers/Dense.cs View File

@@ -4,6 +4,7 @@ using System.Linq;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Operations.Activation;
using static Tensorflow.tf;

namespace Tensorflow.Keras.Layers
{
@@ -55,5 +56,26 @@ namespace Tensorflow.Keras.Layers

built = true;
}

/// <summary>
/// Forward pass of the layer: matrix-multiplies <paramref name="inputs"/> by <c>kernel</c>,
/// then adds <c>bias</c> when <c>use_bias</c> is set and applies <c>activation</c> when one is configured.
/// </summary>
/// <param name="inputs">Input tensor; its rank must not exceed 2.</param>
/// <param name="training">Not read by this implementation.</param>
/// <returns>The activated output, or the pre-activation output when no activation is set.</returns>
/// <exception cref="NotImplementedException">Thrown when <paramref name="inputs"/> has rank greater than 2.</exception>
protected override Tensor call(Tensor inputs, Tensor training = null)
{
    // Guard first so the unsupported case fails with a message that says why.
    if (inputs.rank > 2)
        throw new NotImplementedException("Dense.call: inputs with rank > 2 are not supported");

    Tensor outputs = gen_math_ops.mat_mul(inputs, kernel);

    if (use_bias)
        outputs = nn.bias_add(outputs, bias);

    if (activation != null)
        return activation.Activate(outputs);

    return outputs;
}
}
}

src/TensorFlowNET.Core/Keras/Engine/Layer.cs → src/TensorFlowNET.Core/Keras/Layers/Layer.cs View File

@@ -2,9 +2,10 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Engine
namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Base layer class.
@@ -106,7 +107,7 @@ namespace Tensorflow.Keras.Engine

protected virtual Tensor call(Tensor inputs, Tensor training = null)
{
throw new NotImplementedException("Layer.call");
return inputs;
}

protected virtual string _name_scope()

+ 1
- 2
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -2,11 +2,10 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Layers
{
public class Layer : Keras.Engine.Layer
public class Layer : Keras.Layers.Layer
{
protected Graph _graph;


+ 18
- 0
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -108,5 +108,23 @@ namespace Tensorflow.Operations

return _op.outputs;
}

/// <summary>
/// Computes softmax cross entropy cost and gradients to backpropagate
/// by creating a "SoftmaxCrossEntropyWithLogits" op.
/// </summary>
/// <param name="features">Passed to the op as its <c>features</c> argument.</param>
/// <param name="labels">Passed to the op as its <c>labels</c> argument.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>A tuple holding the op's outputs at index 0 and index 1, in that order.</returns>
public static (Tensor, Tensor) softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = null)
{
    var op = _op_def_lib._apply_op_helper(
        "SoftmaxCrossEntropyWithLogits",
        name: name,
        args: new { features, labels });

    var first = op.outputs[0];
    var second = op.outputs[1];
    return (first, second);
}
}
}

+ 57
- 0
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -46,6 +46,22 @@ namespace Tensorflow
}
}
/// <summary>
/// Packs an array of tensors into a single tensor via <c>gen_array_ops.pack</c>,
/// inside a name scope derived from <paramref name="name"/>.
/// </summary>
/// <param name="list_or_tuple">Tensors to pack; passed to the pack op as-is.</param>
/// <param name="dtype">Not read by this implementation.</param>
/// <param name="name">Name used to open the scope; the resolved scope names the pack op.</param>
public static Tensor _autopacking_helper(Tensor[] list_or_tuple, TF_DataType dtype, string name)
{
    return with(ops.name_scope(name), scope =>
    {
        // Every element is already a Tensor, so there is nothing to convert:
        // the former element-by-element copy and its never-read flag were dead code.
        return gen_array_ops.pack(list_or_tuple, name: scope);
    });
}
/// <summary>
/// Delegates to <c>expand_dims_v2</c> with <paramref name="input"/>, <paramref name="axis"/> and <paramref name="name"/>.
/// </summary>
/// <param name="dim">Accepted by the signature but not read.</param>
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
{
    return expand_dims_v2(input, axis, name);
}
/// <summary>
/// Delegates to <c>gen_array_ops.expand_dims</c> with the same three arguments.
/// </summary>
private static Tensor expand_dims_v2(Tensor input, int axis, string name = null)
{
    return gen_array_ops.expand_dims(input, axis, name);
}
@@ -109,6 +125,44 @@ namespace Tensorflow
});
}
/// <summary>
/// Builds a OneHot op inside a "one_hot" name scope.
/// </summary>
/// <param name="indices">Tensor of indices.</param>
/// <param name="depth">Depth of the one-hot dimension.</param>
/// <param name="on_value">
/// Fill value for selected positions. When null, a constant 1 of <paramref name="dtype"/> is created.
/// </param>
/// <param name="off_value">
/// Fill value for all other positions. When null, a constant 0 of <paramref name="dtype"/> is created.
/// </param>
/// <param name="dtype">
/// Data type used for the default fill values; <c>DtInvalid</c> is treated as <c>TF_FLOAT</c>.
/// It is not inferred from a supplied <paramref name="on_value"/>/<paramref name="off_value"/>,
/// so callers passing only one of them should pass a matching <paramref name="dtype"/>.
/// </param>
/// <param name="axis">Forwarded to the op unchanged.</param>
/// <param name="name">Optional scope name; the op is named after the resolved scope.</param>
public static Tensor one_hot(Tensor indices, int depth,
    Tensor on_value = null,
    Tensor off_value = null,
    TF_DataType dtype = TF_DataType.DtInvalid,
    int axis = -1,
    string name = null)
{
    return with(ops.name_scope(name, "one_hot", new { indices, depth, dtype }), scope =>
    {
        name = scope;

        if (dtype == TF_DataType.DtInvalid)
            dtype = TF_DataType.TF_FLOAT;

        // Only synthesize defaults for values the caller did not provide.
        if (on_value is null)
            on_value = ops.convert_to_tensor(1, dtype, name: "on_value");

        // `name:` (named argument), not `name =`: an assignment here would
        // overwrite the scope name that the OneHot op below must receive.
        if (off_value is null)
            off_value = ops.convert_to_tensor(0, dtype, name: "off_value");

        return gen_array_ops.one_hot(indices, depth,
            on_value: on_value,
            off_value: off_value,
            axis: axis,
            name: name);
    });
}
public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = null)
{
if( x == null && y == null)
@@ -298,5 +352,8 @@ namespace Tensorflow
return gen_array_ops.transpose(a, perm, name);
});
}
/// <summary>
/// Delegates to <c>gen_array_ops.slice</c>.
/// </summary>
/// <typeparam name="Tb">Element type of <paramref name="begin"/>.</typeparam>
/// <typeparam name="Ts">Element type of <paramref name="size"/>.</typeparam>
/// <param name="input">Tensor to slice.</param>
/// <param name="begin">Forwarded as the slice start.</param>
/// <param name="size">Forwarded as the slice size.</param>
/// <param name="name">Optional operation name.</param>
public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null)
{
    return gen_array_ops.slice(input, begin, size, name: name);
}
}
}

+ 26
- 0
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -40,6 +40,13 @@ namespace Tensorflow
return _op.outputs[0];
}

/// <summary>
/// Creates a "Pack" op from <paramref name="values"/> and <paramref name="axis"/>
/// and returns its first output.
/// </summary>
/// <param name="values">Tensors passed to the op as <c>values</c>.</param>
/// <param name="axis">Passed to the op as <c>axis</c> (default 0).</param>
/// <param name="name">Optional operation name.</param>
public static Tensor pack(Tensor[] values, int axis = 0, string name = null)
{
    var packOp = _op_def_lib._apply_op_helper(
        "Pack",
        name: name,
        args: new { values, axis });

    return packOp.outputs[0];
}

public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape });
@@ -126,6 +133,17 @@ namespace Tensorflow
throw new NotImplementedException("where");
}

/// <summary>
/// Creates a "OneHot" op and returns its first output.
/// </summary>
/// <param name="indices">Passed to the op as <c>indices</c>.</param>
/// <param name="depth">Passed to the op as <c>depth</c>.</param>
/// <param name="on_value">Passed to the op as <c>on_value</c>.</param>
/// <param name="off_value">Passed to the op as <c>off_value</c>.</param>
/// <param name="dtype">Accepted by the signature but not passed to the op.</param>
/// <param name="axis">Passed to the op as <c>axis</c>.</param>
/// <param name="name">Optional operation name.</param>
public static Tensor one_hot(Tensor indices, int depth,
    Tensor on_value = null,
    Tensor off_value = null,
    TF_DataType dtype = TF_DataType.DtInvalid,
    int axis = -1,
    string name = null)
{
    var oneHotOp = _op_def_lib._apply_op_helper(
        "OneHot",
        name,
        new { indices, depth, on_value, off_value, axis });

    return oneHotOp.outputs[0];
}

/// <summary>
/// A placeholder op that passes through `input` when its output is not fed.
/// </summary>
@@ -174,12 +192,20 @@ namespace Tensorflow
var _op = _op_def_lib._apply_op_helper("ZerosLike", name, new { x });
return _op.outputs[0];
}

/// <summary>
/// Creates a "StopGradient" op for <paramref name="x"/> and returns its first output.
/// </summary>
/// <param name="x">Passed to the op under the argument name <c>input</c>.</param>
/// <param name="name">
/// Optional operation name. It is also included in the argument object handed to the op helper.
/// </param>
public static Tensor stop_gradient(Tensor x, string name = null)
{
    var stopGradientOp = _op_def_lib._apply_op_helper(
        "StopGradient",
        name,
        args: new { input = x, name });

    return stopGradientOp.outputs[0];
}

/// <summary>
/// Creates a "Slice" op and returns its first output.
/// </summary>
/// <typeparam name="Tb">Element type of <paramref name="begin"/>.</typeparam>
/// <typeparam name="Ts">Element type of <paramref name="size"/>.</typeparam>
/// <param name="input">Passed to the op as <c>input</c>.</param>
/// <param name="begin">Passed to the op as <c>begin</c>.</param>
/// <param name="size">Passed to the op as <c>size</c>.</param>
/// <param name="name">Optional operation name.</param>
public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null)
{
    var sliceOp = _op_def_lib._apply_op_helper(
        "Slice",
        name,
        new { input, begin, size });

    return sliceOp.outputs[0];
}

/// <summary>
/// Removes dimensions of size 1 from the shape of a tensor.
/// Given a tensor `input`, this operation returns a tensor of the same type with


+ 5
- 0
src/TensorFlowNET.Core/Operations/math_ops.py.cs View File

@@ -60,6 +60,11 @@ namespace Tensorflow
return gen_math_ops.square(x, name);
}

/// <summary>
/// Delegates to <c>gen_math_ops.sub</c> with the same three arguments.
/// </summary>
/// <typeparam name="Tx">Type of the left operand.</typeparam>
/// <typeparam name="Ty">Type of the right operand.</typeparam>
/// <param name="x">Left operand.</param>
/// <param name="y">Right operand.</param>
/// <param name="name">Optional operation name.</param>
public static Tensor subtract<Tx, Ty>(Tx x, Ty y, string name = null)
    => gen_math_ops.sub(x, y, name);

public static Tensor log(Tensor x, string name = null)
{
return gen_math_ops.log(x, name);


+ 33
- 0
src/TensorFlowNET.Core/Operations/nn_ops.cs View File

@@ -41,5 +41,38 @@ namespace Tensorflow
return gen_nn_ops.bias_add(value, bias_tensor, data_format: data_format, name: name);
});
}

/// <summary>
/// Calls <c>gen_nn_ops.softmax_cross_entropy_with_logits</c> inside a
/// "softmax_cross_entropy_with_logits" name scope and reshapes the resulting cost
/// to the logits' shape with its last dimension dropped.
/// </summary>
/// <param name="labels">Label tensor, passed to the generated op unchanged.</param>
/// <param name="logits">Logits tensor, passed to the generated op unchanged.</param>
/// <param name="axis">Only -1 is supported; any other value throws.</param>
/// <param name="name">Optional name; used for the name scope and also passed to the generated op.</param>
/// <returns>The reshaped cost tensor.</returns>
/// <exception cref="NotImplementedException">Thrown when <paramref name="axis"/> is not -1.</exception>
public static Tensor softmax_cross_entropy_with_logits_v2_helper(Tensor labels,
Tensor logits,
int axis = -1,
string name = null)
{
// The name scope is opened with an empty value object; the lambda's `scope` is not read below.
return Python.with(ops.name_scope(name, "softmax_cross_entropy_with_logits", new { }), scope =>
{
var precise_logits = logits;
var input_rank = array_ops.rank(precise_logits);
// NOTE(review): `shape` is assigned here but never read in this method.
var shape = logits.getShape();

if (axis != -1)
throw new NotImplementedException("softmax_cross_entropy_with_logits_v2_helper axis != -1");

var input_shape = array_ops.shape(precise_logits);

// Do the actual op computation.
// The second output tensor contains the gradients. We use it in
// _CrossEntropyGrad() in nn_grad but not here.

var (cost, unused_backprop) = gen_nn_ops.softmax_cross_entropy_with_logits(precise_logits, labels, name: name);

// The output cost shape should be the input minus axis.
// Takes the first (rank - 1) entries of the input shape, starting at index 0.
var output_shape = array_ops.slice(input_shape,
new int[] { 0 },
new Tensor[] { math_ops.subtract(input_rank, 1) });

cost = array_ops.reshape(cost, output_shape);

return cost;
});
}
}
}

+ 4
- 1
src/TensorFlowNET.Core/ops.py.cs View File

@@ -434,7 +434,8 @@ namespace Tensorflow

public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid,
string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid,
bool as_ref = false)
bool as_ref = false,
string scope = null)
{
if (dtype == TF_DataType.DtInvalid)
dtype = preferred_dtype;
@@ -443,6 +444,8 @@ namespace Tensorflow
{
case Tensor tensor:
return tensor;
case Tensor[] tensors:
return array_ops._autopacking_helper(tensors, dtype, name);
case string str:
return constant_op.constant(str, dtype: dtype, name: name);
case string[] strArray:


+ 1
- 0
src/TensorFlowNET.Core/tf.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;
using static Tensorflow.ops;

namespace Tensorflow
{


+ 10
- 2
test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs View File

@@ -22,6 +22,9 @@ namespace TensorFlowNET.Examples.TextClassification
private RefVariable embeddings;
private Tensor x_emb;
private Tensor x_expanded;
private Tensor logits;
private Tensor predictions;
private Tensor loss;

public VdCnn(int alphabet_size, int document_max_len, int num_class)
{
@@ -55,8 +58,6 @@ namespace TensorFlowNET.Examples.TextClassification
Tensor h_flat = null;
Tensor fc1_out = null;
Tensor fc2_out = null;
Tensor logits = null;
Tensor predictions = null;

// First Convolution Layer
with(tf.variable_scope("conv-0"), delegate
@@ -116,6 +117,13 @@ namespace TensorFlowNET.Examples.TextClassification
with(tf.name_scope("loss"), delegate
{
var y_one_hot = tf.one_hot(y, num_class);
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot));

var update_ops = tf.get_collection(ops.GraphKeys.UPDATE_OPS) as List<Operation>;
with(tf.control_dependencies(update_ops.ToArray()), delegate
{

});
});
}



Loading…
Cancel
Save