@@ -0,0 +1,47 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using Tensorflow.Queues; | |||
namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
{ | |||
/// <summary> | |||
/// A FIFOQueue that supports batching variable-sized tensors by padding. | |||
/// </summary> | |||
/// <param name="capacity"></param> | |||
/// <param name="dtypes"></param> | |||
/// <param name="shapes"></param> | |||
/// <param name="names"></param> | |||
/// <param name="shared_name"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public PaddingFIFOQueue PaddingFIFOQueue(int capacity, | |||
TF_DataType[] dtypes, | |||
TensorShape[] shapes, | |||
string[] names = null, | |||
string shared_name = null, | |||
string name = "padding_fifo_queue") | |||
=> new PaddingFIFOQueue(capacity, | |||
dtypes, | |||
shapes, | |||
names, | |||
shared_name: shared_name, | |||
name: name); | |||
} | |||
} |
@@ -19,6 +19,7 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using static Tensorflow.OpDef.Types; | |||
using static Tensorflow.Binding; | |||
using Google.Protobuf; | |||
namespace Tensorflow | |||
{ | |||
@@ -194,7 +195,9 @@ namespace Tensorflow | |||
if (attrs.ContainsKey(key)) | |||
{ | |||
attr_protos[key] = SetAttrValue(op_def, attr_def, attrs[key]); | |||
} else { | |||
} | |||
else | |||
{ | |||
if (attr_def.DefaultValue == null) | |||
{ | |||
throw new TypeError("Missing required positional argument " + key); | |||
@@ -311,6 +314,16 @@ namespace Tensorflow | |||
input_types.AddRange(base_types); | |||
} | |||
public ByteString _MakeStr(string value, AttrDef attr_def) | |||
{ | |||
return ByteString.CopyFromUtf8(value ?? string.Empty); | |||
} | |||
public TensorShapeProto _MakeShape(TensorShape shape, AttrDef attr_def) | |||
{ | |||
return shape.as_proto(); | |||
} | |||
public DataType _MakeType(TF_DataType v, AttrDef attr_def) | |||
{ | |||
return v.as_base_dtype().as_datatype_enum(); | |||
@@ -330,7 +343,7 @@ namespace Tensorflow | |||
switch (attr_def.Type) | |||
{ | |||
case "string": | |||
attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); | |||
attr_value.S = _MakeStr((string)value, attr_def); | |||
break; | |||
case "type": | |||
attr_value.Type = _MakeType((TF_DataType)value, attr_def); | |||
@@ -363,6 +376,9 @@ namespace Tensorflow | |||
else if (value is int[] val3) | |||
attr_value.Shape = tensor_util.as_shape(val3); | |||
break; | |||
case "list(shape)": | |||
attr_value.List.Shape.AddRange((value as TensorShape[]).Select(x => _MakeShape(x, attr_def))); | |||
break; | |||
default: | |||
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | |||
@@ -0,0 +1,33 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Framework; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Queues | |||
{ | |||
/// <summary> | |||
/// A FIFOQueue that supports batching variable-sized tensors by padding. | |||
/// </summary> | |||
public class PaddingFIFOQueue : QueueBase | |||
{ | |||
public PaddingFIFOQueue(int capacity, | |||
TF_DataType[] dtypes, | |||
TensorShape[] shapes, | |||
string[] names = null, | |||
string shared_name = null, | |||
string name = "padding_fifo_queue") | |||
: base(dtypes: dtypes, shapes: shapes, names: names) | |||
{ | |||
_queue_ref = gen_data_flow_ops.padding_fifo_queue_v2( | |||
component_types: dtypes, | |||
shapes: shapes, | |||
capacity: capacity, | |||
shared_name: shared_name, | |||
name: name); | |||
_name = _queue_ref.op.name.Split('/').Last(); | |||
} | |||
} | |||
} |
@@ -0,0 +1,56 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Queues | |||
{ | |||
public class QueueBase | |||
{ | |||
protected TF_DataType[] _dtypes; | |||
protected TensorShape[] _shapes; | |||
protected string[] _names; | |||
protected Tensor _queue_ref; | |||
protected string _name; | |||
public QueueBase(TF_DataType[] dtypes, TensorShape[] shapes, string[] names) | |||
{ | |||
_dtypes = dtypes; | |||
_shapes = shapes; | |||
_names = names; | |||
} | |||
public Operation enqueue(Tensor val, string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, $"{_name}_enqueue", val), scope => | |||
{ | |||
var vals = new[] { val }; | |||
if (_queue_ref.dtype == TF_DataType.TF_RESOURCE) | |||
return gen_data_flow_ops.queue_enqueue_v2(_queue_ref, vals, name: scope); | |||
else | |||
return gen_data_flow_ops.queue_enqueue(_queue_ref, vals, name: scope); | |||
}); | |||
} | |||
public Tensor[] dequeue_many(int n, string name = null) | |||
{ | |||
if (name == null) | |||
name = $"{_name}_DequeueMany"; | |||
var ret = gen_data_flow_ops.queue_dequeue_many_v2(_queue_ref, n: n, component_types: _dtypes, name: name); | |||
//var op = ret[0].op; | |||
//var cv = tensor_util.constant_value(op.inputs[1]); | |||
//var batch_dim = new Dimension(cv); | |||
return _dequeue_return_value(ret); | |||
} | |||
public Tensor[] _dequeue_return_value(Tensor[] tensors) | |||
{ | |||
if (_names != null) | |||
throw new NotImplementedException(""); | |||
return tensors; | |||
} | |||
} | |||
} |
@@ -22,10 +22,9 @@ namespace Tensorflow | |||
public static Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name = null) | |||
{ | |||
var _attr_N = indices.Length; | |||
var _op = _op_def_lib._apply_op_helper("DynamicStitch", name, new { indices, data }); | |||
return _op.outputs[0]; | |||
return _op.output; | |||
} | |||
public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid, | |||
@@ -45,5 +44,58 @@ namespace Tensorflow | |||
return (null, null); | |||
} | |||
public static Tensor padding_fifo_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, | |||
int capacity = -1, string container = "", string shared_name = "", | |||
string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("PaddingFIFOQueueV2", name, new | |||
{ | |||
component_types, | |||
shapes, | |||
capacity, | |||
container, | |||
shared_name | |||
}); | |||
return _op.output; | |||
} | |||
public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new | |||
{ | |||
handle, | |||
components, | |||
timeout_ms | |||
}); | |||
return _op; | |||
} | |||
public static Operation queue_enqueue_v2(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("QueueEnqueueV2", name, new | |||
{ | |||
handle, | |||
components, | |||
timeout_ms | |||
}); | |||
return _op; | |||
} | |||
public static Tensor[] queue_dequeue_many_v2(Tensor handle, int n, TF_DataType[] component_types, int timeout_ms = -1, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("QueueDequeueManyV2", name, new | |||
{ | |||
handle, | |||
n, | |||
component_types, | |||
timeout_ms | |||
}); | |||
return _op.outputs; | |||
} | |||
} | |||
} |
@@ -0,0 +1,36 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
[TestClass] | |||
public class QueueTest | |||
{ | |||
[TestMethod] | |||
public void PaddingFIFOQueue() | |||
{ | |||
var numbers = tf.placeholder(tf.int32); | |||
var queue = tf.PaddingFIFOQueue(capacity: 10, dtypes: new[] { tf.int32 }, shapes: new[] { new TensorShape(-1) }); | |||
var enqueue = queue.enqueue(numbers); | |||
var dequeue_many = queue.dequeue_many(n: 3); | |||
using(var sess = tf.Session()) | |||
{ | |||
sess.run(enqueue, (numbers, new[] { 1 })); | |||
sess.run(enqueue, (numbers, new[] { 2, 3 })); | |||
sess.run(enqueue, (numbers, new[] { 3, 4, 5 })); | |||
var result = sess.run(dequeue_many[0]); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>())); | |||
} | |||
} | |||
} | |||
} |