@@ -21,9 +21,32 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class tensorflow | public partial class tensorflow | ||||
{ | { | ||||
public IoApi io { get; } = new IoApi(); | |||||
public class IoApi | |||||
{ | |||||
io_ops ops; | |||||
public IoApi() | |||||
{ | |||||
ops = new io_ops(); | |||||
} | |||||
public Tensor read_file(string filename, string name = null) | |||||
=> ops.read_file(filename, name); | |||||
public Tensor read_file(Tensor filename, string name = null) | |||||
=> ops.read_file(filename, name); | |||||
public Operation save_v2(Tensor prefix, string[] tensor_names, | |||||
string[] shape_and_slices, Tensor[] tensors, string name = null) | |||||
=> ops.save_v2(prefix, tensor_names, shape_and_slices, tensors, name: name); | |||||
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, | |||||
string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | |||||
=> ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name); | |||||
} | |||||
public GFile gfile = new GFile(); | public GFile gfile = new GFile(); | ||||
public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); | |||||
public Tensor read_file(Tensor filename, string name = null) => gen_io_ops.read_file(filename, name); | |||||
public ITensorOrOperation[] import_graph_def(GraphDef graph_def, | public ITensorOrOperation[] import_graph_def(GraphDef graph_def, | ||||
Dictionary<string, Tensor> input_map = null, | Dictionary<string, Tensor> input_map = null, | ||||
@@ -21,12 +21,28 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class tensorflow | public partial class tensorflow | ||||
{ | { | ||||
public strings_internal strings = new strings_internal(); | |||||
public class strings_internal | |||||
public StringsApi strings { get; } = new StringsApi(); | |||||
public class StringsApi | |||||
{ | { | ||||
string_ops ops = new string_ops(); | |||||
/// <summary> | |||||
/// Return substrings from `Tensor` of strings. | |||||
/// </summary> | |||||
/// <param name="input"></param> | |||||
/// <param name="pos"></param> | |||||
/// <param name="len"></param> | |||||
/// <param name="name"></param> | |||||
/// <param name="uint"></param> | |||||
/// <returns></returns> | |||||
public Tensor substr(Tensor input, int pos, int len, | public Tensor substr(Tensor input, int pos, int len, | ||||
string name = null, string @uint = "BYTE") | string name = null, string @uint = "BYTE") | ||||
=> string_ops.substr(input, pos, len, name: name, @uint: @uint); | |||||
=> ops.substr(input, pos, len, @uint: @uint, name: name); | |||||
public Tensor substr(string input, int pos, int len, | |||||
string name = null, string @uint = "BYTE") | |||||
=> ops.substr(input, pos, len, @uint: @uint, name: name); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -47,7 +47,7 @@ namespace Tensorflow.Eager | |||||
status.Check(true); | status.Check(true); | ||||
} | } | ||||
} | } | ||||
if (status.ok()) | |||||
if (status.ok() && attrs != null) | |||||
SetOpAttrs(op, attrs); | SetOpAttrs(op, attrs); | ||||
var outputs = new IntPtr[num_outputs]; | var outputs = new IntPtr[num_outputs]; | ||||
@@ -204,9 +204,6 @@ namespace Tensorflow.Eager | |||||
input_handle = input.EagerTensorHandle; | input_handle = input.EagerTensorHandle; | ||||
flattened_inputs.Add(input); | flattened_inputs.Add(input); | ||||
break; | break; | ||||
case EagerTensor[] input_list: | |||||
input_handle = input_list[0].EagerTensorHandle; | |||||
break; | |||||
default: | default: | ||||
var tensor = tf.convert_to_tensor(inputs); | var tensor = tf.convert_to_tensor(inputs); | ||||
input_handle = (tensor as EagerTensor).EagerTensorHandle; | input_handle = (tensor as EagerTensor).EagerTensorHandle; | ||||
@@ -376,6 +376,16 @@ namespace Tensorflow | |||||
{ | { | ||||
return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
if (pred.ToArray<bool>()[0]) | |||||
return true_fn() as Tensor; | |||||
else | |||||
return false_fn() as Tensor; | |||||
return null; | |||||
} | |||||
// Add the Switch to the graph. | // Add the Switch to the graph. | ||||
var switch_result= @switch(pred, pred); | var switch_result= @switch(pred, pred); | ||||
var (p_2, p_1 )= (switch_result[0], switch_result[1]); | var (p_2, p_1 )= (switch_result[0], switch_result[1]); | ||||
@@ -450,6 +460,16 @@ namespace Tensorflow | |||||
{ | { | ||||
return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | ||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
if (pred.ToArray<bool>()[0]) | |||||
return true_fn() as Tensor[]; | |||||
else | |||||
return false_fn() as Tensor[]; | |||||
return null; | |||||
} | |||||
// Add the Switch to the graph. | // Add the Switch to the graph. | ||||
var switch_result = @switch(pred, pred); | var switch_result = @switch(pred, pred); | ||||
var p_2 = switch_result[0]; | var p_2 = switch_result[0]; | ||||
@@ -1,40 +0,0 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | |||||
{ | |||||
public class gen_string_ops | |||||
{ | |||||
public static Tensor substr(Tensor input, int pos, int len, | |||||
string name = null, string @uint = "BYTE") | |||||
{ | |||||
var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new | |||||
{ | |||||
input, | |||||
pos, | |||||
len, | |||||
unit = @uint | |||||
}); | |||||
return _op.output; | |||||
} | |||||
} | |||||
} |
@@ -16,6 +16,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -63,7 +64,7 @@ namespace Tensorflow | |||||
Func<ITensorOrOperation> _bmp = () => | Func<ITensorOrOperation> _bmp = () => | ||||
{ | { | ||||
int bmp_channels = channels; | int bmp_channels = channels; | ||||
var signature = string_ops.substr(contents, 0, 2); | |||||
var signature = tf.strings.substr(contents, 0, 2); | |||||
var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp"); | var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp"); | ||||
string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP"; | string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP"; | ||||
var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg }); | var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg }); | ||||
@@ -98,7 +99,7 @@ namespace Tensorflow | |||||
return tf_with(ops.name_scope(name, "decode_image"), scope => | return tf_with(ops.name_scope(name, "decode_image"), scope => | ||||
{ | { | ||||
substr = string_ops.substr(contents, 0, 3); | |||||
substr = tf.strings.substr(contents, 0, 3); | |||||
return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg"); | return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg"); | ||||
}); | }); | ||||
} | } | ||||
@@ -128,8 +129,11 @@ namespace Tensorflow | |||||
{ | { | ||||
return tf_with(ops.name_scope(name, "is_jpeg"), scope => | return tf_with(ops.name_scope(name, "is_jpeg"), scope => | ||||
{ | { | ||||
var substr = string_ops.substr(contents, 0, 3); | |||||
return math_ops.equal(substr, "\xff\xd8\xff", name: name); | |||||
var substr = tf.strings.substr(contents, 0, 3); | |||||
var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff }); | |||||
var jpg_tensor = tf.constant(jpg); | |||||
var result = math_ops.equal(substr, jpg_tensor, name: name); | |||||
return result; | |||||
}); | }); | ||||
} | } | ||||
@@ -137,7 +141,7 @@ namespace Tensorflow | |||||
{ | { | ||||
return tf_with(ops.name_scope(name, "is_png"), scope => | return tf_with(ops.name_scope(name, "is_png"), scope => | ||||
{ | { | ||||
var substr = string_ops.substr(contents, 0, 3); | |||||
var substr = tf.strings.substr(contents, 0, 3); | |||||
return math_ops.equal(substr, @"\211PN", name: name); | return math_ops.equal(substr, @"\211PN", name: name); | ||||
}); | }); | ||||
} | } | ||||
@@ -14,31 +14,45 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Tensorflow.Eager; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class gen_io_ops | |||||
public class io_ops | |||||
{ | { | ||||
public static Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) | |||||
public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) | |||||
{ | { | ||||
var _op = tf._op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); | var _op = tf._op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); | ||||
return _op; | return _op; | ||||
} | } | ||||
public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | |||||
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | |||||
{ | { | ||||
var _op = tf._op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | var _op = tf._op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | ||||
return _op.outputs; | return _op.outputs; | ||||
} | } | ||||
public static Tensor read_file<T>(T filename, string name = null) | |||||
public Tensor read_file<T>(T filename, string name = null) | |||||
{ | { | ||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
return read_file_eager_fallback(filename, name: name, tf.context); | |||||
} | |||||
var _op = tf._op_def_lib._apply_op_helper("ReadFile", name: name, args: new { filename }); | var _op = tf._op_def_lib._apply_op_helper("ReadFile", name: name, args: new { filename }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
private Tensor read_file_eager_fallback<T>(T filename, string name = null, Context ctx = null) | |||||
{ | |||||
var filename_tensor = ops.convert_to_tensor(filename, TF_DataType.TF_STRING); | |||||
var _inputs_flat = new[] { filename_tensor }; | |||||
return tf._execute.execute(ctx, "ReadFile", 1, _inputs_flat, null, name: name)[0]; | |||||
} | |||||
} | } | ||||
} | } |
@@ -17,6 +17,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -31,8 +32,30 @@ namespace Tensorflow | |||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <param name="uint"></param> | /// <param name="uint"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor substr(Tensor input, int pos, int len, | |||||
string name = null, string @uint = "BYTE") | |||||
=> gen_string_ops.substr(input, pos, len, name: name, @uint: @uint); | |||||
public Tensor substr<T>(T input, int pos, int len, | |||||
string @uint = "BYTE", string name = null) | |||||
{ | |||||
if (tf.context.executing_eagerly()) | |||||
{ | |||||
var input_tensor = tf.constant(input); | |||||
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
"Substr", name, | |||||
null, | |||||
input, pos, len, | |||||
"unit", @uint); | |||||
return results[0]; | |||||
} | |||||
var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new | |||||
{ | |||||
input, | |||||
pos, | |||||
len, | |||||
unit = @uint | |||||
}); | |||||
return _op.output; | |||||
} | |||||
} | } | ||||
} | } |
@@ -68,9 +68,9 @@ namespace Tensorflow | |||||
throw new ArgumentException($"{nameof(Tensor)} can only be scalar."); | throw new ArgumentException($"{nameof(Tensor)} can only be scalar."); | ||||
IntPtr stringStartAddress = IntPtr.Zero; | IntPtr stringStartAddress = IntPtr.Zero; | ||||
UIntPtr dstLen = UIntPtr.Zero; | |||||
ulong dstLen = 0; | |||||
c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, tf.status.Handle); | |||||
c_api.TF_StringDecode((byte*) this.buffer + 8, this.bytesize, (byte**) &stringStartAddress, ref dstLen, tf.status.Handle); | |||||
tf.status.Check(true); | tf.status.Check(true); | ||||
var dstLenInt = checked((int) dstLen); | var dstLenInt = checked((int) dstLen); | ||||
@@ -453,7 +453,7 @@ namespace Tensorflow | |||||
{ | { | ||||
var buffer = Encoding.UTF8.GetBytes(str); | var buffer = Encoding.UTF8.GetBytes(str); | ||||
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | ||||
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | |||||
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + sizeof(ulong))); | |||||
AllocationType = AllocationType.Tensorflow; | AllocationType = AllocationType.Tensorflow; | ||||
IntPtr tensor = c_api.TF_TensorData(handle); | IntPtr tensor = c_api.TF_TensorData(handle); | ||||
@@ -235,13 +235,12 @@ namespace Tensorflow | |||||
var buffer = new byte[size][]; | var buffer = new byte[size][]; | ||||
var src = c_api.TF_TensorData(_handle); | var src = c_api.TF_TensorData(_handle); | ||||
var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); | |||||
src += (int)(size * 8); | src += (int)(size * 8); | ||||
for (int i = 0; i < buffer.Length; i++) | for (int i = 0; i < buffer.Length; i++) | ||||
{ | { | ||||
IntPtr dst = IntPtr.Zero; | IntPtr dst = IntPtr.Zero; | ||||
UIntPtr dstLen = UIntPtr.Zero; | |||||
var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, tf.status.Handle); | |||||
ulong dstLen = 0; | |||||
var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.status.Handle); | |||||
tf.status.Check(true); | tf.status.Check(true); | ||||
buffer[i] = new byte[(int)dstLen]; | buffer[i] = new byte[(int)dstLen]; | ||||
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | ||||
@@ -254,5 +253,35 @@ namespace Tensorflow | |||||
return _str; | return _str; | ||||
} | } | ||||
public unsafe byte[][] StringBytes() | |||||
{ | |||||
if (dtype != TF_DataType.TF_STRING) | |||||
throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); | |||||
// | |||||
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | |||||
// [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] | |||||
// | |||||
long size = 1; | |||||
foreach (var s in TensorShape.dims) | |||||
size *= s; | |||||
var buffer = new byte[size][]; | |||||
var src = c_api.TF_TensorData(_handle); | |||||
src += (int)(size * 8); | |||||
for (int i = 0; i < buffer.Length; i++) | |||||
{ | |||||
IntPtr dst = IntPtr.Zero; | |||||
ulong dstLen = 0; | |||||
var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.status.Handle); | |||||
tf.status.Check(true); | |||||
buffer[i] = new byte[(int)dstLen]; | |||||
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | |||||
src += (int)read; | |||||
} | |||||
return buffer; | |||||
} | |||||
} | } | ||||
} | } |
@@ -207,7 +207,7 @@ namespace Tensorflow | |||||
public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, SafeStatusHandle status); | public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, SafeStatusHandle status); | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe UIntPtr TF_StringDecode(byte* src, UIntPtr src_len, byte** dst, UIntPtr* dst_len, SafeStatusHandle status); | |||||
public static extern unsafe ulong TF_StringDecode(byte* src, ulong src_len, byte** dst, ref ulong dst_len, SafeStatusHandle status); | |||||
public static c_api.Deallocator EmptyDeallocator = FreeNothingDeallocator; | public static c_api.Deallocator EmptyDeallocator = FreeNothingDeallocator; | ||||
@@ -132,10 +132,22 @@ namespace Tensorflow | |||||
switch (value) | switch (value) | ||||
{ | { | ||||
case EagerTensor val: | |||||
return val; | |||||
case NDArray val: | case NDArray val: | ||||
return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
case string val: | case string val: | ||||
return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
case bool val: | |||||
return new EagerTensor(val, ctx.device_name); | |||||
case byte val: | |||||
return new EagerTensor(val, ctx.device_name); | |||||
case byte[] val: | |||||
return new EagerTensor(val, ctx.device_name); | |||||
case byte[,] val: | |||||
return new EagerTensor(val, ctx.device_name); | |||||
case byte[,,] val: | |||||
return new EagerTensor(val, ctx.device_name); | |||||
case int val: | case int val: | ||||
return new EagerTensor(val, ctx.device_name); | return new EagerTensor(val, ctx.device_name); | ||||
case int[] val: | case int[] val: | ||||
@@ -55,7 +55,7 @@ namespace Tensorflow | |||||
if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2) | if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2) | ||||
{ | { | ||||
return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); | |||||
return tf.io.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -76,7 +76,7 @@ namespace Tensorflow | |||||
dtypes.Add(spec.dtype); | dtypes.Add(spec.dtype); | ||||
} | } | ||||
return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); | |||||
return tf.io.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); | |||||
} | } | ||||
public virtual SaverDef _build_internal(IVariableV1[] names_to_saveables, | public virtual SaverDef _build_internal(IVariableV1[] names_to_saveables, | ||||
@@ -160,7 +160,6 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
Assert.AreEqual(6.0, (double)c); | Assert.AreEqual(6.0, (double)c); | ||||
} | } | ||||
[Ignore] | |||||
[TestMethod] | [TestMethod] | ||||
public void StringEncode() | public void StringEncode() | ||||
{ | { | ||||
@@ -175,7 +174,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); | string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); | ||||
Assert.AreEqual(encoded_str, str); | Assert.AreEqual(encoded_str, str); | ||||
Assert.AreEqual(str.Length, Marshal.ReadByte(dst)); | Assert.AreEqual(str.Length, Marshal.ReadByte(dst)); | ||||
//c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); | |||||
// c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status.Handle); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -2,8 +2,10 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | using System.IO; | ||||
using System.Reflection; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.UnitTest; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest.Basics | namespace TensorFlowNET.UnitTest.Basics | ||||
@@ -20,11 +22,10 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
[TestInitialize] | [TestInitialize] | ||||
public void Initialize() | public void Initialize() | ||||
{ | { | ||||
imgPath = Path.GetFullPath(imgPath); | |||||
contents = tf.read_file(imgPath); | |||||
imgPath = TestHelper.GetFullPathFromDataDir(imgPath); | |||||
contents = tf.io.read_file(imgPath); | |||||
} | } | ||||
[Ignore("")] | |||||
[TestMethod] | [TestMethod] | ||||
public void decode_image() | public void decode_image() | ||||
{ | { | ||||
@@ -6,10 +6,10 @@ using System.Text; | |||||
using Tensorflow; | using Tensorflow; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest.math_test | |||||
namespace TensorFlowNET.UnitTest.TF_API | |||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class MathOperationTest : TFNetApiTest | |||||
public class MathApiTest : TFNetApiTest | |||||
{ | { | ||||
// A constant vector of size 6 | // A constant vector of size 6 | ||||
Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); | Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); |
@@ -0,0 +1,43 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.UnitTest.TF_API | |||||
{ | |||||
[TestClass] | |||||
public class StringsApiTest | |||||
{ | |||||
[TestMethod] | |||||
public void StringEqual() | |||||
{ | |||||
var str1 = tf.constant("Hello1"); | |||||
var str2 = tf.constant("Hello2"); | |||||
var result = tf.equal(str1, str2); | |||||
Assert.IsFalse(result.ToScalar<bool>()); | |||||
var str3 = tf.constant("Hello1"); | |||||
result = tf.equal(str1, str3); | |||||
Assert.IsTrue(result.ToScalar<bool>()); | |||||
var str4 = tf.strings.substr(str1, 0, 5); | |||||
var str5 = tf.strings.substr(str2, 0, 5); | |||||
result = tf.equal(str4, str5); | |||||
Assert.IsTrue(result.ToScalar<bool>()); | |||||
} | |||||
[TestMethod] | |||||
public void ImageType() | |||||
{ | |||||
var imgPath = TestHelper.GetFullPathFromDataDir("shasta-daisy.jpg"); | |||||
var contents = tf.io.read_file(imgPath); | |||||
var substr = tf.strings.substr(contents, 0, 3); | |||||
var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff }); | |||||
var jpg_tensor = tf.constant(jpg); | |||||
var result = math_ops.equal(substr, jpg_tensor); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,16 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.IO; | |||||
using System.Text; | |||||
namespace Tensorflow.UnitTest | |||||
{ | |||||
public class TestHelper | |||||
{ | |||||
public static string GetFullPathFromDataDir(string fileName) | |||||
{ | |||||
var dir = Path.Combine(Directory.GetCurrentDirectory(), "..", "..", "..", "..", "..", "data"); | |||||
return Path.GetFullPath(Path.Combine(dir, fileName)); | |||||
} | |||||
} | |||||
} |