Browse Source

add MarShall.Copy for Boolean when creating Tensor.

tags/v0.9
Oceania2018 6 years ago
parent
commit
e615351373
3 changed files with 11 additions and 5 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  2. +1
    -0
      test/TensorFlowNET.Examples/Program.cs
  3. +7
    -5
      test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs

+ 3
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -55,6 +55,9 @@ namespace Tensorflow
var nd1 = nd.ravel();
switch (nd.dtype.Name)
{
case "Boolean":
Marshal.Copy(nd1.Data<byte>(), 0, dotHandle, nd.size);
break;
case "Int16":
Marshal.Copy(nd1.Data<short>(), 0, dotHandle, nd.size);
break;


+ 1
- 0
test/TensorFlowNET.Examples/Program.cs View File

@@ -64,6 +64,7 @@ namespace TensorFlowNET.Examples
disabled.ForEach(x => Console.WriteLine($"{x} is Disabled!", Color.Tan));
errors.ForEach(x => Console.WriteLine($"{x} is Failed!", Color.Red));

Console.Write("Please [Enter] to quit.");
Console.ReadLine();
}
}


+ 7
- 5
test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs View File

@@ -61,7 +61,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification
var meta_file = model_name + ".meta";
tf.train.import_meta_graph(Path.Join("graph", meta_file));
Console.WriteLine("\tDONE");
//sess.run(tf.global_variables_initializer()); // not necessary here, has already been done before meta graph export
// definitely necessary, otherwize will get the exception of "use uninitialized variable"
sess.run(tf.global_variables_initializer());
var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
var num_batches_per_epoch = (len(train_x) - 1); // BATCH_SIZE + 1
@@ -89,12 +90,13 @@ namespace TensorFlowNET.Examples.CnnTextClassification
};
// original python:
//_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
loss_value = result[2];
var result = sess.run(new Tensor[] { optimizer, global_step, loss }, train_feed_dict);
// exception here, loss value seems like a float[]
//loss_value = result[2];
var step = result[1];
if (step % 10 == 0)
Console.WriteLine($"Step {step} loss: {result[2]}");
if (step % 100 == 0)
Console.WriteLine($"Step {step} loss: {loss_value}");
if (step % 2000 == 0)
{
continue;
// # Test accuracy with validation data for each epoch.


Loading…
Cancel
Save