Browse Source

fix tf.cond.

tags/v0.20
Oceania2018 5 years ago
parent
commit
16ea35126d
24 changed files with 234 additions and 77 deletions
  1. +25
    -2
      src/TensorFlowNET.Core/APIs/tf.io.cs
  2. +19
    -3
      src/TensorFlowNET.Core/APIs/tf.strings.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
  4. +0
    -3
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  5. +20
    -0
      src/TensorFlowNET.Core/Operations/control_flow_ops.cs
  6. +0
    -40
      src/TensorFlowNET.Core/Operations/gen_string_ops.cs
  7. +9
    -5
      src/TensorFlowNET.Core/Operations/image_ops_impl.cs
  8. +18
    -4
      src/TensorFlowNET.Core/Operations/io_ops.cs
  9. +26
    -3
      src/TensorFlowNET.Core/Operations/string_ops.cs
  10. +2
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  12. +32
    -3
      src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
  13. +1
    -1
      src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
  14. +12
    -0
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  15. +2
    -2
      src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs
  16. +1
    -2
      test/TensorFlowNET.UnitTest/ConstantTest.cs
  17. +4
    -3
      test/TensorFlowNET.UnitTest/ImageTest.cs
  18. +0
    -0
      test/TensorFlowNET.UnitTest/TF_API/ActivationFunctionTest.cs
  19. +2
    -2
      test/TensorFlowNET.UnitTest/TF_API/MathApiTest.cs
  20. +43
    -0
      test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs
  21. +0
    -0
      test/TensorFlowNET.UnitTest/TF_API/TFNetApiTest.cs
  22. +0
    -0
      test/TensorFlowNET.UnitTest/TF_API/ZeroFractionTest.cs
  23. +0
    -0
      test/TensorFlowNET.UnitTest/TF_API/nn_test.py
  24. +16
    -0
      test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs

+ 25
- 2
src/TensorFlowNET.Core/APIs/tf.io.cs View File

