feat: support training of RNN
tags/v0.110.0-LSTM-Model
@@ -16,6 +16,7 @@
using System;
using System.Runtime.InteropServices;
using static Tensorflow.CppShapeInferenceResult.Types;

namespace Tensorflow
{
@@ -50,6 +51,19 @@ namespace Tensorflow
            return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
        }

        public unsafe static byte[] ByteStringPiece(IntPtr handle)
        {
            byte* str_data = (byte*)handle.ToPointer();
            List<byte> bytes = new List<byte>();
            byte current = 255;
            while (current != ((byte)'\0'))
            {
                current = *(str_data++);
                bytes.Add(current);
            }
            return bytes.Take(bytes.Count - 1).ToArray();
        }
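A minimal usage sketch of the new ByteStringPiece helper, assuming it lives on the c_api class like the surrounding members; the buffer below is illustrative and not part of this change.

// Allocate a NUL-terminated buffer on the unmanaged heap and read it back.
byte[] payload = { (byte)'R', (byte)'N', (byte)'N', 0 };
IntPtr buffer = Marshal.AllocHGlobal(payload.Length);
try
{
    Marshal.Copy(payload, 0, buffer, payload.Length);
    byte[] result = c_api.ByteStringPiece(buffer);   // { 'R', 'N', 'N' }, trailing NUL dropped
}
finally
{
    Marshal.FreeHGlobal(buffer);
}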
        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
        public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args);

@@ -46,10 +46,10 @@ namespace Tensorflow
            Tensor loop_vars,
            int parallel_iterations = 10)
        {
            Func<Tensor[], Tensor> cond1 = x
            Func<Tensors, Tensor> cond1 = x
                => cond(x[0]);
            Func<Tensor[], Tensor[]> body1 = x
            Func<Tensors, Tensors> body1 = x
                => new[] { body(x[0]) };
            var results = control_flow_ops.while_loop(cond1,
@@ -58,9 +58,9 @@ namespace Tensorflow
            return results[0];
        }

        public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
            Func<Tensor[], Tensor[]> body,
            Tensor[] loop_vars,
        public Tensor[] while_loop(Func<Tensors, Tensor> cond,
            Func<Tensors, Tensors> body,
            Tensors loop_vars,
            int parallel_iterations = 10,
            string name = null)
            => control_flow_ops.while_loop(cond, body, loop_vars,
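A hedged usage sketch of the Tensors-based while_loop overload above; the loop itself is illustrative and would typically run inside a graph/function context.

// Count from 0 to 10 with the new Func<Tensors, ...> signatures.
var i = tf.constant(0);
Tensor[] result = tf.while_loop(
    cond: x => tf.less(x[0], tf.constant(10)),
    body: x => new Tensors(tf.add(x[0], tf.constant(1))),
    loop_vars: new Tensors(i));
// result[0] holds the final counter value.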
@@ -71,15 +71,15 @@ namespace Tensorflow
        public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
            => array_ops.split(
                value: value,
                num_split: num_split,
                num_or_size_splits: num_split,
                axis: axis,
                name: name);

        public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
            => array_ops.split(
                value: value,
                num_split: num_split,
                axis: axis,
                num_or_size_splits: num_split,
                axis: ops.convert_to_tensor(axis),
                name: name);

        public Tensor ensure_shape(Tensor x, Shape shape, string name = null)
@@ -503,7 +503,7 @@ namespace Tensorflow
            case Tensors tensors:
                return tensors.dtype;
            case IEnumerable<Tensor> tensors:
                return tensors.First().dtype;
                return tensors.Where(x => x is not null).First().dtype;
            case RefVariable variable:
                return variable.dtype;
            case ResourceVariable variable:
@@ -18,7 +18,12 @@ namespace Tensorflow.Common.Extensions
            return sequence.Take(sequence.Count() - count);
        }
#endif

        public static Tensors ToTensors(this IEnumerable<Tensor> tensors)
        public static Tensors ToTensors(this Tensor[] tensors)
        {
            return new Tensors(tensors);
        }

        public static Tensors ToTensors(this IList<Tensor> tensors)
        {
            return new Tensors(tensors);
        }
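A small, assumed usage example for the reworked ToTensors extensions; the tensors are placeholders.

using Tensorflow.Common.Extensions;

Tensor[] raw = { tf.constant(1), tf.constant(2) };
Tensors packed = raw.ToTensors();              // Tensor[] overload
Tensors packedList = raw.ToList().ToTensors(); // IList<Tensor> overload (ToList needs System.Linq)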
@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
    /// <summary>
    /// This is a temp solution, which should be removed after refactoring `Tensors`
    /// </summary>
    [Obsolete]
    public class FakeTensorByTensorArray: Tensor
    {
        public TensorArray TensorArray { get; set; }

        public FakeTensorByTensorArray(TensorArray array)
        {
            TensorArray = array;
        }
    }
}
@@ -5,136 +5,65 @@ using System.Text;
namespace Tensorflow.Common.Types
{
    public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?>
    public class GeneralizedTensorShape: Nest<Shape>
    {
        public TensorShapeConfig[] Shapes { get; set; }
        /// <summary>
        /// create a single-dim generalized Tensor shape.
        /// </summary>
        /// <param name="dim"></param>
        public GeneralizedTensorShape(int dim, int size = 1)
        public GeneralizedTensorShape(Shape value, string? name = null)
        {
            var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
            Shapes = Enumerable.Repeat(elem, size).ToArray();
            //Shapes = new TensorShapeConfig[size];
            //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } });
            //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } });
            ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
            NodeValue = value;
            NestType = NestType.Node;
        }

        public GeneralizedTensorShape(Shape shape)
        public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null)
        {
            Shapes = new TensorShapeConfig[] { shape };
            ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList();
            Name = name;
            NestType = NestType.List;
        }

        public GeneralizedTensorShape(TensorShapeConfig shape)
        public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null)
        {
            Shapes = new TensorShapeConfig[] { shape };
            DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>);
            Name = name;
            NestType = NestType.Dictionary;
        }

        public GeneralizedTensorShape(TensorShapeConfig[] shapes)
        public GeneralizedTensorShape(Nest<Shape> other)
        {
            Shapes = shapes;
        }
        public GeneralizedTensorShape(IEnumerable<Shape> shape)
        {
            Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray();
            NestType = other.NestType;
            NodeValue = other.NodeValue;
            DictValue = other.DictValue;
            ListValue = other.ListValue;
            Name = other.Name;
        }

        public Shape ToSingleShape()
        {
            if (Shapes.Length != 1)
            var shapes = Flatten().ToList();
            if (shapes.Count != 1)
            {
                throw new ValueError("The generalized shape contains more than 1 dim.");
            }
            var shape_config = Shapes[0];
            Debug.Assert(shape_config is not null);
            return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray());
            return shapes[0];
        }

        public long ToNumber()
        {
            if(Shapes.Length != 1 || Shapes[0].Items.Length != 1)
            var shapes = Flatten().ToList();
            if (shapes.Count != 1 || shapes[0].ndim != 1)
            {
                throw new ValueError("The generalized shape contains more than 1 dim.");
            }
            var res = Shapes[0].Items[0];
            return res is null ? -1 : res.Value;
        }
        public Shape[] ToShapeArray()
        {
            return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray();
        }
        public IEnumerable<long?> Flatten()
        {
            List<long?> result = new List<long?>();
            foreach(var shapeConfig in Shapes)
            {
                result.AddRange(shapeConfig.Items);
            }
            return result;
        }
        public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func)
        {
            List<Nest<TOut>> lists = new();
            foreach(var shapeConfig in Shapes)
            {
                lists.Add(new Nest<TOut>(shapeConfig.Items.Select(x => new Nest<TOut>(func(x)))));
            }
            return new Nest<TOut>(lists);
            return shapes[0].dims[0];
        }

        public Nest<long?> AsNest()
        public INestStructure<TensorShapeConfig> ToTensorShapeConfigs()
        {
            Nest<long?> DealWithSingleShape(TensorShapeConfig config)
            {
                if (config.Items.Length == 0)
                {
                    return Nest<long?>.Empty;
                }
                else if (config.Items.Length == 1)
                {
                    return new Nest<long?>(config.Items[0]);
                }
                else
                {
                    return new Nest<long?>(config.Items.Select(x => new Nest<long?>(x)));
                }
            }
            if(Shapes.Length == 0)
            {
                return Nest<long?>.Empty;
            }
            else if(Shapes.Length == 1)
            {
                return DealWithSingleShape(Shapes[0]);
            }
            else
            {
                return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s)));
            }
        }
        public static implicit operator GeneralizedTensorShape(int dims)
            => new GeneralizedTensorShape(dims);
        public IEnumerator<long?[]> GetEnumerator()
        {
            foreach (var shape in Shapes)
            {
                yield return shape.Items;
            }
            return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() });
        }

        IEnumerator IEnumerable.GetEnumerator()
        public static implicit operator GeneralizedTensorShape(Shape shape)
        {
            return GetEnumerator();
            return new GeneralizedTensorShape(shape);
        }
    }
}
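A sketch of how the reworked GeneralizedTensorShape (now a Nest of Shape) can be built and flattened; the concrete shapes are illustrative.

// Single shape: wrapped as a Node via the new implicit operator.
GeneralizedTensorShape single = new Shape(32, 4);
// Several shapes: stored as a List nest, e.g. the two state sizes of an LSTM cell.
var states = new GeneralizedTensorShape(new[] { new Shape(4), new Shape(4) });

Shape only = single.ToSingleShape();     // Shape(32, 4)
var flat = states.Flatten().ToList();    // [Shape(4), Shape(4)]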
@@ -10,6 +10,19 @@ namespace Tensorflow.Common.Types
    /// </summary>
    public interface INestStructure<T>: INestable<T>
    {
        NestType NestType { get; }

        /// <summary>
        /// The item count of depth 1 of the nested structure.
        /// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3.
        /// </summary>
        int ShallowNestedCount { get; }
        /// <summary>
        /// The total item count of depth 1 of the nested structure.
        /// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
        /// </summary>
        int TotalNestedCount { get; }

        /// <summary>
        /// Flatten the Nestable object. Node that if the object contains only one value,
        /// it will be flattened to an enumerable with one element.
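An illustration of the two new counts, mirroring the [1, 2, [3, 4, 5]] example from the XML docs, using Nest of int as the structure.

var nested = new Nest<int>(new INestStructure<int>[]
{
    new Nest<int>(1),
    new Nest<int>(2),
    new Nest<int>(new INestStructure<int>[] { new Nest<int>(3), new Nest<int>(4), new Nest<int>(5) })
});
// nested.ShallowNestedCount == 3  (1, 2 and the inner list)
// nested.TotalNestedCount   == 5  (all leaf values)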
@@ -13,7 +13,7 @@ namespace Tensorflow.Common.Types
        /// <param name="template"></param>
        /// <param name="flatItems"></param>
        /// <returns></returns>
        public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems)
        public static Nest<TOut> PackSequenceAs<T, TOut>(INestable<T> template, TOut[] flatItems)
        {
            return template.AsNest().PackSequence(flatItems);
        }
@@ -28,27 +28,58 @@ namespace Tensorflow.Common.Types
        public static Nest<T> Empty => _empty;
        public NestType NestType { get; protected set; }
        public string? Name { get; set; }
        public T? Value { get; protected set; }
        public List<Nest<T>>? ListValue { get; protected set; }
        public Dictionary<string, Nest<T>>? DictValue { get; protected set; }
        public T? NodeValue { get; protected set; }
        public List<INestStructure<T>>? ListValue { get; protected set; }
        public Dictionary<string, INestStructure<T>>? DictValue { get; protected set; }

        public int ShallowNestedCount
        {
            get
            {
                if (NestType == NestType.Empty)
                {
                    return 0;
                }
                else if (NestType == NestType.Node)
                {
                    return 1;
                }
                else if (NestType == NestType.List)
                {
                    return ListValue!.Count;
                }
                else // dict
                {
                    return DictValue!.Count;
                }
            }
        }

        public int TotalNestedCount
        {
            get
            {
                return Flatten().Count();
            }
        }

        protected Nest() { }

        public Nest(T value, string? name = null)
        {
            Value = value;
            NodeValue = value;
            Name = name;
            NestType = NestType.Node;
        }

        public Nest(IEnumerable<Nest<T>> values, string? name = null)
        public Nest(IEnumerable<INestStructure<T>> values, string? name = null)
        {
            ListValue = values.ToList();
            Name = name;
            NestType = NestType.List;
        }

        public Nest(Dictionary<string, Nest<T>> value, string? name = null)
        public Nest(Dictionary<string, INestStructure<T>> value, string? name = null)
        {
            DictValue = value;
            Name = name;
@@ -58,7 +89,7 @@ namespace Tensorflow.Common.Types
        public Nest(Nest<T> other)
        {
            NestType = other.NestType;
            Value = other.Value;
            NodeValue = other.NodeValue;
            DictValue = other.DictValue;
            ListValue = other.ListValue;
            Name = other.Name;
@@ -78,17 +109,17 @@ namespace Tensorflow.Common.Types
        /// </summary>
        /// <param name="flatItems"></param>
        /// <returns></returns>
        public virtual Nest<T> PackSequence(T[] flatItems)
        public virtual Nest<TOut> PackSequence<TOut>(TOut[] flatItems)
        {
            if(flatItems.Length == 0)
            {
                return Nest<T>.Empty;
                return Nest<TOut>.Empty;
            }
            int index = 0;
            return PackSequenceInternal(this, flatItems, ref index);
        }

        private static Nest<T> PackSequenceInternal(Nest<T> template, T[] flatItems, ref int index)
        private static Nest<TOut> PackSequenceInternal<TOut>(Nest<T> template, TOut[] flatItems, ref int index)
        {
            if(template.NestType == NestType.Node)
            {
@@ -96,25 +127,25 @@ namespace Tensorflow.Common.Types
                {
                    throw new InvalidArgumentError("The template and flat items are not matched.");
                }
                return new Nest<T>(flatItems[index++]);
                return new Nest<TOut>(flatItems[index++]);
            }
            else if(template.NestType == NestType.List)
            {
                List<Nest<T>> nestedObjects = new List<Nest<T>>();
                List<Nest<TOut>> nestedObjects = new List<Nest<TOut>>();
                for (int i = 0; i < template.ListValue!.Count; i++)
                {
                    nestedObjects.Add(PackSequenceInternal(template.ListValue![i], flatItems, ref index));
                    nestedObjects.Add(PackSequenceInternal(template.ListValue![i].AsNest(), flatItems, ref index));
                }
                return new Nest<T>(nestedObjects);
                return new Nest<TOut>(nestedObjects);
            }
            else if(template.NestType == NestType.Node)
            {
                Dictionary<string, Nest<T>> dict = new Dictionary<string, Nest<T>>();
                Dictionary<string, INestStructure<TOut>> dict = new Dictionary<string, INestStructure<TOut>>();
                foreach(var (key, value) in template.DictValue!)
                {
                    dict[key] = PackSequenceInternal(value, flatItems, ref index);
                    dict[key] = PackSequenceInternal(value.AsNest(), flatItems, ref index);
                }
                return new Nest<T>(dict);
                return new Nest<TOut>(dict);
            }
            // Consider Empty as invalid type.
            throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node.");
@@ -166,25 +197,11 @@ namespace Tensorflow.Common.Types
            }
            else if(NestType is NestType.List)
            {
                foreach(var item in ListValue!)
                {
                    if(item.NestType is NestType.List or NestType.Dictionary)
                    {
                        return true;
                    }
                }
                return false;
                return ListValue!.Count > 0;
            }
            else
            {
                foreach (var item in DictValue!.Values)
                {
                    if (item.NestType is NestType.List or NestType.Dictionary)
                    {
                        return true;
                    }
                }
                return false;
                return DictValue!.Count > 0;
            }
        }
@@ -223,10 +240,10 @@ namespace Tensorflow.Common.Types
        public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T>
        {
            var nested = input.AsNest();
            return ReduceInternal(nested);
            return ReduceInternal(nested).AsNest();
        }

        private static Nest<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T>
        private static INestStructure<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T>
        {
            if(node.NestType == NestType.Empty)
            {
@@ -234,15 +251,15 @@ namespace Tensorflow.Common.Types
            }
            else if(node.NestType == NestType.Node)
            {
                return node.Value!.AsNest();
                return node.NodeValue!.AsNest();
            }
            else if(node.NestType == NestType.List)
            {
                return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x)));
                return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x.AsNest())));
            }
            else // Dictionary type
            {
                return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value)));
                return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value.AsNest())));
            }
        }
@@ -252,7 +269,7 @@ namespace Tensorflow.Common.Types
        {
            if(index == 0)
            {
                result = node.Value!;
                result = node.NodeValue!;
                return true;
            }
            result = default(T);
@@ -264,7 +281,7 @@ namespace Tensorflow.Common.Types
            {
                if(index == 0)
                {
                    return FindInternal(item, index, out result);
                    return FindInternal(item.AsNest(), index, out result);
                }
                index--;
            }
@@ -277,7 +294,7 @@ namespace Tensorflow.Common.Types
            {
                if (index == 0)
                {
                    return FindInternal(item, index, out result);
                    return FindInternal(item.AsNest(), index, out result);
                }
                index--;
            }
@@ -297,7 +314,7 @@ namespace Tensorflow.Common.Types
            {
                if (index == 0)
                {
                    node.Value = newValue;
                    node.NodeValue = newValue;
                    return true;
                }
                return false;
@@ -308,7 +325,7 @@ namespace Tensorflow.Common.Types
            {
                if (index == 0)
                {
                    return SetInternal(item, index, newValue);
                    return SetInternal(item.AsNest(), index, newValue);
                }
                index--;
            }
@@ -320,7 +337,7 @@ namespace Tensorflow.Common.Types
            {
                if (index == 0)
                {
                    return SetInternal(item, index, newValue);
                    return SetInternal(item.AsNest(), index, newValue);
                }
                index--;
            }
@@ -336,13 +353,13 @@ namespace Tensorflow.Common.Types
        {
            if (node.NestType == NestType.Node)
            {
                yield return node.Value!;
                yield return node.NodeValue!;
            }
            else if (node.NestType == NestType.List)
            {
                foreach (var item in node.ListValue!)
                {
                    foreach(var val in FlattenInternal(item))
                    foreach(var val in FlattenInternal(item.AsNest()))
                    {
                        yield return val;
                    }
@@ -352,7 +369,7 @@ namespace Tensorflow.Common.Types
            {
                foreach (var item in node.DictValue!.Values)
                {
                    foreach (var val in FlattenInternal(item))
                    foreach (var val in FlattenInternal(item.AsNest()))
                    {
                        yield return val;
                    }
@@ -364,23 +381,23 @@ namespace Tensorflow.Common.Types
        {
            if (NestType == NestType.Node)
            {
                return new Nest<TOut>(func(Value!));
                return new Nest<TOut>(func(NodeValue!));
            }
            else if (NestType == NestType.List)
            {
                List<Nest<TOut>> outs = new List<Nest<TOut>>();
                foreach (var item in ListValue!)
                {
                    outs.Add(item.MapStructureInternal(func));
                    outs.Add(item.AsNest().MapStructureInternal(func));
                }
                return new Nest<TOut>(outs);
            }
            else if (NestType == NestType.Dictionary)
            {
                Dictionary<string, Nest<TOut>> outs = new Dictionary<string, Nest<TOut>>();
                Dictionary<string, INestStructure<TOut>> outs = new Dictionary<string, INestStructure<TOut>>();
                foreach (var (key, value) in DictValue!)
                {
                    outs.Add(key, value.MapStructureInternal(func));
                    outs.Add(key, value.AsNest().MapStructureInternal(func));
                }
                return new Nest<TOut>(outs);
            }
@@ -417,14 +434,14 @@ namespace Tensorflow.Common.Types
            }
            if (node.NestType == NestType.Node)
            {
                sb.Append(node.Value!.ToString());
                sb.Append(node.NodeValue!.ToString());
            }
            else if (node.NestType == NestType.List)
            {
                sb.Append("[");
                for(int i = 0; i < node.ListValue!.Count; i++)
                {
                    WriteString(node.ListValue![i], sb);
                    WriteString(node.ListValue![i].AsNest(), sb);
                    if(i != node.ListValue!.Count - 1)
                    {
                        sb.Append(", ");
@@ -440,7 +457,7 @@ namespace Tensorflow.Common.Types
                foreach (var (key, value) in node.DictValue!)
                {
                    sb.Append($"{key}: ");
                    WriteString(value, sb);
                    WriteString(value.AsNest(), sb);
                    if (i != count - 1)
                    {
                        sb.Append(", ");
@@ -454,5 +471,15 @@ namespace Tensorflow.Common.Types
                sb.Append("<empty>");
            }
        }

        public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>) inputs)
        {
            return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2 });
        }

        public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>, INestStructure<T>) inputs)
        {
            return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2, inputs.Item3 });
        }
    }
}
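A sketch of the new implicit tuple-to-Nest conversions, which are convenient for paired RNN states such as (h, c); the tensors below are illustrative.

Tensor h = tf.zeros(new Shape(1, 4));
Tensor c = tf.zeros(new Shape(1, 4));
INestStructure<Tensor> first = new Nest<Tensor>(h);
INestStructure<Tensor> second = new Nest<Tensor>(c);
Nest<Tensor> states = (first, second);   // implicit (INestStructure<T>, INestStructure<T>) -> Nest<T>
// states.ShallowNestedCount == 2; states.Flatten() yields h then c.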
@@ -6,7 +6,11 @@ namespace Tensorflow.Common.Types
{
    public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull
    {
        public NestType NestType => NestType.Dictionary;
        public IDictionary<TKey, TValue> Value { get; set; }
        public int ShallowNestedCount => Values.Count;
        public int TotalNestedCount => Values.Count;
        public NestDictionary(IDictionary<TKey, TValue> dict)
        {
            Value = dict;
@@ -10,29 +10,39 @@ namespace Tensorflow.Common.Types
    /// <typeparam name="T"></typeparam>
    public sealed class NestList<T> : INestStructure<T>, IEnumerable<T>
    {
        public List<T> Value { get; set; }
        public NestType NestType => NestType.List;
        public List<T> Values { get; set; }
        public int ShallowNestedCount => Values.Count;
        public int TotalNestedCount => Values.Count;

        public NestList(params T[] values)
        {
            Values = new List<T>(values);
        }

        public NestList(IEnumerable<T> values)
        {
            Value = new List<T>(values);
            Values = new List<T>(values);
        }
        public IEnumerable<T> Flatten()
        {
            return Value;
            return Values;
        }
        public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
        {
            return new NestList<TOut>(Value.Select(x => func(x)));
            return new NestList<TOut>(Values.Select(x => func(x)));
        }

        public Nest<T> AsNest()
        {
            return new Nest<T>(Value.Select(x => new Nest<T>(x)));
            return new Nest<T>(Values.Select(x => new Nest<T>(x)));
        }

        // Enumerator implementation
        public IEnumerator<T> GetEnumerator()
        {
            return Value.GetEnumerator();
            return Values.GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
@@ -10,7 +10,11 @@ namespace Tensorflow.Common.Types
    /// <typeparam name="T"></typeparam>
    public class NestNode<T> : INestStructure<T>
    {
        public NestType NestType => NestType.Node;
        public T Value { get; set; }
        public int ShallowNestedCount => 1;
        public int TotalNestedCount => 1;
        public NestNode(T value)
        {
            Value = value;
@@ -161,8 +161,8 @@ namespace Tensorflow
                    break;
                }
                yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ?
                    null : new Tensors(results.Skip(FirstInputTensorCount)));
                yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ?
                    null : new Tensors(results.Skip(FirstInputTensorCount).ToArray()));
            }
        }
@@ -352,13 +352,19 @@ namespace Tensorflow.Eager
                    c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));
                    break;
                case TF_AttrType.TF_ATTR_SHAPE:
                    var dims = (value as long[]).ToArray();
                    long[] dims;
                    if (value is Shape shape) dims = shape.dims.ToArray();
                    else if (value is long[] longs) dims = longs.ToArray();
                    else if (value is int[] ints) dims = ints.Select(x => (long)x).ToArray();
                    else dims = ((long[])value).ToArray();
                    c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status);
                    status.Check(true);
                    break;
                case TF_AttrType.TF_ATTR_FUNC:
                    if (value is ConcreteFunction func)
                        c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length);
                    else if(value is string str)
                        c_api.TFE_OpSetAttrFunctionName(op, key, str, str.Length);
                    else
                        throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC");
                    break;
@@ -65,7 +65,7 @@ namespace Tensorflow.Eager
            {
                outgrad_vec = output_gradients.ToList();
            }
            var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);
            var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true);

            bool unconnected_gradients_zero = unconnected_gradients == "zero";
@@ -137,7 +137,6 @@ namespace Tensorflow.Eager
            {
                dims[i] = c_api.TFE_TensorHandleDim(handle, i, status);
            }
            Shape tensor_shape = new(dims);

            if(status.Code != TF_Code.TF_OK)
            {
@@ -145,6 +144,7 @@ namespace Tensorflow.Eager
            }
            else
            {
                Shape tensor_shape = new(dims);
                return new TapeTensor(id, dtype, tensor_shape);
            }
        }
@@ -173,8 +173,12 @@ namespace Tensorflow.Eager
            return dtype == dtypes.variant || dtype == dtypes.resource;
        }

        bool ListContainNone(long[] list)
        bool ListContainNone(long[]? list)
        {
            if(list is null)
            {
                return true;
            }
            int len = list.Length;
            if(len == 0)
            {
@@ -10,6 +10,11 @@ namespace Tensorflow.Eager
            var str = NDArrayRender.ToString(nd);
            return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
        }

        public string ToString(int maxLength)
        {
            var nd = new NDArray(this);
            var str = NDArrayRender.ToString(nd, maxLength);
            return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
        }
    }
}
@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Exceptions
{
    public class NotOkStatusException : TensorflowException
    {
        public NotOkStatusException() : base()
        {
        }

        public NotOkStatusException(string message) : base(message)
        {
        }
    }
}

@@ -1,4 +1,5 @@
using System.Linq;
using Tensorflow.Eager;

namespace Tensorflow.Framework.Models
{
@@ -24,5 +25,17 @@ namespace Tensorflow.Framework.Models
            shapes.Insert(0, dim);
            return new TensorSpec(shapes.ToArray(), _dtype);
        }

        public static TensorSpec FromTensor(Tensor tensor, string? name = null)
        {
            if(tensor is EagerTensor)
            {
                return new TensorSpec(tensor.shape, tensor.dtype, name);
            }
            else
            {
                return new TensorSpec(tensor.shape, tensor.dtype, name ?? tensor.name);
            }
        }
    }
}
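A hedged example of the new TensorSpec.FromTensor helper; the tensor and the property names used in the comment are assumed from the surrounding spec types.

var t = tf.constant(new[,] { { 1f, 2f }, { 3f, 4f } });
var spec = TensorSpec.FromTensor(t);
// spec describes shape (2, 2) and dtype float32; eager tensors get no name by default.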
@@ -0,0 +1,89 @@
using Tensorflow.Graphs;

namespace Tensorflow.Framework
{
    internal static class auto_control_deps_utils
    {
        public static readonly string READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs";

        public static List<int> get_read_only_resource_input_indices_graph(FuncGraph func_graph)
        {
            List<int> result = new List<int>();
            // A cache to store the read only resource inputs of an Op.
            // Operation -> ObjectIdentitySet of resource handles.
            Dictionary<Operation, HashSet<Tensor>> opReadOnlyResourceInputs =
                new Dictionary<Operation, HashSet<Tensor>>();

            for (int inputIndex = 0; inputIndex < func_graph.Inputs.Length; inputIndex++)
            {
                Tensor t = func_graph.Inputs[inputIndex];
                if (t.dtype != dtypes.resource)
                    continue;

                bool readOnly = true;
                foreach (var op in t.consumers())
                {
                    if (opReadOnlyResourceInputs.ContainsKey(op))
                    {
                        if (!opReadOnlyResourceInputs[op].Contains(t))
                        {
                            readOnly = false;
                            break;
                        }
                    }
                    else
                    {
                        List<int> indices = _get_read_only_resource_input_indices_op(op);
                        opReadOnlyResourceInputs[op] = new HashSet<Tensor>(
                            indices.Select(i => op.inputs[i]));
                        if (!opReadOnlyResourceInputs[op].Contains(t))
                        {
                            readOnly = false;
                            break;
                        }
                    }
                }

                if (readOnly)
                    result.Add(inputIndex);
            }

            return result;
        }

        private static List<int> _get_read_only_resource_input_indices_op(Operation op)
        {
            // ignore the RESOURCE_READ_OPS

            int[] read_only_input_indices;
            try
            {
                read_only_input_indices = op.get_attr<int[]>(READ_ONLY_RESOURCE_INPUTS_ATTR);
            }
            catch (InvalidArgumentError)
            {
                return new List<int>();
            }

            int read_only_index = 0;
            List<int> result = new();
            for (int i = 0; i < op.inputs.Length; i++)
            {
                if (read_only_index >= read_only_input_indices.Length)
                {
                    break;
                }
                if (op.inputs[i].dtype != dtypes.resource)
                {
                    continue;
                }
                if (read_only_index < read_only_input_indices.Length && i == read_only_input_indices[read_only_index])
                {
                    result.Add(i);
                    read_only_index++;
                }
            }
            return result;
        }
    }
}
@@ -42,10 +42,10 @@ namespace Tensorflow.Framework
            func_graph.as_default();
            importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false);
            var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]);
            func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));
            func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray());
            var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]);
            func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));
            func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray());
            // TODO(Rinne): func_graph.ControlOutputs
            _set_handle_data(func_graph, fdef);
@@ -8,6 +8,7 @@ using Tensorflow.Gradients;
using Tensorflow.Graphs;
using Tensorflow.Train;
using Tensorflow.Util;
using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;

namespace Tensorflow.Functions
@@ -40,6 +41,18 @@ namespace Tensorflow.Functions
        public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs;
        public IEnumerable<IVariableV1> Variables => func_graph.Variables;
        public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables;

        internal NameAttrList AsNameAttrList
        {
            get
            {
                NameAttrList ret = new() { Name = this.Name };
                foreach (var (name, value) in _attrs)
                {
                    ret.Attr[name] = value;
                }
                return ret;
            }
        }

        public ConcreteFunction(string name)
        {
@@ -90,8 +90,7 @@ namespace Tensorflow.Gradients
                ? input_values[0].rank + dim_int
                : dim_int % input_values[0].rank;
            var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray();
            var sizes_tensor = constant_op.constant(sizes);
            out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList();
            out_grads = array_ops.split(grad, sizes.Select(x => (int)x).ToArray(), ops.convert_to_tensor(non_neg_concat_dim)).ToList();
        }
        else if (constant_op.is_constant(concat_dim))
        {
@@ -127,7 +126,7 @@ namespace Tensorflow.Gradients
                new Tensor[] { non_neg_concat_dim, tf.constant(0) },
                new Tensor[] { tf.constant(1), tf.constant(-1) });
            var squeeze_sizes = array_ops.squeeze(slice);
            out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList();
            out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_or_size_splits: (int)non_neg_concat_dim).ToList();
        }
        else
        {
@@ -81,7 +81,7 @@ public class FuncGraph : Graph, IDisposable
    public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable);
    public Dictionary<string, AttrValue> Attrs { get; set; }

    Dictionary<long, (Tensor, Tensor)> _captures
    internal Dictionary<long, (Tensor, Tensor)> _captures
        = new Dictionary<long, (Tensor, Tensor)>();

    public Tensor[] external_captures
@@ -399,7 +399,7 @@ public class FuncGraph : Graph, IDisposable
    var flat_func_args = nest.flatten(func_args as object);
    var flat_func_kwargs = nest.flatten(func_kwargs as object);
    func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs)
        .Where(x => x is Tensor).Select(x => (Tensor)x));
        .Where(x => x is Tensor).Select(x => (Tensor)x).ToArray());

    //var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true);
    //var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true);
@@ -129,7 +129,7 @@ namespace Tensorflow
        }
    }

    protected Graph outer_graph;
    internal Graph outer_graph;
    public Graph OuterGraph => outer_graph;
    public Dictionary<string, EagerDefinedFunction> Functions => _functions;
    public SafeGraphHandle c_graph => _handle;
@@ -4,8 +4,6 @@
    {
        // TODO: maybe change the `RNNArgs` and implement this class.
        public bool UnitForgetBias { get; set; }
        public float Dropout { get; set; }
        public float RecurrentDropout { get; set; }
        public int Implementation { get; set; }
    }
}
@@ -1,7 +1,35 @@
namespace Tensorflow.Keras.ArgsDefinition.Rnn
using Newtonsoft.Json;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
    // TODO: complete the implementation
    public class LSTMCellArgs : LayerArgs
    public class LSTMCellArgs : AutoSerializeLayerArgs
    {
        [JsonProperty("units")]
        public int Units { get; set; }
        // TODO(Rinne): lack of initialized value of Activation. Merging keras
        // into tf.net could resolve it.
        [JsonProperty("activation")]
        public Activation Activation { get; set; }
        [JsonProperty("recurrent_activation")]
        public Activation RecurrentActivation { get; set; }
        [JsonProperty("use_bias")]
        public bool UseBias { get; set; } = true;
        [JsonProperty("dropout")]
        public float Dropout { get; set; } = .0f;
        [JsonProperty("recurrent_dropout")]
        public float RecurrentDropout { get; set; } = .0f;
        [JsonProperty("kernel_initializer")]
        public IInitializer KernelInitializer { get; set; }
        [JsonProperty("recurrent_initializer")]
        public IInitializer RecurrentInitializer { get; set; }
        [JsonProperty("bias_initializer")]
        public IInitializer BiasInitializer { get; set; }
        [JsonProperty("unit_forget_bias")]
        public bool UnitForgetBias { get; set; } = true;
        [JsonProperty("implementation")]
        public int Implementation { get; set; } = 2;
    }
}
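A hedged sketch of how the fleshed-out LSTM cell arguments are typically consumed through the keras layers API shown further below in this diff; the hyperparameter values are illustrative.

var cell = tf.keras.layers.LSTMCell(4,
    dropout: 0.1f,
    recurrent_dropout: 0.0f,
    unit_forget_bias: true,
    implementation: 2);
// The cell can then be wrapped by the RNN layer machinery this PR enables for training.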
@@ -7,12 +7,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
    // TODO(Rinne): add regularizers.
    public class RNNArgs : AutoSerializeLayerArgs
    {
        [JsonProperty("cell")]
        // TODO: the cell should be serialized with `serialize_keras_object`.
        public IRnnCell Cell { get; set; } = null;
        [JsonProperty("cells")]
        public IList<IRnnCell> Cells { get; set; } = null;
        [JsonProperty("return_sequences")]
        public bool ReturnSequences { get; set; } = false;
        [JsonProperty("return_state")]
@@ -25,8 +19,10 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
        public bool Unroll { get; set; } = false;
        [JsonProperty("time_major")]
        public bool TimeMajor { get; set; } = false;
        public int? InputDim { get; set; }
        public int? InputLength { get; set; }
        // TODO: Add `num_constants` and `zero_output_for_mask`.
        public Dictionary<string, object> Kwargs { get; set; } = null;

        public int Units { get; set; }
        public Activation Activation { get; set; }
@@ -38,21 +34,5 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
        public float Dropout { get; set; } = .0f;
        public bool ZeroOutputForMask { get; set; } = false;
        public float RecurrentDropout { get; set; } = .0f;
        // kernel_regularizer=None,
        // recurrent_regularizer=None,
        // bias_regularizer=None,
        // activity_regularizer=None,
        // kernel_constraint=None,
        // recurrent_constraint=None,
        // bias_constraint=None,
        // dropout=0.,
        // recurrent_dropout=0.,
        // return_sequences=False,
        // return_state=False,
        // go_backwards=False,
        // stateful=False,
        // unroll=False,
        // **kwargs):
    }
}

@@ -1,7 +1,4 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
@@ -25,5 +22,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
        public IInitializer RecurrentInitializer { get; set; }
        [JsonProperty("bias_initializer")]
        public IInitializer BiasInitializer { get; set; }
    }
}

@@ -5,7 +5,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
    public class StackedRNNCellsArgs : LayerArgs
    {
        public IList<IRnnCell> Cells { get; set; }
        public Dictionary<string, object> Kwargs { get; set; } = null;
        public bool ReverseStateOrder = false;
    }
}
@@ -160,6 +160,18 @@ namespace Tensorflow.Keras.Layers
        public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
        public ILayer LeakyReLU(float alpha = 0.3f);

        public IRnnCell LSTMCell(int uints,
            string activation = "tanh",
            string recurrent_activation = "sigmoid",
            bool use_bias = true,
            string kernel_initializer = "glorot_uniform",
            string recurrent_initializer = "orthogonal",
            string bias_initializer = "zeros",
            bool unit_forget_bias = true,
            float dropout = 0f,
            float recurrent_dropout = 0f,
            int implementation = 2);

        public ILayer LSTM(int units,
            Activation activation = null,
            Activation recurrent_activation = null,
@@ -7,13 +7,19 @@ namespace Tensorflow.Keras.Layers.Rnn
{
    public interface IRnnCell: ILayer
    {
        GeneralizedTensorShape StateSize { get; }
        GeneralizedTensorShape OutputSize { get; }
        bool IsTFRnnCell { get; }
        /// <summary>
        /// If the derived class tends to not implement it, please return null.
        /// </summary>
        INestStructure<long>? StateSize { get; }
        /// <summary>
        /// If the derived class tends to not implement it, please return null.
        /// </summary>
        INestStructure<long>? OutputSize { get; }
        /// <summary>
        /// Whether the optional RNN args are supported when appying the layer.
        /// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`.
        /// </summary>
        bool SupportOptionalArgs { get; }
        Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype);
    }
}
@@ -7,7 +7,7 @@ namespace Tensorflow.NumPy
{
    public class NDArrayRender
    {
        public static string ToString(NDArray array)
        public static string ToString(NDArray array, int maxLength = 10)
        {
            Shape shape = array.shape;
            if (shape.IsScalar)
@@ -15,12 +15,12 @@ namespace Tensorflow.NumPy
            var s = new StringBuilder();
            s.Append("array(");
            Build(s, array);
            Build(s, array, maxLength);
            s.Append(")");
            return s.ToString();
        }

        static void Build(StringBuilder s, NDArray array)
        static void Build(StringBuilder s, NDArray array, int maxLength)
        {
            var shape = array.shape;
@@ -35,11 +35,11 @@ namespace Tensorflow.NumPy
            var len = shape[0];
            s.Append("[");

            if (len <= 10)
            if (len <= maxLength)
            {
                for (int i = 0; i < len; i++)
                {
                    Build(s, array[i]);
                    Build(s, array[i], maxLength);
                    if (i < len - 1)
                    {
                        s.Append(", ");
@@ -49,9 +49,9 @@ namespace Tensorflow.NumPy
            }
            else
            {
                for (int i = 0; i < 5; i++)
                for (int i = 0; i < maxLength / 2; i++)
                {
                    Build(s, array[i]);
                    Build(s, array[i], maxLength);
                    if (i < len - 1)
                    {
                        s.Append(", ");
@@ -62,9 +62,9 @@ namespace Tensorflow.NumPy
                s.Append(" ... ");
                s.AppendLine();
                for (int i = (int)len - 5; i < len; i++)
                for (int i = (int)len - maxLength / 2; i < len; i++)
                {
                    Build(s, array[i]);
                    Build(s, array[i], maxLength);
                    if (i < len - 1)
                    {
                        s.Append(", ");
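An assumed usage example for the new maxLength rendering knob; the array contents are illustrative. The same knob is also exposed by the new eager Tensor.ToString(int maxLength) overload earlier in this diff.

var nd = np.arange(100);
Console.WriteLine(NDArrayRender.ToString(nd));      // default maxLength = 10: prints 5 head + 5 tail items
Console.WriteLine(NDArrayRender.ToString(nd, 6));   // prints 3 head + 3 tail items around " ... "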
@@ -19,13 +19,14 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Saving.Common;
using Tensorflow.NumPy;

namespace Tensorflow
{
    [JsonConverter(typeof(CustomizedShapeJsonConverter))]
    public class Shape
    public class Shape : INestStructure<long>
    {
        public int ndim => _dims == null ? -1 : _dims.Length;
        long[] _dims;
@@ -41,6 +42,27 @@ namespace Tensorflow
            }
        }

        public NestType NestType => NestType.List;

        public int ShallowNestedCount => ndim;
        /// <summary>
        /// The total item count of depth 1 of the nested structure.
        /// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
        /// </summary>
        public int TotalNestedCount => ndim;

        public IEnumerable<long> Flatten() => dims.Select(x => x);

        public INestStructure<TOut> MapStructure<TOut>(Func<long, TOut> func)
        {
            return new NestList<TOut>(dims.Select(x => func(x)));
        }

        public Nest<long> AsNest()
        {
            return new NestList<long>(Flatten()).AsNest();
        }
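A small sketch of Shape acting as an INestStructure of long; the dimensions are illustrative.

var shape = new Shape(2, 3, 4);
long count = shape.TotalNestedCount;            // 3, one entry per dimension
var doubled = shape.MapStructure(d => d * 2);   // NestList<long> holding 4, 6, 8
var flat = shape.Flatten().ToArray();           // [2, 3, 4]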
#region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges
        public int Length => ndim;
        public long[] Slice(int start, int length)

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.NumPy;

namespace Tensorflow.Operations.Initializers
{
    /// <summary>
    /// An initializer specially used for debugging (to load weights from disk).
    /// </summary>
    class NpyLoadInitializer : IInitializer
    {
        string _path;
        public NpyLoadInitializer(string path) { _path = path; }
        public string ClassName => "";
        public IDictionary<string, object> Config => new Dictionary<string, object>();
        public Tensor Apply(InitializerArgs args)
        {
            return np.load(_path);
        }
    }
}

@@ -58,8 +58,7 @@ public class Orthogonal : IInitializer
    if (num_rows < num_cols)
    {
        // q = tf.linalg.matrix_transpose(q);
        throw new NotImplementedException("");
        q = array_ops.matrix_transpose(q);
    }

    return _gain * tf.reshape(q, shape);
@@ -89,7 +89,7 @@ namespace Tensorflow
            gate_inputs = nn_ops.bias_add(gate_inputs, _bias);

            // i = input_gate, j = new_input, f = forget_gate, o = output_gate
            var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one);
            var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one);
            var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]);

            var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype);
@@ -181,8 +181,12 @@ namespace Tensorflow
        {
            throw new NotImplementedException();
        }

        public GeneralizedTensorShape StateSize => throw new NotImplementedException();
        public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
        public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
        {
            throw new NotImplementedException();
        }
        public INestStructure<long> StateSize => throw new NotImplementedException();
        public INestStructure<long> OutputSize => throw new NotImplementedException();
        public bool IsTFRnnCell => throw new NotImplementedException();
        public bool SupportOptionalArgs => throw new NotImplementedException();
    }

@@ -15,9 +15,11 @@
 ******************************************************************************/

using Google.Protobuf;
using Google.Protobuf.Collections;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Functions;
using static Tensorflow.Binding;
using static Tensorflow.OpDef.Types;
@@ -387,9 +389,13 @@ namespace Tensorflow
                case "list(type)":
                    attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def)));
                    break;
                case "list(float)":
                    if (value != null)
                        attr_value.List.F.AddRange((value as IEnumerable<float>).ToArray());
                    break;
                case "list(int)":
                    if (value != null)
                        attr_value.List.I.AddRange((value as int[]).Select(x => Convert.ToInt64(x)));
                        attr_value.List.I.AddRange((value as IEnumerable<int>).Select(x => Convert.ToInt64(x)));
                    break;
                case "bool":
                    attr_value.B = (bool)value;
@@ -420,6 +426,15 @@ namespace Tensorflow
                case "list(shape)":
                    attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def)));
                    break;
case "func": | |||||
attr_value.Func = _MakeFunc(value, attr_def.Name); | |||||
break; | |||||
case "list(func)": | |||||
attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name)); | |||||
break; | |||||
case "list(string)": | |||||
attr_value.List.S.AddRange((value as IEnumerable<string>).Select(x => ByteString.CopyFromUtf8(x))); | |||||
break; | |||||
default: | default: | ||||
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | ||||
} | } | ||||
@@ -427,6 +442,47 @@ namespace Tensorflow | |||||
return attr_value; | return attr_value; | ||||
} | } | ||||
private NameAttrList _MakeFunc(object func, string arg_name) | |||||
{ | |||||
if(func is NameAttrList attrList) | |||||
{ | |||||
return attrList; | |||||
} | |||||
NameAttrList fn_attr; | |||||
if(func is string funcStr) | |||||
{ | |||||
fn_attr = new NameAttrList() { Name = funcStr }; | |||||
} | |||||
else if(func is ConcreteFunction concrete) | |||||
{ | |||||
concrete.AddTograph(ops.get_default_graph()); | |||||
fn_attr = concrete.AsNameAttrList; | |||||
} | |||||
else if(func is EagerDefinedFunction eager) | |||||
{ | |||||
eager.AddToGraph(ops.get_default_graph()); | |||||
fn_attr = new NameAttrList() { Name = eager.Name }; | |||||
} | |||||
else | |||||
{ | |||||
throw new TypeError($"Don't know how to convert {func} to a func for argument {arg_name}"); | |||||
} | |||||
return fn_attr; | |||||
} | |||||
private List<NameAttrList> _MakeFuncList(object funcList, string arg_name) | |||||
{ | |||||
List<NameAttrList> res = new List<NameAttrList>(); | |||||
if(funcList is IEnumerable enumerable) | |||||
{ | |||||
foreach(var func in enumerable) | |||||
{ | |||||
res.Add(_MakeFunc(func, arg_name)); | |||||
} | |||||
} | |||||
return res; | |||||
} | |||||
private bool _IsListParameter(ArgDef arg) | private bool _IsListParameter(ArgDef arg) | ||||
{ | { | ||||
if (!String.IsNullOrEmpty(arg.NumberAttr)) | if (!String.IsNullOrEmpty(arg.NumberAttr)) | ||||
@@ -34,7 +34,7 @@ namespace Tensorflow | |||||
return num; | return num; | ||||
} | } | ||||
protected Tensor[] _outputs; | |||||
internal Tensor[] _outputs; | |||||
public virtual Tensor[] outputs => _outputs; | public virtual Tensor[] outputs => _outputs; | ||||
public Tensor output => _outputs.FirstOrDefault(); | public Tensor output => _outputs.FirstOrDefault(); | ||||
@@ -46,9 +46,9 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public partial class Operation : ITensorOrOperation | public partial class Operation : ITensorOrOperation | ||||
{ | { | ||||
private readonly IntPtr _handle; // _c_op in python | |||||
protected IntPtr _handle; // _c_op in python | |||||
private readonly Graph _graph; | |||||
protected Graph _graph; | |||||
internal Func<Operation, object[], Tensor[]> _gradient_function; | internal Func<Operation, object[], Tensor[]> _gradient_function; | ||||
@@ -69,6 +69,7 @@ namespace Tensorflow | |||||
//private OperationDescription _op_desc; | //private OperationDescription _op_desc; | ||||
public NodeDef node_def => GetNodeDef(); | public NodeDef node_def => GetNodeDef(); | ||||
protected Operation() { } | |||||
public Operation(IntPtr handle, Graph g = null) | public Operation(IntPtr handle, Graph g = null) | ||||
{ | { | ||||
@@ -185,7 +186,16 @@ namespace Tensorflow | |||||
} | } | ||||
public virtual T get_attr<T>(string name) | public virtual T get_attr<T>(string name) | ||||
=> (T)get_attr(name); | |||||
{ | |||||
if (typeof(T).IsValueType) | |||||
{ | |||||
return (T)Convert.ChangeType(get_attr(name), typeof(T)); | |||||
} | |||||
else | |||||
{ | |||||
return (T)get_attr(name); | |||||
} | |||||
} | |||||
internal unsafe TF_DataType _get_attr_type(string name) | internal unsafe TF_DataType _get_attr_type(string name) | ||||
{ | { | ||||
@@ -17,6 +17,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Common.Types; | |||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -38,10 +39,6 @@ namespace Tensorflow.Operations | |||||
bool _infer_shape; | bool _infer_shape; | ||||
public override bool infer_shape => _infer_shape; | public override bool infer_shape => _infer_shape; | ||||
public bool _dynamic_size; | |||||
public Shape _element_shape; | |||||
public List<Tensor> _colocate_with; | |||||
Tensor _handle; | Tensor _handle; | ||||
public override Tensor handle => _handle; | public override Tensor handle => _handle; | ||||
@@ -56,6 +53,7 @@ namespace Tensorflow.Operations | |||||
bool infer_shape = true, Shape? element_shape = null, | bool infer_shape = true, Shape? element_shape = null, | ||||
bool colocate_with_first_write_call = true, string name = null) | bool colocate_with_first_write_call = true, string name = null) | ||||
{ | { | ||||
_size = size; | |||||
_flow = constant_op.constant(0); | _flow = constant_op.constant(0); | ||||
_infer_shape = infer_shape; | _infer_shape = infer_shape; | ||||
_element_shape = element_shape ?? Shape.Null; | _element_shape = element_shape ?? Shape.Null; | ||||
@@ -16,7 +16,9 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Common.Types; | |||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -33,18 +35,18 @@ namespace Tensorflow.Operations | |||||
/// first tensor written to it. | /// first tensor written to it. | ||||
/// </summary> | /// </summary> | ||||
bool _colocate_with_first_write_call; | bool _colocate_with_first_write_call; | ||||
public bool colocate_with_first_write_call => _colocate_with_first_write_call; | |||||
public override bool colocate_with_first_write_call => _colocate_with_first_write_call; | |||||
bool _infer_shape; | bool _infer_shape; | ||||
public bool infer_shape => _infer_shape; | |||||
public bool _dynamic_size; | |||||
public override bool infer_shape => _infer_shape; | |||||
public List<Shape> _element_shape; | public List<Shape> _element_shape; | ||||
public List<Tensor> _colocate_with; | public List<Tensor> _colocate_with; | ||||
internal Tensor _handle; | internal Tensor _handle; | ||||
public Tensor handle => _handle; | |||||
public override Tensor handle => _handle; | |||||
internal Tensor _flow; | internal Tensor _flow; | ||||
public override Tensor flow => _flow; | |||||
public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | ||||
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | ||||
@@ -55,6 +57,7 @@ namespace Tensorflow.Operations | |||||
dynamic_size = dynamic_size ?? false; | dynamic_size = dynamic_size ?? false; | ||||
_dynamic_size = dynamic_size.Value; | _dynamic_size = dynamic_size.Value; | ||||
_dtype = dtype; | _dtype = dtype; | ||||
_size = size; | |||||
_colocate_with_first_write_call = colocate_with_first_write_call; | _colocate_with_first_write_call = colocate_with_first_write_call; | ||||
if (colocate_with_first_write_call) | if (colocate_with_first_write_call) | ||||
@@ -235,4 +238,173 @@ namespace Tensorflow.Operations | |||||
return value; | return value; | ||||
} | } | ||||
} | } | ||||
public class _GraphTensorArrayV2 : TensorArray | |||||
{ | |||||
internal TF_DataType _dtype; | |||||
public override TF_DataType dtype => _dtype; | |||||
/// <summary> | |||||
/// Used to keep track of what tensors the TensorArray should be | |||||
/// colocated with. We choose to colocate the TensorArray with the | |||||
/// first tensor written to it. | |||||
/// </summary> | |||||
bool _colocate_with_first_write_call; | |||||
public override bool colocate_with_first_write_call => _colocate_with_first_write_call; | |||||
bool _infer_shape; | |||||
public override bool infer_shape => _infer_shape; | |||||
public Shape _element_shape; | |||||
public List<Tensor> _colocate_with; | |||||
internal Tensor _handle; | |||||
public override Tensor handle => _handle; | |||||
internal Tensor _flow; | |||||
public override Tensor flow => _flow; | |||||
public _GraphTensorArrayV2(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | |||||
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||||
bool infer_shape = true, Shape? element_shape = null, | |||||
bool colocate_with_first_write_call = true, string name = null) | |||||
{ | |||||
Debug.Assert(handle is null); | |||||
dynamic_size = dynamic_size ?? false; | |||||
_dynamic_size = dynamic_size.Value; | |||||
_size = size; | |||||
if(flow is not null && flow.dtype != dtypes.variant) | |||||
{ | |||||
throw new TypeError($"Expected `flow` to be a variant tensor, but received `{flow.dtype}` instead"); | |||||
} | |||||
if(flow is null && size is null) | |||||
{ | |||||
throw new ValueError("Argument `size` must be provided if argument `flow` is not provided."); | |||||
} | |||||
if(flow is not null && size is not null) | |||||
{ | |||||
throw new ValueError("Cannot provide both `flow` and `size` arguments at the same time."); | |||||
} | |||||
if(flow is not null && element_shape is not null) | |||||
{ | |||||
throw new ValueError("Cannot provide both `flow` and `element_shape` arguments at the same time."); | |||||
} | |||||
_dtype = dtype; | |||||
_element_shape = element_shape; | |||||
_infer_shape = infer_shape; | |||||
tf_with(ops.name_scope(name, "TensorArrayV2", new object[] { size, flow }), scope => | |||||
{ | |||||
if (flow is null) | |||||
{ | |||||
_flow = list_ops.tensor_list_reserve(element_shape, size, dtype, scope.scope_name); | |||||
} | |||||
else | |||||
{ | |||||
_flow = flow; | |||||
} | |||||
}); | |||||
_colocate_with_first_write_call = false; | |||||
_colocate_with = null; | |||||
} | |||||
public override TensorArray unstack(Tensor value, string name = null) | |||||
{ | |||||
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _flow, value }), delegate | |||||
{ | |||||
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||||
Debug.Assert(value.dtype == _dtype); | |||||
var flow_out = list_ops.tensor_list_from_tensor(value, value.shape.dims.Skip(1).ToArray()); | |||||
return tensor_array_ops.build_ta_with_new_flow(this, flow_out); | |||||
}); | |||||
} | |||||
public TensorArray scatter(Tensor indices, Tensor value, string name = null) | |||||
{ | |||||
return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _flow, value, indices }), delegate | |||||
{ | |||||
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||||
Debug.Assert(value.dtype == _dtype); | |||||
var flow_out = list_ops.tensor_list_scatter(value, indices, _element_shape, _flow); | |||||
return tensor_array_ops.build_ta_with_new_flow(this, flow_out); | |||||
}); | |||||
} | |||||
public override Tensor read<T>(T index, string name = null) | |||||
{ | |||||
if(index is Tensor tensor) | |||||
{ | |||||
return read(tensor, name); | |||||
} | |||||
else | |||||
{ | |||||
throw new TypeError("Please use non-generic method instead."); | |||||
} | |||||
} | |||||
public Tensor read(Tensor index, string name = null) | |||||
{ | |||||
return tf_with(tf.name_scope(name, "TensorArrayV2Read", new object[] { _flow, index }), scope => | |||||
{ | |||||
return list_ops.tensor_list_get_item(_flow, index, _dtype, _element_shape, name); | |||||
}); | |||||
} | |||||
public override TensorArray write(Tensor index, Tensor value, string name = null) | |||||
{ | |||||
return tf_with(ops.name_scope(name, "TensorArrayV2Write", new { _flow, index, value }), delegate | |||||
{ | |||||
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||||
Debug.Assert(value.dtype == _dtype); | |||||
var flow_out = list_ops.tensor_list_set_item(_flow, index, value, _dynamic_size, name); | |||||
return tensor_array_ops.build_ta_with_new_flow(this, flow_out); | |||||
}); | |||||
} | |||||
public override TensorArray write<T>(int index, T value, string name = null) | |||||
{ | |||||
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||||
var index_tensor = ops.convert_to_tensor(index, name: "index"); | |||||
return write(index_tensor, value_tensor); | |||||
} | |||||
private Tensor size(string name = null) | |||||
{ | |||||
if(!_dynamic_size && _size is not null) | |||||
{ | |||||
return ops.convert_to_tensor(_size, dtypes.int32); | |||||
} | |||||
else | |||||
{ | |||||
return gen_list_ops.tensor_list_length(_flow, name); | |||||
} | |||||
} | |||||
public override Tensor stack(string name = null) | |||||
{ | |||||
return tf_with(ops.name_scope(name, "TensorArrayV2Stack", _flow), delegate | |||||
{ | |||||
int ta_size; | |||||
if(!_dynamic_size && (_size is not null)) | |||||
{ | |||||
var size_tensor = tensor_util.constant_value(_size); | |||||
ta_size = size_tensor is null ? -1 : (int)size_tensor; | |||||
} | |||||
else | |||||
{ | |||||
ta_size = -1; | |||||
} | |||||
var value = list_ops.tensor_list_stack(_flow, _dtype, ta_size, _element_shape); | |||||
return value; | |||||
}); | |||||
} | |||||
public override Tensor gather(Tensor indices, string name = null) | |||||
{ | |||||
return list_ops.tensor_list_gather(_flow, indices, _dtype, _element_shape, name); | |||||
} | |||||
} | |||||
} | } |
@@ -119,6 +119,27 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
public static Tensor zeros(Tensors shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | |||||
{ | |||||
dtype = dtype.as_base_dtype(); | |||||
Tensor shapeTensor; | |||||
if(shape.Length > 1) | |||||
{ | |||||
shapeTensor = ops.convert_to_tensor(shape, dtypes.int32); | |||||
if(shapeTensor.ndim > 1) | |||||
{ | |||||
shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1)); | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
shapeTensor = shape[0]; | |||||
} | |||||
var output = fill(shapeTensor, array_ops.constant(0, dtype), name); | |||||
Debug.Assert(output.dtype.as_base_dtype() == dtype); | |||||
return output; | |||||
} | |||||
public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) | public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) | ||||
{ | { | ||||
return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate | return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate | ||||
@@ -307,6 +328,9 @@ namespace Tensorflow | |||||
public static Tensor fill<T>(Shape dims, T value, string name = null) | public static Tensor fill<T>(Shape dims, T value, string name = null) | ||||
=> gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name); | => gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name); | ||||
public static Tensor fill<T>(Tensor dims, T value, string name = null) | |||||
=> gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the rank of a tensor. | /// Returns the rank of a tensor. | ||||
/// </summary> | /// </summary> | ||||
@@ -947,38 +971,70 @@ namespace Tensorflow | |||||
}); | }); | ||||
} | } | ||||
public static Tensor[] split(Tensor value, Tensor size_splits, int axis, int num = -1, | |||||
string name = "split") | |||||
/// <summary> | |||||
/// Transposes last two dimensions of tensor `a`. | |||||
/// For example: | |||||
/// <code> python | |||||
/// x = tf.constant([[1, 2, 3], [4, 5, 6]]) | |||||
/// tf.matrix_transpose(x) # [[1, 4], | |||||
/// # [2, 5], | |||||
/// # [3, 6]] | |||||
/// </code> | |||||
/// Matrix with two batch dimensions. | |||||
/// x.shape is [1, 2, 3, 4] | |||||
/// tf.linalg.matrix_transpose(x) is shape [1, 2, 4, 3] | |||||
/// </summary> | |||||
/// <param name="a"></param> | |||||
/// <param name="name"></param> | |||||
/// <param name="conjugate"></param> | |||||
/// <returns></returns> | |||||
/// <exception cref="ValueError"></exception> | |||||
public static Tensor matrix_transpose(Tensor a, string name = "matrix_transpose", bool conjugate = false) | |||||
{ | { | ||||
if (num == -1) | |||||
num = (int)size_splits.shape[0]; | |||||
return gen_array_ops.split_v(value, size_splits, tf.convert_to_tensor(axis), num, name: name); | |||||
return tf_with(ops.name_scope(name, "transpose", new { a }), scope => | |||||
{ | |||||
var a_shape = a.shape; | |||||
var ndims = a.shape.ndim; | |||||
Axis perm; | |||||
if (ndims != -1) | |||||
{ | |||||
if (ndims < 2) | |||||
{ | |||||
throw new ValueError("Argument `a` should be a (batch) matrix with rank " + | |||||
$">= 2. Received `a` = {a} with shape: {a_shape}"); | |||||
} | |||||
perm = new Axis(Enumerable.Range(0, ndims - 2).Concat(new int[] { ndims - 1, ndims - 2 }).ToArray()); | |||||
} | |||||
else | |||||
{ | |||||
var a_rank = a.rank; | |||||
perm = new Axis(Enumerable.Range(0, a_rank - 2).Concat(new int[] { a_rank - 1, a_rank - 2 }).ToArray()); | |||||
} | |||||
return transpose(a, perm:perm, conjugate:conjugate); | |||||
}); | |||||
} | } | ||||
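A short C# usage sketch mirroring the doc-comment example for the matrix_transpose helper above (a minimal sketch; tf.constant, tf.zeros, and the static Binding import are assumed to be available as elsewhere in the code base):

using Tensorflow;
using static Tensorflow.Binding;

class MatrixTransposeDemo
{
    static void Main()
    {
        // 2 x 3 matrix -> 3 x 2 matrix.
        var x = tf.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } });
        var xt = array_ops.matrix_transpose(x);
        print(xt.shape); // (3, 2)

        // Batched case: only the last two dimensions are swapped.
        var batched = tf.zeros(new Shape(1, 2, 3, 4));
        print(array_ops.matrix_transpose(batched).shape); // (1, 2, 4, 3)
    }
}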
public static Tensor[] split<T>(Tensor value, int num_split, T axis, | |||||
public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis = null, | |||||
string name = "split") | string name = "split") | ||||
{ | { | ||||
var size_splits = ops.convert_to_tensor(num_split); | |||||
return gen_array_ops.split(split_dim: axis, value: value, num_split: num_or_size_splits, name); | |||||
} | |||||
if (tf.Context.executing_eagerly()) | |||||
public static Tensor[] split(Tensor value, int[] num_or_size_splits, Tensor axis = null, int num = -1, | |||||
string name = "split") | |||||
{ | |||||
if(num_or_size_splits.Length == 0) | |||||
{ | { | ||||
return split_eager_fallback(axis, value, num_split: num_split, name: name, ctx: tf.Context); | |||||
throw new ValueError("Rank-0 tensors are not supported as the num_or_size_splits argument to split."); | |||||
} | } | ||||
var size_splits = ops.convert_to_tensor(num_or_size_splits); | |||||
var _op = tf.OpDefLib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }); | |||||
return _op.outputs; | |||||
} | |||||
private static Tensor[] split_eager_fallback<Ta, Tv>(Ta axis, Tv value, int num_split, string name, Context ctx = null) | |||||
{ | |||||
var (_attr_T, input) = tf.Runner.ArgsToMatchingEager(ctx, args: new object[] { value }); | |||||
var axis_tensor = ops.convert_to_tensor(axis, dtype: TF_DataType.TF_INT32); | |||||
var _inputs_flat = new List<Tensor> { axis_tensor }; | |||||
_inputs_flat.AddRange(input); | |||||
var _attrs = new object[] { "num_split", num_split, "T", _attr_T }; | |||||
if(num == -1) | |||||
{ | |||||
num = (int)size_splits.shape[0]; | |||||
} | |||||
return tf.Runner.Execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name); | |||||
return gen_array_ops.split_v(value: value, size_splits: size_splits, split_dim: axis, num_split: num, name: name); | |||||
} | } | ||||
public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null) | public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null) | ||||
@@ -675,16 +675,17 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
public static Tensor[] while_loop(Func<Tensor[], Tensor> cond, | |||||
Func<Tensor[], Tensor[]> body, | |||||
Tensor[] loop_vars, | |||||
public static Tensors while_loop(Func<Tensors, Tensor> cond, | |||||
Func<Tensors, Tensors> body, | |||||
Tensors loop_vars, | |||||
int parallel_iterations = 10, | int parallel_iterations = 10, | ||||
string name = null) | string name = null) | ||||
{ | { | ||||
var executing_eagerly = tf.Context.executing_eagerly(); | var executing_eagerly = tf.Context.executing_eagerly(); | ||||
if (!executing_eagerly) | if (!executing_eagerly) | ||||
{ | { | ||||
throw new NotImplementedException(""); | |||||
return while_v2.while_loop(cond, body, loop_vars, parallel_iterations: parallel_iterations, | |||||
name: name); | |||||
} | } | ||||
return tf_with(ops.name_scope("name", "while"), delegate | return tf_with(ops.name_scope("name", "while"), delegate | ||||
@@ -16,12 +16,20 @@ | |||||
using System; | using System; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Functions; | |||||
using Tensorflow.Graphs; | |||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class control_flow_util | public class control_flow_util | ||||
{ | { | ||||
public static readonly bool ENABLE_CONTROL_FLOW_V2 = !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0" || | |||||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0") || | |||||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2") != "0") || | |||||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2") != "0") || | |||||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2") != "0"); | |||||
/// <summary> | /// <summary> | ||||
/// Return true if `op` is an Exit. | /// Return true if `op` is an Exit. | ||||
/// </summary> | /// </summary> | ||||
@@ -196,5 +204,74 @@ namespace Tensorflow | |||||
} | } | ||||
return null; | return null; | ||||
} | } | ||||
public static bool EnableControlFlowV2(Graph graph) | |||||
{ | |||||
return ENABLE_CONTROL_FLOW_V2 || graph.building_function && (graph is not FuncGraph func || func.captures.Length == 0); | |||||
} | |||||
public static string create_new_tf_function(FuncGraph func_graph) | |||||
{ | |||||
var func = new EagerDefinedFunction(func_graph.Name, func_graph, func_graph.Inputs, func_graph.Outputs, new Dictionary<string, AttrValue>()); | |||||
func.AddToGraph(func_graph); | |||||
return func_graph.Name; | |||||
} | |||||
public static (Operation, Tensor[]) get_op_and_outputs(Tensor[] inputs) | |||||
{ | |||||
if(inputs.Length == 0) | |||||
{ | |||||
return (null, new Tensor[0]); | |||||
} | |||||
else | |||||
{ | |||||
return (inputs[0], inputs); | |||||
} | |||||
} | |||||
public static Tensor[] run_as_function_for_tape_gradients(Func<Tensor[], Tensor[]> make_op, Tensor[] inputs) | |||||
{ | |||||
if(gradients_util.PossibleTapeGradientTypes(inputs) == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER | |||||
&& !(ops.get_default_graph().building_function)) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
else | |||||
{ | |||||
return make_op(inputs); | |||||
} | |||||
} | |||||
public static string unique_fn_name(string scope, string name) | |||||
{ | |||||
return $"{scope}{name}_{ops.uid()}".Replace("/", "_"); | |||||
} | |||||
public static bool output_all_intermediates() | |||||
{ | |||||
if (in_defun()) | |||||
{ | |||||
return false; | |||||
} | |||||
if(tf.Context.FunctionCallOptions.ExecutorType == "SINGLE_THREADED_EXECUTOR") | |||||
{ | |||||
return false; | |||||
} | |||||
// TODO(Rinne): check this after refactoring keras building. | |||||
return false; | |||||
} | |||||
public static bool in_defun() | |||||
{ | |||||
if (tf.Context.executing_eagerly()) | |||||
{ | |||||
return false; | |||||
} | |||||
var graph = ops.get_default_graph(); | |||||
// TODO(Rinne): CondBranchFuncGraph, WhileBodyFuncGraph, WhileCondFuncGraph | |||||
return graph is FuncGraph; | |||||
} | |||||
} | } | ||||
} | } |
@@ -1778,10 +1778,10 @@ new_height, new_width"); | |||||
{ | { | ||||
// a_y_min: [0], a_x_min: [1], a_y_max: [2], a_x_max[3] | // a_y_min: [0], a_x_min: [1], a_y_max: [2], a_x_max[3] | ||||
var a_xy_minmax = array_ops.split( | var a_xy_minmax = array_ops.split( | ||||
value: boxes_a, num_split: 4, axis: 2); | |||||
value: boxes_a, num_or_size_splits: 4, axis: ops.convert_to_tensor(2)); | |||||
// b_y_min: [0], b_x_min: [1], b_y_max: [2], b_x_max[3] | // b_y_min: [0], b_x_min: [1], b_y_max: [2], b_x_max[3] | ||||
var b_xy_minmax = array_ops.split( | var b_xy_minmax = array_ops.split( | ||||
value: boxes_b, num_split: 4, axis: 2); | |||||
value: boxes_b, num_or_size_splits: 4, axis: ops.convert_to_tensor(2)); | |||||
var i_xmin = math_ops.maximum( | var i_xmin = math_ops.maximum( | ||||
a_xy_minmax[1], array_ops.transpose(b_xy_minmax[1], new[] { 0, 2, 1 })); | a_xy_minmax[1], array_ops.transpose(b_xy_minmax[1], new[] { 0, 2, 1 })); | ||||
@@ -1943,7 +1943,7 @@ new_height, new_width"); | |||||
using (ops.name_scope("canonicalize_coordinates")) | using (ops.name_scope("canonicalize_coordinates")) | ||||
{ | { | ||||
// y_1 = [0], x_1 = [1], y_2 = [2], x_2 = [3] | // y_1 = [0], x_1 = [1], y_2 = [2], x_2 = [3] | ||||
var yx = array_ops.split(value: boxes, num_split: 4, axis: 2); | |||||
var yx = array_ops.split(value: boxes, num_or_size_splits: 4, axis: ops.convert_to_tensor(2)); | |||||
var y_1_is_min = math_ops.reduce_all( | var y_1_is_min = math_ops.reduce_all( | ||||
gen_math_ops.less_equal(yx[0][0, 0, 0], yx[2][0, 0, 0])); | gen_math_ops.less_equal(yx[0][0, 0, 0], yx[2][0, 0, 0])); | ||||
var y_minmax = control_flow_ops.cond( | var y_minmax = control_flow_ops.cond( | ||||
@@ -0,0 +1,111 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Eager; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
internal class list_ops | |||||
{ | |||||
private static void _set_handle_data(Tensor list_handle, Shape element_shape, TF_DataType element_dtype) | |||||
{ | |||||
if(list_handle is EagerTensor eagerTensor) | |||||
{ | |||||
var handle_data = new CppShapeInferenceResult.Types.HandleData(); | |||||
handle_data.IsSet = true; | |||||
handle_data.ShapeAndType.Add(new CppShapeInferenceResult.Types.HandleShapeAndType() | |||||
{ | |||||
Shape = element_shape.as_proto(), | |||||
Dtype = element_dtype.as_datatype_enum(), | |||||
Type = new FullTypeDef() { TypeId = FullTypeId.TftArray } | |||||
}); | |||||
list_handle.HandleData = handle_data; | |||||
} | |||||
} | |||||
private static Tensor _build_element_shape(Shape? shape) | |||||
{ | |||||
if(shape is null || shape.IsNull) | |||||
{ | |||||
return ops.convert_to_tensor(-1); | |||||
} | |||||
else | |||||
{ | |||||
return ops.convert_to_tensor(shape); | |||||
} | |||||
} | |||||
public static Tensor tensor_list_reserve(Shape? shape, Tensor num_elements, TF_DataType element_dtype, string name = null) | |||||
{ | |||||
var result = gen_list_ops.tensor_list_reserve(_build_element_shape(shape), num_elements, element_dtype, name); | |||||
_set_handle_data(result, shape, element_dtype); | |||||
return result; | |||||
} | |||||
public static Tensor tensor_list_from_tensor(Tensor tensor, Shape element_shape, string? name = null) | |||||
{ | |||||
var result = gen_list_ops.tensor_list_from_tensor(tensor, _build_element_shape(element_shape), name); | |||||
_set_handle_data(result, tensor.shape, tensor.dtype); | |||||
return result; | |||||
} | |||||
public static Tensor tensor_list_get_item(Tensor input_handle, Tensor index, TF_DataType element_dtype, | |||||
Shape? element_shape = null, string? name = null) | |||||
{ | |||||
return gen_list_ops.tensor_list_get_item(input_handle, index, _build_element_shape(element_shape), | |||||
element_dtype, name); | |||||
} | |||||
public static Tensor tensor_list_set_item(Tensor input_handle, Tensor index, Tensor item, | |||||
bool resize_if_index_out_of_bounds = false, string? name = null) | |||||
{ | |||||
if (resize_if_index_out_of_bounds) | |||||
{ | |||||
var input_list_size = gen_list_ops.tensor_list_length(input_handle); | |||||
input_handle = control_flow_ops.cond(index >= input_list_size, | |||||
() => gen_list_ops.tensor_list_resize(input_handle, index + 1), | |||||
() => input_handle); | |||||
} | |||||
var output_handle = gen_list_ops.tensor_list_set_item(input_handle, index, item, name); | |||||
handle_data_util.copy_handle_data(input_handle, output_handle); | |||||
return output_handle; | |||||
} | |||||
public static Tensor tensor_list_stack(Tensor input_handle, TF_DataType element_dtype, int num_elements = -1, | |||||
Shape? element_shape = null, string? name = null) | |||||
{ | |||||
return gen_list_ops.tensor_list_stack(input_handle, _build_element_shape(element_shape), element_dtype, num_elements, name); | |||||
} | |||||
public static Tensor tensor_list_gather(Tensor input_handle, Tensor indices, TF_DataType element_dtype, | |||||
Shape? element_shape = null, string? name = null) | |||||
{ | |||||
return gen_list_ops.tensor_list_gather(input_handle, indices, _build_element_shape(element_shape), element_dtype, name); | |||||
} | |||||
public static Tensor tensor_list_scatter(Tensor tensor, Tensor indices, Shape? element_shape = null, Tensor? input_handle = null, | |||||
string? name = null) | |||||
{ | |||||
if(input_handle is not null) | |||||
{ | |||||
var output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(input_handle, tensor, indices, name); | |||||
handle_data_util.copy_handle_data(input_handle, output_handle); | |||||
return output_handle; | |||||
} | |||||
else | |||||
{ | |||||
var output_handle = gen_list_ops.tensor_list_scatter_v2(tensor, indices, _build_element_shape(element_shape), | |||||
constant_op.constant(-1), name); | |||||
_set_handle_data(output_handle, element_shape, tensor.dtype); | |||||
return output_handle; | |||||
} | |||||
} | |||||
public static Tensor empty_tensor_list(Shape? element_shape, TF_DataType element_dtype, int max_num_elements = -1, | |||||
string? name = null) | |||||
{ | |||||
return gen_list_ops.empty_tensor_list(_build_element_shape(element_shape), element_dtype: element_dtype, | |||||
max_num_elements: ops.convert_to_tensor(max_num_elements, dtype: dtypes.int32), name: name); | |||||
} | |||||
} | |||||
} |
@@ -13,11 +13,23 @@ namespace Tensorflow | |||||
/// <returns></returns> | /// <returns></returns> | ||||
public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow) | public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow) | ||||
{ | { | ||||
var new_ta = tf.TensorArray( | |||||
dtype: old_ta.dtype, | |||||
infer_shape: old_ta.infer_shape, | |||||
if (!tf.Context.executing_eagerly() && old_ta is not _GraphTensorArrayV2 && control_flow_util.EnableControlFlowV2(ops.get_default_graph())) | |||||
{ | |||||
throw new NotImplementedException("Attempting to build a graph-mode TF2-style " | |||||
+ "TensorArray from either an eager-mode " | |||||
+ "TensorArray or a TF1-style TensorArray. " | |||||
+ "This is not currently supported. You may be " | |||||
+ "attempting to capture a TensorArray " | |||||
+ "inside a tf.function or tf.data map function. " | |||||
+ "Instead, construct a new TensorArray inside " | |||||
+ "the function."); | |||||
} | |||||
var new_ta = TensorArray.Create(old_ta.dtype, handle: old_ta.handle, flow: flow, infer_shape: old_ta.infer_shape, | |||||
colocate_with_first_write_call: old_ta.colocate_with_first_write_call); | colocate_with_first_write_call: old_ta.colocate_with_first_write_call); | ||||
new_ta._dynamic_size = old_ta._dynamic_size; | |||||
new_ta._size = old_ta._size; | |||||
new_ta._colocate_with = old_ta._colocate_with; | |||||
new_ta._element_shape = old_ta._element_shape; | |||||
return new_ta; | return new_ta; | ||||
} | } | ||||
@@ -0,0 +1,401 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Text; | |||||
using Tensorflow.Common.Extensions; | |||||
using Tensorflow.Common.Types; | |||||
using Tensorflow.Eager; | |||||
using Tensorflow.Framework; | |||||
using Tensorflow.Framework.Models; | |||||
using Tensorflow.Graphs; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
class _OperationWithOutputs : Operation | |||||
{ | |||||
public _OperationWithOutputs(IntPtr handle, Graph g = null) | |||||
{ | |||||
_handle = handle; | |||||
_graph = g; | |||||
_outputs = null; | |||||
g._add_op(this); | |||||
} | |||||
} | |||||
internal class while_v2 | |||||
{ | |||||
public static Tensor[] while_loop(Func<Tensors, Tensor> cond, | |||||
Func<Tensors, Tensors> body, | |||||
Tensors loop_vars, | |||||
int maximum_iterations = -1, | |||||
int parallel_iterations = 10, | |||||
string name = null, | |||||
bool back_prop = true, | |||||
bool return_same_structure = true) | |||||
{ | |||||
var orig_loop_vars = loop_vars; | |||||
var flat_orig_loop_vars = orig_loop_vars.Flatten().ToArray(); | |||||
int len_orig_loop_vars = orig_loop_vars.Length; | |||||
loop_vars = _tensor_array_to_flow(loop_vars); | |||||
loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x, TF_DataType.DtInvalid, null), loop_vars).ToTensors(); | |||||
var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), _tensor_array_to_flow(loop_vars)); | |||||
var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray(); | |||||
if(string.IsNullOrEmpty(name)) | |||||
{ | |||||
name = "while"; | |||||
} | |||||
return tf_with<ITensorFlowObject, Tensor[]>(ops.name_scope(name), nameScopeWhile => | |||||
{ | |||||
string scope = (nameScopeWhile as ops.NameScope).scope_name; | |||||
string cond_name = control_flow_util.unique_fn_name(scope, "cond"); | |||||
string body_name = control_flow_util.unique_fn_name(scope, "body"); | |||||
var maximum_iterations_loop_var = _build_maximum_iterations_loop_var(maximum_iterations); | |||||
var loop_counter = constant_op.constant(0, maximum_iterations == -1 ? TF_DataType.DtInvalid : maximum_iterations_loop_var.dtype, | |||||
name: "loop_counter"); | |||||
loop_vars = new Tensor[] { loop_counter, maximum_iterations_loop_var }.Concat(loop_vars).ToArray(); | |||||
var func_graph_signature = new TensorSpec[] {TensorSpec.FromTensor(loop_counter),TensorSpec.FromTensor(maximum_iterations_loop_var)} | |||||
.Concat(loop_vars_signature.Flatten()).ToArray(); | |||||
// TODO(Rinne): possible wrong implementation here. | |||||
var add_control_dependencies = false; | |||||
object[] wrapped_cond(object[] inputs) | |||||
{ | |||||
Tensor loop_counter = (Tensor)inputs[0]; | |||||
Tensor maximum_iterations_arg = (Tensor)inputs[1]; | |||||
Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray(); | |||||
var pred = cond(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args)); | |||||
if(pred.shape.IsNull || pred.shape.ndim > 0) | |||||
{ | |||||
pred = array_ops.squeeze(pred); | |||||
} | |||||
if(maximum_iterations == -1) | |||||
{ | |||||
return new object[] { pred }; | |||||
} | |||||
else | |||||
{ | |||||
return new object[] { math_ops.logical_and(loop_counter < maximum_iterations_arg, pred) }; | |||||
} | |||||
} | |||||
var cond_graph = FuncGraph.func_graph_from_func(cond_name, wrapped_cond, null, | |||||
null, signature: func_graph_signature, add_control_dependencies: add_control_dependencies); | |||||
bool stateful_parallelism = false; | |||||
object[] wrapped_body(object[] inputs) | |||||
{ | |||||
Tensor loop_counter = (Tensor)inputs[0]; | |||||
Tensor maximum_iterations_arg = (Tensor)inputs[1]; | |||||
Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray(); | |||||
_copy_handle_data(loop_vars.Flatten().Skip(2), args); | |||||
foreach(var t in cond_graph.external_captures) | |||||
{ | |||||
var graph = (FuncGraph)(ops.get_default_graph()); | |||||
graph.capture(t); | |||||
} | |||||
var outputs = body(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args)); | |||||
outputs = _tensor_array_to_flow(outputs); | |||||
return new object[] { loop_counter + 1, maximum_iterations_arg }.Concat(outputs).ToArray(); | |||||
} | |||||
var body_graph = FuncGraph.func_graph_from_func(body_name, wrapped_body, null, null, func_graph_signature, | |||||
add_control_dependencies: add_control_dependencies, acd_record_initial_resource_uses: stateful_parallelism); | |||||
// TODO(Rinne): possible wrong implementation here. | |||||
NestList<Tensors> loop_vars_list = new(new Tensors[] { loop_vars, body_graph.external_captures.ToTensors() }); | |||||
body_graph.Outputs.AddRange(body_graph.internal_captures); | |||||
cond_graph.as_default(); | |||||
int num_cond_captures = cond_graph.external_captures.Length; | |||||
Debug.Assert(cond_graph.external_captures.SequenceEqual(body_graph.external_captures.Take(num_cond_captures).ToArray())); | |||||
_duplicate_body_captures_in_cond(cond_graph, body_graph.external_captures.Skip(num_cond_captures).ToArray()); | |||||
cond_graph.Exit(); | |||||
int first_loop_var_index = 2; | |||||
int num_flattened_outputs = orig_loop_vars.Length; | |||||
int num_original_outputs = body_graph.Outputs.Length; | |||||
if (back_prop && control_flow_util.output_all_intermediates()) | |||||
{ | |||||
var intermediate_tensors = _get_intermediates(body_graph); | |||||
foreach(var intermediate_tensor in intermediate_tensors) | |||||
{ | |||||
var tensor_list = list_ops.empty_tensor_list(intermediate_tensor.shape, intermediate_tensor.dtype, maximum_iterations); | |||||
loop_vars_list.Values.Add(tensor_list); | |||||
cond_graph.as_default(); | |||||
cond_graph.capture(tensor_list); | |||||
cond_graph.Exit(); | |||||
body_graph.as_default(); | |||||
var appended_tensor_list = gen_ops.tensor_list_push_back(tensor_list, intermediate_tensor); | |||||
body_graph.Outputs.Add(appended_tensor_list); | |||||
body_graph.Exit(); | |||||
} | |||||
} | |||||
List<Tensor> flattened_loop_vars = new(); | |||||
foreach(var item in loop_vars_list.Values) | |||||
{ | |||||
flattened_loop_vars.AddRange(item.Flatten()); | |||||
} | |||||
// skip the check | |||||
// TODO(Rinne): deal with control dependencies | |||||
var output_shapes = body_graph.Outputs.Select(t => t.shape).ToArray(); | |||||
var span = new Span<Shape>(output_shapes).Slice(first_loop_var_index, num_flattened_outputs); | |||||
for(int i = 0; i < span.Length; i++) | |||||
{ | |||||
span[i] = flat_shape_invariants[i]; | |||||
} | |||||
Tensor[] outputs = _build_while_op(flattened_loop_vars.ToArray(), cond_graph, body_graph, output_shapes, parallel_iterations, | |||||
(nameScopeWhile as ops.NameScope).scope_name, num_original_outputs, stateful_parallelism); | |||||
if (!ops.get_default_graph().building_function) | |||||
{ | |||||
outputs = outputs.Select(t => array_ops.identity(t)).ToArray(); | |||||
} | |||||
var output_loop_vars = outputs.Skip(first_loop_var_index).Take(num_flattened_outputs).ToArray(); | |||||
if (!back_prop) | |||||
{ | |||||
output_loop_vars = output_loop_vars.Select(t => array_ops.stop_gradient(t)).ToArray(); | |||||
} | |||||
outputs = _pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, output_loop_vars); | |||||
return outputs; | |||||
}); | |||||
} | |||||
private static Tensors _tensor_array_to_flow(Tensors loop_vars) | |||||
{ | |||||
if(loop_vars.NestType == NestType.Node) | |||||
{ | |||||
if(loop_vars.NodeValue is FakeTensorByTensorArray fake) | |||||
{ | |||||
return new Tensors(fake.TensorArray.flow); | |||||
} | |||||
else | |||||
{ | |||||
return new Tensors(loop_vars.NodeValue!); | |||||
} | |||||
} | |||||
else if(loop_vars.NestType == NestType.List) | |||||
{ | |||||
List<INestStructure<Tensor>> list = new(); | |||||
foreach(var item in loop_vars.ListValue!) | |||||
{ | |||||
if(item.NestType == NestType.Node) | |||||
{ | |||||
var nested = item.AsNest(); | |||||
if (nested.NodeValue is FakeTensorByTensorArray fake) | |||||
{ | |||||
list.Add(new Nest<Tensor>(fake.TensorArray.flow)); | |||||
} | |||||
else | |||||
{ | |||||
list.Add(new Nest<Tensor>(nested.NodeValue!)); | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
list.Add(new Nest<Tensor>(item.AsNest())); | |||||
} | |||||
} | |||||
return Tensors.FromNest(new Nest<Tensor>(list)); | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
} | |||||
private static Tensor[] _build_while_op(Tensor[] loop_vars, FuncGraph cond_graph, FuncGraph body_graph, | |||||
Shape[] output_shapes, int parallel_iterations, string name, int num_original_outputs, bool stateful_parallelism) | |||||
{ | |||||
var cond_stateful_ops = cond_graph.get_operations().Select(x => x.op); | |||||
var body_stateful_ops = body_graph.get_operations().Select(x => x.op); | |||||
bool is_stateful = cond_stateful_ops.Count() > 0 || body_stateful_ops.Count() > 0; | |||||
Tensor[] _make_op(Tensor[] inputs) | |||||
{ | |||||
Tensor[] outputs; | |||||
if (is_stateful) | |||||
{ | |||||
outputs = gen_functional_ops._while( | |||||
inputs, | |||||
control_flow_util.create_new_tf_function(cond_graph), | |||||
control_flow_util.create_new_tf_function(body_graph), | |||||
output_shapes, | |||||
parallel_iterations, | |||||
name | |||||
); | |||||
} | |||||
else | |||||
{ | |||||
outputs = gen_functional_ops.stateless_while( | |||||
inputs, | |||||
control_flow_util.create_new_tf_function(cond_graph), | |||||
control_flow_util.create_new_tf_function(body_graph), | |||||
output_shapes, | |||||
parallel_iterations, | |||||
name | |||||
); | |||||
} | |||||
var (while_op, tensors) = control_flow_util.get_op_and_outputs(outputs); | |||||
_copy_handle_data(body_graph.Outputs, tensors); | |||||
_set_read_only_resource_inputs_attr(while_op, new FuncGraph[]{cond_graph, body_graph}); | |||||
while_op._set_attr("_num_original_outputs", new AttrValue() { I = num_original_outputs }); | |||||
while_op._set_attr("_stateful_parallelism", new AttrValue() { B = stateful_parallelism }); | |||||
cond_graph.outer_graph = ops.get_default_graph(); | |||||
body_graph.outer_graph = ops.get_default_graph(); | |||||
// TODO(Rinne): set the two graphs to while_op | |||||
return tensors; | |||||
} | |||||
return control_flow_util.run_as_function_for_tape_gradients(_make_op, loop_vars); | |||||
} | |||||
/// <summary> | |||||
/// Sets the list of resource inputs which are read-only. This is used by AutomaticControlDependencies. | |||||
/// </summary> | |||||
/// <param name="op"></param> | |||||
/// <param name="branch_graphs"></param> | |||||
private static void _set_read_only_resource_inputs_attr(Operation op, FuncGraph[] branch_graphs) | |||||
{ | |||||
List<int> read_only_indices = Enumerable.Range(0, op.inputs.Length).ToList(); | |||||
foreach(var branch_graph in branch_graphs) | |||||
{ | |||||
if (read_only_indices.Count == 0) | |||||
{ | |||||
break; | |||||
} | |||||
var branch_read_only_indices = auto_control_deps_utils.get_read_only_resource_input_indices_graph(branch_graph); | |||||
read_only_indices = read_only_indices.Intersect(branch_read_only_indices).ToList(); | |||||
} | |||||
AttrValue.Types.ListValue listValue = new(); | |||||
listValue.I.AddRange(read_only_indices.OrderBy(x => x).Select(x => (long)x)); | |||||
op._set_attr(auto_control_deps_utils.READ_ONLY_RESOURCE_INPUTS_ATTR, new AttrValue() | |||||
{ | |||||
List = listValue | |||||
}); | |||||
} | |||||
private static Tensors _pack_sequence_as<T>(INestStructure<T> loop_vars_signature, Tensor[] flat_orig_loop_vars, Tensor[] loop_vars) | |||||
{ | |||||
var flattened_loop_vars = zip(loop_vars, flat_orig_loop_vars).Select<(Tensor, Tensor), Tensor>(item => | |||||
{ | |||||
var (flow, y) = item; | |||||
if (y is FakeTensorByTensorArray ta) | |||||
{ | |||||
return new FakeTensorByTensorArray(tensor_array_ops.build_ta_with_new_flow(ta.TensorArray, flow)); | |||||
} | |||||
else | |||||
{ | |||||
return flow; | |||||
} | |||||
}).ToArray(); | |||||
return Nest.PackSequenceAs(loop_vars_signature, flattened_loop_vars).ToTensors(); | |||||
} | |||||
private static Tensor[] _get_intermediates(FuncGraph func_graph) | |||||
{ | |||||
List<Tensor> intermediates = new(); | |||||
var reversed_captures = func_graph.captures.ToDictionary(x => x.Item2, x => x.Item1); | |||||
foreach(var op in func_graph.get_operations()) | |||||
{ | |||||
Debug.Assert(op is Operation); | |||||
var oper = (Operation)op; | |||||
if(oper.type == "Identity" || oper.type == "MutexLock") | |||||
{ | |||||
continue; | |||||
} | |||||
foreach(var o in op.outputs) | |||||
{ | |||||
if(o != func_graph.Inputs[0] && o.dtype != dtypes.resource && !reversed_captures.ContainsKey(o)) | |||||
{ | |||||
intermediates.Add(o); | |||||
} | |||||
} | |||||
} | |||||
return intermediates.ToArray(); | |||||
} | |||||
private static void _duplicate_body_captures_in_cond(FuncGraph cond_graph, Tensor[] body_graph_captures) | |||||
{ | |||||
var types = body_graph_captures.Select(t => t.dtype).ToList(); | |||||
var c_graph = cond_graph.c_graph; | |||||
var placeholders = types.Select(x => CreatePlaceholder(c_graph, _build_cond_placeholders_name_prefix(cond_graph), x)).ToList(); | |||||
var placeholder_ops = placeholders.Select(ph => new _OperationWithOutputs(ph.oper, cond_graph)).ToList(); | |||||
List<Tensor> tensors = new(); | |||||
foreach(var (op, ph, dtype) in zip(placeholder_ops, placeholders, types)) | |||||
{ | |||||
var tensor = Tensor._create_with_tf_output(op, 0, dtype, ph); | |||||
op._outputs = new Tensor[] { tensor }; | |||||
tensors.Add(tensor); | |||||
} | |||||
var tuples = zip(body_graph_captures, tensors).ToList(); | |||||
var keys = body_graph_captures.Select(t => t.Id).ToList(); | |||||
cond_graph._captures.Update(zip(keys, tuples).ToDictionary(x => x.Item1, x => x.Item2)); | |||||
cond_graph.Inputs.AddRange(tensors); | |||||
} | |||||
private static TF_Output CreatePlaceholder(SafeGraphHandle graph, string name, TF_DataType dtype) | |||||
{ | |||||
var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||||
c_api.TF_SetAttrType(desc, "dtype", dtype); | |||||
var op = c_api.TF_FinishOperation(desc, tf.Status); | |||||
tf.Status.Check(true); | |||||
var output = new TF_Output(); | |||||
output.oper = op; | |||||
output.index = 0; | |||||
return output; | |||||
} | |||||
private static string _build_cond_placeholders_name_prefix(FuncGraph cond_graph) | |||||
{ | |||||
return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder"); | |||||
} | |||||
private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype, | |||||
string name) | |||||
{ | |||||
return ops.convert_to_tensor(value, dtype, name, false); | |||||
} | |||||
private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1) | |||||
{ | |||||
return ops.convert_to_tensor(maximum_iterations, dtypes.int32, "maximum_iterations"); | |||||
} | |||||
private static void _copy_handle_data(IEnumerable<Tensor> src_tensors, IEnumerable<Tensor> dst_tensors) | |||||
{ | |||||
foreach(var (src_t, dst_t) in zip(src_tensors, dst_tensors)) | |||||
{ | |||||
handle_data_util.copy_handle_data(src_t, dst_t); | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -17,6 +17,7 @@ | |||||
using System; | using System; | ||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
using Tensorflow.Exceptions; | |||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.c_api; | using static Tensorflow.c_api; | ||||
@@ -88,7 +89,7 @@ namespace Tensorflow | |||||
case TF_Code.TF_INVALID_ARGUMENT: | case TF_Code.TF_INVALID_ARGUMENT: | ||||
throw new InvalidArgumentError(message); | throw new InvalidArgumentError(message); | ||||
default: | default: | ||||
throw new TensorflowException(message); | |||||
throw new NotOkStatusException(message); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -111,7 +111,7 @@ https://tensorflownet.readthedocs.io</Description> | |||||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | ||||
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" /> | <PackageReference Include="Newtonsoft.Json" Version="13.0.3" /> | ||||
<PackageReference Include="OneOf" Version="3.0.223" /> | <PackageReference Include="OneOf" Version="3.0.223" /> | ||||
<PackageReference Include="Protobuf.Text" Version="0.7.0" /> | |||||
<PackageReference Include="Protobuf.Text" Version="0.7.1" /> | |||||
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -105,6 +105,13 @@ namespace Tensorflow | |||||
_id = ops.uid(); | _id = ops.uid(); | ||||
} | } | ||||
internal static Tensor _create_with_tf_output(Operation op, int value_index, TF_DataType dtype, TF_Output tf_output) | |||||
{ | |||||
Tensor ret = new Tensor(op, value_index, dtype); | |||||
ret._tf_output = tf_output; | |||||
return ret; | |||||
} | |||||
protected unsafe void InitTensor(Shape shape, TF_DataType dtype) | protected unsafe void InitTensor(Shape shape, TF_DataType dtype) | ||||
{ | { | ||||
_handle = TF_NewTensor(shape, dtype, null); | _handle = TF_NewTensor(shape, dtype, null); | ||||
@@ -14,7 +14,9 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Tensorflow.Common.Types; | |||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -44,5 +46,27 @@ namespace Tensorflow | |||||
public abstract Tensor stack(string name = null); | public abstract Tensor stack(string name = null); | ||||
public abstract Tensor gather(Tensor indices, string name = null); | public abstract Tensor gather(Tensor indices, string name = null); | ||||
internal bool _dynamic_size; | |||||
internal Tensor _size; | |||||
internal List<Tensor> _colocate_with; | |||||
internal Shape _element_shape; | |||||
public static TensorArray Create(TF_DataType dtype, Tensor size = null, bool dynamic_size = false, | |||||
bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||||
bool infer_shape = true, Shape? element_shape = null, | |||||
bool colocate_with_first_write_call = true, string name = null) | |||||
{ | |||||
if (tf.Context.executing_eagerly() && (flow is null || flow.dtype != dtypes.variant)) | |||||
{ | |||||
return new _EagerTensorArray(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow, | |||||
infer_shape, element_shape, colocate_with_first_write_call, name); | |||||
} | |||||
else | |||||
{ | |||||
return new _GraphTensorArrayV2(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow, | |||||
infer_shape, element_shape, colocate_with_first_write_call, name); | |||||
} | |||||
} | |||||
} | } | ||||
} | } |
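A small usage sketch of the new Create factory: in eager mode (or with a null/non-variant flow) it hands back an _EagerTensorArray, while graph construction gets the TensorList-backed _GraphTensorArrayV2. This is a minimal sketch using only the abstract members declared above:

using Tensorflow;
using static Tensorflow.Binding;

class TensorArrayCreateDemo
{
    static void Main()
    {
        tf.enable_eager_execution();

        // Eager mode selects _EagerTensorArray; inside a graph/function the
        // TensorList-based _GraphTensorArrayV2 is returned instead.
        var ta = TensorArray.Create(tf.float32, size: tf.constant(3));
        ta = ta.write(0, tf.constant(1.0f));
        ta = ta.write(1, tf.constant(2.0f));
        ta = ta.write(2, tf.constant(3.0f));

        print(ta.stack()); // [1, 2, 3]
    }
}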
@@ -4,6 +4,8 @@ using System.Collections; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using Tensorflow.Operations; | |||||
using Tensorflow.Common.Extensions; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -58,7 +60,7 @@ namespace Tensorflow | |||||
public Tensor this[params string[] slices] | public Tensor this[params string[] slices] | ||||
=> this.First()[slices]; | => this.First()[slices]; | ||||
private Tensors(Nest<Tensor> nested) : base(nested) | |||||
internal Tensors(Nest<Tensor> nested) : base(nested) | |||||
{ | { | ||||
} | } | ||||
@@ -68,9 +70,9 @@ namespace Tensorflow | |||||
} | } | ||||
public Tensors(IEnumerable<Tensor> tensors): base(tensors.Select(x => new Nest<Tensor>(x))) | |||||
public Tensors(IList<Tensor> tensors) : base(tensors.Select(x => new Nest<Tensor>(x))) | |||||
{ | { | ||||
} | } | ||||
public Tensors(NDArray nd): base(ops.convert_to_tensor(nd)) | public Tensors(NDArray nd): base(ops.convert_to_tensor(nd)) | ||||
@@ -78,6 +80,32 @@ namespace Tensorflow | |||||
} | } | ||||
/// <summary> | |||||
/// Get the element at the shallow (top) level. For example, for ts = [1, [2, 3], 4], | |||||
/// the common indexer gives ts[1] = 2, while the shallow indexer gives ts[1] = [2, 3]. | |||||
/// </summary> | |||||
/// <param name="index"></param> | |||||
/// <returns></returns> | |||||
public Tensors GetShallow(int index) | |||||
{ | |||||
if(NestType == NestType.Node) | |||||
{ | |||||
if(index > 0) | |||||
{ | |||||
throw new IndexOutOfRangeException(); | |||||
} | |||||
return this; | |||||
} | |||||
else if(NestType == NestType.List) | |||||
{ | |||||
return ListValue![index].AsNest().ToTensors(); | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
} | |||||
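Editor's note: a purely illustrative example of the shallow indexer. It leans on `Tensors.FromNest` and the `Nest<Tensor>` constructors that appear elsewhere in this diff; the tensors and variable names are made up.

// Hypothetical: build the structure [a, [b, c], d] from the XML doc above.
var a = tf.constant(1); var b = tf.constant(2); var c = tf.constant(3); var d = tf.constant(4);
var nested = new Nest<Tensor>(new INestStructure<Tensor>[]
{
    new Nest<Tensor>(a),
    new Nest<Tensor>(new INestStructure<Tensor>[] { new Nest<Tensor>(b), new Nest<Tensor>(c) }),
    new Nest<Tensor>(d)
});
var ts = Tensors.FromNest(nested);
var second        = ts[1];            // flattened indexer: the tensor b
var secondShallow = ts.GetShallow(1); // shallow indexer: a Tensors holding [b, c]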
private static Nest<Tensor> DealWithConstructorArrayInput(Tensor[] tensors) | private static Nest<Tensor> DealWithConstructorArrayInput(Tensor[] tensors) | ||||
{ | { | ||||
if (tensors.Length == 0) | if (tensors.Length == 0) | ||||
@@ -115,8 +143,8 @@ namespace Tensorflow | |||||
else if(NestType == NestType.Node) | else if(NestType == NestType.Node) | ||||
{ | { | ||||
NestType = NestType.List; | NestType = NestType.List; | ||||
ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) }; | |||||
Value = null; | |||||
ListValue = new() { new Nest<Tensor>(NodeValue), new Nest<Tensor>(tensor) }; | |||||
NodeValue = null; | |||||
} | } | ||||
else if(NestType == NestType.List) | else if(NestType == NestType.List) | ||||
{ | { | ||||
@@ -125,7 +153,7 @@ namespace Tensorflow | |||||
else //Empty | else //Empty | ||||
{ | { | ||||
NestType = NestType.Node; | NestType = NestType.Node; | ||||
Value = tensor; | |||||
NodeValue = tensor; | |||||
} | } | ||||
} | } | ||||
@@ -140,9 +168,9 @@ namespace Tensorflow | |||||
else if (NestType == NestType.Node) | else if (NestType == NestType.Node) | ||||
{ | { | ||||
NestType = NestType.List; | NestType = NestType.List; | ||||
ListValue = new() { new Nest<Tensor>(Value) }; | |||||
ListValue = new() { new Nest<Tensor>(NodeValue) }; | |||||
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | ||||
Value = null; | |||||
NodeValue = null; | |||||
} | } | ||||
else if(NestType == NestType.List) | else if(NestType == NestType.List) | ||||
{ | { | ||||
@@ -151,7 +179,7 @@ namespace Tensorflow | |||||
else // empty | else // empty | ||||
{ | { | ||||
NestType = NestType.List; | NestType = NestType.List; | ||||
ListValue = tensors.Select(x => new Nest<Tensor>(x)).ToList(); | |||||
ListValue = tensors.Select(x => new Nest<Tensor>(x) as INestStructure<Tensor>).ToList(); | |||||
} | } | ||||
} | } | ||||
@@ -166,9 +194,9 @@ namespace Tensorflow | |||||
else if(NestType == NestType.Node) | else if(NestType == NestType.Node) | ||||
{ | { | ||||
NestType = NestType.List; | NestType = NestType.List; | ||||
ListValue = new() { new Nest<Tensor>(Value) }; | |||||
ListValue = new() { new Nest<Tensor>(NodeValue) }; | |||||
ListValue.Insert(index, new Nest<Tensor>(tensor)); | ListValue.Insert(index, new Nest<Tensor>(tensor)); | ||||
Value = null; | |||||
NodeValue = null; | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -283,7 +311,7 @@ namespace Tensorflow | |||||
=> tensors?.SingleOrNull; | => tensors?.SingleOrNull; | ||||
public static implicit operator Tensor[](Tensors tensors) | public static implicit operator Tensor[](Tensors tensors) | ||||
=> tensors.Flatten().ToArray(); | |||||
=> tensors.Flatten().ToArray(); | |||||
#endregion | #endregion | ||||
public static Tensors? FromNest(Nest<Tensor> nested) | public static Tensors? FromNest(Nest<Tensor> nested) | ||||
@@ -298,7 +326,7 @@ namespace Tensorflow | |||||
public void Deconstruct(out Tensor a, out Tensors? b) | public void Deconstruct(out Tensor a, out Tensors? b) | ||||
{ | { | ||||
a = this.First(); | a = this.First(); | ||||
b = Length == 1? null : new Tensors(this.Skip(1)); | |||||
b = Length == 1? null : new Tensors(this.Skip(1).ToArray()); | |||||
} | } | ||||
public override string ToString() | public override string ToString() | ||||
@@ -179,8 +179,7 @@ namespace Tensorflow.Train | |||||
// handles slot variables. | // handles slot variables. | ||||
if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable) | if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable) | ||||
{ | { | ||||
var temp = new_variable as Trackable; | |||||
var res = _track_trackable(temp, args.Name, args.Overwrite); | |||||
var res = _track_trackable(new_variable as Trackable, args.Name, args.Overwrite); | |||||
Debug.Assert(res is IVariableV1); | Debug.Assert(res is IVariableV1); | ||||
return res as IVariableV1; | return res as IVariableV1; | ||||
} | } | ||||
@@ -170,11 +170,28 @@ namespace Tensorflow | |||||
public Tensor value() | public Tensor value() | ||||
=> GraphElement ?? _read_variable_op(); | => GraphElement ?? _read_variable_op(); | ||||
protected Tensor _read_variable_op() | |||||
protected Tensor _read_variable_op(bool no_copy = false) | |||||
{ | { | ||||
variable_accessed(this); | variable_accessed(this); | ||||
var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); | |||||
resource_variable_ops._maybe_set_handle_data(_dtype, handle, result); | |||||
Tensor read_and_set_handle(bool no_copy) | |||||
{ | |||||
if (no_copy) | |||||
{ | |||||
gen_resource_variable_ops.disable_copy_on_read(handle); | |||||
} | |||||
var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); | |||||
resource_variable_ops._maybe_set_handle_data(_dtype, handle, result); | |||||
return result; | |||||
} | |||||
// TODO(Rinne): deal with caching device. | |||||
var result = read_and_set_handle(no_copy); | |||||
if (!tf.Context.executing_eagerly()) | |||||
{ | |||||
tf.Runner.TFE_TapeSetRecordOperation("ReadVariableOp", new Tensor[] { result }, new Tensor[] { handle }, | |||||
backward_function: (x, _) => x); | |||||
} | |||||
// have to set shape when converting to substituent placeholder | // have to set shape when converting to substituent placeholder | ||||
if (result.shape.ndim == -1) | if (result.shape.ndim == -1) | ||||
@@ -576,7 +576,7 @@ namespace Tensorflow | |||||
public static HandleData get_resource_handle_data(Tensor graph_op) | public static HandleData get_resource_handle_data(Tensor graph_op) | ||||
{ | { | ||||
var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); | var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); | ||||
return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); | |||||
return HandleData.Parser.ParseFrom(c_api.ByteStringPiece(handle_data)); | |||||
} | } | ||||
public static void dismantle_graph(Graph graph) | public static void dismantle_graph(Graph graph) | ||||
@@ -25,6 +25,7 @@ using static Tensorflow.Binding; | |||||
using static Tensorflow.Graphs.SubGraphUtility; | using static Tensorflow.Graphs.SubGraphUtility; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using System.Diagnostics; | |||||
namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
{ | { | ||||
@@ -485,7 +486,7 @@ namespace Tensorflow.Keras | |||||
var first_flatted_input = flatted_inptus[0]; | var first_flatted_input = flatted_inptus[0]; | ||||
var time_steps = first_flatted_input.shape[0]; | var time_steps = first_flatted_input.shape[0]; | ||||
var batch = first_flatted_input.shape[1]; | var batch = first_flatted_input.shape[1]; | ||||
var time_steps_t = (int)first_flatted_input.shape[0]; | |||||
var time_steps_t = tf.shape(first_flatted_input)[0]; | |||||
foreach (var input_ in flatted_inptus) | foreach (var input_ in flatted_inptus) | ||||
{ | { | ||||
@@ -704,7 +705,7 @@ namespace Tensorflow.Keras | |||||
var input_ta = new List<TensorArray>(); | var input_ta = new List<TensorArray>(); | ||||
for (int i = 0; i < flatted_inptus.Count; i++) | for (int i = 0; i < flatted_inptus.Count; i++) | ||||
{ | { | ||||
input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_steps_t)); | |||||
input_ta.Add(TensorArray.Create(dtype: flatted_inptus[i].dtype, size: time_steps_t)); | |||||
} | } | ||||
foreach(var (ta, input_) in zip(input_ta, flatted_inptus)) | foreach(var (ta, input_) in zip(input_ta, flatted_inptus)) | ||||
@@ -730,18 +731,15 @@ namespace Tensorflow.Keras | |||||
(output_time_zero, _) = step_function(input_time_zero, | (output_time_zero, _) = step_function(input_time_zero, | ||||
constants is null ? initial_states : initial_states.MergeWith(constants)); | constants is null ? initial_states : initial_states.MergeWith(constants)); | ||||
int output_ta_size = return_all_outputs ? time_steps_t : 1; | |||||
Tensor output_ta_size = return_all_outputs ? time_steps_t : constant_op.constant(1); | |||||
var output_ta = new List<TensorArray>(); | var output_ta = new List<TensorArray>(); | ||||
for (int i = 0; i < output_time_zero.ToList().Count; i++) | |||||
foreach(var output in output_time_zero.Flatten()) | |||||
{ | { | ||||
var Out = output_time_zero.ToList()[i]; | |||||
output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); | |||||
output_ta.Add(TensorArray.Create(dtype: output.dtype, size: output_ta_size, element_shape: output.shape)); | |||||
} | } | ||||
var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); | var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); | ||||
Func<Tensor, Tensor>? masking_fn; | Func<Tensor, Tensor>? masking_fn; | ||||
Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null; | Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null; | ||||
if (mask != null) | if (mask != null) | ||||
@@ -750,7 +748,7 @@ namespace Tensorflow.Keras | |||||
{ | { | ||||
mask = tf.reverse(mask, axis: new[] { 0 }); | mask = tf.reverse(mask, axis: new[] { 0 }); | ||||
} | } | ||||
var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_steps_t); | |||||
var mask_ta = TensorArray.Create(dtype: TF_DataType.TF_BOOL, size: time_steps_t); | |||||
mask_ta = mask_ta.unstack(mask); | mask_ta = mask_ta.unstack(mask); | ||||
masking_fn = (time) => | masking_fn = (time) => | ||||
@@ -810,9 +808,9 @@ namespace Tensorflow.Keras | |||||
masking_fn = null; | masking_fn = null; | ||||
} | } | ||||
Func<Tensor, Tensor> cond = (time) => (time < time_steps_t); | |||||
Func<Tensors, Tensor> cond = (time) => (time[0] < time_steps_t); | |||||
int parallel_iterations = 32; | int parallel_iterations = 32; | ||||
new_states = states; | |||||
Tensors final_outputs; | |||||
if (masking_fn != null) | if (masking_fn != null) | ||||
{ | { | ||||
// Mask for the T output will be base on the output of T - 1. In the | // Mask for the T output will be base on the output of T - 1. In the | ||||
@@ -825,7 +823,7 @@ namespace Tensorflow.Keras | |||||
var prev_output = flat_zero_output; | var prev_output = flat_zero_output; | ||||
var output_ta_t = output_ta; | var output_ta_t = output_ta; | ||||
Tensor _step(Tensor time) | |||||
Tensors _step(Tensors tensors) | |||||
{ | { | ||||
/* | /* | ||||
RNN step function. | RNN step function. | ||||
@@ -838,23 +836,28 @@ namespace Tensorflow.Keras | |||||
Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` | Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` | ||||
*/ | */ | ||||
Tensor time = tensors[0]; | |||||
TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray; | |||||
Tensors prev_output = tensors.GetShallow(2); | |||||
Tensors states = new Tensors(tensors.Skip(2 + prev_output.Length).ToArray()); | |||||
var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); | var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); | ||||
// maybe set shape | // maybe set shape | ||||
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type | // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type | ||||
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); | var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); | ||||
var mask_t = masking_fn(time); | var mask_t = masking_fn(time); | ||||
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants)); | |||||
var (output, new_states) = step_function(current_input, states.MergeWith(constants)); | |||||
// mask output | // mask output | ||||
var flat_output = Nest.Flatten(output).ToList(); | var flat_output = Nest.Flatten(output).ToList(); | ||||
var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); | |||||
var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.Flatten().ToList(); | |||||
// TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type | // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type | ||||
var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); | var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); | ||||
// mask states | // mask states | ||||
var flat_state = states.ToList(); | |||||
var flat_new_state = new_states_internal.ToList(); | |||||
var flat_state = states.Flatten().ToList(); | |||||
var flat_new_state = new_states.Flatten().ToList(); | |||||
foreach (var (state, new_state) in zip(flat_state, flat_new_state)) | foreach (var (state, new_state) in zip(flat_state, flat_new_state)) | ||||
{ | { | ||||
@@ -865,38 +868,37 @@ namespace Tensorflow.Keras | |||||
} | } | ||||
var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); | var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); | ||||
new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors(); | |||||
new_states = Nest.PackSequenceAs(new_states, flat_final_state.ToArray()).ToTensors(); | |||||
var ta_index_to_write = return_all_outputs ? time : tf.constant(0); | var ta_index_to_write = return_all_outputs ? time : tf.constant(0); | ||||
output_ta_t = zip(output_ta_t, flat_new_output).Select(item => | |||||
{ | |||||
var (ta, out_) = item; | |||||
return ta.write(ta_index_to_write, out_); | |||||
}).ToList(); | |||||
Debug.Assert(flat_output.Count() == 1); | |||||
output_ta_t = output_ta_t.write(ta_index_to_write, flat_new_output.First()); | |||||
new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); | |||||
output_ta = output_ta_t; | |||||
new_states = new_states_internal; | |||||
return time + 1; | |||||
return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(flat_new_output).Concat(new_states) | |||||
.ToArray().ToTensors(); | |||||
} | } | ||||
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); | |||||
var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) } | |||||
.Concat(flat_zero_output.Flatten()).Concat(states).ToArray().ToTensors(); | |||||
final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations); | |||||
new_states = final_outputs.Skip(3).ToList(); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
var output_ta_t = output_ta; | var output_ta_t = output_ta; | ||||
new_states = states; | new_states = states; | ||||
Tensor _step(Tensor time) | |||||
Tensors _step(Tensors tensors) | |||||
{ | { | ||||
Tensor time = tensors[0]; | |||||
TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray; | |||||
Tensors states = new Tensors(tensors.Skip(2).ToArray()); | |||||
var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); | var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); | ||||
// maybe set shape | // maybe set shape | ||||
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type | // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type | ||||
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); | var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); | ||||
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants)); | |||||
var (output, new_states) = step_function(current_input, states.MergeWith(constants)); | |||||
var flat_state = new_states.Flatten().ToList(); | var flat_state = new_states.Flatten().ToList(); | ||||
var flat_new_state = new_states_internal.Flatten().ToList(); | |||||
var flat_new_state = new_states.Flatten().ToList(); | |||||
foreach (var (state, new_state) in zip(flat_state, flat_new_state)) | foreach (var (state, new_state) in zip(flat_state, flat_new_state)) | ||||
{ | { | ||||
if (new_state is Tensor) | if (new_state is Tensor) | ||||
@@ -906,24 +908,23 @@ namespace Tensorflow.Keras | |||||
} | } | ||||
var flat_output = Nest.Flatten(output); | var flat_output = Nest.Flatten(output); | ||||
var ta_index_to_write = return_all_outputs ? time : tf.constant(0); | var ta_index_to_write = return_all_outputs ? time : tf.constant(0); | ||||
output_ta_t = zip(output_ta_t, flat_output).Select(item => | |||||
{ | |||||
var (ta, out_) = item; | |||||
return ta.write(ta_index_to_write, out_); | |||||
}).ToList(); | |||||
new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); | |||||
output_ta = output_ta_t; | |||||
new_states = new_states_internal; | |||||
return time + 1; | |||||
Debug.Assert(flat_output.Count() == 1); | |||||
output_ta_t = output_ta_t.write(ta_index_to_write, flat_output.First()); | |||||
new_states = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); | |||||
return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(new_states).ToArray().ToTensors(); | |||||
} | } | ||||
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); | |||||
Debug.Assert(output_ta.Count == 1); | |||||
var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) }.Concat(states).ToArray().ToTensors(); | |||||
final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations); | |||||
new_states = final_outputs.Skip(2).ToList(); | |||||
} | } | ||||
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToTensors()); | |||||
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToTensors()); | |||||
outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors(); | |||||
last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors(); | |||||
output_ta = new List<TensorArray> { (final_outputs[1] as FakeTensorByTensorArray).TensorArray }; | |||||
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToArray().ToTensors()); | |||||
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToArray().ToTensors()); | |||||
outputs = Nest.PackSequenceAs(output_time_zero, (Tensor[])outputs).ToTensors(); | |||||
last_output = Nest.PackSequenceAs(output_time_zero, (Tensor[])last_output).ToTensors(); | |||||
} | } | ||||
Func<Tensor, Tensor> set_shape; | Func<Tensor, Tensor> set_shape; | ||||
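Editor's note: both `_step` bodies above now follow the `Func<Tensors, Tensors>` convention expected by `control_flow_ops.while_loop`, packing `[time, wrapped TensorArray, states...]` into one flat `Tensors`. The stripped-down sketch below shows only that packing convention; `FakeTensorByTensorArray`, `TensorArray.Create` and `ToTensors` are taken from this diff, the toy data and loop logic are illustrative.

// Illustrative only: the same loop-variable layout as the RNN loop, on toy data.
Func<Tensors, Tensor> cond = loop_vars => loop_vars[0] < tf.constant(3);
Func<Tensors, Tensors> body = loop_vars =>
{
    var time = loop_vars[0];
    var ta = (loop_vars[1] as FakeTensorByTensorArray).TensorArray;
    var state = loop_vars[2];
    ta = ta.write(time, state);   // record the running state at position `time`
    return new Tensor[] { time + 1, new FakeTensorByTensorArray(ta), state + 1 }.ToTensors();
};
var ta0 = TensorArray.Create(TF_DataType.TF_INT32, size: tf.constant(3));
var init = new Tensor[] { tf.constant(0), new FakeTensorByTensorArray(ta0), tf.constant(0) }.ToTensors();
var results = control_flow_ops.while_loop(cond, body, init, parallel_iterations: 32);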
@@ -38,6 +38,8 @@ namespace Tensorflow.Keras.Engine | |||||
_handle_activity_regularization(inputs, outputs); | _handle_activity_regularization(inputs, outputs); | ||||
_set_mask_metadata(inputs, outputs, null); | _set_mask_metadata(inputs, outputs, null); | ||||
// TODO(Rinne): set save spec if null | |||||
scope.__exit__(); | scope.__exit__(); | ||||
return outputs; | return outputs; | ||||
@@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Engine | |||||
var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph(); | var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph(); | ||||
graph.as_default(); | graph.as_default(); | ||||
var shapes = input_shape.ToShapeArray(); | var shapes = input_shape.ToShapeArray(); | ||||
var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x))); | |||||
var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x)).ToArray()); | |||||
try | try | ||||
{ | { | ||||
Call(x, training: false); | Call(x, training: false); | ||||
@@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
var data_handler = new DataHandler(new DataHandlerArgs | var data_handler = new DataHandler(new DataHandlerArgs | ||||
{ | { | ||||
X = new Tensors(x), | |||||
X = new Tensors(x.ToArray()), | |||||
Y = y, | Y = y, | ||||
Model = this, | Model = this, | ||||
StepsPerExecution = _steps_per_execution | StepsPerExecution = _steps_per_execution | ||||
@@ -188,7 +188,7 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
var data = iterator.next(); | var data = iterator.next(); | ||||
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; | var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; | ||||
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); | |||||
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())); | |||||
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | ||||
return outputs; | return outputs; | ||||
} | } | ||||
@@ -110,7 +110,7 @@ namespace Tensorflow.Keras.Engine | |||||
var data_handler = new DataHandler(new DataHandlerArgs | var data_handler = new DataHandler(new DataHandlerArgs | ||||
{ | { | ||||
X = new Tensors(train_x), | |||||
X = new Tensors(train_x.ToArray()), | |||||
Y = train_y, | Y = train_y, | ||||
BatchSize = batch_size, | BatchSize = batch_size, | ||||
InitialEpoch = initial_epoch, | InitialEpoch = initial_epoch, | ||||
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
var data = iterator.next(); | var data = iterator.next(); | ||||
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; | var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; | ||||
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); | |||||
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())); | |||||
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | ||||
return outputs; | return outputs; | ||||
} | } | ||||
@@ -0,0 +1,4 @@ | |||||
namespace System.Runtime.CompilerServices | |||||
{ | |||||
internal static class IsExternalInit { } | |||||
} |
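Editor's note: this empty shim is the usual workaround that lets C# 9 `init` accessors (used by the new `Cell` property on `RNN` later in this diff) compile on target frameworks that do not ship `IsExternalInit`, such as .NET Standard 2.0. A tiny illustration with a hypothetical type:

// Hypothetical example: with the shim above, init-only setters compile on older TFMs.
public class CellOptions
{
    public int Units { get; init; }             // assignable only in an object initializer
}

var options = new CellOptions { Units = 128 };  // OK
// options.Units = 256;                         // error CS8852: init-only property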
@@ -702,16 +702,14 @@ namespace Tensorflow.Keras.Layers | |||||
UseBias = use_bias, | UseBias = use_bias, | ||||
KernelInitializer = GetInitializerByName(kernel_initializer), | KernelInitializer = GetInitializerByName(kernel_initializer), | ||||
RecurrentInitializer = GetInitializerByName(recurrent_initializer), | RecurrentInitializer = GetInitializerByName(recurrent_initializer), | ||||
BiasInitializer = GetInitializerByName(bias_initializer), | |||||
Dropout = dropout, | Dropout = dropout, | ||||
RecurrentDropout = recurrent_dropout | RecurrentDropout = recurrent_dropout | ||||
}); | }); | ||||
public IRnnCell StackedRNNCells( | public IRnnCell StackedRNNCells( | ||||
IEnumerable<IRnnCell> cells) | IEnumerable<IRnnCell> cells) | ||||
=> new StackedRNNCells(new StackedRNNCellsArgs | |||||
{ | |||||
Cells = cells.ToList() | |||||
}); | |||||
=> new StackedRNNCells(cells.ToList(), new StackedRNNCellsArgs()); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -756,9 +754,8 @@ namespace Tensorflow.Keras.Layers | |||||
bool stateful = false, | bool stateful = false, | ||||
bool unroll = false, | bool unroll = false, | ||||
bool time_major = false) | bool time_major = false) | ||||
=> new RNN(new RNNArgs | |||||
=> new RNN(cell, new RNNArgs | |||||
{ | { | ||||
Cell = cell, | |||||
ReturnSequences = return_sequences, | ReturnSequences = return_sequences, | ||||
ReturnState = return_state, | ReturnState = return_state, | ||||
GoBackwards = go_backwards, | GoBackwards = go_backwards, | ||||
@@ -775,9 +772,8 @@ namespace Tensorflow.Keras.Layers | |||||
bool stateful = false, | bool stateful = false, | ||||
bool unroll = false, | bool unroll = false, | ||||
bool time_major = false) | bool time_major = false) | ||||
=> new RNN(new RNNArgs | |||||
=> new RNN(cell, new RNNArgs | |||||
{ | { | ||||
Cells = cell.ToList(), | |||||
ReturnSequences = return_sequences, | ReturnSequences = return_sequences, | ||||
ReturnState = return_state, | ReturnState = return_state, | ||||
GoBackwards = go_backwards, | GoBackwards = go_backwards, | ||||
@@ -786,6 +782,33 @@ namespace Tensorflow.Keras.Layers | |||||
TimeMajor = time_major | TimeMajor = time_major | ||||
}); | }); | ||||
public IRnnCell LSTMCell(int units, | |||||
string activation = "tanh", | |||||
string recurrent_activation = "sigmoid", | |||||
bool use_bias = true, | |||||
string kernel_initializer = "glorot_uniform", | |||||
string recurrent_initializer = "orthogonal", // TODO(Wanglongzhi2001): glorot_uniform has not been implemented yet. | |||||
string bias_initializer = "zeros", | |||||
bool unit_forget_bias = true, | |||||
float dropout = 0f, | |||||
float recurrent_dropout = 0f, | |||||
int implementation = 2) | |||||
=> new LSTMCell(new LSTMCellArgs | |||||
{ | |||||
Units = units, | |||||
Activation = keras.activations.GetActivationFromName(activation), | |||||
RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation), | |||||
UseBias = use_bias, | |||||
KernelInitializer = GetInitializerByName(kernel_initializer), | |||||
RecurrentInitializer = GetInitializerByName(recurrent_initializer), | |||||
BiasInitializer = GetInitializerByName(bias_initializer), | |||||
UnitForgetBias = unit_forget_bias, | |||||
Dropout = dropout, | |||||
RecurrentDropout = recurrent_dropout, | |||||
Implementation = implementation | |||||
}); | |||||
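Editor's note: a hedged usage sketch of the new `LSTMCell` factory combined with the cell-first `RNN` wrapper from this diff. The `return_sequences` parameter name of `keras.layers.RNN` is assumed from the surrounding hunks, and the shapes are arbitrary.

// Sketch: build a standalone LSTM cell and drive it through the generic RNN layer.
var cell = keras.layers.LSTMCell(4, dropout: 0.1f);
var rnn = keras.layers.RNN(cell, return_sequences: true);
var x = tf.random.normal((8, 10, 16));   // (batch, time, features)
var y = rnn.Apply(x);                    // expected shape (8, 10, 4) with return_sequences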
/// <summary> | /// <summary> | ||||
/// Long Short-Term Memory layer - Hochreiter 1997. | /// Long Short-Term Memory layer - Hochreiter 1997. | ||||
/// </summary> | /// </summary> | ||||
@@ -846,7 +869,8 @@ namespace Tensorflow.Keras.Layers | |||||
GoBackwards = go_backwards, | GoBackwards = go_backwards, | ||||
Stateful = stateful, | Stateful = stateful, | ||||
TimeMajor = time_major, | TimeMajor = time_major, | ||||
Unroll = unroll | |||||
Unroll = unroll, | |||||
UnitForgetBias = unit_forget_bias | |||||
}); | }); | ||||
/// <summary> | /// <summary> | ||||
@@ -4,10 +4,11 @@ using System.Text; | |||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Utils; | |||||
namespace Tensorflow.Keras.Layers.Rnn | namespace Tensorflow.Keras.Layers.Rnn | ||||
{ | { | ||||
public abstract class DropoutRNNCellMixin: RnnCellBase | |||||
public abstract class DropoutRNNCellMixin: Layer, IRnnCell | |||||
{ | { | ||||
public float dropout; | public float dropout; | ||||
public float recurrent_dropout; | public float recurrent_dropout; | ||||
@@ -17,6 +18,14 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
} | } | ||||
public abstract INestStructure<long> StateSize { get; } | |||||
public abstract INestStructure<long> OutputSize { get; } | |||||
public abstract bool SupportOptionalArgs { get; } | |||||
public virtual Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype) | |||||
{ | |||||
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); | |||||
} | |||||
protected void _create_non_trackable_mask_cache() | protected void _create_non_trackable_mask_cache() | ||||
{ | { | ||||
@@ -32,7 +41,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
} | } | ||||
public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) | |||||
public Tensors? get_dropout_mask_for_cell(Tensors input, bool training, int count = 1) | |||||
{ | { | ||||
if (dropout == 0f) | if (dropout == 0f) | ||||
return null; | return null; | ||||
@@ -44,7 +53,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
} | } | ||||
// Get the recurrent dropout mask for RNN cell. | // Get the recurrent dropout mask for RNN cell. | ||||
public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) | |||||
public Tensors? get_recurrent_dropout_mask_for_cell(Tensors input, bool training, int count = 1) | |||||
{ | { | ||||
if (dropout == 0f) | if (dropout == 0f) | ||||
return null; | return null; | ||||
@@ -2,6 +2,7 @@ | |||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | using Tensorflow.Keras.ArgsDefinition.Rnn; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using Tensorflow.Common.Extensions; | |||||
namespace Tensorflow.Keras.Layers.Rnn | namespace Tensorflow.Keras.Layers.Rnn | ||||
{ | { | ||||
@@ -14,22 +15,105 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
public class LSTM : RNN | public class LSTM : RNN | ||||
{ | { | ||||
LSTMArgs args; | LSTMArgs args; | ||||
InputSpec[] state_spec; | |||||
int units => args.Units; | |||||
InputSpec[] _state_spec; | |||||
InputSpec _input_spec; | |||||
bool _could_use_gpu_kernel; | |||||
public LSTM(LSTMArgs args) : | public LSTM(LSTMArgs args) : | ||||
base(args) | |||||
base(CreateCell(args), args) | |||||
{ | { | ||||
this.args = args; | this.args = args; | ||||
state_spec = new[] { units, units } | |||||
.Select(dim => new InputSpec(shape: (-1, dim))) | |||||
.ToArray(); | |||||
_input_spec = new InputSpec(ndim: 3); | |||||
_state_spec = new[] { args.Units, args.Units }.Select(dim => new InputSpec(shape: (-1, dim))).ToArray(); | |||||
_could_use_gpu_kernel = args.Activation == keras.activations.Tanh | |||||
&& args.RecurrentActivation == keras.activations.Sigmoid | |||||
&& args.RecurrentDropout == 0 && !args.Unroll && args.UseBias | |||||
&& ops.executing_eagerly_outside_functions(); | |||||
} | |||||
private static IRnnCell CreateCell(LSTMArgs lstmArgs) | |||||
{ | |||||
return new LSTMCell(new LSTMCellArgs() | |||||
{ | |||||
Units = lstmArgs.Units, | |||||
Activation = lstmArgs.Activation, | |||||
RecurrentActivation = lstmArgs.RecurrentActivation, | |||||
UseBias = lstmArgs.UseBias, | |||||
KernelInitializer = lstmArgs.KernelInitializer, | |||||
RecurrentInitializer = lstmArgs.RecurrentInitializer, | |||||
UnitForgetBias = lstmArgs.UnitForgetBias, | |||||
BiasInitializer = lstmArgs.BiasInitializer, | |||||
// TODO(Rinne): kernel_regularizer | |||||
// TODO(Rinne): recurrent_regularizer | |||||
// TODO(Rinne): bias_regularizer | |||||
// TODO(Rinne): kernel_constraint | |||||
// TODO(Rinne): recurrent_constraint | |||||
// TODO(Rinne): bias_constraint | |||||
Dropout = lstmArgs.Dropout, | |||||
RecurrentDropout = lstmArgs.RecurrentDropout, | |||||
Implementation = lstmArgs.Implementation, | |||||
DType = lstmArgs.DType, | |||||
Trainable = lstmArgs.Trainable | |||||
}); | |||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||||
protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||||
{ | { | ||||
return base.Call(inputs, initial_state: state, training: training); | |||||
// skip the condition of ragged input | |||||
(inputs, initial_state, _) = _process_inputs(inputs, initial_state, null); | |||||
Tensor mask = null; | |||||
if(optional_args is RnnOptionalArgs rnnArgs) | |||||
{ | |||||
mask = rnnArgs.Mask; | |||||
} | |||||
var single_input = inputs.Single; | |||||
var input_shape = single_input.shape; | |||||
var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1]; | |||||
_maybe_reset_cell_dropout_mask(Cell); | |||||
Func<Tensors, Tensors, (Tensors, Tensors)> step = (inputs, states) => | |||||
{ | |||||
var res = Cell.Apply(inputs, states, training is null ? true : training.Value); | |||||
var (output, state) = res; | |||||
return (output, state); | |||||
}; | |||||
var (last_output, outputs, states) = keras.backend.rnn( | |||||
step, | |||||
inputs, | |||||
initial_state, | |||||
constants: null, | |||||
go_backwards: args.GoBackwards, | |||||
mask: mask, | |||||
unroll: args.Unroll, | |||||
input_length: ops.convert_to_tensor(timesteps), | |||||
time_major: args.TimeMajor, | |||||
zero_output_for_mask: args.ZeroOutputForMask, | |||||
return_all_outputs: args.ReturnSequences | |||||
); | |||||
Tensor output; | |||||
if (args.ReturnSequences) | |||||
{ | |||||
output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, args.GoBackwards); | |||||
} | |||||
else | |||||
{ | |||||
output = last_output; | |||||
} | |||||
if (args.ReturnState) | |||||
{ | |||||
return new Tensor[] { output }.Concat(states).ToArray().ToTensors(); | |||||
} | |||||
else | |||||
{ | |||||
return output; | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } |
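Editor's note: since the point of this PR is to make LSTM training work end to end, here is a short, hedged sketch of a forward pass plus one gradient computation. `GradientTape`, `reduce_mean`, `square` and `TrainableVariables` are the pre-existing TF.NET surface as commonly used in its examples; the shapes and data are made up.

// Illustrative only: one backward pass through the new LSTM layer.
var lstm = keras.layers.LSTM(16);
var x = tf.random.normal((4, 10, 8));    // (batch, time, features)
var y = tf.random.normal((4, 16));

using var tape = tf.GradientTape();
Tensor pred = lstm.Apply(x);
var loss = tf.reduce_mean(tf.square(pred - y));
var grads = tape.gradient(loss, lstm.TrainableVariables);
// grads covers the cell's kernel, recurrent_kernel and bias variables.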
@@ -1,16 +1,233 @@ | |||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Newtonsoft.Json; | |||||
using Serilog.Core; | |||||
using System.Diagnostics; | |||||
using Tensorflow.Common.Extensions; | |||||
using Tensorflow.Common.Types; | |||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Saving; | |||||
using Tensorflow.Keras.Utils; | |||||
namespace Tensorflow.Keras.Layers.Rnn | namespace Tensorflow.Keras.Layers.Rnn | ||||
{ | { | ||||
public class LSTMCell : Layer | |||||
/// <summary> | |||||
/// Cell class for the LSTM layer. | |||||
/// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) | |||||
/// for details about the usage of RNN API. | |||||
/// This class processes one step within the whole time sequence input, whereas | |||||
/// `tf.keras.layers.LSTM` processes the whole sequence. | |||||
/// </summary> | |||||
public class LSTMCell : DropoutRNNCellMixin | |||||
{ | { | ||||
LSTMCellArgs args; | |||||
LSTMCellArgs _args; | |||||
IVariableV1 _kernel; | |||||
IVariableV1 _recurrent_kernel; | |||||
IInitializer _bias_initializer; | |||||
IVariableV1 _bias; | |||||
INestStructure<long> _state_size; | |||||
INestStructure<long> _output_size; | |||||
public override INestStructure<long> StateSize => _state_size; | |||||
public override INestStructure<long> OutputSize => _output_size; | |||||
public override bool SupportOptionalArgs => false; | |||||
public LSTMCell(LSTMCellArgs args) | public LSTMCell(LSTMCellArgs args) | ||||
: base(args) | : base(args) | ||||
{ | { | ||||
this.args = args; | |||||
_args = args; | |||||
if (args.Units <= 0) | |||||
{ | |||||
throw new ValueError( | |||||
$"units must be a positive integer, got {args.Units}"); | |||||
} | |||||
_args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout)); | |||||
_args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout)); | |||||
if (_args.RecurrentDropout != 0f && _args.Implementation != 1) | |||||
{ | |||||
Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set. " + | |||||
"Using `implementation=1`."); | |||||
_args.Implementation = 1; | |||||
} | |||||
_state_size = new NestList<long>(_args.Units, _args.Units); | |||||
_output_size = new NestNode<long>(_args.Units); | |||||
} | |||||
public override void build(KerasShapesWrapper input_shape) | |||||
{ | |||||
base.build(input_shape); | |||||
var single_shape = input_shape.ToSingleShape(); | |||||
var input_dim = single_shape[-1]; | |||||
_kernel = add_weight("kernel", (input_dim, _args.Units * 4), | |||||
initializer: _args.KernelInitializer | |||||
); | |||||
_recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units * 4), | |||||
initializer: _args.RecurrentInitializer | |||||
); | |||||
if (_args.UseBias) | |||||
{ | |||||
if (_args.UnitForgetBias) | |||||
{ | |||||
Tensor bias_initializer() | |||||
{ | |||||
return keras.backend.concatenate( | |||||
new Tensors( | |||||
_args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units))), | |||||
tf.ones_initializer.Apply(new InitializerArgs(shape: (_args.Units))), | |||||
_args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units)))), axis: 0); | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
_bias_initializer = _args.BiasInitializer; | |||||
} | |||||
_bias = add_weight("bias", (_args.Units * 4), | |||||
initializer: _bias_initializer | |||||
); | |||||
} | |||||
built = true; | |||||
} | |||||
protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null) | |||||
{ | |||||
var h_tm1 = states[0]; // previous memory state | |||||
var c_tm1 = states[1]; // previous carry state | |||||
var dp_mask = get_dropout_mask_for_cell(inputs, training.Value, count: 4); | |||||
var rec_dp_mask = get_recurrent_dropout_mask_for_cell( | |||||
h_tm1, training.Value, count: 4); | |||||
Tensor c; | |||||
Tensor o; | |||||
if (_args.Implementation == 1) | |||||
{ | |||||
Tensor inputs_i; | |||||
Tensor inputs_f; | |||||
Tensor inputs_c; | |||||
Tensor inputs_o; | |||||
if (0f < _args.Dropout && _args.Dropout < 1f) | |||||
{ | |||||
inputs_i = inputs * dp_mask[0]; | |||||
inputs_f = inputs * dp_mask[1]; | |||||
inputs_c = inputs * dp_mask[2]; | |||||
inputs_o = inputs * dp_mask[3]; | |||||
} | |||||
else | |||||
{ | |||||
inputs_i = inputs; | |||||
inputs_f = inputs; | |||||
inputs_c = inputs; | |||||
inputs_o = inputs; | |||||
} | |||||
var k = tf.split(_kernel.AsTensor(), num_split: 4, axis: 1); | |||||
Tensor k_i = k[0], k_f = k[1], k_c = k[2], k_o = k[3]; | |||||
var x_i = math_ops.matmul(inputs_i, k_i); | |||||
var x_f = math_ops.matmul(inputs_f, k_f); | |||||
var x_c = math_ops.matmul(inputs_c, k_c); | |||||
var x_o = math_ops.matmul(inputs_o, k_o); | |||||
if (_args.UseBias) | |||||
{ | |||||
var b = tf.split(_bias.AsTensor(), num_split: 4, axis: 0); | |||||
Tensor b_i = b[0], b_f = b[1], b_c = b[2], b_o = b[3]; | |||||
x_i = gen_nn_ops.bias_add(x_i, b_i); | |||||
x_f = gen_nn_ops.bias_add(x_f, b_f); | |||||
x_c = gen_nn_ops.bias_add(x_c, b_c); | |||||
x_o = gen_nn_ops.bias_add(x_o, b_o); | |||||
} | |||||
Tensor h_tm1_i; | |||||
Tensor h_tm1_f; | |||||
Tensor h_tm1_c; | |||||
Tensor h_tm1_o; | |||||
if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f) | |||||
{ | |||||
h_tm1_i = h_tm1 * rec_dp_mask[0]; | |||||
h_tm1_f = h_tm1 * rec_dp_mask[1]; | |||||
h_tm1_c = h_tm1 * rec_dp_mask[2]; | |||||
h_tm1_o = h_tm1 * rec_dp_mask[3]; | |||||
} | |||||
else | |||||
{ | |||||
h_tm1_i = h_tm1; | |||||
h_tm1_f = h_tm1; | |||||
h_tm1_c = h_tm1; | |||||
h_tm1_o = h_tm1; | |||||
} | |||||
var x = new Tensor[] { x_i, x_f, x_c, x_o }; | |||||
var h_tm1_array = new Tensor[] { h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o }; | |||||
(c, o) = _compute_carry_and_output(x, h_tm1_array, c_tm1); | |||||
} | |||||
else | |||||
{ | |||||
if (0f < _args.Dropout && _args.Dropout < 1f) | |||||
inputs = inputs * dp_mask[0]; | |||||
var z = math_ops.matmul(inputs, _kernel.AsTensor()); | |||||
z += math_ops.matmul(h_tm1, _recurrent_kernel.AsTensor()); | |||||
if (_args.UseBias) | |||||
{ | |||||
z = tf.nn.bias_add(z, _bias); | |||||
} | |||||
var z_array = tf.split(z, num_split: 4, axis: 1); | |||||
(c, o) = _compute_carry_and_output_fused(z_array, c_tm1); | |||||
} | |||||
var h = o * _args.Activation.Apply(c); | |||||
// This is needed because the Tensors constructor packs every element after the first one into an array. | |||||
return new Nest<Tensor>(new INestStructure<Tensor>[] { new NestNode<Tensor>(h), new NestList<Tensor>(h, c) }).ToTensors(); | |||||
} | |||||
/// <summary> | |||||
/// Computes carry and output using split kernels. | |||||
/// </summary> | |||||
/// <param name="x"></param> | |||||
/// <param name="h_tm1"></param> | |||||
/// <param name="c_tm1"></param> | |||||
/// <returns></returns> | |||||
/// <exception cref="NotImplementedException"></exception> | |||||
public Tensors _compute_carry_and_output(Tensor[] x, Tensor[] h_tm1, Tensor c_tm1) | |||||
{ | |||||
Tensor x_i = x[0], x_f = x[1], x_c = x[2], x_o = x[3]; | |||||
Tensor h_tm1_i = h_tm1[0], h_tm1_f = h_tm1[1], h_tm1_c = h_tm1[2], | |||||
h_tm1_o = h_tm1[3]; | |||||
var _recurrent_kernel_tensor = _recurrent_kernel.AsTensor(); | |||||
int startIndex = (int)_recurrent_kernel_tensor.shape[0]; | |||||
var _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor, | |||||
new[] { 0, 0 }, new[] { startIndex, _args.Units }); | |||||
var i = _args.RecurrentActivation.Apply( | |||||
x_i + math_ops.matmul(h_tm1_i, _recurrent_kernel_slice)); | |||||
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor, | |||||
new[] { 0, _args.Units }, new[] { startIndex, _args.Units}); | |||||
var f = _args.RecurrentActivation.Apply( | |||||
x_f + math_ops.matmul(h_tm1_f, _recurrent_kernel_slice)); | |||||
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor, | |||||
new[] { 0, _args.Units * 2 }, new[] { startIndex, _args.Units }); | |||||
var c = f * c_tm1 + i * _args.Activation.Apply( | |||||
x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice)); | |||||
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor, | |||||
new[] { 0, _args.Units * 3 }, new[] { startIndex, _args.Units }); | |||||
var o = _args.Activation.Apply( | |||||
x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice)); | |||||
return new Tensors(c, o); | |||||
} | |||||
/// <summary> | |||||
/// Computes carry and output using fused kernels. | |||||
/// </summary> | |||||
/// <param name="z"></param> | |||||
/// <param name="c_tm1"></param> | |||||
/// <returns></returns> | |||||
public Tensors _compute_carry_and_output_fused(Tensor[] z, Tensor c_tm1) | |||||
{ | |||||
Tensor z0 = z[0], z1 = z[1], z2 = z[2], z3 = z[3]; | |||||
var i = _args.RecurrentActivation.Apply(z0); | |||||
var f = _args.RecurrentActivation.Apply(z1); | |||||
var c = f * c_tm1 + i * _args.Activation.Apply(z2); | |||||
var o = _args.RecurrentActivation.Apply(z3); | |||||
return new Tensors(c, o); | |||||
} | } | ||||
} | } | ||||
} | } |
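Editor's note: for reference, the fused path (`_compute_carry_and_output_fused`) implements the standard LSTM gate equations. With z already split into four `units`-wide chunks, σ the RecurrentActivation (sigmoid by default) and φ the Activation (tanh by default), the step computes:

i_t = \sigma(z_0), \qquad f_t = \sigma(z_1), \qquad o_t = \sigma(z_3)
c_t = f_t \odot c_{t-1} + i_t \odot \phi(z_2), \qquad h_t = o_t \odot \phi(c_t)

which matches the returned (c, o) pair and the `h = o * Activation(c)` line in `Call`.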
@@ -11,6 +11,7 @@ using Tensorflow.Common.Extensions; | |||||
using System.Linq.Expressions; | using System.Linq.Expressions; | ||||
using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using System.Runtime.CompilerServices; | |||||
// from tensorflow.python.distribute import distribution_strategy_context as ds_context; | // from tensorflow.python.distribute import distribution_strategy_context as ds_context; | ||||
namespace Tensorflow.Keras.Layers.Rnn | namespace Tensorflow.Keras.Layers.Rnn | ||||
@@ -30,25 +31,39 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
private int _num_constants; | private int _num_constants; | ||||
protected IVariableV1 _kernel; | protected IVariableV1 _kernel; | ||||
protected IVariableV1 _bias; | protected IVariableV1 _bias; | ||||
protected IRnnCell _cell; | |||||
public RNN(RNNArgs args) : base(PreConstruct(args)) | |||||
private IRnnCell _cell; | |||||
protected IRnnCell Cell | |||||
{ | { | ||||
_args = args; | |||||
SupportsMasking = true; | |||||
// if is StackedRnncell | |||||
if (args.Cells != null) | |||||
get | |||||
{ | { | ||||
_cell = new StackedRNNCells(new StackedRNNCellsArgs | |||||
{ | |||||
Cells = args.Cells | |||||
}); | |||||
return _cell; | |||||
} | } | ||||
else | |||||
init | |||||
{ | { | ||||
_cell = args.Cell; | |||||
_cell = value; | |||||
_self_tracked_trackables.Add(_cell); | |||||
} | } | ||||
} | |||||
public RNN(IRnnCell cell, RNNArgs args) : base(PreConstruct(args)) | |||||
{ | |||||
_args = args; | |||||
SupportsMasking = true; | |||||
Cell = cell; | |||||
// get input_shape | |||||
_args = PreConstruct(args); | |||||
_num_constants = 0; | |||||
} | |||||
public RNN(IEnumerable<IRnnCell> cells, RNNArgs args) : base(PreConstruct(args)) | |||||
{ | |||||
_args = args; | |||||
SupportsMasking = true; | |||||
Cell = new StackedRNNCells(cells, new StackedRNNCellsArgs()); | |||||
// get input_shape | // get input_shape | ||||
_args = PreConstruct(args); | _args = PreConstruct(args); | ||||
@@ -65,7 +80,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
if (_states == null) | if (_states == null) | ||||
{ | { | ||||
// CHECK(Rinne): check if this is correct. | // CHECK(Rinne): check if this is correct. | ||||
var nested = _cell.StateSize.MapStructure<Tensor?>(x => null); | |||||
var nested = Cell.StateSize.MapStructure<Tensor?>(x => null); | |||||
_states = nested.AsNest().ToTensors(); | _states = nested.AsNest().ToTensors(); | ||||
} | } | ||||
return _states; | return _states; | ||||
@@ -73,7 +88,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
set { _states = value; } | set { _states = value; } | ||||
} | } | ||||
private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape) | |||||
private INestStructure<Shape> compute_output_shape(Shape input_shape) | |||||
{ | { | ||||
var batch = input_shape[0]; | var batch = input_shape[0]; | ||||
var time_step = input_shape[1]; | var time_step = input_shape[1]; | ||||
@@ -83,13 +98,15 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
} | } | ||||
// state_size is a array of ints or a positive integer | // state_size is a array of ints or a positive integer | ||||
var state_size = _cell.StateSize.ToSingleShape(); | |||||
var state_size = Cell.StateSize; | |||||
if(state_size?.TotalNestedCount == 1) | |||||
{ | |||||
state_size = new NestList<long>(state_size.Flatten().First()); | |||||
} | |||||
// TODO(wanglongzhi2001): what type should flat_output_size be, Shape or Tensor? | |||||
Func<Shape, Shape> _get_output_shape; | |||||
_get_output_shape = (flat_output_size) => | |||||
Func<long, Shape> _get_output_shape = (flat_output_size) => | |||||
{ | { | ||||
var output_dim = flat_output_size.as_int_list(); | |||||
var output_dim = new Shape(flat_output_size).as_int_list(); | |||||
Shape output_shape; | Shape output_shape; | ||||
if (_args.ReturnSequences) | if (_args.ReturnSequences) | ||||
{ | { | ||||
@@ -110,33 +127,30 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
return output_shape; | return output_shape; | ||||
}; | }; | ||||
Type type = _cell.GetType(); | |||||
Type type = Cell.GetType(); | |||||
PropertyInfo output_size_info = type.GetProperty("output_size"); | PropertyInfo output_size_info = type.GetProperty("output_size"); | ||||
Shape output_shape; | |||||
INestStructure<Shape> output_shape; | |||||
if (output_size_info != null) | if (output_size_info != null) | ||||
{ | { | ||||
output_shape = nest.map_structure(_get_output_shape, _cell.OutputSize.ToSingleShape()); | |||||
// TODO(wanglongzhi2001): should output_shape simply be a tuple, or a Shape? | |||||
output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape); | |||||
output_shape = Nest.MapStructure(_get_output_shape, Cell.OutputSize); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
output_shape = _get_output_shape(state_size); | |||||
output_shape = new NestNode<Shape>(_get_output_shape(state_size.Flatten().First())); | |||||
} | } | ||||
if (_args.ReturnState) | if (_args.ReturnState) | ||||
{ | { | ||||
Func<Shape, Shape> _get_state_shape; | |||||
_get_state_shape = (flat_state) => | |||||
Func<long, Shape> _get_state_shape = (flat_state) => | |||||
{ | { | ||||
var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list()); | |||||
var state_shape = new int[] { (int)batch }.concat(new Shape(flat_state).as_int_list()); | |||||
return new Shape(state_shape); | return new Shape(state_shape); | ||||
}; | }; | ||||
var state_shape = _get_state_shape(state_size); | |||||
var state_shape = Nest.MapStructure(_get_state_shape, state_size); | |||||
return new List<Shape> { output_shape, state_shape }; | |||||
return new Nest<Shape>(new[] { output_shape, state_shape } ); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -171,7 +185,9 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
public override void build(KerasShapesWrapper input_shape) | public override void build(KerasShapesWrapper input_shape) | ||||
{ | { | ||||
object get_input_spec(Shape shape) | |||||
input_shape = new KerasShapesWrapper(input_shape.Shapes[0]); | |||||
InputSpec get_input_spec(Shape shape) | |||||
{ | { | ||||
var input_spec_shape = shape.as_int_list(); | var input_spec_shape = shape.as_int_list(); | ||||
@@ -206,7 +222,6 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
// append batch dim | // append batch dim | ||||
state_spec_shape = new int[] { -1 }.concat(state_spec_shape); | state_spec_shape = new int[] { -1 }.concat(state_spec_shape); | ||||
return new InputSpec(shape: state_spec_shape); | return new InputSpec(shape: state_spec_shape); | ||||
} | } | ||||
// Check whether the input shape contains any nested shapes. It could be | // Check whether the input shape contains any nested shapes. It could be | ||||
@@ -214,10 +229,13 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
// numpy inputs. | // numpy inputs. | ||||
if (!_cell.Built) | |||||
if (Cell is Layer layer && !layer.Built) | |||||
{ | { | ||||
_cell.build(input_shape); | |||||
layer.build(input_shape); | |||||
layer.Built = true; | |||||
} | } | ||||
this.built = true; | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -248,10 +266,10 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
(inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants); | (inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants); | ||||
_maybe_reset_cell_dropout_mask(_cell); | |||||
if (_cell is StackedRNNCells) | |||||
_maybe_reset_cell_dropout_mask(Cell); | |||||
if (Cell is StackedRNNCells) | |||||
{ | { | ||||
var stack_cell = _cell as StackedRNNCells; | |||||
var stack_cell = Cell as StackedRNNCells; | |||||
foreach (IRnnCell cell in stack_cell.Cells) | foreach (IRnnCell cell in stack_cell.Cells) | ||||
{ | { | ||||
_maybe_reset_cell_dropout_mask(cell); | _maybe_reset_cell_dropout_mask(cell); | ||||
@@ -298,23 +316,23 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
// cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call) | // cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call) | ||||
Func<Tensors, Tensors, (Tensors, Tensors)> step; | Func<Tensors, Tensors, (Tensors, Tensors)> step; | ||||
bool is_tf_rnn_cell = _cell.IsTFRnnCell; | |||||
bool is_tf_rnn_cell = false; | |||||
if (constants is not null) | if (constants is not null) | ||||
{ | { | ||||
if (!_cell.SupportOptionalArgs) | |||||
if (!Cell.SupportOptionalArgs) | |||||
{ | { | ||||
throw new ValueError( | throw new ValueError( | ||||
$"RNN cell {_cell} does not support constants." + | |||||
$"RNN cell {Cell} does not support constants." + | |||||
$"Received: constants={constants}"); | $"Received: constants={constants}"); | ||||
} | } | ||||
step = (inputs, states) => | step = (inputs, states) => | ||||
{ | { | ||||
constants = new Tensors(states.TakeLast(_num_constants)); | |||||
states = new Tensors(states.SkipLast(_num_constants)); | |||||
constants = new Tensors(states.TakeLast(_num_constants).ToArray()); | |||||
states = new Tensors(states.SkipLast(_num_constants).ToArray()); | |||||
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states; | states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states; | ||||
var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); | |||||
return (output, new_states.Single); | |||||
var (output, new_states) = Cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); | |||||
return (output, new_states); | |||||
}; | }; | ||||
} | } | ||||
else | else | ||||
@@ -322,7 +340,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
step = (inputs, states) => | step = (inputs, states) => | ||||
{ | { | ||||
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states; | states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states; | ||||
var (output, new_states) = _cell.Apply(inputs, states); | |||||
var (output, new_states) = Cell.Apply(inputs, states); | |||||
return (output, new_states); | return (output, new_states); | ||||
}; | }; | ||||
} | } | ||||
@@ -366,6 +384,11 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
//var tapeSet = tf.GetTapeSet(); | |||||
//foreach(var tape in tapeSet) | |||||
//{ | |||||
// tape.Watch(output); | |||||
//} | |||||
return output; | return output; | ||||
} | } | ||||
} | } | ||||
@@ -389,18 +412,18 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
} | } | ||||
private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants) | |||||
protected (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants) | |||||
{ | { | ||||
if (inputs.Length > 1) | if (inputs.Length > 1) | ||||
{ | { | ||||
if (_num_constants != 0) | if (_num_constants != 0) | ||||
{ | { | ||||
initial_state = new Tensors(inputs.Skip(1)); | |||||
initial_state = new Tensors(inputs.Skip(1).ToArray()); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants)); | |||||
constants = new Tensors(inputs.TakeLast(_num_constants)); | |||||
initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants).ToArray()); | |||||
constants = new Tensors(inputs.TakeLast(_num_constants).ToArray()); | |||||
} | } | ||||
if (len(initial_state) == 0) | if (len(initial_state) == 0) | ||||
initial_state = null; | initial_state = null; | ||||
@@ -418,7 +441,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
tmp.add(tf.math.count_nonzero(s.Single())); | tmp.add(tf.math.count_nonzero(s.Single())); | ||||
} | } | ||||
var non_zero_count = tf.add_n(tmp); | var non_zero_count = tf.add_n(tmp); | ||||
//initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state); | |||||
initial_state = tf.cond(non_zero_count > 0, States, initial_state); | |||||
if ((int)non_zero_count.numpy() > 0) | if ((int)non_zero_count.numpy() > 0) | ||||
{ | { | ||||
initial_state = States; | initial_state = States; | ||||
@@ -428,16 +451,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
{ | { | ||||
initial_state = States; | initial_state = States; | ||||
} | } | ||||
// TODO(Wanglongzhi2001), | |||||
// initial_state = tf.nest.map_structure( | |||||
//# When the layer has a inferred dtype, use the dtype from the | |||||
//# cell. | |||||
// lambda v: tf.cast( | |||||
// v, self.compute_dtype or self.cell.compute_dtype | |||||
// ), | |||||
// initial_state, | |||||
// ) | |||||
//initial_state = Nest.MapStructure(v => tf.cast(v, this.), initial_state); | |||||
} | } | ||||
else if (initial_state is null) | else if (initial_state is null) | ||||
{ | { | ||||
@@ -477,7 +491,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
} | } | ||||
void _maybe_reset_cell_dropout_mask(ILayer cell) | |||||
protected void _maybe_reset_cell_dropout_mask(ILayer cell) | |||||
{ | { | ||||
if (cell is DropoutRNNCellMixin CellDRCMixin) | if (cell is DropoutRNNCellMixin CellDRCMixin) | ||||
{ | { | ||||
@@ -488,26 +502,21 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
private static RNNArgs PreConstruct(RNNArgs args) | private static RNNArgs PreConstruct(RNNArgs args) | ||||
{ | { | ||||
if (args.Kwargs == null) | |||||
{ | |||||
args.Kwargs = new Dictionary<string, object>(); | |||||
} | |||||
// If true, the output for masked timestep will be zeros, whereas in the | // If true, the output for masked timestep will be zeros, whereas in the | ||||
// false case, output from previous timestep is returned for masked timestep. | // false case, output from previous timestep is returned for masked timestep. | ||||
var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false); | |||||
var zeroOutputForMask = args.ZeroOutputForMask; | |||||
Shape input_shape; | Shape input_shape; | ||||
var propIS = (Shape)args.Kwargs.Get("input_shape", null); | |||||
var propID = (int?)args.Kwargs.Get("input_dim", null); | |||||
var propIL = (int?)args.Kwargs.Get("input_length", null); | |||||
var propIS = args.InputShape; | |||||
var propID = args.InputDim; | |||||
var propIL = args.InputLength; | |||||
if (propIS == null && (propID != null || propIL != null)) | if (propIS == null && (propID != null || propIL != null)) | ||||
{ | { | ||||
input_shape = new Shape( | input_shape = new Shape( | ||||
propIL ?? -1, | propIL ?? -1, | ||||
propID ?? -1); | propID ?? -1); | ||||
args.Kwargs["input_shape"] = input_shape; | |||||
args.InputShape = input_shape; | |||||
} | } | ||||
return args; | return args; | ||||
@@ -558,36 +567,14 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
protected Tensors get_initial_state(Tensors inputs) | protected Tensors get_initial_state(Tensors inputs) | ||||
{ | { | ||||
var get_initial_state_fn = _cell.GetType().GetMethod("get_initial_state"); | |||||
var input = inputs[0]; | var input = inputs[0]; | ||||
var input_shape = inputs.shape; | |||||
var input_shape = array_ops.shape(inputs); | |||||
var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0]; | var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0]; | ||||
var dtype = input.dtype; | var dtype = input.dtype; | ||||
Tensors init_state = new Tensors(); | |||||
if(get_initial_state_fn != null) | |||||
{ | |||||
init_state = (Tensors)get_initial_state_fn.Invoke(_cell, new object[] { inputs, batch_size, dtype }); | |||||
} | |||||
//if (_cell is RnnCellBase rnn_base_cell) | |||||
//{ | |||||
// init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype); | |||||
//} | |||||
else | |||||
{ | |||||
init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype); | |||||
} | |||||
Tensors init_state = Cell.GetInitialState(null, batch_size, dtype); | |||||
return init_state; | return init_state; | ||||
} | } | ||||
// Check whether the state_size contains multiple states. | |||||
public static bool is_multiple_state(GeneralizedTensorShape state_size) | |||||
{ | |||||
return state_size.Shapes.Length > 1; | |||||
} | |||||
} | } | ||||
} | } |
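The RNN changes above drop the reflection-based lookup of a `get_initial_state` method in favour of calling `Cell.GetInitialState` directly, and compute the batch size from `array_ops.shape(inputs)` so the zero state can be built even when the batch dimension is dynamic. A minimal usage sketch, mirroring the `RNNForLSTMCell` unit test later in this diff:

    // Wrap a cell in the generic RNN layer; the initial state is created
    // internally through Cell.GetInitialState(...).
    var inputs = tf.ones((5, 10, 8));                            // (batch, time, features)
    var rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4));
    var output = rnn.Apply(inputs);                              // shape (5, 4)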
@@ -1,24 +0,0 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Common.Types; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Keras.Utils; | |||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
{ | |||||
public abstract class RnnCellBase: Layer, IRnnCell | |||||
{ | |||||
public RnnCellBase(LayerArgs args) : base(args) { } | |||||
public abstract GeneralizedTensorShape StateSize { get; } | |||||
public abstract GeneralizedTensorShape OutputSize { get; } | |||||
public abstract bool IsTFRnnCell { get; } | |||||
public abstract bool SupportOptionalArgs { get; } | |||||
public virtual Tensors GetInitialState(Tensors inputs, long batch_size, TF_DataType dtype) | |||||
{ | |||||
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); | |||||
} | |||||
} | |||||
} |
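The file above removes `RnnCellBase` entirely; the members it declared now live on the concrete cells, typed with `INestStructure<long>` rather than `GeneralizedTensorShape`. A hedged sketch of the surface a cell exposes after this change, pieced together from the `SimpleRNNCell` and `StackedRNNCells` hunks below (the exact interface layout is an assumption, so the name here is hypothetical):

    // Assumed shape of the cell contract once RnnCellBase is gone.
    public interface IRnnCellSketch
    {
        INestStructure<long> StateSize { get; }
        INestStructure<long> OutputSize { get; }
        bool SupportOptionalArgs { get; }
        Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null,
                                TF_DataType dtype = TF_DataType.DtInvalid);
    }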
@@ -10,14 +10,14 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
public class SimpleRNN : RNN | public class SimpleRNN : RNN | ||||
{ | { | ||||
SimpleRNNArgs args; | SimpleRNNArgs args; | ||||
public SimpleRNN(SimpleRNNArgs args) : base(CreateCellForArgs(args)) | |||||
public SimpleRNN(SimpleRNNArgs args) : base(CreateCellForArgs(args), args) | |||||
{ | { | ||||
this.args = args; | this.args = args; | ||||
} | } | ||||
private static SimpleRNNArgs CreateCellForArgs(SimpleRNNArgs args) | |||||
private static SimpleRNNCell CreateCellForArgs(SimpleRNNArgs args) | |||||
{ | { | ||||
args.Cell = new SimpleRNNCell(new SimpleRNNCellArgs() | |||||
return new SimpleRNNCell(new SimpleRNNCellArgs() | |||||
{ | { | ||||
Units = args.Units, | Units = args.Units, | ||||
Activation = args.Activation, | Activation = args.Activation, | ||||
@@ -30,21 +30,6 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
DType = args.DType, | DType = args.DType, | ||||
Trainable = args.Trainable, | Trainable = args.Trainable, | ||||
}); | }); | ||||
return args; | |||||
} | |||||
public override void build(KerasShapesWrapper input_shape) | |||||
{ | |||||
var single_shape = input_shape.ToSingleShape(); | |||||
var input_dim = single_shape[-1]; | |||||
_buildInputShape = input_shape; | |||||
_kernel = add_weight("kernel", (single_shape[-1], args.Units), | |||||
initializer: args.KernelInitializer | |||||
//regularizer = self.kernel_regularizer, | |||||
//constraint = self.kernel_constraint, | |||||
//caching_device = default_caching_device, | |||||
); | |||||
} | } | ||||
} | } | ||||
} | } |
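`SimpleRNN` now builds its cell once in `CreateCellForArgs` and passes it to the `RNN(cell, args)` base constructor instead of stashing it on `args.Cell`, and the layer-level `build` override is gone because weight creation belongs to the cell. Usage is unchanged; a short sketch that mirrors the `SimpleRNN` unit test later in this diff:

    var input = keras.Input((784));
    var x = keras.layers.Reshape((28, 28)).Apply(input);
    x = keras.layers.SimpleRNN(10).Apply(x);                     // cell created internally
    var output = keras.layers.Dense(10, activation: "softmax").Apply(x);
    var model = keras.Model(input, output);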
@@ -7,6 +7,7 @@ using Tensorflow.Keras.Saving; | |||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using Tensorflow.Common.Extensions; | using Tensorflow.Common.Extensions; | ||||
using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
using Tensorflow.Graphs; | |||||
namespace Tensorflow.Keras.Layers.Rnn | namespace Tensorflow.Keras.Layers.Rnn | ||||
{ | { | ||||
@@ -23,12 +24,11 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
IVariableV1 _kernel; | IVariableV1 _kernel; | ||||
IVariableV1 _recurrent_kernel; | IVariableV1 _recurrent_kernel; | ||||
IVariableV1 _bias; | IVariableV1 _bias; | ||||
GeneralizedTensorShape _state_size; | |||||
GeneralizedTensorShape _output_size; | |||||
INestStructure<long> _state_size; | |||||
INestStructure<long> _output_size; | |||||
public override GeneralizedTensorShape StateSize => _state_size; | |||||
public override GeneralizedTensorShape OutputSize => _output_size; | |||||
public override bool IsTFRnnCell => true; | |||||
public override INestStructure<long> StateSize => _state_size; | |||||
public override INestStructure<long> OutputSize => _output_size; | |||||
public override bool SupportOptionalArgs => false; | public override bool SupportOptionalArgs => false; | ||||
public SimpleRNNCell(SimpleRNNCellArgs args) : base(args) | public SimpleRNNCell(SimpleRNNCellArgs args) : base(args) | ||||
@@ -41,8 +41,8 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
} | } | ||||
this._args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout)); | this._args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout)); | ||||
this._args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout)); | this._args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout)); | ||||
_state_size = new GeneralizedTensorShape(args.Units); | |||||
_output_size = new GeneralizedTensorShape(args.Units); | |||||
_state_size = new NestNode<long>(args.Units); | |||||
_output_size = new NestNode<long>(args.Units); | |||||
} | } | ||||
public override void build(KerasShapesWrapper input_shape) | public override void build(KerasShapesWrapper input_shape) | ||||
@@ -74,8 +74,8 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
{ | { | ||||
// TODO(Rinne): check if it will have multiple tensors when not nested. | // TODO(Rinne): check if it will have multiple tensors when not nested. | ||||
Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states; | Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states; | ||||
var dp_mask = get_dropout_maskcell_for_cell(inputs, training.Value); | |||||
var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value); | |||||
var dp_mask = get_dropout_mask_for_cell(inputs, training.Value); | |||||
var rec_dp_mask = get_recurrent_dropout_mask_for_cell(prev_output, training.Value); | |||||
Tensor h; | Tensor h; | ||||
var ranks = inputs.rank; | var ranks = inputs.rank; | ||||
@@ -98,7 +98,6 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
{ | { | ||||
prev_output = math_ops.multiply(prev_output, rec_dp_mask); | prev_output = math_ops.multiply(prev_output, rec_dp_mask); | ||||
} | } | ||||
var tmp = _recurrent_kernel.AsTensor(); | |||||
Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor()); | Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor()); | ||||
if (_args.Activation != null) | if (_args.Activation != null) | ||||
@@ -116,10 +115,5 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
return new Tensors(output, output); | return new Tensors(output, output); | ||||
} | } | ||||
} | } | ||||
public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null) | |||||
{ | |||||
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value); | |||||
} | |||||
} | } | ||||
} | } |
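`SimpleRNNCell` switches its size metadata to `INestStructure<long>`, uses the renamed dropout-mask helpers (`get_dropout_mask_for_cell`, `get_recurrent_dropout_mask_for_cell`), and drops its local `get_initial_state` in favour of the shared cell contract. Driving the cell directly still looks like the `SimpleRNNCell` unit test kept in this diff:

    var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
    var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };         // batch 4, 64 units
    var x = tf.random.normal((4, 100));
    var (y, h1) = cell.Apply(inputs: x, states: h0);             // y: (4, 64), h1[0]: (4, 64)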
@@ -1,10 +1,8 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | |||||
using System.ComponentModel; | using System.ComponentModel; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Common.Extensions; | using Tensorflow.Common.Extensions; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | using Tensorflow.Keras.ArgsDefinition.Rnn; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
@@ -15,30 +13,15 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
public class StackedRNNCells : Layer, IRnnCell | public class StackedRNNCells : Layer, IRnnCell | ||||
{ | { | ||||
public IList<IRnnCell> Cells { get; set; } | public IList<IRnnCell> Cells { get; set; } | ||||
public bool reverse_state_order; | |||||
public bool _reverse_state_order; | |||||
public StackedRNNCells(StackedRNNCellsArgs args) : base(args) | |||||
public StackedRNNCells(IEnumerable<IRnnCell> cells, StackedRNNCellsArgs args) : base(args) | |||||
{ | { | ||||
if (args.Kwargs == null) | |||||
{ | |||||
args.Kwargs = new Dictionary<string, object>(); | |||||
} | |||||
foreach (var cell in args.Cells) | |||||
{ | |||||
//Type type = cell.GetType(); | |||||
//var CallMethodInfo = type.GetMethod("Call"); | |||||
//if (CallMethodInfo == null) | |||||
//{ | |||||
// throw new ValueError( | |||||
// "All cells must have a `Call` method. " + | |||||
// $"Received cell without a `Call` method: {cell}"); | |||||
//} | |||||
} | |||||
Cells = args.Cells; | |||||
reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false); | |||||
Cells = cells.ToList(); | |||||
if (reverse_state_order) | |||||
_reverse_state_order = args.ReverseStateOrder; | |||||
if (_reverse_state_order) | |||||
{ | { | ||||
throw new WarningException("reverse_state_order=True in StackedRNNCells will soon " + | throw new WarningException("reverse_state_order=True in StackedRNNCells will soon " + | ||||
"be deprecated. Please update the code to work with the " + | "be deprecated. Please update the code to work with the " + | ||||
@@ -47,49 +30,37 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
} | } | ||||
} | } | ||||
public GeneralizedTensorShape StateSize | |||||
public bool SupportOptionalArgs => false; | |||||
public INestStructure<long> StateSize | |||||
{ | { | ||||
get | get | ||||
{ | { | ||||
GeneralizedTensorShape state_size = new GeneralizedTensorShape(1, Cells.Count); | |||||
if (reverse_state_order && Cells.Count > 0) | |||||
if (_reverse_state_order) | |||||
{ | { | ||||
var idxAndCell = Cells.Reverse().Select((cell, idx) => (idx, cell)); | |||||
foreach (var cell in idxAndCell) | |||||
{ | |||||
state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First(); | |||||
} | |||||
var state_sizes = Cells.Reverse().Select(cell => cell.StateSize); | |||||
return new Nest<long>(state_sizes); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
//foreach (var cell in Cells) | |||||
//{ | |||||
// state_size.Shapes.add(cell.StateSize.Shapes.First()); | |||||
//} | |||||
var idxAndCell = Cells.Select((cell, idx) => (idx, cell)); | |||||
foreach (var cell in idxAndCell) | |||||
{ | |||||
state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First(); | |||||
} | |||||
var state_sizes = Cells.Select(cell => cell.StateSize); | |||||
return new Nest<long>(state_sizes); | |||||
} | } | ||||
return state_size; | |||||
} | } | ||||
} | } | ||||
public object output_size | |||||
public INestStructure<long> OutputSize | |||||
{ | { | ||||
get | get | ||||
{ | { | ||||
var lastCell = Cells.LastOrDefault(); | |||||
if (lastCell.OutputSize.ToSingleShape() != -1) | |||||
var lastCell = Cells.Last(); | |||||
if(lastCell.OutputSize is not null) | |||||
{ | { | ||||
return lastCell.OutputSize; | return lastCell.OutputSize; | ||||
} | } | ||||
else if (RNN.is_multiple_state(lastCell.StateSize)) | |||||
else if (RnnUtils.is_multiple_state(lastCell.StateSize)) | |||||
{ | { | ||||
return lastCell.StateSize.First(); | |||||
//throw new NotImplementedException(""); | |||||
return new NestNode<long>(lastCell.StateSize.Flatten().First()); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -98,79 +69,65 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
} | } | ||||
} | } | ||||
public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null) | |||||
public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
{ | { | ||||
var cells = reverse_state_order ? Cells.Reverse() : Cells; | |||||
Tensors initial_states = new Tensors(); | |||||
var cells = _reverse_state_order ? Cells.Reverse() : Cells; | |||||
List<Tensor> initial_states = new List<Tensor>(); | |||||
foreach (var cell in cells) | foreach (var cell in cells) | ||||
{ | { | ||||
var get_initial_state_fn = cell.GetType().GetMethod("get_initial_state"); | |||||
if (get_initial_state_fn != null) | |||||
{ | |||||
var result = (Tensors)get_initial_state_fn.Invoke(cell, new object[] { inputs, batch_size, dtype }); | |||||
initial_states.Add(result); | |||||
} | |||||
else | |||||
{ | |||||
initial_states.Add(RnnUtils.generate_zero_filled_state_for_cell(cell, inputs, batch_size.Value, dtype.Value)); | |||||
} | |||||
initial_states.Add(cell.GetInitialState(inputs, batch_size, dtype)); | |||||
} | } | ||||
return initial_states; | |||||
return new Tensors(initial_states); | |||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||||
protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null) | |||||
{ | { | ||||
// Recover per-cell states. | // Recover per-cell states. | ||||
var state_size = reverse_state_order ? StateSize.Reverse() : StateSize; | |||||
var nested_states = reverse_state_order ? state.Flatten().Reverse() : state.Flatten(); | |||||
var state_size = _reverse_state_order ? new NestList<long>(StateSize.Flatten().Reverse()) : StateSize; | |||||
var nested_states = Nest.PackSequenceAs(state_size, Nest.Flatten(states).ToArray()); | |||||
var new_nest_states = new Tensors(); | |||||
var new_nest_states = Nest<Tensor>.Empty; | |||||
// Call the cells in order and store the returned states. | // Call the cells in order and store the returned states. | ||||
foreach (var (cell, states) in zip(Cells, nested_states)) | |||||
foreach (var (cell, internal_states) in zip(Cells, nested_states)) | |||||
{ | { | ||||
// states = states if tf.nest.is_nested(states) else [states] | |||||
var type = cell.GetType(); | |||||
bool IsTFRnnCell = type.GetProperty("IsTFRnnCell") != null; | |||||
state = len(state) == 1 && IsTFRnnCell ? state.FirstOrDefault() : state; | |||||
RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; | RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; | ||||
Tensors? constants = rnn_optional_args?.Constants; | Tensors? constants = rnn_optional_args?.Constants; | ||||
Tensors new_states; | Tensors new_states; | ||||
(inputs, new_states) = cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); | |||||
(inputs, new_states) = cell.Apply(inputs, internal_states, optional_args: new RnnOptionalArgs() { Constants = constants }); | |||||
new_nest_states.Add(new_states); | |||||
new_nest_states = new_nest_states.MergeWith(new_states); | |||||
} | } | ||||
new_nest_states = reverse_state_order ? new_nest_states.Reverse().ToArray() : new_nest_states.ToArray(); | |||||
return new Nest<Tensor>(new List<Nest<Tensor>> { | |||||
new Nest<Tensor>(new List<Nest<Tensor>> { new Nest<Tensor>(inputs.Single()) }), new Nest<Tensor>(new_nest_states) }) | |||||
.ToTensors(); | |||||
return Tensors.FromNest((inputs, Nest.PackSequenceAs(state_size, Nest.Flatten(new_nest_states).ToArray()))); | |||||
} | } | ||||
public void build() | |||||
public override void build(KerasShapesWrapper input_shape) | |||||
{ | { | ||||
built = true; | |||||
// @tf_utils.shape_type_conversion | |||||
// def build(self, input_shape) : | |||||
// if isinstance(input_shape, list) : | |||||
// input_shape = input_shape[0] | |||||
// for cell in self.cells: | |||||
// if isinstance(cell, Layer) and not cell.built: | |||||
// with K.name_scope(cell.name): | |||||
// cell.build(input_shape) | |||||
// cell.built = True | |||||
// if getattr(cell, 'output_size', None) is not None: | |||||
// output_dim = cell.output_size | |||||
// elif _is_multiple_state(cell.state_size) : | |||||
// output_dim = cell.state_size[0] | |||||
// else: | |||||
// output_dim = cell.state_size | |||||
// input_shape = tuple([input_shape[0]] + | |||||
// tensor_shape.TensorShape(output_dim).as_list()) | |||||
// self.built = True | |||||
var shape = input_shape.ToSingleShape(); | |||||
foreach(var cell in Cells) | |||||
{ | |||||
if(cell is Layer layer && !layer.Built) | |||||
{ | |||||
// ignored the name scope. | |||||
layer.build(shape); | |||||
layer.Built = true; | |||||
} | |||||
INestStructure<long> output_dim; | |||||
if(cell.OutputSize is not null) | |||||
{ | |||||
output_dim = cell.OutputSize; | |||||
} | |||||
else if (RnnUtils.is_multiple_state(cell.StateSize)) | |||||
{ | |||||
output_dim = new NestNode<long>(cell.StateSize.Flatten().First()); | |||||
} | |||||
else | |||||
{ | |||||
output_dim = cell.StateSize; | |||||
} | |||||
shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.Flatten()).ToArray()); | |||||
} | |||||
this.Built = true; | |||||
} | } | ||||
public override IKerasConfig get_config() | public override IKerasConfig get_config() | ||||
@@ -198,14 +155,5 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
// deserialize_layer(cell_config, custom_objects = custom_objects)) | // deserialize_layer(cell_config, custom_objects = custom_objects)) | ||||
// return cls(cells, **config) | // return cls(cells, **config) | ||||
} | } | ||||
public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); | |||||
public bool IsTFRnnCell => true; | |||||
public bool SupportOptionalArgs => throw new NotImplementedException(); | |||||
} | } | ||||
} | } |
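`StackedRNNCells` now receives its cells through the constructor, reports `StateSize`/`OutputSize` as nest structures, packs per-cell states with `Nest.PackSequenceAs`, and gains a real `build` that threads each cell's output dimension into the next cell's input shape. A hedged construction sketch; whether `StackedRNNCellsArgs` can be default-constructed like this, and whether the `RNN` factory accepts the stacked cell directly, are assumptions:

    var cells = new List<IRnnCell>
    {
        new SimpleRNNCell(new SimpleRNNCellArgs { Units = 8 }),
        new SimpleRNNCell(new SimpleRNNCellArgs { Units = 4 }),
    };
    var stacked = new StackedRNNCells(cells, new StackedRNNCellsArgs());
    // The stacked cell can then be driven by the generic RNN layer.
    var rnn = tf.keras.layers.RNN(stacked);
    var output = rnn.Apply(tf.ones((5, 10, 8)));                 // last cell has 4 units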
@@ -10,33 +10,33 @@ namespace Tensorflow.Keras.Utils | |||||
{ | { | ||||
internal static class RnnUtils | internal static class RnnUtils | ||||
{ | { | ||||
internal static Tensors generate_zero_filled_state(long batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype) | |||||
internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, INestStructure<long> state_size, TF_DataType dtype) | |||||
{ | { | ||||
Func<GeneralizedTensorShape, Tensor> create_zeros; | |||||
create_zeros = (GeneralizedTensorShape unnested_state_size) => | |||||
Func<long, Tensor> create_zeros = (unnested_state_size) => | |||||
{ | { | ||||
var flat_dims = unnested_state_size.ToSingleShape().dims; | |||||
var init_state_size = new long[] { batch_size_tensor }.Concat(flat_dims).ToArray(); | |||||
return array_ops.zeros(new Shape(init_state_size), dtype: dtype); | |||||
var flat_dims = new Shape(unnested_state_size).dims; | |||||
var init_state_size = new Tensor[] { batch_size_tensor }. | |||||
Concat(flat_dims.Select(x => tf.constant(x, dtypes.int32))).ToArray(); | |||||
return array_ops.zeros(init_state_size, dtype: dtype); | |||||
}; | }; | ||||
// TODO(Rinne): map structure with nested tensors. | // TODO(Rinne): map structure with nested tensors. | ||||
if(state_size.Shapes.Length > 1) | |||||
if(state_size.TotalNestedCount > 1) | |||||
{ | { | ||||
return new Tensors(state_size.ToShapeArray().Select(s => create_zeros(new GeneralizedTensorShape(s)))); | |||||
return new Tensors(state_size.Flatten().Select(s => create_zeros(s)).ToArray()); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
return create_zeros(state_size); | |||||
return create_zeros(state_size.Flatten().First()); | |||||
} | } | ||||
} | } | ||||
internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, long batch_size, TF_DataType dtype) | |||||
internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, Tensor batch_size, TF_DataType dtype) | |||||
{ | { | ||||
if (inputs != null) | |||||
if (inputs is not null) | |||||
{ | { | ||||
batch_size = inputs.shape[0]; | |||||
batch_size = array_ops.shape(inputs)[0]; | |||||
dtype = inputs.dtype; | dtype = inputs.dtype; | ||||
} | } | ||||
return generate_zero_filled_state(batch_size, cell.StateSize, dtype); | return generate_zero_filled_state(batch_size, cell.StateSize, dtype); | ||||
@@ -77,17 +77,27 @@ namespace Tensorflow.Keras.Utils | |||||
Debug.Assert(initial_state is null && constants is null); | Debug.Assert(initial_state is null && constants is null); | ||||
if(num_constants > 0) | if(num_constants > 0) | ||||
{ | { | ||||
constants = inputs.TakeLast(num_constants).ToTensors(); | |||||
inputs = inputs.SkipLast(num_constants).ToTensors(); | |||||
constants = inputs.TakeLast(num_constants).ToArray().ToTensors(); | |||||
inputs = inputs.SkipLast(num_constants).ToArray().ToTensors(); | |||||
} | } | ||||
if(inputs.Length > 1) | if(inputs.Length > 1) | ||||
{ | { | ||||
initial_state = inputs.Skip(1).ToTensors(); | |||||
inputs = inputs.Take(1).ToTensors(); | |||||
initial_state = inputs.Skip(1).ToArray().ToTensors(); | |||||
inputs = inputs.Take(1).ToArray().ToTensors(); | |||||
} | } | ||||
} | } | ||||
return (inputs, initial_state, constants); | return (inputs, initial_state, constants); | ||||
} | } | ||||
/// <summary> | |||||
/// Check whether the state_size contains multiple states. | |||||
/// </summary> | |||||
/// <param name="state_size">The nested state-size structure reported by the cell.</param> | |||||
/// <returns>True if the structure contains more than one state entry.</returns> | |||||
public static bool is_multiple_state(INestStructure<long> state_size) | |||||
{ | |||||
return state_size.TotalNestedCount > 1; | |||||
} | |||||
} | } | ||||
} | } |
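`generate_zero_filled_state` now takes the batch size as a `Tensor` and the state size as `INestStructure<long>`, concatenating the dynamic batch dimension with the constant per-state dims before calling `array_ops.zeros`, so zero states work with batch sizes that are only known at run time; `is_multiple_state` also moves here from `RNN`. A hedged sketch of an internal call for a single-state cell of 4 units (`RnnUtils` is internal, so this only applies inside the Keras assembly):

    // Batch size as a runtime tensor, e.g. array_ops.shape(inputs)[0].
    Tensor batch = tf.constant(2, dtypes.int32);
    var state_size = new NestNode<long>(4);
    Tensors zeros = RnnUtils.generate_zero_filled_state(batch, state_size, TF_DataType.TF_FLOAT);
    // zeros contains a single tensor of shape (2, 4).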
@@ -21,21 +21,6 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||||
[TestMethod] | [TestMethod] | ||||
public void SimpleRNNCell() | public void SimpleRNNCell() | ||||
{ | { | ||||
//var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f); | |||||
//var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; | |||||
//var x = tf.random.normal((4, 100)); | |||||
//var (y, h1) = cell.Apply(inputs: x, states: h0); | |||||
//var h2 = h1; | |||||
//Assert.AreEqual((4, 64), y.shape); | |||||
//Assert.AreEqual((4, 64), h2[0].shape); | |||||
//var model = keras.Sequential(new List<ILayer> | |||||
//{ | |||||
// keras.layers.InputLayer(input_shape: (4,100)), | |||||
// keras.layers.SimpleRNNCell(64) | |||||
//}); | |||||
//model.summary(); | |||||
var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f); | var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f); | ||||
var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; | var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; | ||||
var x = tf.random.normal((4, 100)); | var x = tf.random.normal((4, 100)); | ||||
@@ -60,24 +45,63 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public void SimpleRNN() | |||||
public void LSTMCell() | |||||
{ | { | ||||
//var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32); | |||||
///*var simple_rnn = keras.layers.SimpleRNN(4); | |||||
//var output = simple_rnn.Apply(inputs); | |||||
//Assert.AreEqual((32, 4), output.shape);*/ | |||||
var inputs = tf.ones((2, 100)); | |||||
var states = new Tensors { tf.zeros((2, 4)), tf.zeros((2, 4)) }; | |||||
var rnn = tf.keras.layers.LSTMCell(4); | |||||
var (output, new_states) = rnn.Apply(inputs, states); | |||||
Assert.AreEqual((2, 4), output.shape); | |||||
Assert.AreEqual((2, 4), new_states[0].shape); | |||||
} | |||||
//var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true); | |||||
//var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs); | |||||
//Assert.AreEqual((6, 10, 4), whole_sequence_output.shape); | |||||
//Assert.AreEqual((6, 4), final_state.shape); | |||||
[TestMethod] | |||||
public void TrainLSTMWithMnist() | |||||
{ | |||||
var input = keras.Input((784)); | |||||
var x = keras.layers.Reshape((28, 28)).Apply(input); | |||||
x = keras.layers.LSTM(50, return_sequences: true).Apply(x); | |||||
x = keras.layers.LSTM(100).Apply(x); | |||||
var output = keras.layers.Dense(10, activation: "softmax").Apply(x); | |||||
var inputs = keras.Input(shape: (10, 8)); | |||||
var x = keras.layers.SimpleRNN(4).Apply(inputs); | |||||
var output = keras.layers.Dense(10).Apply(x); | |||||
var model = keras.Model(inputs, output); | |||||
var model = keras.Model(input, output); | |||||
model.summary(); | model.summary(); | ||||
model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" }); | |||||
var data_loader = new MnistModelLoader(); | |||||
var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||||
{ | |||||
TrainDir = "mnist", | |||||
OneHot = true, | |||||
ValidationSize = 55000, | |||||
}).Result; | |||||
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1); | |||||
} | } | ||||
[TestMethod] | |||||
public void SimpleRNN() | |||||
{ | |||||
var input = keras.Input((784)); | |||||
var x = keras.layers.Reshape((28, 28)).Apply(input); | |||||
x = keras.layers.SimpleRNN(10).Apply(x); | |||||
var output = keras.layers.Dense(10, activation: "softmax").Apply(x); | |||||
var model = keras.Model(input, output); | |||||
model.summary(); | |||||
model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" }); | |||||
var data_loader = new MnistModelLoader(); | |||||
var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||||
{ | |||||
TrainDir = "mnist", | |||||
OneHot = false, | |||||
ValidationSize = 58000, | |||||
}).Result; | |||||
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 2); | |||||
} | |||||
[TestMethod] | [TestMethod] | ||||
public void RNNForSimpleRNNCell() | public void RNNForSimpleRNNCell() | ||||
{ | { | ||||
@@ -100,15 +124,13 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public void WlzTest() | |||||
public void RNNForLSTMCell() | |||||
{ | { | ||||
long[] b = { 1, 2, 3 }; | |||||
Shape a = new Shape(Unknown).concatenate(b); | |||||
Console.WriteLine(a); | |||||
var inputs = tf.ones((5, 10, 8)); | |||||
var rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4)); | |||||
var output = rnn.Apply(inputs); | |||||
Console.WriteLine($"output: {output}"); | |||||
Assert.AreEqual((5, 4), output.shape); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -28,8 +28,8 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
var i = tf.constant(2); | var i = tf.constant(2); | ||||
var j = tf.constant(3); | var j = tf.constant(3); | ||||
Func<Tensor[], Tensor> c = (x) => tf.less(x[0] + x[1], 10); | |||||
Func<Tensor[], Tensor[]> b = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) }; | |||||
Func<Tensors, Tensor> c = (x) => tf.less(x[0] + x[1], 10); | |||||
Func<Tensors, Tensors> b = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) }; | |||||
var r = tf.while_loop(c, b, new[] { i, j }); | var r = tf.while_loop(c, b, new[] { i, j }); | ||||
Assert.AreEqual(5, (int)r[0]); | Assert.AreEqual(5, (int)r[0]); | ||||
Assert.AreEqual(6, (int)r[1]); | Assert.AreEqual(6, (int)r[1]); | ||||
@@ -21,7 +21,8 @@ namespace Tensorflow.CodeGen | |||||
{ | { | ||||
sb.Append("Operation "); | sb.Append("Operation "); | ||||
} | } | ||||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) | |||||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) | |||||
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) | |||||
{ | { | ||||
sb.Append("Tensor "); | sb.Append("Tensor "); | ||||
} | } | ||||
@@ -70,7 +71,8 @@ namespace Tensorflow.CodeGen | |||||
{ | { | ||||
sb.AppendLine("return null;"); | sb.AppendLine("return null;"); | ||||
} | } | ||||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) | |||||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) | |||||
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) | |||||
{ | { | ||||
sb.AppendLine("return _fast_path_result[0];"); | sb.AppendLine("return _fast_path_result[0];"); | ||||
} | } | ||||
@@ -81,6 +83,14 @@ namespace Tensorflow.CodeGen | |||||
sb.AppendLine("}"); // try | sb.AppendLine("}"); // try | ||||
sb.Append("catch(NotOkStatusException ex1)\n{\n"); | |||||
sb.AppendLine("throw ex1;"); | |||||
sb.AppendLine("}"); // catch | |||||
sb.Append("catch(InvalidArgumentError ex2)\n{\n"); | |||||
sb.AppendLine("throw ex2;"); | |||||
sb.AppendLine("}"); // catch | |||||
sb.Append("catch(Exception)\n{\n"); | sb.Append("catch(Exception)\n{\n"); | ||||
sb.AppendLine("}"); // catch | sb.AppendLine("}"); // catch | ||||
@@ -149,7 +159,8 @@ namespace Tensorflow.CodeGen | |||||
{ | { | ||||
sb.AppendLine("return _op;"); | sb.AppendLine("return _op;"); | ||||
} | } | ||||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) | |||||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) | |||||
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) | |||||
{ | { | ||||
sb.AppendLine("return _result[0];"); | sb.AppendLine("return _result[0];"); | ||||
} | } | ||||
@@ -174,7 +185,7 @@ namespace Tensorflow.CodeGen | |||||
{ | { | ||||
argName = $"{argName}_"; | argName = $"{argName}_"; | ||||
} | } | ||||
if (!string.IsNullOrEmpty(arg.NumberAttr)) | |||||
if (!string.IsNullOrEmpty(arg.NumberAttr) || !string.IsNullOrEmpty(arg.TypeListAttr)) | |||||
{ | { | ||||
sb.Append($"Tensors {argName}, "); | sb.Append($"Tensors {argName}, "); | ||||
} | } | ||||
@@ -273,7 +284,8 @@ namespace Tensorflow.CodeGen | |||||
{ | { | ||||
sb.Append("Operation "); | sb.Append("Operation "); | ||||
} | } | ||||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) | |||||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) | |||||
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) | |||||
{ | { | ||||
sb.Append("Tensor "); | sb.Append("Tensor "); | ||||
} | } | ||||
@@ -366,6 +378,13 @@ namespace Tensorflow.CodeGen | |||||
sb.Append($"\"{attr.Name}\", {attrRealName}, "); | sb.Append($"\"{attr.Name}\", {attrRealName}, "); | ||||
} | } | ||||
} | } | ||||
else if(attr.Type == "list(type)") | |||||
{ | |||||
if (op.InputArg.Any(x => x.TypeListAttr == attr.Name)) | |||||
{ | |||||
continue; | |||||
} | |||||
} | |||||
else if(attr.Type == "int" && op.InputArg.Any(x => x.NumberAttr == attr.Name)) | else if(attr.Type == "int" && op.InputArg.Any(x => x.NumberAttr == attr.Name)) | ||||
{ | { | ||||
bool found = false; | bool found = false; | ||||
@@ -408,7 +427,8 @@ namespace Tensorflow.CodeGen | |||||
{ | { | ||||
sb.AppendLine("return null;"); | sb.AppendLine("return null;"); | ||||
} | } | ||||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) | |||||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) | |||||
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) | |||||
{ | { | ||||
sb.AppendLine("return _result[0];"); | sb.AppendLine("return _result[0];"); | ||||
} | } | ||||
@@ -39,6 +39,7 @@ namespace Tensorflow.CodeGen | |||||
// Add commonly used namespaces. | // Add commonly used namespaces. | ||||
sb.AppendLine("using Tensorflow.Eager;"); | sb.AppendLine("using Tensorflow.Eager;"); | ||||
sb.AppendLine("using Tensorflow.Contexts;"); | sb.AppendLine("using Tensorflow.Contexts;"); | ||||
sb.AppendLine("using Tensorflow.Exceptions;"); | |||||
sb.AppendLine("using static Tensorflow.Binding;"); | sb.AppendLine("using static Tensorflow.Binding;"); | ||||
sb.AppendLine(); | sb.AppendLine(); | ||||
@@ -9,7 +9,7 @@ namespace Tensorflow.CodeGen | |||||
{ | { | ||||
public class OpClassifier | public class OpClassifier | ||||
{ | { | ||||
private static readonly string _filenamePattern = @"^gen_[a-z]*_ops.py$"; | |||||
private static readonly string _filenamePattern = @"^gen_[a-z_]*_ops.py$"; | |||||
private static readonly string _pythonFunctionPattern = @"def\s+(\w+\d*\w*)\((?:\s*\w+\s*(?:=\s*[\S]*)*,\s*)*\s*name=None\):"; | private static readonly string _pythonFunctionPattern = @"def\s+(\w+\d*\w*)\((?:\s*\w+\s*(?:=\s*[\S]*)*,\s*)*\s*name=None\):"; | ||||
private Dictionary<string, HashSet<string>> _opSet = new(); | private Dictionary<string, HashSet<string>> _opSet = new(); | ||||
public Dictionary<string, HashSet<string>> OpSet => _opSet; | public Dictionary<string, HashSet<string>> OpSet => _opSet; | ||||
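Adding the underscore to the character class lets `OpClassifier` pick up generated-op scripts whose names contain more than one word. A quick check (file names here are just illustrative):

    using System;
    using System.Text.RegularExpressions;

    var oldPattern = @"^gen_[a-z]*_ops.py$";
    var newPattern = @"^gen_[a-z_]*_ops.py$";
    Console.WriteLine(Regex.IsMatch("gen_math_ops.py", oldPattern));          // True
    Console.WriteLine(Regex.IsMatch("gen_ragged_array_ops.py", oldPattern));  // False, '_' not in [a-z]
    Console.WriteLine(Regex.IsMatch("gen_ragged_array_ops.py", newPattern));  // True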
@@ -5,7 +5,7 @@ using System.Text; | |||||
using System.Xml.Linq; | using System.Xml.Linq; | ||||
using Tensorflow.CodeGen; | using Tensorflow.CodeGen; | ||||
GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops", | |||||
GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops_v2", | |||||
@"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", | @"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", | ||||
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\api_def\base_api", | @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\api_def\base_api", | ||||
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); | @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); | ||||
@@ -9,7 +9,7 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" /> | <PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" /> | ||||
<PackageReference Include="Protobuf.Text" Version="0.7.0" /> | |||||
<PackageReference Include="Protobuf.Text" Version="0.7.1" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -155,6 +155,10 @@ namespace Tensorflow.CodeGen | |||||
} | } | ||||
else if (attr.Type == "list(type)") | else if (attr.Type == "list(type)") | ||||
{ | { | ||||
if(op.InputArg.Any(x => x.TypeListAttr == attr.Name)) | |||||
{ | |||||
continue; | |||||
} | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | ||||
{ | { | ||||
List<TF_DataType> values = new(); | List<TF_DataType> values = new(); | ||||
@@ -174,10 +178,25 @@ namespace Tensorflow.CodeGen | |||||
else if (attr.Type == "list(shape)") | else if (attr.Type == "list(shape)") | ||||
{ | { | ||||
res.Add((attr.Name, "Shape[]", "NOVALUE")); | res.Add((attr.Name, "Shape[]", "NOVALUE")); | ||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||||
{ | |||||
List<string> exps = new(); | |||||
foreach (var value in attr.DefaultValue.List.Shape) | |||||
{ | |||||
exps.Add($"new Shape({string.Join(", ", value.Dim.Select(x => x.Size))})"); | |||||
} | |||||
string expression = "new Shape[]{" + $"{string.Join(", ", exps)}" + "}"; | |||||
dynamicDefaultValues[attr.Name] = expression; | |||||
res.Add((attr.Name, "string[]", $"null")); | |||||
} | |||||
else | |||||
{ | |||||
res.Add((attr.Name, "string[]", "NOVALUE")); | |||||
} | |||||
} | } | ||||
else if (attr.Type == "list(string)") | else if (attr.Type == "list(string)") | ||||
{ | { | ||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) | |||||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) | |||||
{ | { | ||||
List<string> values = new(); | List<string> values = new(); | ||||
foreach (var value in attr.DefaultValue.List.S) | foreach (var value in attr.DefaultValue.List.S) | ||||
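Two default-value fixes land in this hunk: `list(shape)` defaults are now turned into a ready-made C# expression stored in `dynamicDefaultValues`, and `list(string)` defaults are read from the `List` case of `AttrValue` rather than the scalar `S` case. For instance, a `list(shape)` default of `[[2, 3], [4]]` would be rendered as the expression below (the attribute name is illustrative):

    // dynamicDefaultValues["shapes"] for a default of [[2, 3], [4]]:
    new Shape[]{new Shape(2, 3), new Shape(4)}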
@@ -231,11 +250,11 @@ namespace Tensorflow.CodeGen | |||||
} | } | ||||
else if (attr.Type == "func") | else if (attr.Type == "func") | ||||
{ | { | ||||
res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE")); | |||||
res.Add((attr.Name, "object", "NOVALUE")); | |||||
} | } | ||||
else if (attr.Type == "list(func)") | else if (attr.Type == "list(func)") | ||||
{ | { | ||||
res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE")); | |||||
res.Add((attr.Name, "object[]", "NOVALUE")); | |||||
} | } | ||||
else if (attr.Type == "tensor") | else if (attr.Type == "tensor") | ||||
{ | { | ||||