Browse Source

fix tf.cond.

tags/v0.20
Oceania2018 5 years ago
parent
commit
16ea35126d
24 changed files with 234 additions and 77 deletions
  1. +25
    -2
      src/TensorFlowNET.Core/APIs/tf.io.cs
  2. +19
    -3
      src/TensorFlowNET.Core/APIs/tf.strings.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
  4. +0
    -3
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  5. +20
    -0
      src/TensorFlowNET.Core/Operations/control_flow_ops.cs
  6. +0
    -40
      src/TensorFlowNET.Core/Operations/gen_string_ops.cs
  7. +9
    -5
      src/TensorFlowNET.Core/Operations/image_ops_impl.cs
  8. +18
    -4
      src/TensorFlowNET.Core/Operations/io_ops.cs
  9. +26
    -3
      src/TensorFlowNET.Core/Operations/string_ops.cs
  10. +2
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  12. +32
    -3
      src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
  13. +1
    -1
      src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
  14. +12
    -0
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  15. +2
    -2
      src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs
  16. +1
    -2
      test/TensorFlowNET.UnitTest/ConstantTest.cs
  17. +4
    -3
      test/TensorFlowNET.UnitTest/ImageTest.cs
  18. +0
    -0
      test/TensorFlowNET.UnitTest/TF_API/ActivationFunctionTest.cs
  19. +2
    -2
      test/TensorFlowNET.UnitTest/TF_API/MathApiTest.cs
  20. +43
    -0
      test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs
  21. +0
    -0
      test/TensorFlowNET.UnitTest/TF_API/TFNetApiTest.cs
  22. +0
    -0
      test/TensorFlowNET.UnitTest/TF_API/ZeroFractionTest.cs
  23. +0
    -0
      test/TensorFlowNET.UnitTest/TF_API/nn_test.py
  24. +16
    -0
      test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs

+ 25
- 2
src/TensorFlowNET.Core/APIs/tf.io.cs View File

