Esther2013 6 years ago
parent
commit
7780c25181
26 changed files with 447 additions and 133 deletions
  1. +1
    -1
      .github/FUNDING.yml
  2. +3
    -2
      src/TensorFlowNET.Core/APIs/tf.train.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Framework/meta_graph.cs
  4. +20
    -0
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  5. +46
    -48
      src/TensorFlowNET.Core/Gradients/control_flow_grad.cs
  6. +1
    -4
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  7. +54
    -0
      src/TensorFlowNET.Core/Gradients/image_grad.cs
  8. +88
    -0
      src/TensorFlowNET.Core/Gradients/nn_grad.cs
  9. +3
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  10. +0
    -14
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  11. +13
    -4
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  12. +27
    -0
      src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs
  13. +29
    -0
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  14. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.Input.cs
  15. +41
    -0
      src/TensorFlowNET.Core/Operations/Operation.Instance.cs
  16. +9
    -8
      src/TensorFlowNET.Core/Operations/Operation.cs
  17. +1
    -1
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  18. +26
    -3
      src/TensorFlowNET.Core/Operations/control_flow_ops.cs
  19. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  20. +9
    -2
      src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs
  21. +14
    -0
      src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs
  22. +16
    -21
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  23. +1
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  24. +19
    -18
      src/TensorFlowNET.Core/ops.cs
  25. +20
    -1
      test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs
  26. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs

+ 1
- 1
.github/FUNDING.yml View File

@@ -9,4 +9,4 @@ community_bridge: # Replace with a single Community Bridge project-name e.g., cl
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
custom: ['https://paypal.me/pools/c/8fK9eKwbbL']# Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
custom: ['https://bit.ly/2op1mu5']# Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

+ 3
- 2
src/TensorFlowNET.Core/APIs/tf.train.cs View File

@@ -43,7 +43,8 @@ namespace Tensorflow
public ExponentialMovingAverage ExponentialMovingAverage(float decay)
=> new ExponentialMovingAverage(decay);

public Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list);
public Saver Saver(VariableV1[] var_list = null, int max_to_keep = 5)
=> new Saver(var_list: var_list, max_to_keep: max_to_keep);

public string write_graph(Graph graph, string logdir, string name, bool as_text = true)
=> graph_io.write_graph(graph, logdir, name, as_text);
@@ -54,7 +55,7 @@ namespace Tensorflow
clear_devices,
import_scope).Item1;

