@@ -603,7 +603,17 @@ namespace Tensorflow | |||
} | |||
} | |||
return gen_array_ops.shape(input, name: name, out_type: out_type); | |||
return tf.Context.ExecuteOp("Shape", name, new ExecuteOpArgs(input) | |||
{ | |||
GetGradientAttrs = (op) => new | |||
{ | |||
T = op.get_attr<TF_DataType>("T"), | |||
out_type = op.get_attr<TF_DataType>("out_type") | |||
} | |||
}.SetAttributes(new | |||
{ | |||
out_type | |||
})).First(); | |||
}); | |||
} | |||
@@ -703,23 +713,26 @@ namespace Tensorflow | |||
int new_axis_mask = 0, | |||
int shrink_axis_mask = 0, | |||
string name = null) | |||
{ | |||
var op = gen_array_ops.strided_slice( | |||
input: input_, | |||
begin: begin, | |||
end: end, | |||
strides: strides, | |||
begin_mask: begin_mask, | |||
end_mask: end_mask, | |||
ellipsis_mask: ellipsis_mask, | |||
new_axis_mask: new_axis_mask, | |||
shrink_axis_mask: shrink_axis_mask, | |||
name: name); | |||
string parent_name = name; | |||
return op; | |||
} | |||
=> tf.Context.ExecuteOp("StridedSlice", name, new ExecuteOpArgs(input_, begin, end, strides) | |||
{ | |||
GetGradientAttrs = (op) => new | |||
{ | |||
T = op.get_attr<TF_DataType>("T"), | |||
Index = op.get_attr<TF_DataType>("Index"), | |||
begin_mask = op.get_attr<long>("begin_mask"), | |||
end_mask = op.get_attr<long>("end_mask"), | |||
ellipsis_mask = op.get_attr<long>("ellipsis_mask"), | |||
new_axis_mask = op.get_attr<long>("new_axis_mask"), | |||
shrink_axis_mask = op.get_attr<long>("shrink_axis_mask") | |||
} | |||
}.SetAttributes(new | |||
{ | |||
begin_mask, | |||
end_mask, | |||
ellipsis_mask, | |||
new_axis_mask, | |||
shrink_axis_mask | |||
})); | |||
/// <summary> | |||
/// Returns the gradient of `StridedSlice`. | |||
@@ -5,12 +5,17 @@ | |||
/// </summary> | |||
public class KerasTensor | |||
{ | |||
private Tensor _tensor; | |||
public void SetTensor(Tensors tensor) | |||
=> _tensor = tensor; | |||
private Tensors _inferred_value; | |||
public Tensors inferred_value | |||
{ | |||
get => _inferred_value; | |||
set => _inferred_value = value; | |||
} | |||
private TensorSpec _type_spec; | |||
private string _name; | |||
private TensorSpec _type_spec; | |||
public Shape shape => _type_spec.shape; | |||
public TF_DataType dtype => _type_spec.dtype; | |||
public KerasTensor(TensorSpec type_spec, string name = null) | |||
{ | |||
@@ -22,15 +27,23 @@ public class KerasTensor | |||
{ | |||
var type_spec = tensor.ToTensorSpec(); | |||
var kt = new KerasTensor(type_spec, name: tensor.name); | |||
kt.SetTensor(tensor); | |||
kt.inferred_value = tensor; | |||
return kt; | |||
} | |||
public override string ToString() | |||
=> _inferred_value.Length switch | |||
{ | |||
> 1 => "[" + string.Join(", ", _inferred_value.Select(x => $"<KerasTensor: shape={x.shape} dtype={x.dtype}>")) + "]", | |||
1 => $"<KerasTensor: shape={_inferred_value.shape} dtype={_inferred_value.dtype}>", | |||
_ => _inferred_value.ToString(), | |||
}; | |||
public static implicit operator Tensors(KerasTensor kt) | |||
=> kt._tensor; | |||
=> kt._inferred_value; | |||
public static implicit operator Tensor(KerasTensor kt) | |||
=> kt._tensor; | |||
=> kt._inferred_value; | |||
public static implicit operator KerasTensor(Tensor tensor) | |||
=> from_tensor(tensor); | |||
@@ -42,7 +42,7 @@ namespace Tensorflow | |||
array_ops.stack(args.End), | |||
array_ops.stack(args.Strides)); | |||
return gen_array_ops.strided_slice( | |||
return array_ops.strided_slice( | |||
this, | |||
packed_begin, | |||
packed_end, | |||