@@ -21,9 +21,32 @@ namespace Tensorflow
{ {
public partial class tensorflow public partial class tensorflow
{ {
public IoApi io { get; } = new IoApi();

public class IoApi
{
io_ops ops;
public IoApi()
{
ops = new io_ops();
}

public Tensor read_file(string filename, string name = null)
=> ops.read_file(filename, name);

public Tensor read_file(Tensor filename, string name = null)
=> ops.read_file(filename, name);

public Operation save_v2(Tensor prefix, string[] tensor_names,
string[] shape_and_slices, Tensor[] tensors, string name = null)
=> ops.save_v2(prefix, tensor_names, shape_and_slices, tensors, name: name);

public Tensor[] restore_v2(Tensor prefix, string[] tensor_names,
string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
=> ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name);
}

public GFile gfile = new GFile(); public GFile gfile = new GFile();
public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name);
public Tensor read_file(Tensor filename, string name = null) => gen_io_ops.read_file(filename, name);


public ITensorOrOperation[] import_graph_def(GraphDef graph_def, public ITensorOrOperation[] import_graph_def(GraphDef graph_def,
Dictionary<string, Tensor> input_map = null, Dictionary<string, Tensor> input_map = null,


+ 19
- 3
src/TensorFlowNET.Core/APIs/tf.strings.cs View File

@@ -21,12 +21,28 @@ namespace Tensorflow
{ {
public partial class tensorflow public partial class tensorflow
{ {
public strings_internal strings = new strings_internal();
public class strings_internal
public StringsApi strings { get; } = new StringsApi();

public class StringsApi
{ {
string_ops ops = new string_ops();

/// <summary>
/// Return substrings from `Tensor` of strings.
/// </summary>
/// <param name="input"></param>
/// <param name="pos"></param>
/// <param name="len"></param>
/// <param name="name"></param>
/// <param name="uint"></param>
/// <returns></returns>
public Tensor substr(Tensor input, int pos, int len, public Tensor substr(Tensor input, int pos, int len,
string name = null, string @uint = "BYTE") string name = null, string @uint = "BYTE")
=> string_ops.substr(input, pos, len, name: name, @uint: @uint);
=> ops.substr(input, pos, len, @uint: @uint, name: name);

public Tensor substr(string input, int pos, int len,
string name = null, string @uint = "BYTE")
=> ops.substr(input, pos, len, @uint: @uint, name: name);
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs View File

@@ -47,7 +47,7 @@ namespace Tensorflow.Eager
status.Check(true); status.Check(true);
} }
} }
if (status.ok())
if (status.ok() && attrs != null)
SetOpAttrs(op, attrs); SetOpAttrs(op, attrs);


var outputs = new IntPtr[num_outputs]; var outputs = new IntPtr[num_outputs];


+ 0
- 3
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -204,9 +204,6 @@ namespace Tensorflow.Eager
input_handle = input.EagerTensorHandle; input_handle = input.EagerTensorHandle;
flattened_inputs.Add(input); flattened_inputs.Add(input);
break; break;
case EagerTensor[] input_list:
input_handle = input_list[0].EagerTensorHandle;
break;
default: default:
var tensor = tf.convert_to_tensor(inputs); var tensor = tf.convert_to_tensor(inputs);
input_handle = (tensor as EagerTensor).EagerTensorHandle; input_handle = (tensor as EagerTensor).EagerTensorHandle;


+ 20
- 0
src/TensorFlowNET.Core/Operations/control_flow_ops.cs View File

@@ -376,6 +376,16 @@ namespace Tensorflow
{ {
return tf_with(ops.name_scope(name, "cond", new { pred }), delegate return tf_with(ops.name_scope(name, "cond", new { pred }), delegate
{ {
if (tf.context.executing_eagerly())
{
if (pred.ToArray<bool>()[0])
return true_fn() as Tensor;
else
return false_fn() as Tensor;

return null;
}

// Add the Switch to the graph. // Add the Switch to the graph.
var switch_result= @switch(pred, pred); var switch_result= @switch(pred, pred);
var (p_2, p_1 )= (switch_result[0], switch_result[1]); var (p_2, p_1 )= (switch_result[0], switch_result[1]);
@@ -450,6 +460,16 @@ namespace Tensorflow
{ {
return tf_with(ops.name_scope(name, "cond", new { pred }), delegate return tf_with(ops.name_scope(name, "cond", new { pred }), delegate
{ {
if (tf.context.executing_eagerly())
{
if (pred.ToArray<bool>()[0])
return true_fn() as Tensor[];
else
return false_fn() as Tensor[];

return null;
}

// Add the Switch to the graph. // Add the Switch to the graph.
var switch_result = @switch(pred, pred); var switch_result = @switch(pred, pred);
var p_2 = switch_result[0]; var p_2 = switch_result[0];


+ 0
- 40
src/TensorFlowNET.Core/Operations/gen_string_ops.cs View File

@@ -1,40 +0,0 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class gen_string_ops
{
public static Tensor substr(Tensor input, int pos, int len,
string name = null, string @uint = "BYTE")
{
var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new
{
input,
pos,
len,
unit = @uint
});

return _op.output;
}
}
}

+ 9
- 5
src/TensorFlowNET.Core/Operations/image_ops_impl.cs View File

@@ -16,6 +16,7 @@


using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Operations; using Tensorflow.Operations;
using static Tensorflow.Binding; using static Tensorflow.Binding;
@@ -63,7 +64,7 @@ namespace Tensorflow
Func<ITensorOrOperation> _bmp = () => Func<ITensorOrOperation> _bmp = () =>
{ {
int bmp_channels = channels; int bmp_channels = channels;
var signature = string_ops.substr(contents, 0, 2);
var signature = tf.strings.substr(contents, 0, 2);
var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp"); var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp");
string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP"; string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP";
var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg }); var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg });
@@ -98,7 +99,7 @@ namespace Tensorflow


return tf_with(ops.name_scope(name, "decode_image"), scope => return tf_with(ops.name_scope(name, "decode_image"), scope =>
{ {
substr = string_ops.substr(contents, 0, 3);
substr = tf.strings.substr(contents, 0, 3);
return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg"); return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg");
}); });
} }
@@ -128,8 +129,11 @@ namespace Tensorflow
{ {
return tf_with(ops.name_scope(name, "is_jpeg"), scope => return tf_with(ops.name_scope(name, "is_jpeg"), scope =>
{ {
var substr = string_ops.substr(contents, 0, 3);
return math_ops.equal(substr, "\xff\xd8\xff", name: name);
var substr = tf.strings.substr(contents, 0, 3);
var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff });
var jpg_tensor = tf.constant(jpg);
var result = math_ops.equal(substr, jpg_tensor, name: name);
return result;
}); });
} }


@@ -137,7 +141,7 @@ namespace Tensorflow
{ {
return tf_with(ops.name_scope(name, "is_png"), scope => return tf_with(ops.name_scope(name, "is_png"), scope =>
{ {
var substr = string_ops.substr(contents, 0, 3);
var substr = tf.strings.substr(contents, 0, 3);
return math_ops.equal(substr, @"\211PN", name: name); return math_ops.equal(substr, @"\211PN", name: name);
}); });
} }


src/TensorFlowNET.Core/Operations/gen_io_ops.cs → src/TensorFlowNET.Core/Operations/io_ops.cs View File

@@ -14,31 +14,45 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Tensorflow.Eager;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
public class gen_io_ops
public class io_ops
{ {
public static Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null)
public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null)
{ {
var _op = tf._op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); var _op = tf._op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors });


return _op; return _op;
} }


public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
{ {
var _op = tf._op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); var _op = tf._op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes });


return _op.outputs; return _op.outputs;
} }


public static Tensor read_file<T>(T filename, string name = null)
public Tensor read_file<T>(T filename, string name = null)
{ {
if (tf.context.executing_eagerly())
{
return read_file_eager_fallback(filename, name: name, tf.context);
}

var _op = tf._op_def_lib._apply_op_helper("ReadFile", name: name, args: new { filename }); var _op = tf._op_def_lib._apply_op_helper("ReadFile", name: name, args: new { filename });


return _op.outputs[0]; return _op.outputs[0];
} }

private Tensor read_file_eager_fallback<T>(T filename, string name = null, Context ctx = null)
{
var filename_tensor = ops.convert_to_tensor(filename, TF_DataType.TF_STRING);
var _inputs_flat = new[] { filename_tensor };

return tf._execute.execute(ctx, "ReadFile", 1, _inputs_flat, null, name: name)[0];
}
} }
} }