@@ -21,9 +21,32 @@ namespace Tensorflow
{
public partial class tensorflow
{
public IoApi io { get; } = new IoApi();

public class IoApi
{
io_ops ops;
public IoApi()
{
ops = new io_ops();
}

public Tensor read_file(string filename, string name = null)
=> ops.read_file(filename, name);

public Tensor read_file(Tensor filename, string name = null)
=> ops.read_file(filename, name);

public Operation save_v2(Tensor prefix, string[] tensor_names,
string[] shape_and_slices, Tensor[] tensors, string name = null)
=> ops.save_v2(prefix, tensor_names, shape_and_slices, tensors, name: name);

public Tensor[] restore_v2(Tensor prefix, string[] tensor_names,
string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
=> ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name);
}

public GFile gfile = new GFile();
public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name);
public Tensor read_file(Tensor filename, string name = null) => gen_io_ops.read_file(filename, name);

public ITensorOrOperation[] import_graph_def(GraphDef graph_def,
Dictionary<string, Tensor> input_map = null,


+ 19
- 3
src/TensorFlowNET.Core/APIs/tf.strings.cs View File

@@ -21,12 +21,28 @@ namespace Tensorflow
{
public partial class tensorflow
{
public strings_internal strings = new strings_internal();
public class strings_internal
public StringsApi strings { get; } = new StringsApi();

public class StringsApi
{
string_ops ops = new string_ops();

/// <summary>
/// Return substrings from `Tensor` of strings.
/// </summary>
/// <param name="input"></param>
/// <param name="pos"></param>
/// <param name="len"></param>
/// <param name="name"></param>
/// <param name="uint"></param>
/// <returns></returns>
public Tensor substr(Tensor input, int pos, int len,
string name = null, string @uint = "BYTE")
=> string_ops.substr(input, pos, len, name: name, @uint: @uint);
=> ops.substr(input, pos, len, @uint: @uint, name: name);

public Tensor substr(string input, int pos, int len,
string name = null, string @uint = "BYTE")
=> ops.substr(input, pos, len, @uint: @uint, name: name);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs View File

@@ -47,7 +47,7 @@ namespace Tensorflow.Eager
status.Check(true);
}
}
if (status.ok())
if (status.ok() && attrs != null)
SetOpAttrs(op, attrs);

var outputs = new IntPtr[num_outputs];


+ 0
- 3
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -204,9 +204,6 @@ namespace Tensorflow.Eager
input_handle = input.EagerTensorHandle;
flattened_inputs.Add(input);
break;
case EagerTensor[] input_list:
input_handle = input_list[0].EagerTensorHandle;
break;
default:
var tensor = tf.convert_to_tensor(inputs);
input_handle = (tensor as EagerTensor).EagerTensorHandle;


+ 20
- 0
src/TensorFlowNET.Core/Operations/control_flow_ops.cs View File

@@ -376,6 +376,16 @@ namespace Tensorflow
{
return tf_with(ops.name_scope(name, "cond", new { pred }), delegate
{
if (tf.context.executing_eagerly())
{
if (pred.ToArray<bool>()[0])
return true_fn() as Tensor;
else
return false_fn() as Tensor;

return null;
}

// Add the Switch to the graph.
var switch_result= @switch(pred, pred);
var (p_2, p_1 )= (switch_result[0], switch_result[1]);
@@ -450,6 +460,16 @@ namespace Tensorflow
{
return tf_with(ops.name_scope(name, "cond", new { pred }), delegate
{
if (tf.context.executing_eagerly())
{
if (pred.ToArray<bool>()[0])
return true_fn() as Tensor[];
else
return false_fn() as Tensor[];

return null;
}

// Add the Switch to the graph.
var switch_result = @switch(pred, pred);
var p_2 = switch_result[0];


+ 0
- 40
src/TensorFlowNET.Core/Operations/gen_string_ops.cs View File

@@ -1,40 +0,0 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class gen_string_ops
{
public static Tensor substr(Tensor input, int pos, int len,
string name = null, string @uint = "BYTE")
{
var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new
{
input,
pos,
len,
unit = @uint
});

return _op.output;
}
}
}

+ 9
- 5
src/TensorFlowNET.Core/Operations/image_ops_impl.cs View File

@@ -16,6 +16,7 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations;
using static Tensorflow.Binding;
@@ -63,7 +64,7 @@ namespace Tensorflow
Func<ITensorOrOperation> _bmp = () =>
{
int bmp_channels = channels;
var signature = string_ops.substr(contents, 0, 2);
var signature = tf.strings.substr(contents, 0, 2);
var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp");
string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP";
var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg });
@@ -98,7 +99,7 @@ namespace Tensorflow

return tf_with(ops.name_scope(name, "decode_image"), scope =>
{
substr = string_ops.substr(contents, 0, 3);
substr = tf.strings.substr(contents, 0, 3);
return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg");
});
}
@@ -128,8 +129,11 @@ namespace Tensorflow
{
return tf_with(ops.name_scope(name, "is_jpeg"), scope =>
{
var substr = string_ops.substr(contents, 0, 3);
return math_ops.equal(substr, "\xff\xd8\xff", name: name);
var substr = tf.strings.substr(contents, 0, 3);
var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff });
var jpg_tensor = tf.constant(jpg);
var result = math_ops.equal(substr, jpg_tensor, name: name);
return result;
});
}

@@ -137,7 +141,7 @@ namespace Tensorflow
{
return tf_with(ops.name_scope(name, "is_png"), scope =>
{
var substr = string_ops.substr(contents, 0, 3);
var substr = tf.strings.substr(contents, 0, 3);
return math_ops.equal(substr, @"\211PN", name: name);
});
}


src/TensorFlowNET.Core/Operations/gen_io_ops.cs → src/TensorFlowNET.Core/Operations/io_ops.cs View File

@@ -14,31 +14,45 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class gen_io_ops
public class io_ops
{
public static Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null)
public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null)
{
var _op = tf._op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors });

return _op;
}

public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
{
var _op = tf._op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes });

return _op.outputs;
}

public static Tensor read_file<T>(T filename, string name = null)
public Tensor read_file<T>(T filename, string name = null)
{
if (tf.context.executing_eagerly())
{
return read_file_eager_fallback(filename, name: name, tf.context);
}

var _op = tf._op_def_lib._apply_op_helper("ReadFile", name: name, args: new { filename });

return _op.outputs[0];
}

private Tensor read_file_eager_fallback<T>(T filename, string name = null, Context ctx = null)
{
var filename_tensor = ops.convert_to_tensor(filename, TF_DataType.TF_STRING);
var _inputs_flat = new[] { filename_tensor };

return tf._execute.execute(ctx, "ReadFile", 1, _inputs_flat, null, name: name)[0];
}
}
}

