Browse Source

Use operation with customized C API.

tags/v0.100.5-BERT-load
Yaohui Liu 2 years ago
parent
commit
1d1657dd2c
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
4 changed files with 22 additions and 32 deletions
  1. +13
    -0
      src/TensorFlowNET.Core/APIs/c_api.customize.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  3. +7
    -30
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs

+ 13
- 0
src/TensorFlowNET.Core/APIs/c_api.customize.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Part of the partial <c>c_api</c> class that declares P/Invoke bindings
/// using the <c>TFC_</c> prefix.
/// NOTE(review): the <c>TFC_</c> prefix presumably marks functions from a customized
/// build of the native library rather than the stock TensorFlow C API - confirm.
/// </summary>
public partial class c_api
{
/// <summary>
/// Binding for the native <c>TFC_SetAttr</c> function exported by the library
/// named by <c>TensorFlowLibName</c>.
/// NOTE(review): judging by the name and parameters, this presumably sets the attribute
/// <paramref name="attr_name"/> on an existing operation from a serialized
/// attribute value - confirm against the native implementation.
/// </summary>
/// <param name="graph">Handle of the graph passed to the native call.</param>
/// <param name="op">Raw native pointer of the operation passed to the native call.</param>
/// <param name="attr_name">
/// Attribute name. No <c>CharSet</c> or <c>MarshalAs</c> is specified, so the
/// runtime's default string marshalling for <c>DllImport</c> applies.
/// </param>
/// <param name="attr_value_proto">
/// Buffer handle carrying the attribute value (presumably a serialized AttrValue proto - TODO confirm).
/// </param>
/// <param name="status">Status handle handed to the native call so it can report its outcome.</param>
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
}
}

+ 1
- 1
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -223,7 +223,7 @@ namespace Tensorflow.Functions
{
input_tangents = new TangentInfo();
}
if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER)
if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER || tf.Runner.MustRecordGradient())
{
if(input_tangents.Indices is not null || executing_eagerly)
{


+ 7
- 30
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -317,27 +317,18 @@ namespace Tensorflow
{
Debug.Assert(types.Length == shapes.Length);
int orig_num_outputs = this.outputs.Length;
//var new_outputs = new List<Tensor>(_outputs);

var old_outputs = _outputs;
_outputs = new Tensor[orig_num_outputs + types.Length];
for(int i = 0; i < orig_num_outputs; i++)
{
_outputs[i] = old_outputs[i];
}
var new_outputs = new List<Tensor>(_outputs);

// Since the `_outputs` is defined as `Array`, when we add new output, we
// have to create a new array, which brings some performance concerns.
// In the future maybe the type of `outputs` should be reconsidered.
for(int i = 0; i < types.Length; i++)
{
var t = new Tensor(this, orig_num_outputs + 1, types[i]);
_outputs[i] = t;
//t = tf.ensure_shape(t, shapes[i]);
var t = new Tensor(this, orig_num_outputs + i, types[i]);
t.shape = shapes[i];
//new_outputs.Add(t);
new_outputs.Add(t);
}
//_outputs = new_outputs.ToArray();
_outputs = new_outputs.ToArray();
}

internal void _set_func_attr(string attr_name, string func_name)
@@ -372,23 +363,9 @@ namespace Tensorflow

/// <summary>
/// Forwards an attribute name and value buffer for this operation to the native
/// <c>c_api.TFC_SetAttr</c> function and then checks the resulting status.
/// </summary>
/// <param name="attr_name">Name of the attribute, passed through unchanged to the native call.</param>
/// <param name="attr_buf">
/// Buffer holding the attribute value
/// (presumably a serialized AttrValue proto - TODO confirm against callers).
/// </param>
internal void _set_attr_with_buf(string attr_name, Buffer attr_buf)
{
// A fresh Status is created for this call only, so the check below reflects
// just this native invocation.
Status status = new();
// The call works directly on this operation's native handle (`_handle`) within `graph`.
c_api.TFC_SetAttr(graph, _handle, attr_name, attr_buf, status);
// NOTE(review): Check(true) presumably throws when the native call reported an error - confirm.
status.Check(true);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -135,7 +135,7 @@ namespace Tensorflow

protected virtual void SetShapeInternal(Shape value)
{
if (value == null)
if (value is null || value.ndim == 0 || value.ndim == -1)
c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), null, -1, tf.Status);
else
c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), value.dims, value.ndim, tf.Status);


Loading…
Cancel
Save