@@ -54,7 +54,7 @@ namespace Tensorflow
                clear_devices,
                import_scope).Item1;

        public (MetaGraphDef, Dictionary<string, RefVariable>) export_meta_graph(string filename = "",
        public (MetaGraphDef, Dictionary<string, VariableV1>) export_meta_graph(string filename = "",
            bool as_text = false,
            bool clear_devices = false,
            bool clear_extraneous_savers = false,
@@ -167,7 +167,7 @@ namespace Tensorflow
        /// <param name="strip_default_attrs"></param>
        /// <param name="meta_info_def"></param>
        /// <returns></returns>
        public static (MetaGraphDef, Dictionary<string, RefVariable>) export_scoped_meta_graph(string filename = "",
        public static (MetaGraphDef, Dictionary<string, VariableV1>) export_scoped_meta_graph(string filename = "",
            GraphDef graph_def = null,
            bool as_text = false,
            string unbound_inputs_col_name = "unbound_inputs",
@@ -179,8 +179,8 @@ namespace Tensorflow
        {
            var graph = ops.get_default_graph();
            var var_list = new Dictionary<string, RefVariable>();
            var variables = graph.get_collection<RefVariable>(tf.GraphKeys.GLOBAL_VARIABLES);
            var var_list = new Dictionary<string, VariableV1>();
            var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES);

            if (variables != null)
            {
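
The hunks above widen the exported variable map from `RefVariable` to the common `VariableV1` base, so a single dictionary type can carry both ref-style and resource-style variables when a meta graph is exported. A minimal sketch of why the broader element type is needed, using hypothetical stand-in classes rather than the real TensorFlow.NET types:

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-ins mirroring the hierarchy the diff relies on:
// both concrete variable kinds share the VariableV1 base.
class VariableV1 { }
class RefVariable : VariableV1 { }
class ResourceVariable : VariableV1 { }

class Demo
{
    static void Main()
    {
        // One map can now hold either kind of variable.
        var var_list = new Dictionary<string, VariableV1>
        {
            ["w"] = new RefVariable(),
            ["b"] = new ResourceVariable()
        };
        Console.WriteLine(var_list.Count); // 2
    }
}
```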
@@ -190,6 +190,26 @@ namespace Tensorflow.Gradients
            return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
        }

        [RegisterGradient("Pad")]
        public static Tensor[] _PadGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
            var x = op.inputs[0];
            var a = op.inputs[1];
            var pad_before = array_ops.slice(a, new[] { 0, 0 },
                new[] { array_ops.stack(new object[] { array_ops.rank(x), 1 }) });
            // Make it a 1-D tensor.
            var begin = array_ops.reshape(pad_before, new[] { -1 });
            var sizes = array_ops.shape(x);
            var x_grad = array_ops.slice(grad, new[] { begin }, new[] { sizes });
            if (len(op.inputs) == 3)
                return new Tensor[] { x_grad, null, null };
            else
                return new Tensor[] { x_grad, null };
        }

        [RegisterGradient("Squeeze")]
        public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
        {
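
The new `_PadGrad` slices the incoming gradient back to the original, un-padded region: `pad_before` supplies the starting offset and `shape(x)` the extent. A self-contained 1-D illustration in plain C# (hypothetical helper names, not the TensorFlow.NET API):

```csharp
using System;
using System.Linq;

class PadGradDemo
{
    // Pad a 1-D array with `before` zeros in front and `after` zeros behind.
    static double[] Pad(double[] x, int before, int after) =>
        Enumerable.Repeat(0.0, before).Concat(x).Concat(Enumerable.Repeat(0.0, after)).ToArray();

    // Gradient of Pad w.r.t. x: slice the upstream gradient starting at `before`
    // and keep x.Length elements, the same slice _PadGrad builds with
    // array_ops.slice(grad, begin, sizes).
    static double[] PadGrad(double[] upstreamGrad, int before, int length) =>
        upstreamGrad.Skip(before).Take(length).ToArray();

    static void Main()
    {
        var x = new[] { 1.0, 2.0, 3.0 };
        var padded = Pad(x, before: 2, after: 1);            // [0 0 1 2 3 0]
        var upstream = new[] { .1, .2, .3, .4, .5, .6 };     // dL/d(padded)
        var dx = PadGrad(upstream, before: 2, length: x.Length);

        Console.WriteLine(string.Join(", ", padded));        // 0, 0, 1, 2, 3, 0
        Console.WriteLine(string.Join(", ", dx));            // 0.3, 0.4, 0.5
    }
}
```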
@@ -36,56 +36,54 @@ namespace Tensorflow.Gradients
        /// </summary>
        /// <returns></returns>
        [RegisterGradient("Switch")]
        public Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
        public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
            var graph = ops.get_default_graph();
            var op_ctxt = op._get_control_flow_context();
            var grad_ctxt = graph._get_control_flow_context();
            switch (op_ctxt)
            {
                case WhileContext cwhile:
                    throw new NotImplementedException("_SwitchGrad WhileContext");
                case CondContext ccond:
                    {
                        var zero_grad = grads[1 - op_ctxt.branch];
                        // At this point, we have created zero_grad guarded by the right switch.
                        // Unfortunately, we may still get None here for not trainable data types.
                        if (zero_grad == null)
                        {
                            throw new NotImplementedException("_SwitchGrad CondContext zero_grad");
                        }
                        return new Tensor[]
                        {
                            merge(grads, name: "cond_grad")[0],
                            null
                        };
                    }
                default:
                    throw new NotImplementedException("_SwitchGrad WhileContext");
            }

            throw new NotImplementedException("_SwitchGrad");
            //graph = ops.get_default_graph()
            //# pylint: disable=protected-access
            //op_ctxt = op._get_control_flow_context()
            //grad_ctxt = graph._get_control_flow_context()
            //# pylint: enable=protected-access
            //if isinstance(op_ctxt, WhileContext):
            //  merge_grad = grad_ctxt.grad_state.switch_map.get(op)
            //  if merge_grad is not None:
            //    # This is the second time this Switch is visited. It comes from
            //    # the non-exit branch of the Switch, so update the second input
            //    # to the Merge.
            //    # TODO(yuanbyu): Perform shape inference with this new input.
            //    if grad[1] is not None:
            //      # pylint: disable=protected-access
            //      control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1],
            //                                           enforce_shape_invariant=False)
            //      # pylint: enable=protected-access
            //    return None, None
            //  elif grad[0] is not None:
            //    # This is the first time this Switch is visited. It comes from
            //    # the Exit branch, which is grad[0]. grad[1] is empty at this point.
            //    # Use grad[0] for both inputs to merge for now, but update the second
            //    # input of merge when we see this Switch the second time.
            //    merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
            //    grad_ctxt.grad_state.switch_map[op] = merge_grad
            //    return merge_grad, None
            //  else:
            //    # This is the first time this Switch is visited. It comes from the
            //    # Identity branch. Such a Switch has `None` gradient for the Exit branch,
            //    # meaning the output is not differentiable.
            //    return None, None
            //elif isinstance(op_ctxt, CondContext):
            //  zero_grad = grad[1 - op_ctxt.branch]
            //  # At this point, we have created zero_grad guarded by the right switch.
            //  # Unfortunately, we may still get None here for not trainable data types.
            //  if zero_grad is None:
            //    # For resource variables we get None always on the other branch, so bypass
            //    # this.
            //    if op.inputs[0].dtype == dtypes.resource:
            //      return merge(
            //          [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None
            //    return None, None
            //  return merge(grad, name="cond_grad")[0], None
            //else:
            //  false_grad = switch(grad[0], op.inputs[1])[0]
            //  true_grad = switch(grad[1], op.inputs[1])[1]
            //  return merge([false_grad, true_grad])[0], None
        }

        /// <summary>
        /// Returns the value of an available element of `inputs`.
        /// </summary>
        /// <param name="inputs"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        internal static Tensor[] merge(Tensor[] inputs, string name = null)
        {
            return tf_with(ops.name_scope(name, "Merge", inputs), scope =>
            {
                name = scope;
                if (inputs.Count(x => x.dtype.is_ref_dtype()) == inputs.Length)
                    return gen_control_flow_ops.ref_merge(inputs, name: name);
                else
                    return gen_control_flow_ops.merge(inputs, name: name);
            });
        }

        /// <summary>
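
In the cond case above, `_SwitchGrad` combines the two branch gradients with `merge`, which forwards whichever input is available together with the index of that input (the two outputs of the underlying Merge op). A hypothetical, non-TensorFlow stand-in that mimics that contract:

```csharp
using System;

class MergeDemo
{
    // Hypothetical stand-in for the Merge contract relied on by _SwitchGrad:
    // forward the first available (non-null) input and report which input it was,
    // mirroring outputs[0] (value) and outputs[1] (value_index) of the Merge op.
    static (double? value, int valueIndex) Merge(double?[] inputs)
    {
        for (int i = 0; i < inputs.Length; i++)
            if (inputs[i] != null)
                return (inputs[i], i);
        return (null, -1);
    }

    static void Main()
    {
        // Only the taken cond branch produces a gradient; the other is "dead" (null here).
        var branchGrads = new double?[] { null, 0.25 };
        var (value, index) = Merge(branchGrads);
        Console.WriteLine($"grad = {value}, from branch {index}"); // grad = 0.25, from branch 1
    }
}
```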
@@ -108,7 +108,7 @@ namespace Tensorflow
                {
                    // generate gradient subgraph for op.
                    var op = queue.Dequeue();
                    if (tf.get_default_graph()._nodes_by_name.Count > 18577)
                    if (tf.get_default_graph()._nodes_by_name.Count >= 20611)
                    {
                    }
@@ -0,0 +1,54 @@
/*****************************************************************************
   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Framework;
using static Tensorflow.Binding;

namespace Tensorflow.Gradients
{
    [RegisterGradient("image_grad")]
    public class image_grad
    {
        [RegisterGradient("ResizeNearestNeighbor")]
        public static Tensor[] _ResizeNearestNeighborGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
            var image = op.inputs[0];
            var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray());
            Tensor image_shape = null;
            if (shape.is_fully_defined())
                throw new NotImplementedException("_ResizeNearestNeighborGrad shape.is_fully_defined");
            else
                image_shape = array_ops.shape(image)["1:3"];

            grad = gen_image_ops.resize_nearest_neighbor_grad(
                grad,
                image_shape,
                align_corners: op.get_attr<bool>("align_corners"),
                half_pixel_centers: op.get_attr<bool>("half_pixel_centers"));

            return new Tensor[]
            {
                grad,
                null
            };
        }
    }
}
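
For nearest-neighbor resizing, each output pixel copies exactly one input pixel, so the gradient with respect to the input accumulates the upstream gradient over every output pixel that sampled it; `ResizeNearestNeighborGrad` performs this accumulation for 4-D image batches. A plain C# 1-D sketch of the idea (hypothetical helpers; `align_corners` and `half_pixel_centers` are ignored):

```csharp
using System;

class ResizeNnGradDemo
{
    // Map an output position back to the input position it copied from
    // in a simple nearest-neighbor "resize" of a 1-D signal.
    static int SourceIndex(int outIdx, int inLen, int outLen) =>
        Math.Min((int)Math.Floor(outIdx * (double)inLen / outLen), inLen - 1);

    // Gradient: every output position routes its upstream gradient back to
    // the single input position it sampled, summing when several outputs
    // sampled the same input.
    static double[] ResizeNnGrad(double[] upstream, int inLen)
    {
        var dx = new double[inLen];
        for (int o = 0; o < upstream.Length; o++)
            dx[SourceIndex(o, inLen, upstream.Length)] += upstream[o];
        return dx;
    }

    static void Main()
    {
        var upstream = new[] { 1.0, 1.0, 1.0, 1.0 };  // dL/d(output), output length 4
        var dx = ResizeNnGrad(upstream, inLen: 2);    // input length 2
        Console.WriteLine(string.Join(", ", dx));     // 2, 2  (each input pixel sampled twice)
    }
}
```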
@@ -440,6 +440,9 @@ namespace Tensorflow
                case List<VariableV1> list:
                    t = list.Select(x => (T)(object)x).ToList();
                    break;
                case List<ResourceVariable> list:
                    t = list.Select(x => (T)(object)x).ToList();
                    break;
                case List<RefVariable> list:
                    t = list.Select(x => (T)(object)x).ToList();
                    break;
@@ -518,7 +518,7 @@ namespace Tensorflow
                inputs = inputs.Select(inp =>
                        ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true))
                    .ToArray();
                return gen_control_flow_ops.merge(inputs, name).Item1;
                return gen_control_flow_ops.merge(inputs, name)[0];
            });
        }
@@ -148,11 +148,18 @@ namespace Tensorflow
            return new[] { _op.outputs[0], _op.outputs[1] };
        }

        public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null)
        public static Tensor[] ref_merge(Tensor[] inputs, string name = null)
        {
            var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs });

            return _op.outputs;
        }

        public static Tensor[] merge(Tensor[] inputs, string name = null)
        {
            var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs });

            return (_op.outputs[0], _op.outputs[1]);
            return _op.outputs;
        }
    }
}
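
With `merge` now returning the op's raw output array instead of a tuple, call sites index into it, as the `control_flow_ops` hunk above already does. A hedged usage sketch, not taken from the diff itself, assuming (per the Merge op's outputs) that `outputs[0]` is the forwarded value and `outputs[1]` the index of the input it came from:

```csharp
// Sketch only; t1 and t2 are assumed to be existing Tensor values in scope.
var outputs = gen_control_flow_ops.merge(new[] { t1, t2 });
var merged = outputs[0];       // the forwarded tensor
var value_index = outputs[1];  // which input produced it
```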
@@ -183,5 +183,19 @@ namespace Tensorflow
            return op.output;
        }

        public static Tensor resize_nearest_neighbor_grad<Tsize>(Tensor grads, Tsize size, bool align_corners = false,
            bool half_pixel_centers = false, string name = null)
        {
            var op = _op_def_lib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: new
            {
                grads,
                size,
                align_corners,
                half_pixel_centers
            });

            return op.output;
        }
    }
}