public (MetaGraphDef, Dictionary<string, RefVariable>) export_meta_graph(string filename = "",
public (MetaGraphDef, Dictionary<string, VariableV1>) export_meta_graph(string filename = "",
bool as_text = false,
bool clear_devices = false,
bool clear_extraneous_savers = false,


src/TensorFlowNET.Core/Framework/meta_graph.py.cs → src/TensorFlowNET.Core/Framework/meta_graph.cs View File

@@ -167,7 +167,7 @@ namespace Tensorflow
/// <param name="strip_default_attrs"></param>
/// <param name="meta_info_def"></param>
/// <returns></returns>
public static (MetaGraphDef, Dictionary<string, RefVariable>) export_scoped_meta_graph(string filename = "",
public static (MetaGraphDef, Dictionary<string, VariableV1>) export_scoped_meta_graph(string filename = "",
GraphDef graph_def = null,
bool as_text = false,
string unbound_inputs_col_name = "unbound_inputs",
@@ -179,8 +179,8 @@ namespace Tensorflow
{
var graph = ops.get_default_graph();

var var_list = new Dictionary<string, RefVariable>();
var variables = graph.get_collection<RefVariable>(tf.GraphKeys.GLOBAL_VARIABLES);
var var_list = new Dictionary<string, VariableV1>();
var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES);

if (variables != null)
{

+ 20
- 0
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -190,6 +190,26 @@ namespace Tensorflow.Gradients
return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
}

[RegisterGradient("Pad")]
public static Tensor[] _PadGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var x = op.inputs[0];
var a = op.inputs[1];
var size = array_ops.stack(new object[] { array_ops.rank(x), 1 });
var pad_before = array_ops.slice(a, new[] { 0, 0 }, size);

// Make it a 1-D tensor.
var begin = array_ops.reshape(pad_before, new[] { -1 });
var sizes = array_ops.shape(x);
var x_grad = array_ops.slice(grad, begin, sizes);

if (len(op.inputs) == 3)
return new Tensor[] { x_grad, null, null };
else
return new Tensor[] { x_grad, null };
}

[RegisterGradient("Squeeze")]
public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
{


+ 46
- 48
src/TensorFlowNET.Core/Gradients/control_flow_grad.cs View File

@@ -36,56 +36,54 @@ namespace Tensorflow.Gradients
/// </summary>
/// <returns></returns>
[RegisterGradient("Switch")]
public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads)
public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var graph = ops.get_default_graph();
var op_ctxt = op._get_control_flow_context();
var grad_ctxt = graph._get_control_flow_context();
switch (op_ctxt)
{
case WhileContext cwhile:
throw new NotImplementedException("_SwitchGrad WhileContext");
case CondContext ccond:
{
var zero_grad = grads[1 - op_ctxt.branch];
// At this point, we have created zero_grad guarded by the right switch.
// Unfortunately, we may still get None here for not trainable data types.
if(zero_grad == null)
{
throw new NotImplementedException("_SwitchGrad CondContext zero_grad");
}

return new Tensor[]
{
merge(grads, name: "cond_grad")[0],
null
};
}
default:
throw new NotImplementedException("_SwitchGrad WhileContext");
}
throw new NotImplementedException("_SwitchGrad");
//graph = ops.get_default_graph()
//# pylint: disable=protected-access
//op_ctxt = op._get_control_flow_context()
//grad_ctxt = graph._get_control_flow_context()
//# pylint: enable=protected-access
//if isinstance(op_ctxt, WhileContext):
// merge_grad = grad_ctxt.grad_state.switch_map.get(op)
// if merge_grad is not None:
// # This is the second time this Switch is visited. It comes from
// # the non-exit branch of the Switch, so update the second input
// # to the Merge.
// # TODO(yuanbyu): Perform shape inference with this new input.
// if grad[1] is not None:
// # pylint: disable=protected-access
// control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1],
// enforce_shape_invariant=False)
// # pylint: enable=protected-access
// return None, None
// elif grad[0] is not None:
// # This is the first time this Switch is visited. It comes from
// # the Exit branch, which is grad[0]. grad[1] is empty at this point.
// # Use grad[0] for both inputs to merge for now, but update the second
// # input of merge when we see this Switch the second time.
// merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
// grad_ctxt.grad_state.switch_map[op] = merge_grad
// return merge_grad, None
// else:
// # This is the first time this Switch is visited. It comes from the
// # Identity branch. Such a Switch has `None` gradient for the Exit branch,
// # meaning the output is not differentiable.
// return None, None
//elif isinstance(op_ctxt, CondContext):
// zero_grad = grad[1 - op_ctxt.branch]
// # At this point, we have created zero_grad guarded by the right switch.
// # Unfortunately, we may still get None here for not trainable data types.
// if zero_grad is None:
// # For resource variables we get None always on the other branch, so bypass
// # this.
// if op.inputs[0].dtype == dtypes.resource:
// return merge(
// [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None
// return None, None
// return merge(grad, name="cond_grad")[0], None
//else:
// false_grad = switch(grad[0], op.inputs[1])[0]
// true_grad = switch(grad[1], op.inputs[1])[1]
// return merge([false_grad, true_grad])[0], None
}

/// <summary>
/// Returns the value of an available element of `inputs`.
/// </summary>
/// <param name="inputs"></param>
/// <param name="name"></param>
/// <returns></returns>
internal static Tensor[] merge(Tensor[] inputs, string name = null)
{
return tf_with(ops.name_scope(name, "Merge", inputs), scope =>
{
name = scope;
if (inputs.Count(x => x.dtype.is_ref_dtype()) == inputs.Length)
return gen_control_flow_ops.ref_merge(inputs, name: name);
else
return gen_control_flow_ops.merge(inputs, name: name);
});
}

/// <summary>


+ 1
- 4
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -108,10 +108,7 @@ namespace Tensorflow
{
// generate gradient subgraph for op.
var op = queue.Dequeue();
if(tf.get_default_graph()._nodes_by_name.Count > 18505)
{

}
_maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
//if (loop_state != null)
//loop_state.EnterGradWhileContext(op, before: true);
@@ -216,7 +213,7 @@ namespace Tensorflow
in_grad.Tag == null && // maybe a IndexedSlice
t_in.dtype != TF_DataType.TF_RESOURCE)
{
in_grad.shape = t_in.shape;
in_grad.set_shape(t_in.TensorShape);
}

_SetGrad(grads, t_in, in_grad);


+ 54
- 0
src/TensorFlowNET.Core/Gradients/image_grad.cs View File

@@ -0,0 +1,54 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Framework;
using static Tensorflow.Binding;

namespace Tensorflow.Gradients
{
[RegisterGradient("image_grad")]
public class image_grad
{
[RegisterGradient("ResizeNearestNeighbor")]
public static Tensor[] _ResizeNearestNeighborGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var image = op.inputs[0];
var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray());
Tensor image_shape = null;
if (shape.is_fully_defined())
throw new NotImplementedException("_ResizeNearestNeighborGrad shape.is_fully_defined");
else
image_shape = array_ops.shape(image)["1:3"];

grad = gen_image_ops.resize_nearest_neighbor_grad(
grad,
image_shape,
align_corners: op.get_attr<bool>("align_corners"),
half_pixel_centers: op.get_attr<bool>("half_pixel_centers"));

return new Tensor[]
{
grad,
null
};
}
}
}

