Browse Source

Fix inferred_value of KerasTensor. #1142

tags/v0.110.4-Transformer-Model
Haiping Chen 2 years ago
parent
commit
6ec39ba3cb
10 changed files with 88 additions and 14 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.reshape.cs
  2. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.tile.cs
  3. +2
    -1
      src/TensorFlowNET.Core/GlobalUsing.cs
  4. +15
    -4
      src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs
  5. +27
    -2
      src/TensorFlowNET.Core/Operations/array_ops.cs
  6. +27
    -0
      src/TensorFlowNET.Core/Tensors/shape_utils.cs
  7. +3
    -0
      src/TensorFlowNET.Core/Tensors/tf.constant.cs
  8. +9
    -2
      src/TensorFlowNET.Core/ops.cs
  9. +1
    -1
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
  10. +2
    -2
      test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.reshape.cs View File

@@ -31,6 +31,6 @@ namespace Tensorflow
public Tensor reshape(Tensor tensor, public Tensor reshape(Tensor tensor,
object[] shape, object[] shape,
string name = null) string name = null)
=> gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name);
=> array_ops.reshape(tensor, shape, name);
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.tile.cs View File

@@ -23,7 +23,7 @@ namespace Tensorflow
=> gen_array_ops.tile(input, multiples, name); => gen_array_ops.tile(input, multiples, name);


public Tensor tile(Tensor input, object[] multiples, string name = null) public Tensor tile(Tensor input, object[] multiples, string name = null)
=> gen_array_ops.tile(input, ops.convert_to_tensor(multiples), name);
=> array_ops.tile(input, multiples, name);


public Tensor tile(Tensor input, Shape multiples, string name = null) public Tensor tile(Tensor input, Shape multiples, string name = null)
{ {


+ 2
- 1
src/TensorFlowNET.Core/GlobalUsing.cs View File

@@ -5,4 +5,5 @@ global using System.Collections;
global using System.Data; global using System.Data;
global using System.Linq; global using System.Linq;
global using Tensorflow.Keras.Engine; global using Tensorflow.Keras.Engine;
global using Tensorflow.Framework.Models;
global using Tensorflow.Framework.Models;
global using static Tensorflow.Binding;

+ 15
- 4
src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs View File

@@ -30,21 +30,32 @@ public class KerasTensor
public static KerasTensor from_tensor(Tensor tensor) public static KerasTensor from_tensor(Tensor tensor)
{ {
var type_spec = tensor.ToTensorSpec(); var type_spec = tensor.ToTensorSpec();
var kt = new KerasTensor(type_spec, name: tensor.name);
Shape? inferred_value = default;
if (tensor.dtype == TF_DataType.TF_INT32 && tensor.rank < 2)
{
inferred_value = tf.ones(tensor).shape;
}
var kt = new KerasTensor(type_spec, inferred_value: inferred_value, name: tensor.name);
kt.original_tensors = tensor; kt.original_tensors = tensor;
return kt; return kt;
} }


/// <summary>
/// Indexes into the first of the wrapped original tensors.
/// </summary>
/// <param name="idx">Index forwarded to the first original tensor's indexer.</param>
/// <returns>The indexed result, converted to a <see cref="KerasTensor"/>.</returns>
public KerasTensor this[int idx]
{
    get
    {
        var head = _original_tensors.First();
        return head[idx];
    }
}

/// <summary>
/// Slices the first of the wrapped original tensors.
/// </summary>
/// <param name="slices">Slice specifications forwarded to the first original tensor's indexer.</param>
/// <returns>The sliced result, converted to a <see cref="KerasTensor"/>.</returns>
public KerasTensor this[params Slice[] slices]
{
    get
    {
        var head = _original_tensors.First();
        return head[slices];
    }
}

public override string ToString() public override string ToString()
=> _original_tensors.Length switch => _original_tensors.Length switch
{ {
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype}")) + "]",
1 => $"KerasTensor: shape={_original_tensors.shape} {GetInferredValueString()} dtype={_original_tensors.dtype}",
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype.as_numpy_name()}{GetInferredValueString()}")) + "]",
1 => $"KerasTensor: shape={_original_tensors.shape} dtype={_original_tensors.dtype.as_numpy_name()}{GetInferredValueString()}",
_ => _original_tensors.ToString(), _ => _original_tensors.ToString(),
}; };


