Browse Source

added sparse_softmax_cross_entropy_with_logits

tags/v0.9
Oceania2018 6 years ago
parent
commit
a4f03c22ec
2 changed files with 32 additions and 0 deletions
  1. +11
    -0
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  2. +21
    -0
      test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs

+ 11
- 0
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -90,6 +90,17 @@ namespace Tensorflow
public static Tensor softmax(Tensor logits, int axis = -1, string name = null)
=> gen_nn_ops.softmax(logits, name);

/// <summary>
/// Computes sparse softmax cross entropy between `logits` and `labels`.
/// </summary>
/// <param name="labels"></param>
/// <param name="logits"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor sparse_softmax_cross_entropy_with_logits(Tensor labels = null,
Tensor logits = null, string name = null)
=> nn_ops.sparse_softmax_cross_entropy_with_logits(labels: labels, logits: logits, name: name);

public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null)
=> nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name);
}


+ 21
- 0
test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs View File

@@ -203,6 +203,27 @@ namespace TensorFlowNET.Examples
var h_drop = tf.nn.dropout(h_pool_flat, keep_prob);
});

Tensor logits = null;
Tensor predictions = null;
with(tf.name_scope("output"), delegate
{
logits = tf.layers.dense(h_pool_flat, keep_prob);
predictions = tf.argmax(logits, -1, output_type: tf.int32);
});

with(tf.name_scope("loss"), delegate
{
var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y);
var loss = tf.reduce_mean(sscel);
var optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step: global_step);
});

with(tf.name_scope("accuracy"), delegate
{
var correct_predictions = tf.equal(predictions, y);
var accuracy = tf.reduce_mean(tf.cast(correct_predictions, TF_DataType.TF_FLOAT), name: "accuracy");
});

return graph;
}



Loading…
Cancel
Save