+ 88
- 0
src/TensorFlowNET.Core/Gradients/nn_grad.cs View File

@@ -166,6 +166,94 @@ namespace Tensorflow.Gradients
};
}

[RegisterGradient("FusedBatchNorm")]
public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads)
=> _BaseFusedBatchNormGrad(op, 0, grads);

/// <summary>
/// Return the gradients for the 3 inputs of BatchNorm.
/// </summary>
/// <param name="op"></param>
/// <param name="version"></param>
/// <param name="grads"></param>
/// <returns></returns>
public static Tensor[] _BaseFusedBatchNormGrad(Operation op, int version, Tensor[] grads)
{
var x = op.inputs[0];
var grad_y = grads[0];
var scale = op.inputs[1];
var epsilon = op.get_attr<float>("epsilon");
var data_format = op.get_attr<string>("data_format");
var is_training = op.get_attr<bool>("is_training");
Func<FusedBatchNormParams, Tensor[]> grad_fun = null;

switch (version)
{
case 2:
throw new NotImplementedException("");
case 1:
throw new NotImplementedException("");
default:
grad_fun = gen_nn_ops.fused_batch_norm_grad;
break;
}

if (is_training)
{
return grad_fun(new FusedBatchNormParams
{
YBackprop = grad_y,
X = x,
Scale = scale,
ReserveSpace1 = op.outputs[3],
ReserveSpace2 = op.outputs[4],
ReserveSpace3 = version == 2 ? op.outputs[5] : null,
Epsilon = epsilon,
DataFormat = data_format,
IsTraining = is_training
});
}
else
{
var pop_mean = op.inputs[3];
var pop_var = op.inputs[4];
if (data_format == "NCHW")
throw new NotImplementedException("");

var results = grad_fun(new FusedBatchNormParams
{
YBackprop = grad_y,
X = x,
Scale = scale,
ReserveSpace1 = op.outputs[3],
ReserveSpace2 = op.outputs[4],
ReserveSpace3 = version == 2 ? op.outputs[5] : null,
Epsilon = epsilon,
DataFormat = data_format,
IsTraining = is_training
});

var (dx, dscale, doffset) = (results[0], results[1], results[2]);
if (data_format == "NCHW")
throw new NotImplementedException("");

return new Tensor[]
{
dx,
dscale,
doffset,
null,
null
};
}
}

[RegisterGradient("BatchNormWithGlobalNormalization")]
public static Tensor _BatchNormWithGlobalNormalizationGrad(Operation op, Tensor[] grads)
{
throw new NotImplementedException("BatchNormWithGlobalNormalization");
}

private static bool IsZero(Tensor g)
{
if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type))