private string GetInferredValueString() private string GetInferredValueString()
=> _inferred_value == null ? "" : "";
=> _inferred_value == null ? "" : $" inferred_value={_inferred_value}";


public static implicit operator Tensors(KerasTensor kt) public static implicit operator Tensors(KerasTensor kt)
=> kt._original_tensors; => kt._original_tensors;


+ 27
- 2
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -137,7 +137,7 @@ namespace Tensorflow
if(shape.Length > 1) if(shape.Length > 1)
{ {
shapeTensor = ops.convert_to_tensor(shape, dtypes.int32); shapeTensor = ops.convert_to_tensor(shape, dtypes.int32);
if(shapeTensor.ndim > 1)
if (shapeTensor.ndim > 1)
{ {
shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1)); shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1));
} }
@@ -304,6 +304,10 @@ namespace Tensorflow
{ {
elems_as_tensors.Add(tensor); elems_as_tensors.Add(tensor);
} }
else if (elem is KerasTensor kt)
{
elems_as_tensors.Add(kt);
}
else else
{ {
var elem_tensor = constant_op.constant(elem, dtype: dtype, name: i.ToString()); var elem_tensor = constant_op.constant(elem, dtype: dtype, name: i.ToString());
@@ -404,7 +408,10 @@ namespace Tensorflow
=> gen_array_ops.reshape(tensor, shape, name: name); => gen_array_ops.reshape(tensor, shape, name: name);


public static Tensor reshape(Tensor tensor, object[] shape, string name = null) public static Tensor reshape(Tensor tensor, object[] shape, string name = null)
=> gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name: name);
{
var dims = shape_utils.from_object_array(shape);
return gen_array_ops.reshape(tensor, dims, name: name);
}


private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true)
{ {
@@ -425,6 +432,10 @@ namespace Tensorflow
return tf_with(ops.name_scope(name, "ones", new { shape }), scope => return tf_with(ops.name_scope(name, "ones", new { shape }), scope =>
{ {
name = scope; name = scope;
if (shape._shape_tuple().Length == 0)
{
shape = reshape(shape, new Shape(-1));
}
var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name);
return output; return output;
}); });
@@ -647,6 +658,20 @@ namespace Tensorflow
} }
}); });


/// <summary>
/// Executes the "Tile" op on <paramref name="input"/>, with the multiples given as a
/// heterogeneous object array.
/// </summary>
/// <param name="input">Tensor passed as the first op argument.</param>
/// <param name="multiples">
/// Entries resolved into a <see cref="Shape"/> via <c>shape_utils.from_object_array</c>
/// before being passed as the second op argument.
/// </param>
/// <param name="name">Optional op name.</param>
/// <returns>The result of the "Tile" op.</returns>
public static Tensor tile(Tensor input, object[] multiples, string name = null)
{
    // Resolve every entry of `multiples` to a concrete dimension first.
    Shape repeats = shape_utils.from_object_array(multiples);

    var opArgs = new ExecuteOpArgs(input, repeats)
    {
        // Attributes read back from the executed op for gradient computation.
        GetGradientAttrs = (op) => new
        {
            T = op.get_attr<TF_DataType>("T"),
            Tmultiples = op.get_attr<TF_DataType>("Tmultiples")
        }
    };

    return tf.Context.ExecuteOp("Tile", name, opArgs);
}

