@@ -9,4 +9,4 @@ community_bridge: # Replace with a single Community Bridge project-name e.g., cl | |||||
liberapay: # Replace with a single Liberapay username | liberapay: # Replace with a single Liberapay username | ||||
issuehunt: # Replace with a single IssueHunt username | issuehunt: # Replace with a single IssueHunt username | ||||
otechie: # Replace with a single Otechie username | otechie: # Replace with a single Otechie username | ||||
custom: ['https://paypal.me/pools/c/8fK9eKwbbL']# Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] | |||||
custom: ['https://bit.ly/2op1mu5']# Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] |
@@ -43,7 +43,8 @@ namespace Tensorflow | |||||
public ExponentialMovingAverage ExponentialMovingAverage(float decay) | public ExponentialMovingAverage ExponentialMovingAverage(float decay) | ||||
=> new ExponentialMovingAverage(decay); | => new ExponentialMovingAverage(decay); | ||||
public Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list); | |||||
public Saver Saver(VariableV1[] var_list = null, int max_to_keep = 5) | |||||
=> new Saver(var_list: var_list, max_to_keep: max_to_keep); | |||||
public string write_graph(Graph graph, string logdir, string name, bool as_text = true) | public string write_graph(Graph graph, string logdir, string name, bool as_text = true) | ||||
=> graph_io.write_graph(graph, logdir, name, as_text); | => graph_io.write_graph(graph, logdir, name, as_text); | ||||
@@ -54,7 +55,7 @@ namespace Tensorflow | |||||
clear_devices, | clear_devices, | ||||
import_scope).Item1; | import_scope).Item1; | ||||
public (MetaGraphDef, Dictionary<string, RefVariable>) export_meta_graph(string filename = "", | |||||
public (MetaGraphDef, Dictionary<string, VariableV1>) export_meta_graph(string filename = "", | |||||
bool as_text = false, | bool as_text = false, | ||||
bool clear_devices = false, | bool clear_devices = false, | ||||
bool clear_extraneous_savers = false, | bool clear_extraneous_savers = false, | ||||
@@ -167,7 +167,7 @@ namespace Tensorflow | |||||
/// <param name="strip_default_attrs"></param> | /// <param name="strip_default_attrs"></param> | ||||
/// <param name="meta_info_def"></param> | /// <param name="meta_info_def"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static (MetaGraphDef, Dictionary<string, RefVariable>) export_scoped_meta_graph(string filename = "", | |||||
public static (MetaGraphDef, Dictionary<string, VariableV1>) export_scoped_meta_graph(string filename = "", | |||||
GraphDef graph_def = null, | GraphDef graph_def = null, | ||||
bool as_text = false, | bool as_text = false, | ||||
string unbound_inputs_col_name = "unbound_inputs", | string unbound_inputs_col_name = "unbound_inputs", | ||||
@@ -179,8 +179,8 @@ namespace Tensorflow | |||||
{ | { | ||||
var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
var var_list = new Dictionary<string, RefVariable>(); | |||||
var variables = graph.get_collection<RefVariable>(tf.GraphKeys.GLOBAL_VARIABLES); | |||||
var var_list = new Dictionary<string, VariableV1>(); | |||||
var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES); | |||||
if (variables != null) | if (variables != null) | ||||
{ | { |
@@ -190,6 +190,26 @@ namespace Tensorflow.Gradients | |||||
return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; | return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; | ||||
} | } | ||||
[RegisterGradient("Pad")] | |||||
public static Tensor[] _PadGrad(Operation op, Tensor[] grads) | |||||
{ | |||||
var grad = grads[0]; | |||||
var x = op.inputs[0]; | |||||
var a = op.inputs[1]; | |||||
var size = array_ops.stack(new object[] { array_ops.rank(x), 1 }); | |||||
var pad_before = array_ops.slice(a, new[] { 0, 0 }, size); | |||||
// Make it a 1-D tensor. | |||||
var begin = array_ops.reshape(pad_before, new[] { -1 }); | |||||
var sizes = array_ops.shape(x); | |||||
var x_grad = array_ops.slice(grad, begin, sizes); | |||||
if (len(op.inputs) == 3) | |||||
return new Tensor[] { x_grad, null, null }; | |||||
else | |||||
return new Tensor[] { x_grad, null }; | |||||
} | |||||
[RegisterGradient("Squeeze")] | [RegisterGradient("Squeeze")] | ||||
public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads) | public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads) | ||||
{ | { | ||||
@@ -36,56 +36,54 @@ namespace Tensorflow.Gradients | |||||
/// </summary> | /// </summary> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[RegisterGradient("Switch")] | [RegisterGradient("Switch")] | ||||
public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads) | |||||
public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads) | |||||
{ | { | ||||
var grad = grads[0]; | |||||
var graph = ops.get_default_graph(); | |||||
var op_ctxt = op._get_control_flow_context(); | |||||
var grad_ctxt = graph._get_control_flow_context(); | |||||
switch (op_ctxt) | |||||
{ | |||||
case WhileContext cwhile: | |||||
throw new NotImplementedException("_SwitchGrad WhileContext"); | |||||
case CondContext ccond: | |||||
{ | |||||
var zero_grad = grads[1 - op_ctxt.branch]; | |||||
// At this point, we have created zero_grad guarded by the right switch. | |||||
// Unfortunately, we may still get None here for not trainable data types. | |||||
if(zero_grad == null) | |||||
{ | |||||
throw new NotImplementedException("_SwitchGrad CondContext zero_grad"); | |||||
} | |||||
return new Tensor[] | |||||
{ | |||||
merge(grads, name: "cond_grad")[0], | |||||
null | |||||
}; | |||||
} | |||||
default: | |||||
throw new NotImplementedException("_SwitchGrad WhileContext"); | |||||
} | |||||
throw new NotImplementedException("_SwitchGrad"); | throw new NotImplementedException("_SwitchGrad"); | ||||
//graph = ops.get_default_graph() | |||||
//# pylint: disable=protected-access | |||||
//op_ctxt = op._get_control_flow_context() | |||||
//grad_ctxt = graph._get_control_flow_context() | |||||
//# pylint: enable=protected-access | |||||
//if isinstance(op_ctxt, WhileContext): | |||||
// merge_grad = grad_ctxt.grad_state.switch_map.get(op) | |||||
// if merge_grad is not None: | |||||
// # This is the second time this Switch is visited. It comes from | |||||
// # the non-exit branch of the Switch, so update the second input | |||||
// # to the Merge. | |||||
// # TODO(yuanbyu): Perform shape inference with this new input. | |||||
// if grad[1] is not None: | |||||
// # pylint: disable=protected-access | |||||
// control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1], | |||||
// enforce_shape_invariant=False) | |||||
// # pylint: enable=protected-access | |||||
// return None, None | |||||
// elif grad[0] is not None: | |||||
// # This is the first time this Switch is visited. It comes from | |||||
// # the Exit branch, which is grad[0]. grad[1] is empty at this point. | |||||
// # Use grad[0] for both inputs to merge for now, but update the second | |||||
// # input of merge when we see this Switch the second time. | |||||
// merge_grad = merge([grad[0], grad[0]], name="b_switch")[0] | |||||
// grad_ctxt.grad_state.switch_map[op] = merge_grad | |||||
// return merge_grad, None | |||||
// else: | |||||
// # This is the first time this Switch is visited. It comes from the | |||||
// # Identity branch. Such a Switch has `None` gradient for the Exit branch, | |||||
// # meaning the output is not differentiable. | |||||
// return None, None | |||||
//elif isinstance(op_ctxt, CondContext): | |||||
// zero_grad = grad[1 - op_ctxt.branch] | |||||
// # At this point, we have created zero_grad guarded by the right switch. | |||||
// # Unfortunately, we may still get None here for not trainable data types. | |||||
// if zero_grad is None: | |||||
// # For resource variables we get None always on the other branch, so bypass | |||||
// # this. | |||||
// if op.inputs[0].dtype == dtypes.resource: | |||||
// return merge( | |||||
// [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None | |||||
// return None, None | |||||
// return merge(grad, name="cond_grad")[0], None | |||||
//else: | |||||
// false_grad = switch(grad[0], op.inputs[1])[0] | |||||
// true_grad = switch(grad[1], op.inputs[1])[1] | |||||
// return merge([false_grad, true_grad])[0], None | |||||
} | |||||
/// <summary> | |||||
/// Returns the value of an available element of `inputs`. | |||||
/// </summary> | |||||
/// <param name="inputs"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
internal static Tensor[] merge(Tensor[] inputs, string name = null) | |||||
{ | |||||
return tf_with(ops.name_scope(name, "Merge", inputs), scope => | |||||
{ | |||||
name = scope; | |||||
if (inputs.Count(x => x.dtype.is_ref_dtype()) == inputs.Length) | |||||
return gen_control_flow_ops.ref_merge(inputs, name: name); | |||||
else | |||||
return gen_control_flow_ops.merge(inputs, name: name); | |||||
}); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -108,10 +108,7 @@ namespace Tensorflow | |||||
{ | { | ||||
// generate gradient subgraph for op. | // generate gradient subgraph for op. | ||||
var op = queue.Dequeue(); | var op = queue.Dequeue(); | ||||
if(tf.get_default_graph()._nodes_by_name.Count > 18505) | |||||
{ | |||||
} | |||||
_maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); | _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); | ||||
//if (loop_state != null) | //if (loop_state != null) | ||||
//loop_state.EnterGradWhileContext(op, before: true); | //loop_state.EnterGradWhileContext(op, before: true); | ||||
@@ -216,7 +213,7 @@ namespace Tensorflow | |||||
in_grad.Tag == null && // maybe a IndexedSlice | in_grad.Tag == null && // maybe a IndexedSlice | ||||
t_in.dtype != TF_DataType.TF_RESOURCE) | t_in.dtype != TF_DataType.TF_RESOURCE) | ||||
{ | { | ||||
in_grad.shape = t_in.shape; | |||||
in_grad.set_shape(t_in.TensorShape); | |||||
} | } | ||||
_SetGrad(grads, t_in, in_grad); | _SetGrad(grads, t_in, in_grad); | ||||
@@ -0,0 +1,54 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow.Framework; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Gradients | |||||
{ | |||||
[RegisterGradient("image_grad")] | |||||
public class image_grad | |||||
{ | |||||
[RegisterGradient("ResizeNearestNeighbor")] | |||||
public static Tensor[] _ResizeNearestNeighborGrad(Operation op, Tensor[] grads) | |||||
{ | |||||
var grad = grads[0]; | |||||
var image = op.inputs[0]; | |||||
var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); | |||||
Tensor image_shape = null; | |||||
if (shape.is_fully_defined()) | |||||
throw new NotImplementedException("_ResizeNearestNeighborGrad shape.is_fully_defined"); | |||||
else | |||||
image_shape = array_ops.shape(image)["1:3"]; | |||||
grad = gen_image_ops.resize_nearest_neighbor_grad( | |||||
grad, | |||||
image_shape, | |||||
align_corners: op.get_attr<bool>("align_corners"), | |||||
half_pixel_centers: op.get_attr<bool>("half_pixel_centers")); | |||||
return new Tensor[] | |||||
{ | |||||
grad, | |||||
null | |||||
}; | |||||
} | |||||
} | |||||
} |
@@ -166,6 +166,94 @@ namespace Tensorflow.Gradients | |||||
}; | }; | ||||
} | } | ||||
[RegisterGradient("FusedBatchNorm")] | |||||
public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads) | |||||
=> _BaseFusedBatchNormGrad(op, 0, grads); | |||||
/// <summary> | |||||
/// Return the gradients for the 3 inputs of BatchNorm. | |||||
/// </summary> | |||||
/// <param name="op"></param> | |||||
/// <param name="version"></param> | |||||
/// <param name="grads"></param> | |||||
/// <returns></returns> | |||||
public static Tensor[] _BaseFusedBatchNormGrad(Operation op, int version, Tensor[] grads) | |||||
{ | |||||
var x = op.inputs[0]; | |||||
var grad_y = grads[0]; | |||||
var scale = op.inputs[1]; | |||||
var epsilon = op.get_attr<float>("epsilon"); | |||||
var data_format = op.get_attr<string>("data_format"); | |||||
var is_training = op.get_attr<bool>("is_training"); | |||||
Func<FusedBatchNormParams, Tensor[]> grad_fun = null; | |||||
switch (version) | |||||
{ | |||||
case 2: | |||||
throw new NotImplementedException(""); | |||||
case 1: | |||||
throw new NotImplementedException(""); | |||||
default: | |||||
grad_fun = gen_nn_ops.fused_batch_norm_grad; | |||||
break; | |||||
} | |||||
if (is_training) | |||||
{ | |||||
return grad_fun(new FusedBatchNormParams | |||||
{ | |||||
YBackprop = grad_y, | |||||
X = x, | |||||
Scale = scale, | |||||
ReserveSpace1 = op.outputs[3], | |||||
ReserveSpace2 = op.outputs[4], | |||||
ReserveSpace3 = version == 2 ? op.outputs[5] : null, | |||||
Epsilon = epsilon, | |||||
DataFormat = data_format, | |||||
IsTraining = is_training | |||||
}); | |||||
} | |||||
else | |||||
{ | |||||
var pop_mean = op.inputs[3]; | |||||
var pop_var = op.inputs[4]; | |||||
if (data_format == "NCHW") | |||||
throw new NotImplementedException(""); | |||||
var results = grad_fun(new FusedBatchNormParams | |||||
{ | |||||
YBackprop = grad_y, | |||||
X = x, | |||||
Scale = scale, | |||||
ReserveSpace1 = op.outputs[3], | |||||
ReserveSpace2 = op.outputs[4], | |||||
ReserveSpace3 = version == 2 ? op.outputs[5] : null, | |||||
Epsilon = epsilon, | |||||
DataFormat = data_format, | |||||
IsTraining = is_training | |||||
}); | |||||
var (dx, dscale, doffset) = (results[0], results[1], results[2]); | |||||
if (data_format == "NCHW") | |||||
throw new NotImplementedException(""); | |||||
return new Tensor[] | |||||
{ | |||||
dx, | |||||
dscale, | |||||
doffset, | |||||
null, | |||||
null | |||||
}; | |||||
} | |||||
} | |||||
[RegisterGradient("BatchNormWithGlobalNormalization")] | |||||
public static Tensor _BatchNormWithGlobalNormalizationGrad(Operation op, Tensor[] grads) | |||||
{ | |||||
throw new NotImplementedException("BatchNormWithGlobalNormalization"); | |||||
} | |||||
private static bool IsZero(Tensor g) | private static bool IsZero(Tensor g) | ||||
{ | { | ||||
if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type)) | if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type)) | ||||
@@ -440,6 +440,9 @@ namespace Tensorflow | |||||
case List<VariableV1> list: | case List<VariableV1> list: | ||||
t = list.Select(x => (T)(object)x).ToList(); | t = list.Select(x => (T)(object)x).ToList(); | ||||
break; | break; | ||||
case List<ResourceVariable> list: | |||||
t = list.Select(x => (T)(object)x).ToList(); | |||||
break; | |||||
case List<RefVariable> list: | case List<RefVariable> list: | ||||
t = list.Select(x => (T)(object)x).ToList(); | t = list.Select(x => (T)(object)x).ToList(); | ||||
break; | break; | ||||
@@ -27,20 +27,6 @@ namespace Tensorflow.Operations | |||||
/// </summary> | /// </summary> | ||||
public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext> | public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext> | ||||
{ | { | ||||
/// <summary> | |||||
/// The boolean tensor for the cond predicate | |||||
/// </summary> | |||||
private Tensor _pred; | |||||
public Tensor pred => _pred; | |||||
/// <summary> | |||||
/// 0 or 1 representing this branch | |||||
/// </summary> | |||||
private int _branch; | |||||
private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>(); | private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>(); | ||||
/// <summary> | /// <summary> | ||||
@@ -45,10 +45,19 @@ namespace Tensorflow.Operations | |||||
/// The predicate tensor in this branch | /// The predicate tensor in this branch | ||||
/// </summary> | /// </summary> | ||||
protected Tensor _pivot; | protected Tensor _pivot; | ||||
public Tensor pivot | |||||
{ | |||||
get => _pivot; | |||||
} | |||||
public Tensor pivot => _pivot; | |||||
/// <summary> | |||||
/// The boolean tensor for the cond predicate | |||||
/// </summary> | |||||
protected Tensor _pred; | |||||
public Tensor pred => _pred; | |||||
/// <summary> | |||||
/// 0 or 1 representing this branch | |||||
/// </summary> | |||||
protected int _branch; | |||||
public int branch => _branch; | |||||
protected Stack<ControlFlowContext> _context_stack; | protected Stack<ControlFlowContext> _context_stack; | ||||
protected ControlFlowContext _outer_context; | protected ControlFlowContext _outer_context; | ||||
@@ -0,0 +1,27 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
public class FusedBatchNormParams | |||||
{ | |||||
public string Name { get; set; } | |||||
public Tensor YBackprop { get; set; } | |||||
public Tensor X { get; set; } | |||||
public Tensor Scale { get; set; } | |||||
public Tensor ReserveSpace1 { get; set; } | |||||
public Tensor ReserveSpace2 { get; set; } | |||||
public Tensor ReserveSpace3 { get; set; } | |||||
public float Epsilon { get; set; } | |||||
public string DataFormat { get; set; } | |||||
public bool IsTraining { get; set; } | |||||
public FusedBatchNormParams() | |||||
{ | |||||
Epsilon = 0.0001f; | |||||
DataFormat = "NHWC"; | |||||
IsTraining = true; | |||||
} | |||||
} | |||||
} |
@@ -156,6 +156,35 @@ namespace Tensorflow.Operations | |||||
return op.output; | return op.output; | ||||
} | } | ||||
/// <summary> | |||||
/// Gradient for batch normalization. | |||||
/// </summary> | |||||
/// <param name="y_backprop"></param> | |||||
/// <param name="x"></param> | |||||
/// <param name="scale"></param> | |||||
/// <param name="reserve_space_1"></param> | |||||
/// <param name="reserve_space_2"></param> | |||||
/// <param name="epsilon"></param> | |||||
/// <param name="data_format"></param> | |||||
/// <param name="is_training"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor[] fused_batch_norm_grad(FusedBatchNormParams @params) | |||||
{ | |||||
var op = _op_def_lib._apply_op_helper("FusedBatchNormGrad", name: @params.Name, args: new | |||||
{ | |||||
y_backprop = @params.YBackprop, | |||||
x = @params.X, | |||||
scale = @params.Scale, | |||||
reserve_space_1 = @params.ReserveSpace1, | |||||
reserve_space_2 = @params.ReserveSpace2, | |||||
epsilon = @params.Epsilon, | |||||
data_format = @params.DataFormat, | |||||
is_training = @params.IsTraining | |||||
}); | |||||
return op.outputs; | |||||
} | |||||
public static Tensor[] fused_batch_norm(Tensor x, | public static Tensor[] fused_batch_norm(Tensor x, | ||||
Tensor scale, | Tensor scale, | ||||
Tensor offset, | Tensor offset, | ||||
@@ -53,7 +53,7 @@ namespace Tensorflow | |||||
for (int i = 0; i < NumInputs; i++) | for (int i = 0; i < NumInputs; i++) | ||||
{ | { | ||||
var tf_output = Input(i); | var tf_output = Input(i); | ||||
var op = new Operation(tf_output.oper); | |||||
var op = GetOperation(tf_output.oper); | |||||
retval[i] = op.outputs[tf_output.index]; | retval[i] = op.outputs[tf_output.index]; | ||||
} | } | ||||
@@ -0,0 +1,41 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class Operation | |||||
{ | |||||
// cache the mapping between managed and unmanaged op | |||||
// some data is stored in managed instance, so when | |||||
// create Operation by IntPtr, it will lost some data. | |||||
private static Dictionary<IntPtr, Operation> OpInstances = new Dictionary<IntPtr, Operation>(); | |||||
/// <summary> | |||||
/// Get operation by handle | |||||
/// </summary> | |||||
/// <param name="handle"></param> | |||||
/// <returns></returns> | |||||
public Operation GetOperation(IntPtr handle) | |||||
{ | |||||
return OpInstances.ContainsKey(handle) ? | |||||
OpInstances[handle] : | |||||
new Operation(handle); | |||||
} | |||||
} | |||||
} |
@@ -84,9 +84,10 @@ namespace Tensorflow | |||||
_control_flow_context = _graph._get_control_flow_context(); | _control_flow_context = _graph._get_control_flow_context(); | ||||
// Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. | // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. | ||||
OpInstances[_handle] = this; | |||||
} | } | ||||
public Operation(Graph g, string opType, string oper_name) | |||||
/*public Operation(Graph g, string opType, string oper_name) | |||||
{ | { | ||||
_graph = g; | _graph = g; | ||||
@@ -102,7 +103,7 @@ namespace Tensorflow | |||||
// Dict mapping op name to file and line information for op colocation | // Dict mapping op name to file and line information for op colocation | ||||
// context managers. | // context managers. | ||||
_control_flow_context = graph._get_control_flow_context(); | _control_flow_context = graph._get_control_flow_context(); | ||||
} | |||||
}*/ | |||||
/// <summary> | /// <summary> | ||||
/// Creates an `Operation`. | /// Creates an `Operation`. | ||||
@@ -151,11 +152,6 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
if(node_def.Name == "define_loss/conv_lobj_branch/batch_normalization/cond/FusedBatchNorm_1") | |||||
{ | |||||
} | |||||
// Dict mapping op name to file and line information for op colocation | // Dict mapping op name to file and line information for op colocation | ||||
// context managers. | // context managers. | ||||
_control_flow_context = graph._get_control_flow_context(); | _control_flow_context = graph._get_control_flow_context(); | ||||
@@ -164,7 +160,7 @@ namespace Tensorflow | |||||
if (op_def == null) | if (op_def == null) | ||||
op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | |||||
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | |||||
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | ||||
// Initialize self._outputs. | // Initialize self._outputs. | ||||
@@ -180,6 +176,8 @@ namespace Tensorflow | |||||
if (_handle != IntPtr.Zero) | if (_handle != IntPtr.Zero) | ||||
_control_flow_post_processing(); | _control_flow_post_processing(); | ||||
OpInstances[_handle] = this; | |||||
} | } | ||||
public void run(FeedItem[] feed_dict = null, Session session = null) | public void run(FeedItem[] feed_dict = null, Session session = null) | ||||
@@ -220,6 +218,9 @@ namespace Tensorflow | |||||
return grouped_inputs.ToArray(); | return grouped_inputs.ToArray(); | ||||
} | } | ||||
public T get_attr<T>(string name) | |||||
=> (T)get_attr(name); | |||||
public object get_attr(string name) | public object get_attr(string name) | ||||
{ | { | ||||
AttrValue x = null; | AttrValue x = null; | ||||
@@ -611,7 +611,7 @@ namespace Tensorflow | |||||
}); | }); | ||||
} | } | ||||
public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null) | |||||
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | |||||
=> gen_array_ops.slice(input, begin, size, name: name); | => gen_array_ops.slice(input, begin, size, name: name); | ||||
public static Tensor stack(object values, int axis = 0, string name = "stack") | public static Tensor stack(object values, int axis = 0, string name = "stack") | ||||
@@ -518,7 +518,7 @@ namespace Tensorflow | |||||
inputs = inputs.Select(inp => | inputs = inputs.Select(inp => | ||||
ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true)) | ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true)) | ||||
.ToArray(); | .ToArray(); | ||||
return gen_control_flow_ops.merge(inputs, name).Item1; | |||||
return gen_control_flow_ops.merge(inputs, name)[0]; | |||||
}); | }); | ||||
} | } | ||||
@@ -557,8 +557,31 @@ namespace Tensorflow | |||||
throw new NotImplementedException("ZerosLikeOutsideLoop"); | throw new NotImplementedException("ZerosLikeOutsideLoop"); | ||||
return array_ops.zeros_like(val, optimize: false); | return array_ops.zeros_like(val, optimize: false); | ||||
} | } | ||||
throw new NotImplementedException("ZerosLikeOutsideLoop"); | |||||
else | |||||
{ | |||||
var op_ctxt = op._get_control_flow_context(); | |||||
if(op_ctxt != null) | |||||
{ | |||||
// We are in a cond context. Use a switch to create zeros only when needed. | |||||
var pred = op_ctxt.pred; | |||||
var branch = op_ctxt.branch; | |||||
var switch_val = @switch(op.inputs[0], pred)[1 - branch]; | |||||
var pivot = array_ops.identity(switch_val); | |||||
if (val.dtype == dtypes.resource) | |||||
throw new NotImplementedException(""); | |||||
var zeros_shape = array_ops.shape_internal(switch_val, optimize: false); | |||||
// Ensure ops created within array_ops.zeros are dominated by switch in | |||||
// cond context. | |||||
return tf_with(ops.control_dependencies(new[] { pivot }), delegate | |||||
{ | |||||
return array_ops.zeros(zeros_shape, dtype: val.dtype); | |||||
}); | |||||
} | |||||
else | |||||
{ | |||||
return array_ops.zeros_like(val, optimize: false); | |||||
} | |||||
} | |||||
} | } | ||||
/// <summary> | /// <summary> |
@@ -475,7 +475,7 @@ namespace Tensorflow | |||||
return op.output; | return op.output; | ||||
} | } | ||||
public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null) | |||||
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | |||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); | var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -148,11 +148,18 @@ namespace Tensorflow | |||||
return new []{_op.outputs[0], _op.outputs[1]}; | return new []{_op.outputs[0], _op.outputs[1]}; | ||||
} | } | ||||
public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null) | |||||
public static Tensor[] ref_merge(Tensor[] inputs, string name = null) | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs }); | |||||
return _op.outputs; | |||||
} | |||||
public static Tensor[] merge(Tensor[] inputs, string name = null) | |||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs }); | var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs }); | ||||
return (_op.outputs[0], _op.outputs[1]); | |||||
return _op.outputs; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -183,5 +183,19 @@ namespace Tensorflow | |||||
return op.output; | return op.output; | ||||
} | } | ||||
public static Tensor resize_nearest_neighbor_grad<Tsize>(Tensor grads, Tsize size, bool align_corners = false, | |||||
bool half_pixel_centers = false, string name = null) | |||||
{ | |||||
var op = _op_def_lib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: new | |||||
{ | |||||
grads, | |||||
size, | |||||
align_corners, | |||||
half_pixel_centers | |||||
}); | |||||
return op.output; | |||||
} | |||||
} | } | ||||
} | } |
@@ -105,10 +105,13 @@ namespace Tensorflow | |||||
if (_handle == IntPtr.Zero) | if (_handle == IntPtr.Zero) | ||||
{ | { | ||||
var status = new Status(); | |||||
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | |||||
status.Check(); | |||||
} else | |||||
using (var status = new Status()) | |||||
{ | |||||
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | |||||
status.Check(); | |||||
} | |||||
} | |||||
else | |||||
{ | { | ||||
for (int i = 0; i < rank; i++) | for (int i = 0; i < rank; i++) | ||||
dims[i] = c_api.TF_Dim(_handle, i); | dims[i] = c_api.TF_Dim(_handle, i); | ||||
@@ -119,14 +122,15 @@ namespace Tensorflow | |||||
set | set | ||||
{ | { | ||||
var status = new Status(); | |||||
if (value == null) | |||||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | |||||
else | |||||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||||
using (var status = new Status()) | |||||
{ | |||||
if (value == null) | |||||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | |||||
else | |||||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||||
status.Check(true); | |||||
status.Check(true); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -142,16 +146,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public void set_shape(TensorShape shape) | public void set_shape(TensorShape shape) | ||||
{ | { | ||||
this.shape = (int[]) shape.dims.Clone(); | |||||
} | |||||
/// <summary> | |||||
/// Updates the shape of this tensor. | |||||
/// </summary> | |||||
[Obsolete("Please use set_shape(TensorShape shape) instead.", false)] | |||||
public void SetShape(TensorShape shape) | |||||
{ | |||||
this.shape = (int[]) shape.dims.Clone(); | |||||
this.shape = shape.rank > 0 ? shape.dims : null; | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -33,6 +33,7 @@ namespace Tensorflow | |||||
public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? | public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? | ||||
public static TF_DataType float16 = TF_DataType.TF_HALF; | public static TF_DataType float16 = TF_DataType.TF_HALF; | ||||
public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | ||||
public static TF_DataType resource = TF_DataType.TF_RESOURCE; | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -227,29 +227,30 @@ namespace Tensorflow | |||||
throw new NotImplementedException("_create_c_op"); | throw new NotImplementedException("_create_c_op"); | ||||
} | } | ||||
var status = new Status(); | |||||
// Add control inputs | |||||
foreach (var control_input in control_inputs) | |||||
c_api.TF_AddControlInput(op_desc, control_input); | |||||
// Add attrs | |||||
foreach (var attr in node_def.Attr) | |||||
using (var status = new Status()) | |||||
{ | { | ||||
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||||
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||||
Marshal.Copy(bytes, 0, proto, bytes.Length); | |||||
uint len = (uint) bytes.Length; | |||||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||||
// Add control inputs | |||||
foreach (var control_input in control_inputs) | |||||
c_api.TF_AddControlInput(op_desc, control_input); | |||||
status.Check(true); | |||||
} | |||||
// Add attrs | |||||
foreach (var attr in node_def.Attr) | |||||
{ | |||||
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||||
var protoHandle = Marshal.AllocHGlobal(bytes.Length); | |||||
Marshal.Copy(bytes, 0, protoHandle, bytes.Length); | |||||
uint len = (uint)bytes.Length; | |||||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status); | |||||
status.Check(true); | |||||
Marshal.FreeHGlobal(protoHandle); | |||||
} | |||||
var c_op = c_api.TF_FinishOperation(op_desc, status); | |||||
var c_op = c_api.TF_FinishOperation(op_desc, status); | |||||
status.Check(true); | |||||
status.Check(true); | |||||
return c_op; | |||||
return c_op; | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -54,6 +54,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
List<RefVariable> first_stage_trainable_var_list; | List<RefVariable> first_stage_trainable_var_list; | ||||
Operation train_op_with_frozen_variables; | Operation train_op_with_frozen_variables; | ||||
Operation train_op_with_all_variables; | Operation train_op_with_all_variables; | ||||
Saver loader; | |||||
Saver saver; | |||||
#endregion | #endregion | ||||
public bool Run() | public bool Run() | ||||
@@ -74,7 +76,9 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
public void Train(Session sess) | public void Train(Session sess) | ||||
{ | { | ||||
sess.run(tf.global_variables_initializer()); | |||||
print($"=> Restoring weights from: {cfg.TRAIN.INITIAL_WEIGHT} ... "); | |||||
loader.restore(sess, cfg.TRAIN.INITIAL_WEIGHT); | |||||
} | } | ||||
public void Test(Session sess) | public void Test(Session sess) | ||||
@@ -184,6 +188,21 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
}); | }); | ||||
}); | }); | ||||
tf_with(tf.name_scope("loader_and_saver"), delegate | |||||
{ | |||||
loader = tf.train.Saver(net_var); | |||||
saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10); | |||||
}); | |||||
tf_with(tf.name_scope("summary"), delegate | |||||
{ | |||||
tf.summary.scalar("learn_rate", learn_rate); | |||||
tf.summary.scalar("giou_loss", giou_loss); | |||||
tf.summary.scalar("conf_loss", conf_loss); | |||||
tf.summary.scalar("prob_loss", prob_loss); | |||||
tf.summary.scalar("total_loss", loss); | |||||
}); | |||||
return graph; | return graph; | ||||
} | } | ||||
@@ -60,7 +60,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
public TrainConfig(string root) | public TrainConfig(string root) | ||||
{ | { | ||||
_root = root; | _root = root; | ||||
INITIAL_WEIGHT = Path.Combine(_root, "data", "checkpoint", "yolov3_coco_demo.ckpt"); | |||||
INITIAL_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco_demo.ckpt"); | |||||
ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt"); | ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt"); | ||||
} | } | ||||
} | } | ||||