+ 26
- 3
src/TensorFlowNET.Core/Operations/string_ops.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -31,8 +32,30 @@ namespace Tensorflow
/// <param name="name"></param>
/// <param name="uint"></param>
/// <returns></returns>
public static Tensor substr(Tensor input, int pos, int len,
string name = null, string @uint = "BYTE")
=> gen_string_ops.substr(input, pos, len, name: name, @uint: @uint);
public Tensor substr<T>(T input, int pos, int len,
string @uint = "BYTE", string name = null)
{
if (tf.context.executing_eagerly())
{
var input_tensor = tf.constant(input);
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Substr", name,
null,
input, pos, len,
"unit", @uint);

return results[0];
}

var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new
{
input,
pos,
len,
unit = @uint
});

return _op.output;
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs View File

@@ -68,9 +68,9 @@ namespace Tensorflow
throw new ArgumentException($"{nameof(Tensor)} can only be scalar.");

IntPtr stringStartAddress = IntPtr.Zero;
UIntPtr dstLen = UIntPtr.Zero;
ulong dstLen = 0;

c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, tf.status.Handle);
c_api.TF_StringDecode((byte*) this.buffer + 8, this.bytesize, (byte**) &stringStartAddress, ref dstLen, tf.status.Handle);
tf.status.Check(true);

var dstLenInt = checked((int) dstLen);


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -453,7 +453,7 @@ namespace Tensorflow
{
var buffer = Encoding.UTF8.GetBytes(str);
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + sizeof(ulong)));
AllocationType = AllocationType.Tensorflow;

IntPtr tensor = c_api.TF_TensorData(handle);


+ 32
- 3
src/TensorFlowNET.Core/Tensors/Tensor.Value.cs View File

@@ -235,13 +235,12 @@ namespace Tensorflow

var buffer = new byte[size][];
var src = c_api.TF_TensorData(_handle);
var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize);
src += (int)(size * 8);
for (int i = 0; i < buffer.Length; i++)
{
IntPtr dst = IntPtr.Zero;
UIntPtr dstLen = UIntPtr.Zero;
var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, tf.status.Handle);
ulong dstLen = 0;
var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.status.Handle);
tf.status.Check(true);
buffer[i] = new byte[(int)dstLen];
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
@@ -254,5 +253,35 @@ namespace Tensorflow

return _str;
}

public unsafe byte[][] StringBytes()
{
if (dtype != TF_DataType.TF_STRING)
throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})");

//
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.
// [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes]
//
long size = 1;
foreach (var s in TensorShape.dims)
size *= s;

var buffer = new byte[size][];
var src = c_api.TF_TensorData(_handle);
src += (int)(size * 8);
for (int i = 0; i < buffer.Length; i++)
{
IntPtr dst = IntPtr.Zero;
ulong dstLen = 0;
var read = c_api.TF_StringDecode((byte*)src, bytesize, (byte**)&dst, ref dstLen, tf.status.Handle);
tf.status.Check(true);
buffer[i] = new byte[(int)dstLen];
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
src += (int)read;
}

return buffer;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Tensors/c_api.tensor.cs View File

@@ -207,7 +207,7 @@ namespace Tensorflow
public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern unsafe UIntPtr TF_StringDecode(byte* src, UIntPtr src_len, byte** dst, UIntPtr* dst_len, SafeStatusHandle status);
public static extern unsafe ulong TF_StringDecode(byte* src, ulong src_len, byte** dst, ref ulong dst_len, SafeStatusHandle status);


public static c_api.Deallocator EmptyDeallocator = FreeNothingDeallocator;


+ 12
- 0
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -132,10 +132,22 @@ namespace Tensorflow

switch (value)
{
case EagerTensor val:
return val;
case NDArray val:
return new EagerTensor(val, ctx.device_name);
case string val:
return new EagerTensor(val, ctx.device_name);
case bool val:
return new EagerTensor(val, ctx.device_name);
case byte val:
return new EagerTensor(val, ctx.device_name);
case byte[] val:
return new EagerTensor(val, ctx.device_name);
case byte[,] val:
return new EagerTensor(val, ctx.device_name);
case byte[,,] val:
return new EagerTensor(val, ctx.device_name);
case int val:
return new EagerTensor(val, ctx.device_name);
case int[] val:


+ 2
- 2
src/TensorFlowNET.Core/Training/Saving/BaseSaverBuilder.cs View File

@@ -55,7 +55,7 @@ namespace Tensorflow

if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2)
{
return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray());
return tf.io.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray());
}
else
{
@@ -76,7 +76,7 @@ namespace Tensorflow
dtypes.Add(spec.dtype);
}