public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
{ {
return tf_with(ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope => return tf_with(ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope =>


+ 27
- 0
src/TensorFlowNET.Core/Tensors/shape_utils.cs View File

@@ -1,5 +1,6 @@
using System; using System;
using System.Linq; using System.Linq;
using Tensorflow.Eager;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
@@ -13,5 +14,31 @@ namespace Tensorflow


throw new NotImplementedException(""); throw new NotImplementedException("");
} }

/// <summary>
/// Builds a <see cref="Shape"/> from a heterogeneous array of dimension entries.
/// </summary>
/// <param name="shape">
/// Dimension entries. Each entry must be one of:
/// a <c>KerasTensor</c> whose <c>inferred_value</c> is not null (its first element is used),
/// an <c>EagerTensor</c> of dtype <c>TF_INT32</c> (its first element is used),
/// an <see cref="int"/>, or a <see cref="long"/>.
/// </param>
/// <returns>A <see cref="Shape"/> whose dimensions are the resolved entries, in order.</returns>
/// <exception cref="NotImplementedException">
/// An entry is null or of a kind not listed above (including a <c>KerasTensor</c>
/// without an inferred value, or an <c>EagerTensor</c> of another dtype).
/// </exception>
public static Shape from_object_array(object[] shape)
{
    var dims = shape.Select((x, index) =>
    {
        // Every branch widens to long explicitly so the resulting array type does not
        // depend on lambda return-type inference across mixed int/long returns.
        if (x is KerasTensor kt && kt.inferred_value != null)
        {
            return (long)kt.inferred_value.as_int_list()[0];
        }
        else if (x is EagerTensor et && et.dtype == TF_DataType.TF_INT32)
        {
            return (long)et.ToArray<int>()[0];
        }
        else if (x is int i)
        {
            return (long)i;
        }
        else if (x is long l)
        {
            return l;
        }

        // Same exception type as before, but now it says which entry could not be resolved.
        var kind = x == null ? "null" : x.GetType().FullName;
        throw new NotImplementedException(
            $"Cannot resolve shape dimension at index {index} from an entry of type {kind}.");
    }).ToArray();

    return new Shape(dims);
}
} }
} }

+ 3
- 0
src/TensorFlowNET.Core/Tensors/tf.constant.cs View File

@@ -46,6 +46,9 @@ namespace Tensorflow
public Tensor ones(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) public Tensor ones(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
=> array_ops.ones(shape, dtype, name); => array_ops.ones(shape, dtype, name);


/// <summary>
/// Overload of <c>ones</c> that takes the shape as a <see cref="Tensor"/>;
/// forwards directly to <c>array_ops.ones</c>.
/// </summary>
/// <param name="shape">Tensor passed through as the shape argument.</param>
/// <param name="dtype">Element type of the result; defaults to <c>TF_FLOAT</c>.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The tensor returned by <c>array_ops.ones</c>.</returns>
public Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
    return array_ops.ones(shape, dtype, name);
}

public Tensor size(Tensor input, public Tensor size(Tensor input,
string name = null, string name = null,
TF_DataType out_type = TF_DataType.TF_INT32) => array_ops.size(input, TF_DataType out_type = TF_DataType.TF_INT32) => array_ops.size(input,


+ 9
- 2
src/TensorFlowNET.Core/ops.cs View File

@@ -144,11 +144,18 @@ namespace Tensorflow
} }
if (!graph.building_function) if (!graph.building_function)
{ {
throw new RuntimeError("Attempting to capture an EagerTensor without building a function.");
// return eager_tensor.AsPlaceholder(name: name);
// throw new RuntimeError("Attempting to capture an EagerTensor without building a function.");
return eager_tensor.AsPlaceholder(name: name);
} }
} }
} }
else if (value is KerasTensor kt)
{
if (kt.inferred_value != null)
{
return convert_to_tensor(kt.inferred_value, dtype: kt.dtype, name: name);
}
}


// graph mode // graph mode
Tensor ret = value switch Tensor ret = value switch


+ 1
- 1
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -141,7 +141,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac


<ItemGroup> <ItemGroup>
<PackageReference Include="HDF5-CSharp" Version="1.17.0" /> <PackageReference Include="HDF5-CSharp" Version="1.17.0" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.149" />
<PackageReference Include="SharpZipLib" Version="1.4.2" /> <PackageReference Include="SharpZipLib" Version="1.4.2" />
</ItemGroup> </ItemGroup>




+ 2
- 2
test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj View File

@@ -41,8 +41,8 @@


<ItemGroup> <ItemGroup>
<PackageReference Include="FluentAssertions" Version="5.10.3" /> <PackageReference Include="FluentAssertions" Version="5.10.3" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.149" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.6.3" />
<PackageReference Include="MSTest.TestAdapter" Version="2.2.10" /> <PackageReference Include="MSTest.TestAdapter" Version="2.2.10" />
<PackageReference Include="MSTest.TestFramework" Version="2.2.10" /> <PackageReference Include="MSTest.TestFramework" Version="2.2.10" />
</ItemGroup> </ItemGroup>


Loading…
Cancel
Save