+ 26
- 3
src/TensorFlowNET.Core/Operations/string_ops.cs View File

@@ -17,6 +17,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
@@ -31,8 +32,30 @@ namespace Tensorflow
/// <param name="name"></param> /// <param name="name"></param>
/// <param name="uint"></param> /// <param name="uint"></param>
/// <returns></returns> /// <returns></returns>
public static Tensor substr(Tensor input, int pos, int len,
string name = null, string @uint = "BYTE")
=> gen_string_ops.substr(input, pos, len, name: name, @uint: @uint);
public Tensor substr<T>(T input, int pos, int len,
string @uint = "BYTE", string name = null)
{
if (tf.context.executing_eagerly())
{
var input_tensor = tf.constant(input);
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Substr", name,
null,
input, pos, len,
"unit", @uint);

return results[0];
}

var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new
{
input,
pos,
len,
unit = @uint
});

return _op.output;
}
} }
} }

+ 2
- 2
src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs View File

@@ -68,9 +68,9 @@ namespace Tensorflow
throw new ArgumentException($"{nameof(Tensor)} can only be scalar."); throw new ArgumentException($"{nameof(Tensor)} can only be scalar.");


IntPtr stringStartAddress = IntPtr.Zero; IntPtr stringStartAddress = IntPtr.Zero;
UIntPtr dstLen = UIntPtr.Zero;
ulong dstLen = 0;


c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, tf.status.Handle);
c_api.TF_StringDecode((byte*) this.buffer + 8, this.bytesize, (byte**) &stringStartAddress, ref dstLen, tf.status.Handle);
tf.status.Check(true); tf.status.Check(true);


var dstLenInt = checked((int) dstLen); var dstLenInt = checked((int) dstLen);


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -453,7 +453,7 @@ namespace Tensorflow
{ {
var buffer = Encoding.UTF8.GetBytes(str); var buffer = Encoding.UTF8.GetBytes(str);
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + sizeof(ulong)));
AllocationType = AllocationType.Tensorflow; AllocationType = AllocationType.Tensorflow;