return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray());
return tf.io.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray());
}

public virtual SaverDef _build_internal(IVariableV1[] names_to_saveables,


+ 1
- 2
test/TensorFlowNET.UnitTest/ConstantTest.cs View File

@@ -160,7 +160,6 @@ namespace TensorFlowNET.UnitTest.Basics
Assert.AreEqual(6.0, (double)c);
}

[Ignore]
[TestMethod]
public void StringEncode()
{
@@ -175,7 +174,7 @@ namespace TensorFlowNET.UnitTest.Basics
string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte));
Assert.AreEqual(encoded_str, str);
Assert.AreEqual(str.Length, Marshal.ReadByte(dst));
//c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status);
// c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status.Handle);
}

[TestMethod]


+ 4
- 3
test/TensorFlowNET.UnitTest/ImageTest.cs View File

@@ -2,8 +2,10 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Reflection;
using System.Text;
using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Basics
@@ -20,11 +22,10 @@ namespace TensorFlowNET.UnitTest.Basics
[TestInitialize]
public void Initialize()
{
imgPath = Path.GetFullPath(imgPath);
contents = tf.read_file(imgPath);
imgPath = TestHelper.GetFullPathFromDataDir(imgPath);
contents = tf.io.read_file(imgPath);
}

[Ignore("")]
[TestMethod]
public void decode_image()
{


test/TensorFlowNET.UnitTest/nn_test/ActivationFunctionTest.cs → test/TensorFlowNET.UnitTest/TF_API/ActivationFunctionTest.cs View File


test/TensorFlowNET.UnitTest/math_test/MathOperationTest.cs → test/TensorFlowNET.UnitTest/TF_API/MathApiTest.cs View File

@@ -6,10 +6,10 @@ using System.Text;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.math_test
namespace TensorFlowNET.UnitTest.TF_API
{
[TestClass]
public class MathOperationTest : TFNetApiTest
public class MathApiTest : TFNetApiTest
{
// A constant vector of size 6
Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f });

+ 43
- 0
test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs View File

@@ -0,0 +1,43 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.UnitTest.TF_API
{
[TestClass]
public class StringsApiTest
{
[TestMethod]
public void StringEqual()
{
var str1 = tf.constant("Hello1");
var str2 = tf.constant("Hello2");
var result = tf.equal(str1, str2);
Assert.IsFalse(result.ToScalar<bool>());

var str3 = tf.constant("Hello1");
result = tf.equal(str1, str3);
Assert.IsTrue(result.ToScalar<bool>());

var str4 = tf.strings.substr(str1, 0, 5);
var str5 = tf.strings.substr(str2, 0, 5);
result = tf.equal(str4, str5);
Assert.IsTrue(result.ToScalar<bool>());
}

[TestMethod]
public void ImageType()
{
var imgPath = TestHelper.GetFullPathFromDataDir("shasta-daisy.jpg");
var contents = tf.io.read_file(imgPath);

var substr = tf.strings.substr(contents, 0, 3);
var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff });
var jpg_tensor = tf.constant(jpg);

var result = math_ops.equal(substr, jpg_tensor);
}
}
}

test/TensorFlowNET.UnitTest/TFNetApiTest.cs → test/TensorFlowNET.UnitTest/TF_API/TFNetApiTest.cs View File


test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs → test/TensorFlowNET.UnitTest/TF_API/ZeroFractionTest.cs View File


test/TensorFlowNET.UnitTest/nn_test/nn_test.py → test/TensorFlowNET.UnitTest/TF_API/nn_test.py View File


+ 16
- 0
test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

namespace Tensorflow.UnitTest
{
public class TestHelper
{
public static string GetFullPathFromDataDir(string fileName)
{
var dir = Path.Combine(Directory.GetCurrentDirectory(), "..", "..", "..", "..", "..", "data");
return Path.GetFullPath(Path.Combine(dir, fileName));
}
}
}

Loading…
Cancel
Save