Browse Source

Add shape and dtype to KerasTensor

tags/v0.110.4-Transformer-Model
Haiping Chen 2 years ago
parent
commit
ed1a8d2edf
3 changed files with 52 additions and 26 deletions
  1. +31
    -18
      src/TensorFlowNET.Core/Operations/array_ops.cs
  2. +20
    -7
      src/TensorFlowNET.Core/Tensors/KerasTensor.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Index.cs

+ 31
- 18
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -603,7 +603,17 @@ namespace Tensorflow
}
}

return gen_array_ops.shape(input, name: name, out_type: out_type);
return tf.Context.ExecuteOp("Shape", name, new ExecuteOpArgs(input)
{
GetGradientAttrs = (op) => new
{
T = op.get_attr<TF_DataType>("T"),
out_type = op.get_attr<TF_DataType>("out_type")
}
}.SetAttributes(new
{
out_type
})).First();
});
}

@@ -703,23 +713,26 @@ namespace Tensorflow
int new_axis_mask = 0,
int shrink_axis_mask = 0,
string name = null)
{
var op = gen_array_ops.strided_slice(
input: input_,
begin: begin,
end: end,
strides: strides,
begin_mask: begin_mask,
end_mask: end_mask,
ellipsis_mask: ellipsis_mask,
new_axis_mask: new_axis_mask,
shrink_axis_mask: shrink_axis_mask,
name: name);

string parent_name = name;

return op;
}
=> tf.Context.ExecuteOp("StridedSlice", name, new ExecuteOpArgs(input_, begin, end, strides)
{
GetGradientAttrs = (op) => new
{
T = op.get_attr<TF_DataType>("T"),
Index = op.get_attr<TF_DataType>("Index"),
begin_mask = op.get_attr<long>("begin_mask"),
end_mask = op.get_attr<long>("end_mask"),
ellipsis_mask = op.get_attr<long>("ellipsis_mask"),
new_axis_mask = op.get_attr<long>("new_axis_mask"),
shrink_axis_mask = op.get_attr<long>("shrink_axis_mask")
}
}.SetAttributes(new
{
begin_mask,
end_mask,
ellipsis_mask,
new_axis_mask,
shrink_axis_mask
}));

/// <summary>
/// Returns the gradient of `StridedSlice`.


+ 20
- 7
src/TensorFlowNET.Core/Tensors/KerasTensor.cs View File

@@ -5,12 +5,17 @@
/// </summary>
public class KerasTensor
{
private Tensor _tensor;
public void SetTensor(Tensors tensor)
=> _tensor = tensor;
private Tensors _inferred_value;
public Tensors inferred_value
{
get => _inferred_value;
set => _inferred_value = value;
}

private TensorSpec _type_spec;
private string _name;
private TensorSpec _type_spec;
public Shape shape => _type_spec.shape;
public TF_DataType dtype => _type_spec.dtype;

public KerasTensor(TensorSpec type_spec, string name = null)
{
@@ -22,15 +27,23 @@ public class KerasTensor
{
var type_spec = tensor.ToTensorSpec();
var kt = new KerasTensor(type_spec, name: tensor.name);
kt.SetTensor(tensor);
kt.inferred_value = tensor;
return kt;
}

public override string ToString()
=> _inferred_value.Length switch
{
> 1 => "[" + string.Join(", ", _inferred_value.Select(x => $"<KerasTensor: shape={x.shape} dtype={x.dtype}>")) + "]",
1 => $"<KerasTensor: shape={_inferred_value.shape} dtype={_inferred_value.dtype}>",
_ => _inferred_value.ToString(),
};

public static implicit operator Tensors(KerasTensor kt)
=> kt._tensor;
=> kt._inferred_value;

public static implicit operator Tensor(KerasTensor kt)
=> kt._tensor;
=> kt._inferred_value;

public static implicit operator KerasTensor(Tensor tensor)
=> from_tensor(tensor);


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Index.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow
array_ops.stack(args.End),
array_ops.stack(args.Strides));

return gen_array_ops.strided_slice(
return array_ops.strided_slice(
this,
packed_begin,
packed_end,


Loading…
Cancel
Save