@@ -21,9 +21,32 @@ namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
{ | |||
public IoApi io { get; } = new IoApi(); | |||
public class IoApi | |||
{ | |||
io_ops ops; | |||
public IoApi() | |||
{ | |||
ops = new io_ops(); | |||
} | |||
public Tensor read_file(string filename, string name = null) | |||
=> ops.read_file(filename, name); | |||
public Tensor read_file(Tensor filename, string name = null) | |||
=> ops.read_file(filename, name); | |||
public Operation save_v2(Tensor prefix, string[] tensor_names, | |||
string[] shape_and_slices, Tensor[] tensors, string name = null) | |||
=> ops.save_v2(prefix, tensor_names, shape_and_slices, tensors, name: name); | |||
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, | |||
string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | |||
=> ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name); | |||
} | |||
public GFile gfile = new GFile(); | |||
public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); | |||
public Tensor read_file(Tensor filename, string name = null) => gen_io_ops.read_file(filename, name); | |||
public ITensorOrOperation[] import_graph_def(GraphDef graph_def, | |||
Dictionary<string, Tensor> input_map = null, | |||
@@ -21,12 +21,28 @@ namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
{ | |||
public strings_internal strings = new strings_internal(); | |||
public class strings_internal | |||
public StringsApi strings { get; } = new StringsApi(); | |||
public class StringsApi | |||
{ | |||
string_ops ops = new string_ops(); | |||
/// <summary> | |||
/// Return substrings from `Tensor` of strings. | |||
/// </summary> | |||
/// <param name="input"></param> | |||
/// <param name="pos"></param> | |||
/// <param name="len"></param> | |||
/// <param name="name"></param> | |||
/// <param name="uint"></param> | |||
/// <returns></returns> | |||
public Tensor substr(Tensor input, int pos, int len, | |||
string name = null, string @uint = "BYTE") | |||
=> string_ops.substr(input, pos, len, name: name, @uint: @uint); | |||
=> ops.substr(input, pos, len, @uint: @uint, name: name); | |||
public Tensor substr(string input, int pos, int len, | |||
string name = null, string @uint = "BYTE") | |||
=> ops.substr(input, pos, len, @uint: @uint, name: name); | |||
} | |||
} | |||
} |
@@ -47,7 +47,7 @@ namespace Tensorflow.Eager | |||
status.Check(true); | |||
} | |||
} | |||
if (status.ok()) | |||
if (status.ok() && attrs != null) | |||
SetOpAttrs(op, attrs); | |||
var outputs = new IntPtr[num_outputs]; | |||
@@ -204,9 +204,6 @@ namespace Tensorflow.Eager | |||
input_handle = input.EagerTensorHandle; | |||
flattened_inputs.Add(input); | |||
break; | |||
case EagerTensor[] input_list: | |||
input_handle = input_list[0].EagerTensorHandle; | |||
break; | |||
default: | |||
var tensor = tf.convert_to_tensor(inputs); | |||
input_handle = (tensor as EagerTensor).EagerTensorHandle; | |||
@@ -376,6 +376,16 @@ namespace Tensorflow | |||
{ | |||
return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | |||
{ | |||
if (tf.context.executing_eagerly()) | |||
{ | |||
if (pred.ToArray<bool>()[0]) | |||
return true_fn() as Tensor; | |||
else | |||
return false_fn() as Tensor; | |||
return null; | |||
} | |||
// Add the Switch to the graph. | |||
var switch_result= @switch(pred, pred); | |||
var (p_2, p_1 )= (switch_result[0], switch_result[1]); | |||
@@ -450,6 +460,16 @@ namespace Tensorflow | |||
{ | |||
return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | |||
{ | |||
if (tf.context.executing_eagerly()) | |||
{ | |||
if (pred.ToArray<bool>()[0]) | |||
return true_fn() as Tensor[]; | |||
else | |||
return false_fn() as Tensor[]; | |||
return null; | |||
} | |||
// Add the Switch to the graph. | |||
var switch_result = @switch(pred, pred); | |||
var p_2 = switch_result[0]; | |||
@@ -1,40 +0,0 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public class gen_string_ops | |||
{ | |||
public static Tensor substr(Tensor input, int pos, int len, | |||
string name = null, string @uint = "BYTE") | |||
{ | |||
var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new | |||
{ | |||
input, | |||
pos, | |||
len, | |||
unit = @uint | |||
}); | |||
return _op.output; | |||
} | |||
} | |||
} |
@@ -16,6 +16,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Operations; | |||
using static Tensorflow.Binding; | |||
@@ -63,7 +64,7 @@ namespace Tensorflow | |||
Func<ITensorOrOperation> _bmp = () => | |||
{ | |||
int bmp_channels = channels; | |||
var signature = string_ops.substr(contents, 0, 2); | |||
var signature = tf.strings.substr(contents, 0, 2); | |||
var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp"); | |||
string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP"; | |||
var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg }); | |||
@@ -98,7 +99,7 @@ namespace Tensorflow | |||
return tf_with(ops.name_scope(name, "decode_image"), scope => | |||
{ | |||
substr = string_ops.substr(contents, 0, 3); | |||
substr = tf.strings.substr(contents, 0, 3); | |||
return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg"); | |||
}); | |||
} | |||
@@ -128,8 +129,11 @@ namespace Tensorflow | |||
{ | |||
return tf_with(ops.name_scope(name, "is_jpeg"), scope => | |||
{ | |||
var substr = string_ops.substr(contents, 0, 3); | |||
return math_ops.equal(substr, "\xff\xd8\xff", name: name); | |||
var substr = tf.strings.substr(contents, 0, 3); | |||
var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff }); | |||
var jpg_tensor = tf.constant(jpg); | |||
var result = math_ops.equal(substr, jpg_tensor, name: name); | |||
return result; | |||
}); | |||
} | |||
@@ -137,7 +141,7 @@ namespace Tensorflow | |||
{ | |||
return tf_with(ops.name_scope(name, "is_png"), scope => | |||
{ | |||
var substr = string_ops.substr(contents, 0, 3); | |||
var substr = tf.strings.substr(contents, 0, 3); | |||
return math_ops.equal(substr, @"\211PN", name: name); | |||
}); | |||
} | |||
@@ -14,31 +14,45 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Tensorflow.Eager; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public class gen_io_ops | |||
public class io_ops | |||
{ | |||
public static Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) | |||
public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) | |||
{ | |||
var _op = tf._op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); | |||
return _op; | |||
} | |||
public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | |||
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | |||
{ | |||
var _op = tf._op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | |||
return _op.outputs; | |||
} | |||
public static Tensor read_file<T>(T filename, string name = null) | |||
public Tensor read_file<T>(T filename, string name = null) | |||
{ | |||
if (tf.context.executing_eagerly()) | |||
{ | |||
return read_file_eager_fallback(filename, name: name, tf.context); | |||
} | |||
var _op = tf._op_def_lib._apply_op_helper("ReadFile", name: name, args: new { filename }); | |||
return _op.outputs[0]; | |||
} | |||
private Tensor read_file_eager_fallback<T>(T filename, string name = null, Context ctx = null) | |||
{ | |||
var filename_tensor = ops.convert_to_tensor(filename, TF_DataType.TF_STRING); | |||
var _inputs_flat = new[] { filename_tensor }; | |||
return tf._execute.execute(ctx, "ReadFile", 1, _inputs_flat, null, name: name)[0]; | |||
} | |||
} | |||
} |
@@ -17,6 +17,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
@@ -31,8 +32,30 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <param name="uint"></param> | |||
/// <returns></returns> | |||
public static Tensor substr(Tensor input, int pos, int len, | |||
string name = null, string @uint = "BYTE") | |||
=> gen_string_ops.substr(input, pos, len, name: name, @uint: @uint); | |||
public Tensor substr<T>(T input, int pos, int len, | |||
string @uint = "BYTE", string name = null) | |||
{ | |||
if (tf.context.executing_eagerly()) | |||
{ | |||
var input_tensor = tf.constant(input); | |||
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
"Substr", name, | |||
null, | |||
input, pos, len, | |||
"unit", @uint); | |||
return results[0]; | |||
} | |||
var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new | |||
{ | |||
input, | |||
pos, | |||
len, | |||
unit = @uint | |||
}); | |||
return _op.output; | |||
} | |||
} | |||
} |
@@ -68,9 +68,9 @@ namespace Tensorflow | |||
throw new ArgumentException($"{nameof(Tensor)} can only be scalar."); | |||
IntPtr stringStartAddress = IntPtr.Zero; | |||
UIntPtr dstLen = UIntPtr.Zero; | |||
ulong dstLen = 0; | |||
c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, tf.status.Handle); | |||
c_api.TF_StringDecode((byte*) this.buffer + 8, this.bytesize, (byte**) &stringStartAddress, ref dstLen, tf.status.Handle); | |||
tf.status.Check(true); | |||
var dstLenInt = checked((int) dstLen); | |||
@@ -453,7 +453,7 @@ namespace Tensorflow | |||
{ | |||
var buffer = Encoding.UTF8.GetBytes(str); | |||
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | |||
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | |||
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + sizeof(ulong))); | |||
AllocationType = AllocationType.Tensorflow; | |||
IntPtr tensor = c_api.TF_TensorData(handle); | |||
@@ -235,13 +235,12 @@ namespace Tensorflow | |||
var buffer = new byte[size][]; | |||
var src = c_api.TF_TensorData(_handle); | |||
var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); | |||
src += (int)(size * 8); | |||
for (int i = 0; i < buffer.Length; i++) | |||
{ | |||
IntPtr dst = IntPtr.Zero; | |||
UIntPtr dstLen = UIntPtr.Zero; | |||
var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, tf.status.Handle); | |||
ulong dstLen = 0; | |||
var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.status.Handle); | |||
tf.status.Check(true); | |||
buffer[i] = new byte[(int)dstLen]; | |||
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | |||
@@ -254,5 +253,35 @@ namespace Tensorflow | |||
return _str; | |||
} | |||
public unsafe byte[][] StringBytes() | |||
{ | |||
if (dtype != TF_DataType.TF_STRING) | |||
throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); | |||
// | |||
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | |||
// [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] | |||
// | |||
long size = 1; | |||
foreach (var s in TensorShape.dims) | |||
size *= s; | |||
var buffer = new byte[size][]; | |||
var src = c_api.TF_TensorData(_handle); | |||
src += (int)(size * 8); | |||
for (int i = 0; i < buffer.Length; i++) | |||
{ | |||
IntPtr dst = IntPtr.Zero; | |||
ulong dstLen = 0; | |||
var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.status.Handle); | |||
tf.status.Check(true); | |||
buffer[i] = new byte[(int)dstLen]; | |||
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | |||
src += (int)read; | |||
} | |||
return buffer; | |||
} | |||
} | |||
} |
@@ -207,7 +207,7 @@ namespace Tensorflow | |||
public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe UIntPtr TF_StringDecode(byte* src, UIntPtr src_len, byte** dst, UIntPtr* dst_len, SafeStatusHandle status); | |||
public static extern unsafe ulong TF_StringDecode(byte* src, ulong src_len, byte** dst, ref ulong dst_len, SafeStatusHandle status); | |||
public static c_api.Deallocator EmptyDeallocator = FreeNothingDeallocator; | |||
@@ -132,10 +132,22 @@ namespace Tensorflow | |||
switch (value) | |||
{ | |||
case EagerTensor val: | |||
return val; | |||
case NDArray val: | |||
return new EagerTensor(val, ctx.device_name); | |||
case string val: | |||
return new EagerTensor(val, ctx.device_name); | |||
case bool val: | |||
return new EagerTensor(val, ctx.device_name); | |||
case byte val: | |||
return new EagerTensor(val, ctx.device_name); | |||
case byte[] val: | |||
return new EagerTensor(val, ctx.device_name); | |||
case byte[,] val: | |||
return new EagerTensor(val, ctx.device_name); | |||
case byte[,,] val: | |||
return new EagerTensor(val, ctx.device_name); | |||
case int val: | |||
return new EagerTensor(val, ctx.device_name); | |||
case int[] val: | |||
@@ -55,7 +55,7 @@ namespace Tensorflow | |||
if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2) | |||
{ | |||
return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); | |||
return tf.io.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); | |||
} | |||
else | |||
{ | |||
@@ -76,7 +76,7 @@ namespace Tensorflow | |||
dtypes.Add(spec.dtype); | |||
} | |||
return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); | |||
return tf.io.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); | |||
} | |||
public virtual SaverDef _build_internal(IVariableV1[] names_to_saveables, | |||
@@ -160,7 +160,6 @@ namespace TensorFlowNET.UnitTest.Basics | |||
Assert.AreEqual(6.0, (double)c); | |||
} | |||
[Ignore] | |||
[TestMethod] | |||
public void StringEncode() | |||
{ | |||
@@ -175,7 +174,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); | |||
Assert.AreEqual(encoded_str, str); | |||
Assert.AreEqual(str.Length, Marshal.ReadByte(dst)); | |||
//c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); | |||
// c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status.Handle); | |||
} | |||
[TestMethod] | |||
@@ -2,8 +2,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Reflection; | |||
using System.Text; | |||
using Tensorflow; | |||
using Tensorflow.UnitTest; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.Basics | |||
@@ -20,11 +22,10 @@ namespace TensorFlowNET.UnitTest.Basics | |||
[TestInitialize] | |||
public void Initialize() | |||
{ | |||
imgPath = Path.GetFullPath(imgPath); | |||
contents = tf.read_file(imgPath); | |||
imgPath = TestHelper.GetFullPathFromDataDir(imgPath); | |||
contents = tf.io.read_file(imgPath); | |||
} | |||
[Ignore("")] | |||
[TestMethod] | |||
public void decode_image() | |||
{ | |||
@@ -6,10 +6,10 @@ using System.Text; | |||
using Tensorflow; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.math_test | |||
namespace TensorFlowNET.UnitTest.TF_API | |||
{ | |||
[TestClass] | |||
public class MathOperationTest : TFNetApiTest | |||
public class MathApiTest : TFNetApiTest | |||
{ | |||
// A constant vector of size 6 | |||
Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); |
@@ -0,0 +1,43 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.UnitTest.TF_API | |||
{ | |||
[TestClass] | |||
public class StringsApiTest | |||
{ | |||
[TestMethod] | |||
public void StringEqual() | |||
{ | |||
var str1 = tf.constant("Hello1"); | |||
var str2 = tf.constant("Hello2"); | |||
var result = tf.equal(str1, str2); | |||
Assert.IsFalse(result.ToScalar<bool>()); | |||
var str3 = tf.constant("Hello1"); | |||
result = tf.equal(str1, str3); | |||
Assert.IsTrue(result.ToScalar<bool>()); | |||
var str4 = tf.strings.substr(str1, 0, 5); | |||
var str5 = tf.strings.substr(str2, 0, 5); | |||
result = tf.equal(str4, str5); | |||
Assert.IsTrue(result.ToScalar<bool>()); | |||
} | |||
[TestMethod] | |||
public void ImageType() | |||
{ | |||
var imgPath = TestHelper.GetFullPathFromDataDir("shasta-daisy.jpg"); | |||
var contents = tf.io.read_file(imgPath); | |||
var substr = tf.strings.substr(contents, 0, 3); | |||
var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff }); | |||
var jpg_tensor = tf.constant(jpg); | |||
var result = math_ops.equal(substr, jpg_tensor); | |||
} | |||
} | |||
} |
@@ -0,0 +1,16 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Text; | |||
namespace Tensorflow.UnitTest | |||
{ | |||
public class TestHelper | |||
{ | |||
public static string GetFullPathFromDataDir(string fileName) | |||
{ | |||
var dir = Path.Combine(Directory.GetCurrentDirectory(), "..", "..", "..", "..", "..", "data"); | |||
return Path.GetFullPath(Path.Combine(dir, fileName)); | |||
} | |||
} | |||
} |