@@ -54,7 +54,7 @@ namespace Tensorflow | |||||
clear_devices, | clear_devices, | ||||
import_scope).Item1; | import_scope).Item1; | ||||
public (MetaGraphDef, Dictionary<string, RefVariable>) export_meta_graph(string filename = "", | |||||
public (MetaGraphDef, Dictionary<string, VariableV1>) export_meta_graph(string filename = "", | |||||
bool as_text = false, | bool as_text = false, | ||||
bool clear_devices = false, | bool clear_devices = false, | ||||
bool clear_extraneous_savers = false, | bool clear_extraneous_savers = false, | ||||
@@ -167,7 +167,7 @@ namespace Tensorflow | |||||
/// <param name="strip_default_attrs"></param> | /// <param name="strip_default_attrs"></param> | ||||
/// <param name="meta_info_def"></param> | /// <param name="meta_info_def"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static (MetaGraphDef, Dictionary<string, RefVariable>) export_scoped_meta_graph(string filename = "", | |||||
public static (MetaGraphDef, Dictionary<string, VariableV1>) export_scoped_meta_graph(string filename = "", | |||||
GraphDef graph_def = null, | GraphDef graph_def = null, | ||||
bool as_text = false, | bool as_text = false, | ||||
string unbound_inputs_col_name = "unbound_inputs", | string unbound_inputs_col_name = "unbound_inputs", | ||||
@@ -179,8 +179,8 @@ namespace Tensorflow | |||||
{ | { | ||||
var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
var var_list = new Dictionary<string, RefVariable>(); | |||||
var variables = graph.get_collection<RefVariable>(tf.GraphKeys.GLOBAL_VARIABLES); | |||||
var var_list = new Dictionary<string, VariableV1>(); | |||||
var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES); | |||||
if (variables != null) | if (variables != null) | ||||
{ | { |
@@ -190,6 +190,26 @@ namespace Tensorflow.Gradients | |||||
return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; | return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; | ||||
} | } | ||||
/// <summary>
/// Gradient for the Pad / PadV2 ops: slices the padding regions back out of
/// the incoming gradient so it matches the shape of the original input.
/// </summary>
/// <param name="op">The forward Pad operation.</param>
/// <param name="grads">Incoming gradient w.r.t. the padded output.</param>
/// <returns>
/// Gradients per input: (x_grad, null) for Pad, (x_grad, null, null) for PadV2
/// (paddings and constant_values are not differentiable).
/// </returns>
[RegisterGradient("Pad")]
public static Tensor[] _PadGrad(Operation op, Tensor[] grads)
{
    var grad = grads[0];
    var x = op.inputs[0];          // original (unpadded) input
    var paddings = op.inputs[1];   // assumes shape [rank(x), 2] — TODO confirm against Pad op spec
    // First column of `paddings`: amount padded before each dimension.
    var pad_before = array_ops.slice(paddings, new[] { 0, 0 },
        new[] { array_ops.stack(new object[] { array_ops.rank(x), 1 }) });
    // Flatten into a 1-D begin vector for the slice below.
    var begin = array_ops.reshape(pad_before, new[] { -1 });
    var sizes = array_ops.shape(x);
    var x_grad = array_ops.slice(grad, new[] { begin }, new[] { sizes });
    // PadV2 carries a third input (constant_values); it receives no gradient.
    return len(op.inputs) == 3
        ? new Tensor[] { x_grad, null, null }
        : new Tensor[] { x_grad, null };
}
[RegisterGradient("Squeeze")] | [RegisterGradient("Squeeze")] | ||||
public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads) | public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads) | ||||
{ | { | ||||
@@ -36,56 +36,54 @@ namespace Tensorflow.Gradients | |||||
/// </summary> | /// </summary> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[RegisterGradient("Switch")] | [RegisterGradient("Switch")] | ||||
/// <summary>
/// Gradient for a Switch op: merges the gradients arriving from the two
/// Switch outputs back into a single gradient for the data input.
/// </summary>
/// <param name="op">The forward Switch operation.</param>
/// <param name="grads">Gradients w.r.t. the two Switch outputs.</param>
/// <returns>
/// Gradients w.r.t. (data, predicate); the predicate is not differentiable,
/// so its slot is null.
/// </returns>
/// <remarks>
/// Only the CondContext case is implemented. The WhileContext path (see the
/// Python reference implementation in tensorflow/python/ops/control_flow_grad.py)
/// and the context-free path are still TODO.
/// </remarks>
[RegisterGradient("Switch")]
public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
{
    // NOTE: the Python version also consults graph._get_control_flow_context()
    // (the gradient context) for the WhileContext path; it is not needed for
    // the CondContext case implemented here.
    var op_ctxt = op._get_control_flow_context();
    switch (op_ctxt)
    {
        case WhileContext cwhile:
            throw new NotImplementedException("_SwitchGrad WhileContext");
        case CondContext ccond:
            {
                // `branch` selects which Switch output carried real data; the
                // other output's gradient may legitimately be null for
                // non-trainable dtypes.
                var zero_grad = grads[1 - op_ctxt.branch];
                if (zero_grad == null)
                {
                    // Python bypasses this for resource dtypes; not ported yet.
                    throw new NotImplementedException("_SwitchGrad CondContext zero_grad");
                }
                return new Tensor[]
                {
                    merge(grads, name: "cond_grad")[0],
                    null
                };
            }
        default:
            // Switch outside any control-flow context (uses switch/merge of the
            // raw gradients in the Python reference) — not implemented.
            throw new NotImplementedException("_SwitchGrad");
    }
}
/// <summary>
/// Returns the value of an available element of `inputs`.
/// </summary>
/// <param name="inputs">Candidate tensors; exactly one is expected to become available.</param>
/// <param name="name">Optional name scope for the created op.</param>
/// <returns>The outputs of the created Merge (or RefMerge) op.</returns>
internal static Tensor[] merge(Tensor[] inputs, string name = null)
{
    return tf_with(ops.name_scope(name, "Merge", inputs), scope =>
    {
        name = scope;
        // RefMerge is required when every input is a reference dtype;
        // otherwise fall back to the plain Merge op.
        bool all_refs = inputs.All(x => x.dtype.is_ref_dtype());
        return all_refs
            ? gen_control_flow_ops.ref_merge(inputs, name: name)
            : gen_control_flow_ops.merge(inputs, name: name);
    });
}
/// <summary> | /// <summary> | ||||
@@ -108,7 +108,7 @@ namespace Tensorflow | |||||
{ | { | ||||
// generate gradient subgraph for op. | // generate gradient subgraph for op. | ||||
var op = queue.Dequeue(); | var op = queue.Dequeue(); | ||||
if(tf.get_default_graph()._nodes_by_name.Count > 18577) | |||||
if(tf.get_default_graph()._nodes_by_name.Count >= 20611) | |||||
{ | { | ||||
} | } | ||||
@@ -0,0 +1,54 @@ | |||||
/*****************************************************************************
   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Framework;
using static Tensorflow.Binding;

namespace Tensorflow.Gradients
{
    /// <summary>
    /// Gradient registrations for image ops.
    /// </summary>
    [RegisterGradient("image_grad")]
    public class image_grad
    {
        /// <summary>
        /// Gradient for ResizeNearestNeighbor: resizes the incoming gradient
        /// back to the spatial size of the original input image.
        /// </summary>
        /// <param name="op">The forward ResizeNearestNeighbor operation.</param>
        /// <param name="grads">Gradient w.r.t. the resized output.</param>
        /// <returns>(gradient w.r.t. the image, null for the size input).</returns>
        [RegisterGradient("ResizeNearestNeighbor")]
        public static Tensor[] _ResizeNearestNeighborGrad(Operation op, Tensor[] grads)
        {
            var image = op.inputs[0];
            // Static [height, width] of the input, if known at graph time.
            var static_hw = new TensorShape(image.shape.Skip(1).Take(2).ToArray());

            Tensor image_shape = null;
            if (static_hw.is_fully_defined())
                throw new NotImplementedException("_ResizeNearestNeighborGrad shape.is_fully_defined");
            else
                image_shape = array_ops.shape(image)["1:3"]; // dynamic [height, width]

            var grad = gen_image_ops.resize_nearest_neighbor_grad(
                grads[0],
                image_shape,
                align_corners: op.get_attr<bool>("align_corners"),
                half_pixel_centers: op.get_attr<bool>("half_pixel_centers"));

            // No gradient flows to the size input.
            return new Tensor[] { grad, null };
        }
    }
}
@@ -440,6 +440,9 @@ namespace Tensorflow | |||||
case List<VariableV1> list: | case List<VariableV1> list: | ||||
t = list.Select(x => (T)(object)x).ToList(); | t = list.Select(x => (T)(object)x).ToList(); | ||||
break; | break; | ||||
case List<ResourceVariable> list: | |||||
t = list.Select(x => (T)(object)x).ToList(); | |||||
break; | |||||
case List<RefVariable> list: | case List<RefVariable> list: | ||||
t = list.Select(x => (T)(object)x).ToList(); | t = list.Select(x => (T)(object)x).ToList(); | ||||
break; | break; | ||||
@@ -518,7 +518,7 @@ namespace Tensorflow | |||||
inputs = inputs.Select(inp => | inputs = inputs.Select(inp => | ||||
ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true)) | ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true)) | ||||
.ToArray(); | .ToArray(); | ||||
return gen_control_flow_ops.merge(inputs, name).Item1; | |||||
return gen_control_flow_ops.merge(inputs, name)[0]; | |||||
}); | }); | ||||
} | } | ||||
@@ -148,11 +148,18 @@ namespace Tensorflow | |||||
return new []{_op.outputs[0], _op.outputs[1]}; | return new []{_op.outputs[0], _op.outputs[1]}; | ||||
} | } | ||||
/// <summary>
/// Creates a RefMerge op, which forwards the value of an available
/// reference-typed tensor from `inputs`.
/// </summary>
/// <param name="inputs">Reference-dtype candidate tensors.</param>
/// <param name="name">Optional op name.</param>
/// <returns>All outputs of the created op (value and value_index).</returns>
public static Tensor[] ref_merge(Tensor[] inputs, string name = null)
{
    var op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs });
    return op.outputs;
}
/// <summary>
/// Creates a Merge op, which forwards the value of an available tensor
/// from `inputs`.
/// </summary>
/// <param name="inputs">Candidate tensors.</param>
/// <param name="name">Optional op name.</param>
/// <returns>All outputs of the created op (value and value_index).</returns>
public static Tensor[] merge(Tensor[] inputs, string name = null)
{
    var op = _op_def_lib._apply_op_helper("Merge", name, new { inputs });
    return op.outputs;
}
} | } | ||||
} | } |
@@ -183,5 +183,19 @@ namespace Tensorflow | |||||
return op.output; | return op.output; | ||||
} | } | ||||
/// <summary>
/// Creates a ResizeNearestNeighborGrad op: computes the gradient of
/// nearest-neighbor resizing by resizing `grads` to `size`.
/// </summary>
/// <typeparam name="Tsize">Type of the target size argument.</typeparam>
/// <param name="grads">Incoming gradient tensor.</param>
/// <param name="size">Target [height, width] to resize the gradient to.</param>
/// <param name="align_corners">Forwarded op attribute.</param>
/// <param name="half_pixel_centers">Forwarded op attribute.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The single output tensor of the created op.</returns>
public static Tensor resize_nearest_neighbor_grad<Tsize>(Tensor grads, Tsize size, bool align_corners = false,
    bool half_pixel_centers = false, string name = null)
{
    var args = new
    {
        grads,
        size,
        align_corners,
        half_pixel_centers
    };
    var op = _op_def_lib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: args);
    return op.output;
}
} | } | ||||
} | } |