+ 3
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -440,6 +440,9 @@ namespace Tensorflow
case List<VariableV1> list:
t = list.Select(x => (T)(object)x).ToList();
break;
case List<ResourceVariable> list:
t = list.Select(x => (T)(object)x).ToList();
break;
case List<RefVariable> list:
t = list.Select(x => (T)(object)x).ToList();
break;


+ 0
- 14
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -27,20 +27,6 @@ namespace Tensorflow.Operations
/// </summary>
public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext>
{


/// <summary>
/// The boolean tensor for the cond predicate
/// </summary>
private Tensor _pred;

public Tensor pred => _pred;

/// <summary>
/// 0 or 1 representing this branch
/// </summary>
private int _branch;

private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>();

/// <summary>


+ 13
- 4
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -45,10 +45,19 @@ namespace Tensorflow.Operations
/// The predicate tensor in this branch
/// </summary>
protected Tensor _pivot;
public Tensor pivot
{
get => _pivot;
}
public Tensor pivot => _pivot;

/// <summary>
/// The boolean tensor for the cond predicate
/// </summary>
protected Tensor _pred;
public Tensor pred => _pred;

/// <summary>
/// 0 or 1 representing this branch
/// </summary>
protected int _branch;
public int branch => _branch;

protected Stack<ControlFlowContext> _context_stack;
protected ControlFlowContext _outer_context;


+ 27
- 0
src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
public class FusedBatchNormParams
{
public string Name { get; set; }
public Tensor YBackprop { get; set; }
public Tensor X { get; set; }
public Tensor Scale { get; set; }
public Tensor ReserveSpace1 { get; set; }
public Tensor ReserveSpace2 { get; set; }
public Tensor ReserveSpace3 { get; set; }
public float Epsilon { get; set; }
public string DataFormat { get; set; }
public bool IsTraining { get; set; }

public FusedBatchNormParams()
{
Epsilon = 0.0001f;
DataFormat = "NHWC";
IsTraining = true;
}
}
}

+ 29
- 0
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -156,6 +156,35 @@ namespace Tensorflow.Operations
return op.output;
}

/// <summary>
/// Gradient for batch normalization.
/// </summary>
/// <param name="y_backprop"></param>
/// <param name="x"></param>
/// <param name="scale"></param>
/// <param name="reserve_space_1"></param>
/// <param name="reserve_space_2"></param>
/// <param name="epsilon"></param>
/// <param name="data_format"></param>
/// <param name="is_training"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor[] fused_batch_norm_grad(FusedBatchNormParams @params)
{
var op = _op_def_lib._apply_op_helper("FusedBatchNormGrad", name: @params.Name, args: new
{
y_backprop = @params.YBackprop,
x = @params.X,
scale = @params.Scale,
reserve_space_1 = @params.ReserveSpace1,
reserve_space_2 = @params.ReserveSpace2,
epsilon = @params.Epsilon,
data_format = @params.DataFormat,
is_training = @params.IsTraining
});
return op.outputs;
}

public static Tensor[] fused_batch_norm(Tensor x,
Tensor scale,
Tensor offset,


+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.Input.cs View File

@@ -53,7 +53,7 @@ namespace Tensorflow
for (int i = 0; i < NumInputs; i++)
{
var tf_output = Input(i);
var op = new Operation(tf_output.oper);
var op = GetOperation(tf_output.oper);
retval[i] = op.outputs[tf_output.index];
}


+ 41
- 0
src/TensorFlowNET.Core/Operations/Operation.Instance.cs View File

@@ -0,0 +1,41 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;

namespace Tensorflow
{
public partial class Operation
{
// cache the mapping between managed and unmanaged op
// some data is stored in managed instance, so when
// create Operation by IntPtr, it will lost some data.
private static Dictionary<IntPtr, Operation> OpInstances = new Dictionary<IntPtr, Operation>();

/// <summary>
/// Get operation by handle
/// </summary>
/// <param name="handle"></param>
/// <returns></returns>
public Operation GetOperation(IntPtr handle)
{
return OpInstances.ContainsKey(handle) ?
OpInstances[handle] :
new Operation(handle);
}
}
}

+ 9
- 8
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -84,9 +84,10 @@ namespace Tensorflow
_control_flow_context = _graph._get_control_flow_context();

// Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor.
OpInstances[_handle] = this;
}

