Browse Source

tf.PaddingFIFOQueue #396

tags/v0.12
Oceania2018 6 years ago
parent
commit
0ece29177e
6 changed files with 244 additions and 4 deletions
  1. +47
    -0
      src/TensorFlowNET.Core/APIs/tf.queue.cs
  2. +18
    -2
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  3. +33
    -0
      src/TensorFlowNET.Core/Operations/Queues/PaddingFIFOQueue.cs
  4. +56
    -0
      src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs
  5. +54
    -2
      src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
  6. +36
    -0
      test/TensorFlowNET.UnitTest/QueueTest.cs

+ 47
- 0
src/TensorFlowNET.Core/APIs/tf.queue.cs View File

@@ -0,0 +1,47 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using Tensorflow.Queues;

namespace Tensorflow
{
public partial class tensorflow
{
/// <summary>
/// A FIFOQueue that supports batching variable-sized tensors by padding.
/// </summary>
/// <param name="capacity"></param>
/// <param name="dtypes"></param>
/// <param name="shapes"></param>
/// <param name="names"></param>
/// <param name="shared_name"></param>
/// <param name="name"></param>
/// <returns></returns>
public PaddingFIFOQueue PaddingFIFOQueue(int capacity,
TF_DataType[] dtypes,
TensorShape[] shapes,
string[] names = null,
string shared_name = null,
string name = "padding_fifo_queue")
=> new PaddingFIFOQueue(capacity,
dtypes,
shapes,
names,
shared_name: shared_name,
name: name);
}
}

+ 18
- 2
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.Linq;
using static Tensorflow.OpDef.Types;
using static Tensorflow.Binding;
using Google.Protobuf;