IntPtr tensor = c_api.TF_TensorData(handle); IntPtr tensor = c_api.TF_TensorData(handle);


+ 32
- 3
src/TensorFlowNET.Core/Tensors/Tensor.Value.cs View File

@@ -235,13 +235,12 @@ namespace Tensorflow


var buffer = new byte[size][]; var buffer = new byte[size][];
var src = c_api.TF_TensorData(_handle); var src = c_api.TF_TensorData(_handle);
var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize);
src += (int)(size * 8); src += (int)(size * 8);
for (int i = 0; i < buffer.Length; i++) for (int i = 0; i < buffer.Length; i++)
{ {
IntPtr dst = IntPtr.Zero; IntPtr dst = IntPtr.Zero;
UIntPtr dstLen = UIntPtr.Zero;
var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, tf.status.Handle);
ulong dstLen = 0;
var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.status.Handle);
tf.status.Check(true); tf.status.Check(true);
buffer[i] = new byte[(int)dstLen]; buffer[i] = new byte[(int)dstLen];
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
@@ -254,5 +253,35 @@ namespace Tensorflow


return _str; return _str;
} }

public unsafe byte[][] StringBytes()
{
if (dtype != TF_DataType.TF_STRING)
throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})");

//
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.
// [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes]
//
long size = 1;
foreach (var s in TensorShape.dims)
size *= s;

var buffer = new byte[size][];
var src = c_api.TF_TensorData(_handle);
src += (int)(size * 8);
for (int i = 0; i < buffer.Length; i++)
{
IntPtr dst = IntPtr.Zero;
ulong dstLen = 0;
var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.status.Handle);
tf.status.Check(true);
buffer[i] = new byte[(int)dstLen];
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
src += (int)read;
}

return buffer;
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Tensors/c_api.tensor.cs View File

@@ -207,7 +207,7 @@ namespace Tensorflow
public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, SafeStatusHandle status); public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, SafeStatusHandle status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe UIntPtr TF_StringDecode(byte* src, UIntPtr src_len, byte** dst, UIntPtr* dst_len, SafeStatusHandle status);
public static extern unsafe ulong TF_StringDecode(byte* src, ulong src_len, byte** dst, ref ulong dst_len, SafeStatusHandle status);




public static c_api.Deallocator EmptyDeallocator = FreeNothingDeallocator; public static c_api.Deallocator EmptyDeallocator = FreeNothingDeallocator;


+ 12
- 0
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -132,10 +132,22 @@ namespace Tensorflow


switch (value) switch (value)
{ {
case EagerTensor val:
return val;
case NDArray val: case NDArray val:
return new EagerTensor(val, ctx.device_name); return new EagerTensor(val, ctx.device_name);
case string val: case string val:
return new EagerTensor(val, ctx.device_name); return new EagerTensor(val, ctx.device_name);
case bool val:
return new EagerTensor(val, ctx.device_name);
case byte val:
return new EagerTensor(val, ctx.device_name);
case byte[] val:
return new EagerTensor(val, ctx.device_name);
case byte[,] val:
return new EagerTensor(val, ctx.device_name);
case byte[,,] val:
return new EagerTensor(val, ctx.device_name);
case int val: case int val:
return new EagerTensor(val, ctx.device_name); return new EagerTensor(val, ctx.device_name);
case int[] val: case int[] val:


+ 2
- 2
src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs View File

@@ -55,7 +55,7 @@ namespace Tensorflow


if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2) if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2)
{ {
return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray());
return tf.io.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray());
} }
else else
{ {
@@ -76,7 +76,7 @@ namespace Tensorflow
dtypes.Add(spec.dtype); dtypes.Add(spec.dtype);
} }


return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray());
return tf.io.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray());
} }


public virtual SaverDef _build_internal(IVariableV1[] names_to_saveables, public virtual SaverDef _build_internal(IVariableV1[] names_to_saveables,


+ 1
- 2
test/TensorFlowNET.UnitTest/ConstantTest.cs View File

@@ -160,7 +160,6 @@ namespace TensorFlowNET.UnitTest.Basics
Assert.AreEqual(6.0, (double)c); Assert.AreEqual(6.0, (double)c);
} }


