@@ -31,6 +31,6 @@ namespace Tensorflow | |||||
public Tensor reshape(Tensor tensor, | public Tensor reshape(Tensor tensor, | ||||
object[] shape, | object[] shape, | ||||
string name = null) | string name = null) | ||||
=> gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name); | |||||
=> array_ops.reshape(tensor, shape, name); | |||||
} | } | ||||
} | } |
@@ -23,7 +23,7 @@ namespace Tensorflow | |||||
=> gen_array_ops.tile(input, multiples, name); | => gen_array_ops.tile(input, multiples, name); | ||||
public Tensor tile(Tensor input, object[] multiples, string name = null) | public Tensor tile(Tensor input, object[] multiples, string name = null) | ||||
=> gen_array_ops.tile(input, ops.convert_to_tensor(multiples), name); | |||||
=> array_ops.tile(input, multiples, name); | |||||
public Tensor tile(Tensor input, Shape multiples, string name = null) | public Tensor tile(Tensor input, Shape multiples, string name = null) | ||||
{ | { | ||||
@@ -5,4 +5,5 @@ global using System.Collections; | |||||
global using System.Data; | global using System.Data; | ||||
global using System.Linq; | global using System.Linq; | ||||
global using Tensorflow.Keras.Engine; | global using Tensorflow.Keras.Engine; | ||||
global using Tensorflow.Framework.Models; | |||||
global using Tensorflow.Framework.Models; | |||||
global using static Tensorflow.Binding; |
@@ -30,21 +30,32 @@ public class KerasTensor | |||||
public static KerasTensor from_tensor(Tensor tensor) | public static KerasTensor from_tensor(Tensor tensor) | ||||
{ | { | ||||
var type_spec = tensor.ToTensorSpec(); | var type_spec = tensor.ToTensorSpec(); | ||||
var kt = new KerasTensor(type_spec, name: tensor.name); | |||||
Shape? inferred_value = default; | |||||
if (tensor.dtype == TF_DataType.TF_INT32 && tensor.rank < 2) | |||||
{ | |||||
inferred_value = tf.ones(tensor).shape; | |||||
} | |||||
var kt = new KerasTensor(type_spec, inferred_value: inferred_value, name: tensor.name); | |||||
kt.original_tensors = tensor; | kt.original_tensors = tensor; | ||||
return kt; | return kt; | ||||
} | } | ||||
public KerasTensor this[int idx] | |||||
=> _original_tensors.First()[idx]; | |||||
public KerasTensor this[params Slice[] slices] | |||||
=> _original_tensors.First()[slices]; | |||||
public override string ToString() | public override string ToString() | ||||
=> _original_tensors.Length switch | => _original_tensors.Length switch | ||||
{ | { | ||||
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype}")) + "]", | |||||
1 => $"KerasTensor: shape={_original_tensors.shape} {GetInferredValueString()} dtype={_original_tensors.dtype}", | |||||
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype.as_numpy_name()}{GetInferredValueString()}")) + "]", | |||||
1 => $"KerasTensor: shape={_original_tensors.shape} dtype={_original_tensors.dtype.as_numpy_name()}{GetInferredValueString()}", | |||||
_ => _original_tensors.ToString(), | _ => _original_tensors.ToString(), | ||||
}; | }; | ||||
private string GetInferredValueString() | private string GetInferredValueString() | ||||
=> _inferred_value == null ? "" : ""; | |||||
=> _inferred_value == null ? "" : $" inferred_value={_inferred_value}"; | |||||
public static implicit operator Tensors(KerasTensor kt) | public static implicit operator Tensors(KerasTensor kt) | ||||
=> kt._original_tensors; | => kt._original_tensors; | ||||
@@ -137,7 +137,7 @@ namespace Tensorflow | |||||
if(shape.Length > 1) | if(shape.Length > 1) | ||||
{ | { | ||||
shapeTensor = ops.convert_to_tensor(shape, dtypes.int32); | shapeTensor = ops.convert_to_tensor(shape, dtypes.int32); | ||||
if(shapeTensor.ndim > 1) | |||||
if (shapeTensor.ndim > 1) | |||||
{ | { | ||||
shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1)); | shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1)); | ||||
} | } | ||||
@@ -304,6 +304,10 @@ namespace Tensorflow | |||||
{ | { | ||||
elems_as_tensors.Add(tensor); | elems_as_tensors.Add(tensor); | ||||
} | } | ||||
else if (elem is KerasTensor kt) | |||||
{ | |||||
elems_as_tensors.Add(kt); | |||||
} | |||||
else | else | ||||
{ | { | ||||
var elem_tensor = constant_op.constant(elem, dtype: dtype, name: i.ToString()); | var elem_tensor = constant_op.constant(elem, dtype: dtype, name: i.ToString()); | ||||
@@ -404,7 +408,10 @@ namespace Tensorflow | |||||
=> gen_array_ops.reshape(tensor, shape, name: name); | => gen_array_ops.reshape(tensor, shape, name: name); | ||||
public static Tensor reshape(Tensor tensor, object[] shape, string name = null) | public static Tensor reshape(Tensor tensor, object[] shape, string name = null) | ||||
=> gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name: name); | |||||
{ | |||||
var dims = shape_utils.from_object_array(shape); | |||||
return gen_array_ops.reshape(tensor, dims, name: name); | |||||
} | |||||
private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | ||||
{ | { | ||||
@@ -425,6 +432,10 @@ namespace Tensorflow | |||||
return tf_with(ops.name_scope(name, "ones", new { shape }), scope => | return tf_with(ops.name_scope(name, "ones", new { shape }), scope => | ||||
{ | { | ||||
name = scope; | name = scope; | ||||
if (shape._shape_tuple().Length == 0) | |||||
{ | |||||
shape = reshape(shape, new Shape(-1)); | |||||
} | |||||
var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); | var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); | ||||
return output; | return output; | ||||
}); | }); | ||||
@@ -647,6 +658,20 @@ namespace Tensorflow | |||||
} | } | ||||
}); | }); | ||||
public static Tensor tile(Tensor input, object[] multiples, string name = null) | |||||
{ | |||||
Shape dims = shape_utils.from_object_array(multiples); | |||||
return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, dims) | |||||
{ | |||||
GetGradientAttrs = (op) => new | |||||
{ | |||||
T = op.get_attr<TF_DataType>("T"), | |||||
Tmultiples = op.get_attr<TF_DataType>("Tmultiples") | |||||
} | |||||
}); | |||||
} | |||||
public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | ||||
{ | { | ||||
return tf_with(ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope => | return tf_with(ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope => | ||||
@@ -1,5 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Eager; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -13,5 +14,31 @@ namespace Tensorflow | |||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
public static Shape from_object_array(object[] shape) | |||||
{ | |||||
var dims = shape.Select(x => | |||||
{ | |||||
if (x is KerasTensor kt && kt.inferred_value != null) | |||||
{ | |||||
return kt.inferred_value.as_int_list()[0]; | |||||
} | |||||
else if (x is EagerTensor et && et.dtype == TF_DataType.TF_INT32) | |||||
{ | |||||
return et.ToArray<int>()[0]; | |||||
} | |||||
else if (x is int i) | |||||
{ | |||||
return i; | |||||
} | |||||
else if (x is long l) | |||||
{ | |||||
return l; | |||||
} | |||||
throw new NotImplementedException(); | |||||
}).ToArray(); | |||||
return new Shape(dims); | |||||
} | |||||
} | } | ||||
} | } |
@@ -46,6 +46,9 @@ namespace Tensorflow | |||||
public Tensor ones(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | public Tensor ones(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | ||||
=> array_ops.ones(shape, dtype, name); | => array_ops.ones(shape, dtype, name); | ||||
public Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | |||||
=> array_ops.ones(shape, dtype, name); | |||||
public Tensor size(Tensor input, | public Tensor size(Tensor input, | ||||
string name = null, | string name = null, | ||||
TF_DataType out_type = TF_DataType.TF_INT32) => array_ops.size(input, | TF_DataType out_type = TF_DataType.TF_INT32) => array_ops.size(input, | ||||
@@ -144,11 +144,18 @@ namespace Tensorflow | |||||
} | } | ||||
if (!graph.building_function) | if (!graph.building_function) | ||||
{ | { | ||||
throw new RuntimeError("Attempting to capture an EagerTensor without building a function."); | |||||
// return eager_tensor.AsPlaceholder(name: name); | |||||
// throw new RuntimeError("Attempting to capture an EagerTensor without building a function."); | |||||
return eager_tensor.AsPlaceholder(name: name); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
else if (value is KerasTensor kt) | |||||
{ | |||||
if (kt.inferred_value != null) | |||||
{ | |||||
return convert_to_tensor(kt.inferred_value, dtype: kt.dtype, name: name); | |||||
} | |||||
} | |||||
// graph mode | // graph mode | ||||
Tensor ret = value switch | Tensor ret = value switch | ||||
@@ -141,7 +141,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="HDF5-CSharp" Version="1.17.0" /> | <PackageReference Include="HDF5-CSharp" Version="1.17.0" /> | ||||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | |||||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.149" /> | |||||
<PackageReference Include="SharpZipLib" Version="1.4.2" /> | <PackageReference Include="SharpZipLib" Version="1.4.2" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -41,8 +41,8 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="FluentAssertions" Version="5.10.3" /> | <PackageReference Include="FluentAssertions" Version="5.10.3" /> | ||||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" /> | |||||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.149" /> | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.6.3" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.2.10" /> | <PackageReference Include="MSTest.TestAdapter" Version="2.2.10" /> | ||||
<PackageReference Include="MSTest.TestFramework" Version="2.2.10" /> | <PackageReference Include="MSTest.TestFramework" Version="2.2.10" /> | ||||
</ItemGroup> | </ItemGroup> | ||||