public Operation(Graph g, string opType, string oper_name)
/*public Operation(Graph g, string opType, string oper_name)
{
_graph = g;

@@ -102,7 +103,7 @@ namespace Tensorflow
// Dict mapping op name to file and line information for op colocation
// context managers.
_control_flow_context = graph._get_control_flow_context();
}
}*/

/// <summary>
/// Creates an `Operation`.
@@ -151,11 +152,6 @@ namespace Tensorflow
}
}

if(node_def.Name == "define_loss/conv_lobj_branch/batch_normalization/cond/FusedBatchNorm_1")
{
}

// Dict mapping op name to file and line information for op colocation
// context managers.
_control_flow_context = graph._get_control_flow_context();
@@ -164,7 +160,7 @@ namespace Tensorflow
if (op_def == null)
op_def = g.GetOpDef(node_def.Op);

var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

// Initialize self._outputs.
@@ -180,6 +176,8 @@ namespace Tensorflow

if (_handle != IntPtr.Zero)
_control_flow_post_processing();

OpInstances[_handle] = this;
}

public void run(FeedItem[] feed_dict = null, Session session = null)
@@ -220,6 +218,9 @@ namespace Tensorflow
return grouped_inputs.ToArray();
}

public T get_attr<T>(string name)
=> (T)get_attr(name);

public object get_attr(string name)
{
AttrValue x = null;


+ 1
- 1
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -611,7 +611,7 @@ namespace Tensorflow
});
}
public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null)
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
=> gen_array_ops.slice(input, begin, size, name: name);
public static Tensor stack(object values, int axis = 0, string name = "stack")


src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs → src/TensorFlowNET.Core/Operations/control_flow_ops.cs View File

@@ -518,7 +518,7 @@ namespace Tensorflow
inputs = inputs.Select(inp =>
ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true))
.ToArray();
return gen_control_flow_ops.merge(inputs, name).Item1;
return gen_control_flow_ops.merge(inputs, name)[0];
});
}

@@ -557,8 +557,31 @@ namespace Tensorflow
throw new NotImplementedException("ZerosLikeOutsideLoop");
return array_ops.zeros_like(val, optimize: false);
}

throw new NotImplementedException("ZerosLikeOutsideLoop");
else
{
var op_ctxt = op._get_control_flow_context();
if(op_ctxt != null)
{
// We are in a cond context. Use a switch to create zeros only when needed.
var pred = op_ctxt.pred;
var branch = op_ctxt.branch;
var switch_val = @switch(op.inputs[0], pred)[1 - branch];
var pivot = array_ops.identity(switch_val);
if (val.dtype == dtypes.resource)
throw new NotImplementedException("");
var zeros_shape = array_ops.shape_internal(switch_val, optimize: false);
// Ensure ops created within array_ops.zeros are dominated by switch in
// cond context.
return tf_with(ops.control_dependencies(new[] { pivot }), delegate
{
return array_ops.zeros(zeros_shape, dtype: val.dtype);
});
}
else
{
return array_ops.zeros_like(val, optimize: false);
}
}
}

/// <summary>

+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -475,7 +475,7 @@ namespace Tensorflow
return op.output;
}

public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null)
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size });
return _op.outputs[0];


src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs → src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs View File

@@ -148,11 +148,18 @@ namespace Tensorflow
return new []{_op.outputs[0], _op.outputs[1]};
}

public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null)
public static Tensor[] ref_merge(Tensor[] inputs, string name = null)
{
var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs });

return _op.outputs;
}

public static Tensor[] merge(Tensor[] inputs, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs });

return (_op.outputs[0], _op.outputs[1]);
return _op.outputs;
}
}
}

+ 14
- 0
src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs View File

@@ -183,5 +183,19 @@ namespace Tensorflow

return op.output;
}

public static Tensor resize_nearest_neighbor_grad<Tsize>(Tensor grads, Tsize size, bool align_corners = false,
bool half_pixel_centers = false, string name = null)
{
var op = _op_def_lib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: new
{
grads,
size,
align_corners,
half_pixel_centers
});