namespace Tensorflow
{
@@ -194,7 +195,9 @@ namespace Tensorflow
if (attrs.ContainsKey(key))
{
attr_protos[key] = SetAttrValue(op_def, attr_def, attrs[key]);
} else {
}
else
{
if (attr_def.DefaultValue == null)
{
throw new TypeError("Missing required positional argument " + key);
@@ -311,6 +314,16 @@ namespace Tensorflow
input_types.AddRange(base_types);
}

public ByteString _MakeStr(string value, AttrDef attr_def)
{
return ByteString.CopyFromUtf8(value ?? string.Empty);
}

public TensorShapeProto _MakeShape(TensorShape shape, AttrDef attr_def)
{
return shape.as_proto();
}

public DataType _MakeType(TF_DataType v, AttrDef attr_def)
{
return v.as_base_dtype().as_datatype_enum();
@@ -330,7 +343,7 @@ namespace Tensorflow
switch (attr_def.Type)
{
case "string":
attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value);
attr_value.S = _MakeStr((string)value, attr_def);
break;
case "type":
attr_value.Type = _MakeType((TF_DataType)value, attr_def);
@@ -363,6 +376,9 @@ namespace Tensorflow
else if (value is int[] val3)
attr_value.Shape = tensor_util.as_shape(val3);

break;
case "list(shape)":
attr_value.List.Shape.AddRange((value as TensorShape[]).Select(x => _MakeShape(x, attr_def)));
break;
default:
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos.");


+ 33
- 0
src/TensorFlowNET.Core/Operations/Queues/PaddingFIFOQueue.cs View File

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Framework;
using static Tensorflow.Binding;

namespace Tensorflow.Queues
{
/// <summary>
/// A FIFOQueue that supports batching variable-sized tensors by padding.
/// </summary>
public class PaddingFIFOQueue : QueueBase
{
public PaddingFIFOQueue(int capacity,
TF_DataType[] dtypes,
TensorShape[] shapes,
string[] names = null,
string shared_name = null,
string name = "padding_fifo_queue")
: base(dtypes: dtypes, shapes: shapes, names: names)
{
_queue_ref = gen_data_flow_ops.padding_fifo_queue_v2(
component_types: dtypes,
shapes: shapes,
capacity: capacity,
shared_name: shared_name,
name: name);

_name = _queue_ref.op.name.Split('/').Last();
}
}
}

+ 56
- 0
src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs View File

@@ -0,0 +1,56 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Queues
{
public class QueueBase
{
protected TF_DataType[] _dtypes;
protected TensorShape[] _shapes;
protected string[] _names;
protected Tensor _queue_ref;
protected string _name;

public QueueBase(TF_DataType[] dtypes, TensorShape[] shapes, string[] names)
{
_dtypes = dtypes;
_shapes = shapes;
_names = names;
}

public Operation enqueue(Tensor val, string name = null)
{
return tf_with(ops.name_scope(name, $"{_name}_enqueue", val), scope =>
{
var vals = new[] { val };
if (_queue_ref.dtype == TF_DataType.TF_RESOURCE)
return gen_data_flow_ops.queue_enqueue_v2(_queue_ref, vals, name: scope);
else
return gen_data_flow_ops.queue_enqueue(_queue_ref, vals, name: scope);
});
}

public Tensor[] dequeue_many(int n, string name = null)
{
if (name == null)
name = $"{_name}_DequeueMany";

var ret = gen_data_flow_ops.queue_dequeue_many_v2(_queue_ref, n: n, component_types: _dtypes, name: name);
//var op = ret[0].op;
//var cv = tensor_util.constant_value(op.inputs[1]);
//var batch_dim = new Dimension(cv);

return _dequeue_return_value(ret);
}

public Tensor[] _dequeue_return_value(Tensor[] tensors)
{
if (_names != null)
throw new NotImplementedException("");
return tensors;
}
}
}

+ 54
- 2
src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs View File

@@ -22,10 +22,9 @@ namespace Tensorflow

public static Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name = null)
{
var _attr_N = indices.Length;
var _op = _op_def_lib._apply_op_helper("DynamicStitch", name, new { indices, data });

return _op.outputs[0];
return _op.output;
}

public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid,
@@ -45,5 +44,58 @@ namespace Tensorflow

return (null, null);
}

public static Tensor padding_fifo_queue_v2(TF_DataType[] component_types, TensorShape[] shapes,
int capacity = -1, string container = "", string shared_name = "",
string name = null)
{
var _op = _op_def_lib._apply_op_helper("PaddingFIFOQueueV2", name, new
{
component_types,
shapes,
capacity,
container,
shared_name
});

return _op.output;
}

public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null)
{
var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new
{
handle,
components,
timeout_ms
});

return _op;
}

public static Operation queue_enqueue_v2(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null)
{
var _op = _op_def_lib._apply_op_helper("QueueEnqueueV2", name, new
{
handle,
components,
timeout_ms
});

return _op;
}

public static Tensor[] queue_dequeue_many_v2(Tensor handle, int n, TF_DataType[] component_types, int timeout_ms = -1, string name = null)
{
var _op = _op_def_lib._apply_op_helper("QueueDequeueManyV2", name, new
{
handle,
n,
component_types,
timeout_ms
});

return _op.outputs;
}
}
}

+ 36
- 0
test/TensorFlowNET.UnitTest/QueueTest.cs View File

@@ -0,0 +1,36 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
{
[TestClass]
public class QueueTest
{
[TestMethod]
public void PaddingFIFOQueue()
{
var numbers = tf.placeholder(tf.int32);
var queue = tf.PaddingFIFOQueue(capacity: 10, dtypes: new[] { tf.int32 }, shapes: new[] { new TensorShape(-1) });
var enqueue = queue.enqueue(numbers);
var dequeue_many = queue.dequeue_many(n: 3);

using(var sess = tf.Session())
{
sess.run(enqueue, (numbers, new[] { 1 }));
sess.run(enqueue, (numbers, new[] { 2, 3 }));
sess.run(enqueue, (numbers, new[] { 3, 4, 5 }));

var result = sess.run(dequeue_many[0]);

Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>()));
}
}
}
}

Loading…
Cancel
Save