@@ -138,7 +138,7 @@ namespace Tensorflow | |||||
_VerifyGeneratedGradients(in_grads, op); | _VerifyGeneratedGradients(in_grads, op); | ||||
} | } | ||||
if (gate_gradients) | |||||
if (gate_gradients && in_grads.Count(x => x != null) > 1) | |||||
{ | { | ||||
} | } | ||||
@@ -153,9 +153,13 @@ namespace Tensorflow | |||||
var inputs = _NonEagerInputs(op, xs).ToList(); | var inputs = _NonEagerInputs(op, xs).ToList(); | ||||
foreach (var (t_in, in_grad) in Python.zip(inputs, in_grads)) | foreach (var (t_in, in_grad) in Python.zip(inputs, in_grads)) | ||||
{ | { | ||||
if(in_grad.op != null) | |||||
if(in_grad != null) | |||||
{ | { | ||||
in_grad.shape = t_in.shape; | |||||
if(in_grad is Tensor && t_in.dtype != TF_DataType.TF_RESOURCE) | |||||
{ | |||||
in_grad.shape = t_in.shape; | |||||
} | |||||
_SetGrad(grads, t_in, in_grad); | _SetGrad(grads, t_in, in_grad); | ||||
} | } | ||||
} | } | ||||
@@ -188,8 +192,8 @@ namespace Tensorflow | |||||
{ | { | ||||
if (!pending_count.ContainsKey(x.op.Name)) | if (!pending_count.ContainsKey(x.op.Name)) | ||||
pending_count[x.op.Name] = 0; | pending_count[x.op.Name] = 0; | ||||
else | |||||
pending_count[x.op.Name] -= 1; | |||||
pending_count[x.op.Name] -= 1; | |||||
var ready = pending_count[x.op.Name] == 0; | var ready = pending_count[x.op.Name] == 0; | ||||
@@ -284,7 +288,7 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
if (is_stop_op) | if (is_stop_op) | ||||
stop_ops.Add(op); | |||||
stop_ops.Insert(0, op); | |||||
} | } | ||||
stop_ops.AddRange(stop_gradient_ops.Where(x => !stop_ops.Contains(x))); | stop_ops.AddRange(stop_gradient_ops.Where(x => !stop_ops.Contains(x))); | ||||
return stop_ops.ToArray(); | return stop_ops.ToArray(); | ||||
@@ -1,5 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -13,8 +14,67 @@ namespace Tensorflow | |||||
{ | { | ||||
var x = op.inputs[0]; | var x = op.inputs[0]; | ||||
var y = op.inputs[1]; | var y = op.inputs[1]; | ||||
if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad)) | |||||
return (grad, grad); | |||||
return (grad, grad); | |||||
var sx = array_ops.shape(x); | |||||
var sy = array_ops.shape(y); | |||||
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); | |||||
var r1 = gen_array_ops.reshape(math_ops.reduce_sum(grad, rx), sx); | |||||
var r2 = gen_array_ops.reshape(math_ops.reduce_sum(grad, ry), sy); | |||||
return (r1, r2); | |||||
} | |||||
public static (Tensor, Tensor) _IdGrad(Operation op, Tensor grad) | |||||
{ | |||||
return (grad, null); | |||||
} | |||||
public static (Tensor, Tensor) _MulGrad(Operation op, Tensor grad) | |||||
{ | |||||
var x = op.inputs[0]; | |||||
var y = op.inputs[1]; | |||||
if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad) && | |||||
new TF_DataType[] { tf.int32, tf.float32 }.Contains(grad.dtype)) | |||||
return (gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)); | |||||
var sx = array_ops.shape(x); | |||||
var sy = array_ops.shape(y); | |||||
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); | |||||
x = math_ops.conj(x); | |||||
y = math_ops.conj(y); | |||||
var r1 = math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx); | |||||
var r2 = math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry); | |||||
return (gen_array_ops.reshape(r1, sx), gen_array_ops.reshape(r2, sy)); | |||||
} | |||||
public static (Tensor, Tensor) _SubGrad(Operation op, Tensor grad) | |||||
{ | |||||
var x = op.inputs[0]; | |||||
var y = op.inputs[1]; | |||||
if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad)) | |||||
return (grad, -grad); | |||||
var sx = array_ops.shape(x); | |||||
var sy = array_ops.shape(y); | |||||
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); | |||||
var r1 = gen_array_ops.reshape(math_ops.reduce_sum(grad, rx), sx); | |||||
var r2 = gen_array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy); | |||||
return (r1, r2); | |||||
} | |||||
public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad) | |||||
{ | |||||
return false; | |||||
/*return string.Join(",", x.shape).Equals(string.Join(",", y.shape)) && | |||||
string.Join(",", x.shape).Equals(string.Join(",", grad.shape));*/ | |||||
} | } | ||||
public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad) | public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad) | ||||
@@ -27,10 +87,15 @@ namespace Tensorflow | |||||
var input_shape = array_ops.shape(op.inputs[0]); | var input_shape = array_ops.shape(op.inputs[0]); | ||||
ops.colocate_with(input_shape); | ops.colocate_with(input_shape); | ||||
var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]); | var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]); | ||||
//var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims); | |||||
//var grad = array_ops.reshape(grad, output_shape_kept_dims); | |||||
var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims); | |||||
grad = gen_array_ops.reshape(grad, output_shape_kept_dims); | |||||
return (gen_array_ops.tile(grad, tile_scaling), null); | |||||
} | |||||
return (grad, grad); | |||||
public static Tensor _safe_shape_div(Tensor x, Tensor y) | |||||
{ | |||||
return math_ops.floordiv(x, gen_math_ops.maximum(y, 1)); | |||||
} | } | ||||
public static (Tensor, Tensor) _RealDivGrad(Operation op, Tensor grad) | public static (Tensor, Tensor) _RealDivGrad(Operation op, Tensor grad) | ||||
@@ -53,5 +118,36 @@ namespace Tensorflow | |||||
return (gen_array_ops.reshape(reduce_sum1, sx), gen_array_ops.reshape(reduce_sum2, sy)); | return (gen_array_ops.reshape(reduce_sum1, sx), gen_array_ops.reshape(reduce_sum2, sy)); | ||||
} | } | ||||
public static (Tensor, Tensor) _PowGrad(Operation op, Tensor grad) | |||||
{ | |||||
var x = op.inputs[0]; | |||||
var y = op.inputs[1]; | |||||
var z = op.outputs[0]; | |||||
var sx = array_ops.shape(x); | |||||
var sy = array_ops.shape(y); | |||||
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); | |||||
x = math_ops.conj(x); | |||||
y = math_ops.conj(y); | |||||
y = math_ops.conj(z); | |||||
var gx = gen_array_ops.reshape(math_ops.reduce_sum(grad * y * gen_math_ops.pow(x, y - 1.0), rx), sx); | |||||
Tensor log_x = null; | |||||
// Avoid false singularity at x = 0 | |||||
if (x.dtype.is_complex()) | |||||
{ | |||||
throw new NotImplementedException("x.dtype.is_complex()"); | |||||
} | |||||
else | |||||
{ | |||||
var x1 = gen_array_ops.log(x); | |||||
var y1 = array_ops.zeros_like(x); | |||||
log_x = array_ops.where(x > 0.0, x1, y1); | |||||
} | |||||
var gy = gen_array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy); | |||||
return (gx, gy); | |||||
} | |||||
} | } | ||||
} | } |
@@ -110,10 +110,18 @@ namespace Tensorflow | |||||
add_to_collection(name, value); | add_to_collection(name, value); | ||||
} | } | ||||
public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, | |||||
public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, | |||||
TF_DataType[] input_types = null, string name = "", | TF_DataType[] input_types = null, string name = "", | ||||
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | ||||
{ | { | ||||
if (inputs == null) | |||||
inputs = new Tensor[0]; | |||||
foreach ((int idx, Tensor a) in Python.enumerate(inputs)) | |||||
{ | |||||
} | |||||
if (String.IsNullOrEmpty(name)) | if (String.IsNullOrEmpty(name)) | ||||
{ | { | ||||
name = op_type; | name = op_type; | ||||
@@ -122,9 +130,6 @@ namespace Tensorflow | |||||
name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); | name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); | ||||
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | ||||
if (inputs == null) | |||||
inputs = new List<Tensor>(); | |||||
var input_ops = inputs.Select(x => x.op).ToArray(); | var input_ops = inputs.Select(x => x.op).ToArray(); | ||||
var control_inputs = _control_dependencies_for_inputs(input_ops); | var control_inputs = _control_dependencies_for_inputs(input_ops); | ||||
@@ -40,6 +40,7 @@ namespace Tensorflow | |||||
} | } | ||||
var attrs = new Dictionary<string, object>(); | var attrs = new Dictionary<string, object>(); | ||||
var inferred_from = new Dictionary<string, object>(); | |||||
var inputs = new List<Tensor>(); | var inputs = new List<Tensor>(); | ||||
var input_types = new List<TF_DataType>(); | var input_types = new List<TF_DataType>(); | ||||
var base_types = new List<TF_DataType>(); | var base_types = new List<TF_DataType>(); | ||||
@@ -49,8 +50,8 @@ namespace Tensorflow | |||||
// Perform input type inference | // Perform input type inference | ||||
foreach (var input_arg in op_def.InputArg) | foreach (var input_arg in op_def.InputArg) | ||||
{ | { | ||||
var input_arg_name = input_arg.Name; | |||||
var values = keywords[input_arg_name]; | |||||
var input_name = input_arg.Name; | |||||
var values = keywords[input_name]; | |||||
// Goals: | // Goals: | ||||
// * Convert values to Tensors if it contains constants. | // * Convert values to Tensors if it contains constants. | ||||
// * Verify that values is a list if that matches the input_arg's | // * Verify that values is a list if that matches the input_arg's | ||||
@@ -69,14 +70,25 @@ namespace Tensorflow | |||||
if (_IsListParameter(input_arg)) | if (_IsListParameter(input_arg)) | ||||
{ | { | ||||
if (!_IsListValue(values)) | if (!_IsListValue(values)) | ||||
throw new TypeError($"Expected list for '{input_arg_name}' argument to '{op_type_name}' Op, not {values}."); | |||||
throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}."); | |||||
if(input_arg.Type != DataType.DtInvalid) | if(input_arg.Type != DataType.DtInvalid) | ||||
{ | { | ||||
dtype = input_arg.Type; | dtype = input_arg.Type; | ||||
} | } | ||||
else if (!String.IsNullOrEmpty(input_arg.NumberAttr)) | else if (!String.IsNullOrEmpty(input_arg.NumberAttr)) | ||||
{ | { | ||||
if (attrs.ContainsKey(input_arg.TypeAttr)) | |||||
{ | |||||
dtype = (DataType)attrs[input_arg.TypeAttr]; | |||||
} | |||||
else | |||||
{ | |||||
if (values is Tensor[] values1) | |||||
dtype = values1[0].dtype.as_datatype_enum(); | |||||
} | |||||
if (dtype == DataType.DtInvalid && default_type_attr_map.ContainsKey(input_arg.TypeAttr)) | |||||
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; | |||||
} | } | ||||
if(input_arg.IsRef && dtype != DataType.DtInvalid) | if(input_arg.IsRef && dtype != DataType.DtInvalid) | ||||
@@ -89,19 +101,48 @@ namespace Tensorflow | |||||
if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) | if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) | ||||
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; | default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; | ||||
if (keywords[input_arg_name] is Tensor) | |||||
if (keywords[input_name] is Tensor) | |||||
{ | { | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
keywords[input_arg_name] = ops.internal_convert_to_tensor(values, name: input_arg_name); | |||||
keywords[input_name] = ops.internal_convert_to_tensor(values, name: input_name); | |||||
} | } | ||||
if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | ||||
{ | { | ||||
attrs[input_arg.TypeAttr] = (keywords[input_arg_name] as Tensor).dtype; | |||||
attrs[input_arg.TypeAttr] = (keywords[input_name] as Tensor).dtype; | |||||
} | |||||
values = new Tensor[] { keywords[input_name] as Tensor }; | |||||
} | |||||
if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | |||||
{ | |||||
if (attrs.ContainsKey(input_arg.NumberAttr)) | |||||
{ | |||||
} | |||||
else | |||||
{ | |||||
attrs[input_arg.NumberAttr] = (values as Tensor[]).Length; | |||||
inferred_from[input_arg.NumberAttr] = input_name; | |||||
var num_attr = op_def.Attr.First(x => x.Name == input_arg.NumberAttr); | |||||
if (num_attr.HasMinimum && (values as Tensor[]).Length < num_attr.Minimum) | |||||
throw new ValueError($"List argument '{input_name}' to '{op_type_name}' Op with length {(values as Tensor[]).Length} shorter " + | |||||
$"than minimum length {num_attr.Minimum}"); | |||||
} | |||||
// All tensors must have the same base type. | |||||
if(input_arg.Type != DataType.DtInvalid) | |||||
{ | |||||
} | |||||
else | |||||
{ | |||||
attrs[input_arg.TypeAttr] = base_types[0]; | |||||
inferred_from[input_arg.TypeAttr] = input_name; | |||||
var type_attr = op_def.Attr.First(x => x.Name == input_arg.TypeAttr); | |||||
} | } | ||||
values = new Tensor[] { keywords[input_arg_name] as Tensor }; | |||||
} | } | ||||
inputs.AddRange(values as Tensor[]); | inputs.AddRange(values as Tensor[]); | ||||
@@ -125,30 +166,8 @@ namespace Tensorflow | |||||
var key = attr_def.Name; | var key = attr_def.Name; | ||||
if (!attrs.ContainsKey(key)) | if (!attrs.ContainsKey(key)) | ||||
Console.WriteLine($"_apply_op_helper: key '{key}' is not found in '{op_def.Name}' operation's attr_def."); | Console.WriteLine($"_apply_op_helper: key '{key}' is not found in '{op_def.Name}' operation's attr_def."); | ||||
var value = attrs[key]; | |||||
var attr_value = new AttrValue(); | |||||
switch (attr_def.Type) | |||||
{ | |||||
case "string": | |||||
attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); | |||||
break; | |||||
case "type": | |||||
attr_value.Type = _MakeType((TF_DataType)value, attr_def); | |||||
break; | |||||
case "bool": | |||||
attr_value.B = (bool)value; | |||||
break; | |||||
case "shape": | |||||
attr_value.Shape = value == null ? | |||||
attr_def.DefaultValue.Shape : | |||||
tensor_util.as_shape((long[])value); | |||||
break; | |||||
default: | |||||
throw new InvalidDataException($"attr_def.Type {attr_def.Type}"); | |||||
} | |||||
attr_protos[key] = attr_value; | |||||
attr_protos[key] = SetAttrValue(op_def, attr_def, attrs[key]); | |||||
} | } | ||||
// Determine output types (possibly using attrs) | // Determine output types (possibly using attrs) | ||||
@@ -167,7 +186,7 @@ namespace Tensorflow | |||||
} | } | ||||
// Add Op to graph | // Add Op to graph | ||||
var op = g.create_op(op_type_name, inputs, output_types.ToArray(), | |||||
var op = g.create_op(op_type_name, inputs.ToArray(), output_types.ToArray(), | |||||
name: scope, | name: scope, | ||||
input_types: input_types.ToArray(), | input_types: input_types.ToArray(), | ||||
attrs: attr_protos, | attrs: attr_protos, | ||||
@@ -182,6 +201,41 @@ namespace Tensorflow | |||||
return v.as_base_dtype().as_datatype_enum(); | return v.as_base_dtype().as_datatype_enum(); | ||||
} | } | ||||
private AttrValue SetAttrValue(OpDef op_def, AttrDef attr_def, object value) | |||||
{ | |||||
var attr_value = new AttrValue(); | |||||
switch (attr_def.Type) | |||||
{ | |||||
case "string": | |||||
attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); | |||||
break; | |||||
case "type": | |||||
attr_value.Type = _MakeType((TF_DataType)value, attr_def); | |||||
break; | |||||
case "bool": | |||||
attr_value.B = (bool)value; | |||||
break; | |||||
case "float": | |||||
attr_value.F = (float)value; | |||||
break; | |||||
case "int": | |||||
attr_value.I = (int)value; | |||||
if (attr_def.HasMinimum && attr_value.I < attr_def.Minimum) | |||||
throw new ValueError($"Attr '{attr_def.Name}' of '{op_def.Name}' Op passed {attr_value.I} less than minimum {attr_def.Minimum}."); | |||||
break; | |||||
case "shape": | |||||
attr_value.Shape = value == null ? | |||||
attr_def.DefaultValue.Shape : | |||||
tensor_util.as_shape((long[])value); | |||||
break; | |||||
default: | |||||
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | |||||
} | |||||
return attr_value; | |||||
} | |||||
private bool _IsListParameter(ArgDef arg) | private bool _IsListParameter(ArgDef arg) | ||||
{ | { | ||||
if (!String.IsNullOrEmpty(arg.NumberAttr)) | if (!String.IsNullOrEmpty(arg.NumberAttr)) | ||||
@@ -1,4 +1,5 @@ | |||||
using System; | |||||
using Google.Protobuf.Collections; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
@@ -74,7 +75,7 @@ namespace Tensorflow | |||||
/// </param> | /// </param> | ||||
/// <param name="original_op"></param> | /// <param name="original_op"></param> | ||||
/// <param name="op_def"></param> | /// <param name="op_def"></param> | ||||
public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||||
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||||
{ | { | ||||
Graph = g; | Graph = g; | ||||
@@ -101,7 +102,8 @@ namespace Tensorflow | |||||
if(op_def == null) | if(op_def == null) | ||||
op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
_handle = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray()); | |||||
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | |||||
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
// Initialize self._outputs. | // Initialize self._outputs. | ||||
output_types = new TF_DataType[NumOutputs]; | output_types = new TF_DataType[NumOutputs]; | ||||
@@ -118,6 +120,41 @@ namespace Tensorflow | |||||
_control_flow_post_processing(); | _control_flow_post_processing(); | ||||
} | } | ||||
private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs) | |||||
{ | |||||
var grouped_inputs = new List<object>(); | |||||
int i = 0; | |||||
int input_len = 0; | |||||
bool is_sequence = false; | |||||
foreach (var input_arg in op_def.InputArg) | |||||
{ | |||||
if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | |||||
{ | |||||
input_len = (int)attrs[input_arg.NumberAttr].I; | |||||
is_sequence = true; | |||||
} | |||||
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | |||||
{ | |||||
input_len = attrs[input_arg.TypeListAttr].List.Type.Count; | |||||
is_sequence = true; | |||||
} | |||||
else | |||||
{ | |||||
input_len = 1; | |||||
is_sequence = false; | |||||
} | |||||
if (is_sequence) | |||||
grouped_inputs.Add(inputs.Skip(i).Take(input_len).ToArray()); | |||||
else | |||||
grouped_inputs.Add(inputs[i]); | |||||
i += input_len; | |||||
} | |||||
return grouped_inputs.ToArray(); | |||||
} | |||||
public object get_attr<T>(string name) | public object get_attr<T>(string name) | ||||
{ | { | ||||
AttrValue x = null; | AttrValue x = null; | ||||
@@ -58,6 +58,22 @@ namespace Tensorflow | |||||
return math_ops.rank_internal(input, name, optimize: true); | return math_ops.rank_internal(input, name, optimize: true); | ||||
} | } | ||||
public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = "") | |||||
{ | |||||
if( x == null && y == null) | |||||
{ | |||||
throw new NotImplementedException("where"); | |||||
} | |||||
else if(x != null && y != null) | |||||
{ | |||||
return gen_array_ops.select(condition, x, y, name); | |||||
} | |||||
else | |||||
{ | |||||
throw new ValueError("x and y must both be non-None or both be None."); | |||||
} | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Returns the shape of a tensor. | /// Returns the shape of a tensor. | ||||
/// </summary> | /// </summary> | ||||
@@ -128,5 +144,30 @@ namespace Tensorflow | |||||
return null; | return null; | ||||
}); | }); | ||||
} | } | ||||
public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool optimize = true) | |||||
{ | |||||
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope => | |||||
{ | |||||
name = scope; | |||||
tensor = ops.convert_to_tensor(tensor, name: "tensor"); | |||||
// is_fully_defined return unexpected value. | |||||
if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT) | |||||
{ | |||||
} | |||||
if(dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT) | |||||
{ | |||||
throw new NotImplementedException("zeros_like"); | |||||
// return zeros(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name); | |||||
} | |||||
else | |||||
{ | |||||
return gen_array_ops.zeros_like(tensor, name: name); | |||||
} | |||||
}); | |||||
} | |||||
} | } | ||||
} | } |
@@ -12,6 +12,20 @@ namespace Tensorflow | |||||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | ||||
public static Execute _execute = new Execute(); | public static Execute _execute = new Execute(); | ||||
public static Tensor greater<Tx, Ty>(Tx x, Ty y, string name = "") | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("Greater", name: name, args: new { x, y }); | |||||
return _op.outputs[0]; | |||||
} | |||||
public static Tensor less<Tx, Ty>(Tx x, Ty y, string name = "") | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("Less", name: name, args: new { x, y }); | |||||
return _op.outputs[0]; | |||||
} | |||||
public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = "") | public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = "") | ||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Placeholder", args: new { dtype, shape }); | var _op = _op_def_lib._apply_op_helper("Placeholder", args: new { dtype, shape }); | ||||
@@ -39,6 +53,13 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
public static Tensor log(Tensor x, string name = "") | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("Log", name: name, args: new { x }); | |||||
return _op.outputs[0]; | |||||
} | |||||
public static Tensor rank(Tensor input, string name = "") | public static Tensor rank(Tensor input, string name = "") | ||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Rank", name: name, args: new { input }); | var _op = _op_def_lib._apply_op_helper("Rank", name: name, args: new { input }); | ||||
@@ -80,6 +101,17 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
public static Tensor where() | |||||
{ | |||||
throw new NotImplementedException("where"); | |||||
} | |||||
public static Tensor select(Tensor condition, Tensor t, Tensor e, string name = "") | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("Select", name, new { condition, t, e }); | |||||
return _op.outputs[0]; | |||||
} | |||||
public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = "") | public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = "") | ||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Shape", name, new { input, out_type }); | var _op = _op_def_lib._apply_op_helper("Shape", name, new { input, out_type }); | ||||
@@ -91,5 +123,17 @@ namespace Tensorflow | |||||
var _op = _op_def_lib._apply_op_helper("Size", name, new { input, out_type }); | var _op = _op_def_lib._apply_op_helper("Size", name, new { input, out_type }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
public static Tensor tile(Tensor input, Tensor multiples, string name = "") | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples }); | |||||
return _op.outputs[0]; | |||||
} | |||||
public static Tensor zeros_like(Tensor x, string name = "") | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("ZerosLike", name, new { x }); | |||||
return _op.outputs[0]; | |||||
} | |||||
} | } | ||||
} | } |
@@ -24,7 +24,7 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
public static Tensor sub(Tensor x, Tensor y, string name = "") | |||||
public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = "") | |||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y }); | ||||
@@ -52,6 +52,13 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
public static Tensor floor_div(Tensor x, Tensor y, string name = "") | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("FloorDiv", name, args: new { x, y }); | |||||
return _op.outputs[0]; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Multiply the matrix "a" by the matrix "b". | /// Multiply the matrix "a" by the matrix "b". | ||||
/// </summary> | /// </summary> | ||||
@@ -68,9 +75,23 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
public static Tensor pow(Tensor x, double y) | |||||
/// <summary> | |||||
/// Returns the max of x and y (i.e. x > y ? x : y) element-wise. | |||||
/// </summary> | |||||
/// <param name="x"></param> | |||||
/// <param name="y"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor maximum<T1, T2>(T1 x, T2 y, string name = "") | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("Maximum", name, args: new { x, y }); | |||||
return _op.outputs[0]; | |||||
} | |||||
public static Tensor pow<Tx, Ty>(Tx x, Ty y, string name = "") | |||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Pow", args: new { x, y }); | |||||
var _op = _op_def_lib._apply_op_helper("Pow", name, args: new { x, y }); | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
@@ -104,6 +104,14 @@ namespace Tensorflow | |||||
}); | }); | ||||
} | } | ||||
public static Tensor floordiv(Tensor x, Tensor y, string name = "") | |||||
{ | |||||
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "floordiv", new object[] { }), scope => | |||||
{ | |||||
return gen_math_ops.floor_div(x, y, name); | |||||
}); | |||||
} | |||||
public static Tensor rank_internal(Tensor input, string name = "", bool optimize = true) | public static Tensor rank_internal(Tensor input, string name = "", bool optimize = true) | ||||
{ | { | ||||
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Rank", new List<Tensor> { input }), scope => | return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Rank", new List<Tensor> { input }), scope => | ||||
@@ -23,15 +23,10 @@ namespace Tensorflow | |||||
}); | }); | ||||
} | } | ||||
public static Tensor operator -(Tensor t1) | |||||
{ | |||||
return gen_math_ops.neg(t1); | |||||
} | |||||
public static Tensor operator -(Tensor t1, Tensor t2) | |||||
{ | |||||
return gen_math_ops.sub(t1, t2); | |||||
} | |||||
public static Tensor operator -(Tensor t1) => gen_math_ops.neg(t1); | |||||
public static Tensor operator -(Tensor t1, Tensor t2) => gen_math_ops.sub(t1, t2); | |||||
public static Tensor operator -(Tensor t1, int t2) => gen_math_ops.sub(t1, t2); | |||||
public static Tensor operator -(Tensor t1, double t2) => gen_math_ops.sub(t1, t2); | |||||
public static Tensor operator *(double x, Tensor y) | public static Tensor operator *(double x, Tensor y) | ||||
{ | { | ||||
@@ -84,5 +79,10 @@ namespace Tensorflow | |||||
return gen_math_ops.floor_mod(x, y, scope); | return gen_math_ops.floor_mod(x, y, scope); | ||||
}); | }); | ||||
} | } | ||||
public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y); | |||||
public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y); | |||||
public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y); | |||||
public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y); | |||||
} | } | ||||
} | } |
@@ -102,7 +102,7 @@ namespace Tensorflow | |||||
/// </param> | /// </param> | ||||
/// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | /// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | ||||
/// <returns>A wrapped TF_Operation*.</returns> | /// <returns>A wrapped TF_Operation*.</returns> | ||||
public static IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs, Operation[] control_inputs) | |||||
public static IntPtr _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | |||||
{ | { | ||||
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | ||||
@@ -111,15 +111,12 @@ namespace Tensorflow | |||||
{ | { | ||||
foreach (var op_input in inputs) | foreach (var op_input in inputs) | ||||
{ | { | ||||
bool isList = false; | |||||
if (!isList) | |||||
{ | |||||
c_api.TF_AddInput(op_desc, op_input._as_tf_output()); | |||||
} | |||||
if (op_input is Tensor[] op_inputs) | |||||
c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Length); | |||||
else if (op_input is Tensor op_input1) | |||||
c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); | |||||
else | else | ||||
{ | |||||
c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); | |||||
} | |||||
throw new NotImplementedException("_create_c_op"); | |||||
} | } | ||||
} | } | ||||
@@ -291,17 +288,28 @@ namespace Tensorflow | |||||
return (oper, out_grads) => | return (oper, out_grads) => | ||||
{ | { | ||||
Console.WriteLine($"get_gradient_function: {oper.type} '{oper.Name}'"); | |||||
switch (oper.type) | switch (oper.type) | ||||
{ | { | ||||
case "Add": | case "Add": | ||||
return math_grad._AddGrad(oper, out_grads); | return math_grad._AddGrad(oper, out_grads); | ||||
case "Identity": | |||||
return math_grad._IdGrad(oper, out_grads); | |||||
case "Mul": | |||||
return math_grad._MulGrad(oper, out_grads); | |||||
case "Sum": | case "Sum": | ||||
return math_grad._SumGrad(oper, out_grads); | return math_grad._SumGrad(oper, out_grads); | ||||
case "Sub": | |||||
return math_grad._SubGrad(oper, out_grads); | |||||
case "Pow": | |||||
return math_grad._PowGrad(oper, out_grads); | |||||
case "RealDiv": | case "RealDiv": | ||||
return math_grad._RealDivGrad(oper, out_grads); | return math_grad._RealDivGrad(oper, out_grads); | ||||
default: | default: | ||||
throw new NotImplementedException($"get_gradient_function {oper.type}"); | throw new NotImplementedException($"get_gradient_function {oper.type}"); | ||||
} | } | ||||
/*var result = typeof(math_grad).GetMethod($"_{op.type}Grad").Invoke(null, new object[] { op, out_grads }); | /*var result = typeof(math_grad).GetMethod($"_{op.type}Grad").Invoke(null, new object[] { op, out_grads }); | ||||
var p1 = result.GetType().GetProperty("Item1"); | var p1 = result.GetType().GetProperty("Item1"); | ||||
var p2 = result.GetType().GetProperty("Item2"); | var p2 = result.GetType().GetProperty("Item2"); | ||||