Browse Source

add _ResizeNearestNeighborGrad and _SwitchGrad

tags/v0.12
Oceania2018 6 years ago
parent
commit
58d714badd
10 changed files with 152 additions and 56 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.train.cs
  2. +3
    -3
      src/TensorFlowNET.Core/Framework/meta_graph.cs
  3. +20
    -0
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  4. +46
    -48
      src/TensorFlowNET.Core/Gradients/control_flow_grad.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  6. +54
    -0
      src/TensorFlowNET.Core/Gradients/image_grad.cs
  7. +3
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Operations/control_flow_ops.cs
  9. +9
    -2
      src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs
  10. +14
    -0
      src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.train.cs View File

@@ -54,7 +54,7 @@ namespace Tensorflow
clear_devices, clear_devices,
import_scope).Item1; import_scope).Item1;


public (MetaGraphDef, Dictionary<string, RefVariable>) export_meta_graph(string filename = "",
public (MetaGraphDef, Dictionary<string, VariableV1>) export_meta_graph(string filename = "",
bool as_text = false, bool as_text = false,
bool clear_devices = false, bool clear_devices = false,
bool clear_extraneous_savers = false, bool clear_extraneous_savers = false,


src/TensorFlowNET.Core/Framework/meta_graph.py.cs → src/TensorFlowNET.Core/Framework/meta_graph.cs View File

@@ -167,7 +167,7 @@ namespace Tensorflow
/// <param name="strip_default_attrs"></param> /// <param name="strip_default_attrs"></param>
/// <param name="meta_info_def"></param> /// <param name="meta_info_def"></param>
/// <returns></returns> /// <returns></returns>
public static (MetaGraphDef, Dictionary<string, RefVariable>) export_scoped_meta_graph(string filename = "",
public static (MetaGraphDef, Dictionary<string, VariableV1>) export_scoped_meta_graph(string filename = "",
GraphDef graph_def = null, GraphDef graph_def = null,
bool as_text = false, bool as_text = false,
string unbound_inputs_col_name = "unbound_inputs", string unbound_inputs_col_name = "unbound_inputs",
@@ -179,8 +179,8 @@ namespace Tensorflow
{ {
var graph = ops.get_default_graph(); var graph = ops.get_default_graph();


var var_list = new Dictionary<string, RefVariable>();
var variables = graph.get_collection<RefVariable>(tf.GraphKeys.GLOBAL_VARIABLES);
var var_list = new Dictionary<string, VariableV1>();
var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES);


if (variables != null) if (variables != null)
{ {

+ 20
- 0
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -190,6 +190,26 @@ namespace Tensorflow.Gradients
return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
} }


[RegisterGradient("Pad")]
public static Tensor[] _PadGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var x = op.inputs[0];
var a = op.inputs[1];
var pad_before = array_ops.slice(a, new[] { 0, 0 },
new[] { array_ops.stack(new object[] { array_ops.rank(x), 1 }) });

// Make it a 1-D tensor.
var begin = array_ops.reshape(pad_before, new[] { -1 });
var sizes = array_ops.shape(x);
var x_grad = array_ops.slice(grad, new[] { begin }, new[] { sizes });

if (len(op.inputs) == 3)
return new Tensor[] { x_grad, null, null };
else
return new Tensor[] { x_grad, null };
}

[RegisterGradient("Squeeze")] [RegisterGradient("Squeeze")]
public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads) public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
{ {


+ 46
- 48
src/TensorFlowNET.Core/Gradients/control_flow_grad.cs View File

@@ -36,56 +36,54 @@ namespace Tensorflow.Gradients
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
[RegisterGradient("Switch")] [RegisterGradient("Switch")]
public Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
{ {
var grad = grads[0];
var graph = ops.get_default_graph();
var op_ctxt = op._get_control_flow_context();
var grad_ctxt = graph._get_control_flow_context();
switch (op_ctxt)
{
case WhileContext cwhile:
throw new NotImplementedException("_SwitchGrad WhileContext");
case CondContext ccond:
{
var zero_grad = grads[1 - op_ctxt.branch];
// At this point, we have created zero_grad guarded by the right switch.
// Unfortunately, we may still get None here for not trainable data types.
if(zero_grad == null)
{
throw new NotImplementedException("_SwitchGrad CondContext zero_grad");
}

return new Tensor[]
{
merge(grads, name: "cond_grad")[0],
null
};
}
default:
throw new NotImplementedException("_SwitchGrad WhileContext");
}
throw new NotImplementedException("_SwitchGrad"); throw new NotImplementedException("_SwitchGrad");
//graph = ops.get_default_graph()
//# pylint: disable=protected-access
//op_ctxt = op._get_control_flow_context()
//grad_ctxt = graph._get_control_flow_context()
//# pylint: enable=protected-access
//if isinstance(op_ctxt, WhileContext):
// merge_grad = grad_ctxt.grad_state.switch_map.get(op)
// if merge_grad is not None:
// # This is the second time this Switch is visited. It comes from
// # the non-exit branch of the Switch, so update the second input
// # to the Merge.
// # TODO(yuanbyu): Perform shape inference with this new input.
// if grad[1] is not None:
// # pylint: disable=protected-access
// control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1],
// enforce_shape_invariant=False)
// # pylint: enable=protected-access
// return None, None
// elif grad[0] is not None:
// # This is the first time this Switch is visited. It comes from
// # the Exit branch, which is grad[0]. grad[1] is empty at this point.
// # Use grad[0] for both inputs to merge for now, but update the second
// # input of merge when we see this Switch the second time.
// merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
// grad_ctxt.grad_state.switch_map[op] = merge_grad
// return merge_grad, None
// else:
// # This is the first time this Switch is visited. It comes from the
// # Identity branch. Such a Switch has `None` gradient for the Exit branch,
// # meaning the output is not differentiable.
// return None, None
//elif isinstance(op_ctxt, CondContext):
// zero_grad = grad[1 - op_ctxt.branch]
// # At this point, we have created zero_grad guarded by the right switch.
// # Unfortunately, we may still get None here for not trainable data types.
// if zero_grad is None:
// # For resource variables we get None always on the other branch, so bypass
// # this.
// if op.inputs[0].dtype == dtypes.resource:
// return merge(
// [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None
// return None, None
// return merge(grad, name="cond_grad")[0], None
//else:
// false_grad = switch(grad[0], op.inputs[1])[0]
// true_grad = switch(grad[1], op.inputs[1])[1]
// return merge([false_grad, true_grad])[0], None
}

/// <summary>
/// Returns the value of an available element of `inputs`.
/// </summary>
/// <param name="inputs"></param>
/// <param name="name"></param>
/// <returns></returns>
internal static Tensor[] merge(Tensor[] inputs, string name = null)
{
return tf_with(ops.name_scope(name, "Merge", inputs), scope =>
{
name = scope;
if (inputs.Count(x => x.dtype.is_ref_dtype()) == inputs.Length)
return gen_control_flow_ops.ref_merge(inputs, name: name);
else
return gen_control_flow_ops.merge(inputs, name: name);
});
} }


/// <summary> /// <summary>


+ 1
- 1
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -108,7 +108,7 @@ namespace Tensorflow
{ {
// generate gradient subgraph for op. // generate gradient subgraph for op.
var op = queue.Dequeue(); var op = queue.Dequeue();
if(tf.get_default_graph()._nodes_by_name.Count > 18577)
if(tf.get_default_graph()._nodes_by_name.Count >= 20611)
{ {


} }


+ 54
- 0
src/TensorFlowNET.Core/Gradients/image_grad.cs View File

@@ -0,0 +1,54 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Framework;
using static Tensorflow.Binding;

namespace Tensorflow.Gradients
{
[RegisterGradient("image_grad")]
public class image_grad
{
[RegisterGradient("ResizeNearestNeighbor")]
public static Tensor[] _ResizeNearestNeighborGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var image = op.inputs[0];
var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray());
Tensor image_shape = null;
if (shape.is_fully_defined())
throw new NotImplementedException("_ResizeNearestNeighborGrad shape.is_fully_defined");
else
image_shape = array_ops.shape(image)["1:3"];

grad = gen_image_ops.resize_nearest_neighbor_grad(
grad,
image_shape,
align_corners: op.get_attr<bool>("align_corners"),
half_pixel_centers: op.get_attr<bool>("half_pixel_centers"));

return new Tensor[]
{
grad,
null
};
}
}
}

+ 3
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -440,6 +440,9 @@ namespace Tensorflow
case List<VariableV1> list: case List<VariableV1> list:
t = list.Select(x => (T)(object)x).ToList(); t = list.Select(x => (T)(object)x).ToList();
break; break;
case List<ResourceVariable> list:
t = list.Select(x => (T)(object)x).ToList();
break;
case List<RefVariable> list: case List<RefVariable> list:
t = list.Select(x => (T)(object)x).ToList(); t = list.Select(x => (T)(object)x).ToList();
break; break;


src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs → src/TensorFlowNET.Core/Operations/control_flow_ops.cs View File

@@ -518,7 +518,7 @@ namespace Tensorflow
inputs = inputs.Select(inp => inputs = inputs.Select(inp =>
ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true)) ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true))
.ToArray(); .ToArray();
return gen_control_flow_ops.merge(inputs, name).Item1;
return gen_control_flow_ops.merge(inputs, name)[0];
}); });
} }



src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs → src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs View File

@@ -148,11 +148,18 @@ namespace Tensorflow
return new []{_op.outputs[0], _op.outputs[1]}; return new []{_op.outputs[0], _op.outputs[1]};
} }


public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null)
public static Tensor[] ref_merge(Tensor[] inputs, string name = null)
{
var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs });

return _op.outputs;
}

public static Tensor[] merge(Tensor[] inputs, string name = null)
{ {
var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs }); var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs });


return (_op.outputs[0], _op.outputs[1]);
return _op.outputs;
} }
} }
} }

+ 14
- 0
src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs View File

@@ -183,5 +183,19 @@ namespace Tensorflow


return op.output; return op.output;
} }

public static Tensor resize_nearest_neighbor_grad<Tsize>(Tensor grads, Tsize size, bool align_corners = false,
bool half_pixel_centers = false, string name = null)
{
var op = _op_def_lib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: new
{
grads,
size,
align_corners,
half_pixel_centers
});

return op.output;
}
} }
} }

Loading…
Cancel
Save