return op.output;
}
}
}

+ 16
- 21
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -105,10 +105,13 @@ namespace Tensorflow

if (_handle == IntPtr.Zero)
{
var status = new Status();
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status);
status.Check();
} else
using (var status = new Status())
{
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status);
status.Check();
}
}
else
{
for (int i = 0; i < rank; i++)
dims[i] = c_api.TF_Dim(_handle, i);
@@ -119,14 +122,15 @@ namespace Tensorflow

set
{
var status = new Status();
if (value == null)
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status);
else
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);
using (var status = new Status())
{
if (value == null)
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status);
else
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);

status.Check(true);
status.Check(true);
}
}
}

@@ -142,16 +146,7 @@ namespace Tensorflow
/// </summary>
public void set_shape(TensorShape shape)
{
this.shape = (int[]) shape.dims.Clone();
}

/// <summary>
/// Updates the shape of this tensor.
/// </summary>
[Obsolete("Please use set_shape(TensorShape shape) instead.", false)]
public void SetShape(TensorShape shape)
{
this.shape = (int[]) shape.dims.Clone();
this.shape = shape.rank > 0 ? shape.dims : null;
}

/// <summary>


+ 1
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -33,6 +33,7 @@ namespace Tensorflow
public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
public static TF_DataType float16 = TF_DataType.TF_HALF;
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
public static TF_DataType resource = TF_DataType.TF_RESOURCE;

/// <summary>
///


+ 19
- 18
src/TensorFlowNET.Core/ops.cs View File

@@ -227,29 +227,30 @@ namespace Tensorflow
throw new NotImplementedException("_create_c_op");
}

var status = new Status();

// Add control inputs
foreach (var control_input in control_inputs)
c_api.TF_AddControlInput(op_desc, control_input);

// Add attrs
foreach (var attr in node_def.Attr)
using (var status = new Status())
{
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream.
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak
Marshal.Copy(bytes, 0, proto, bytes.Length);
uint len = (uint) bytes.Length;
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status);
// Add control inputs
foreach (var control_input in control_inputs)
c_api.TF_AddControlInput(op_desc, control_input);

status.Check(true);
}
// Add attrs
foreach (var attr in node_def.Attr)
{
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream.
var protoHandle = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, protoHandle, bytes.Length);
uint len = (uint)bytes.Length;
c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status);
status.Check(true);
Marshal.FreeHGlobal(protoHandle);
}

var c_op = c_api.TF_FinishOperation(op_desc, status);
var c_op = c_api.TF_FinishOperation(op_desc, status);

status.Check(true);
status.Check(true);

return c_op;
return c_op;
}
}
}



+ 20
- 1
test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs View File

@@ -54,6 +54,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
List<RefVariable> first_stage_trainable_var_list;
Operation train_op_with_frozen_variables;
Operation train_op_with_all_variables;
Saver loader;
Saver saver;
#endregion

public bool Run()
@@ -74,7 +76,9 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO

public void Train(Session sess)
{

sess.run(tf.global_variables_initializer());
print($"=> Restoring weights from: {cfg.TRAIN.INITIAL_WEIGHT} ... ");
loader.restore(sess, cfg.TRAIN.INITIAL_WEIGHT);
}

public void Test(Session sess)
@@ -184,6 +188,21 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
});
});

tf_with(tf.name_scope("loader_and_saver"), delegate
{
loader = tf.train.Saver(net_var);
saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10);
});

tf_with(tf.name_scope("summary"), delegate
{
tf.summary.scalar("learn_rate", learn_rate);
tf.summary.scalar("giou_loss", giou_loss);
tf.summary.scalar("conf_loss", conf_loss);
tf.summary.scalar("prob_loss", prob_loss);
tf.summary.scalar("total_loss", loss);
});

return graph;
}



+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs View File

@@ -60,7 +60,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO
public TrainConfig(string root)
{
_root = root;
INITIAL_WEIGHT = Path.Combine(_root, "data", "checkpoint", "yolov3_coco_demo.ckpt");
INITIAL_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco_demo.ckpt");
ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt");
}
}


Loading…
Cancel
Save