[Ignore]
[TestMethod] [TestMethod]
public void StringEncode() public void StringEncode()
{ {
@@ -175,7 +174,7 @@ namespace TensorFlowNET.UnitTest.Basics
string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte));
Assert.AreEqual(encoded_str, str); Assert.AreEqual(encoded_str, str);
Assert.AreEqual(str.Length, Marshal.ReadByte(dst)); Assert.AreEqual(str.Length, Marshal.ReadByte(dst));
//c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status);
// c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status.Handle);
} }


[TestMethod] [TestMethod]


+ 4
- 3
test/TensorFlowNET.UnitTest/ImageTest.cs View File

@@ -2,8 +2,10 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Reflection;
using System.Text; using System.Text;
using Tensorflow; using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest.Basics namespace TensorFlowNET.UnitTest.Basics
@@ -20,11 +22,10 @@ namespace TensorFlowNET.UnitTest.Basics
[TestInitialize] [TestInitialize]
public void Initialize() public void Initialize()
{ {
imgPath = Path.GetFullPath(imgPath);
contents = tf.read_file(imgPath);
imgPath = TestHelper.GetFullPathFromDataDir(imgPath);
contents = tf.io.read_file(imgPath);
} }


[Ignore("")]
[TestMethod] [TestMethod]
public void decode_image() public void decode_image()
{ {


test/TensorFlowNET.UnitTest/nn_test/ActivationFunctionTest.cs → test/TensorFlowNET.UnitTest/TF_API/ActivationFunctionTest.cs View File


test/TensorFlowNET.UnitTest/math_test/MathOperationTest.cs → test/TensorFlowNET.UnitTest/TF_API/MathApiTest.cs View File

@@ -6,10 +6,10 @@ using System.Text;
using Tensorflow; using Tensorflow;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest.math_test
namespace TensorFlowNET.UnitTest.TF_API
{ {
[TestClass] [TestClass]
public class MathOperationTest : TFNetApiTest
public class MathApiTest : TFNetApiTest
{ {
// A constant vector of size 6 // A constant vector of size 6
Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f });

+ 43
- 0
test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs View File

@@ -0,0 +1,43 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.UnitTest.TF_API
{
[TestClass]
public class StringsApiTest
{
[TestMethod]
public void StringEqual()
{
var str1 = tf.constant("Hello1");
var str2 = tf.constant("Hello2");
var result = tf.equal(str1, str2);
Assert.IsFalse(result.ToScalar<bool>());

var str3 = tf.constant("Hello1");
result = tf.equal(str1, str3);
Assert.IsTrue(result.ToScalar<bool>());

var str4 = tf.strings.substr(str1, 0, 5);
var str5 = tf.strings.substr(str2, 0, 5);
result = tf.equal(str4, str5);
Assert.IsTrue(result.ToScalar<bool>());
}

[TestMethod]
public void ImageType()
{
var imgPath = TestHelper.GetFullPathFromDataDir("shasta-daisy.jpg");
var contents = tf.io.read_file(imgPath);

var substr = tf.strings.substr(contents, 0, 3);
var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff });
var jpg_tensor = tf.constant(jpg);

var result = math_ops.equal(substr, jpg_tensor);
}
}
}

test/TensorFlowNET.UnitTest/TFNetApiTest.cs → test/TensorFlowNET.UnitTest/TF_API/TFNetApiTest.cs View File


test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs → test/TensorFlowNET.UnitTest/TF_API/ZeroFractionTest.cs View File


test/TensorFlowNET.UnitTest/nn_test/nn_test.py → test/TensorFlowNET.UnitTest/TF_API/nn_test.py View File


+ 16
- 0
test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

namespace Tensorflow.UnitTest
{
public class TestHelper
{
public static string GetFullPathFromDataDir(string fileName)
{
var dir = Path.Combine(Directory.GetCurrentDirectory(), "..", "..", "..", "..", "..", "data");
return Path.GetFullPath(Path.Combine(dir, fileName));
}
}
}

Loading…
Cancel
Save