Browse Source

Revise some implementation details.

tags/v0.100.5-BERT-load
Yaohui Liu 2 years ago
parent
commit
c8691387ab
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
11 changed files with 31 additions and 65 deletions
  1. +4
    -0
      src/TensorFlowNET.Core/APIs/c_api.customize.cs
  2. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.gradients.cs
  3. +0
    -36
      src/TensorFlowNET.Core/Framework/importer.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  5. +18
    -1
      src/TensorFlowNET.Core/Functions/Function.cs
  6. +2
    -2
      src/TensorFlowNET.Core/Functions/IGenericFunction.cs
  7. +2
    -2
      src/TensorFlowNET.Core/Functions/TracingCompiler.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  9. +0
    -15
      src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs
  10. +0
    -4
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  11. +2
    -2
      src/TensorFlowNET.Core/Operations/gen_ops.cs

+ 4
- 0
src/TensorFlowNET.Core/APIs/c_api.customize.cs View File

@@ -9,5 +9,9 @@ namespace Tensorflow
{
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
}
}

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.gradients.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow
{
public partial class tensorflow
{
internal GradientTape _tapeSet;
GradientTape _tapeSet;

/// <summary>
/// Record operations for automatic differentiation.


+ 0
- 36
src/TensorFlowNET.Core/Framework/importer.cs View File

@@ -79,42 +79,6 @@ namespace Tensorflow
return _GatherReturnElements(return_elements, graph, results);
}

//private static ITensorOrOperation[] _import_graph_def_internal(GraphDef graph_def, Dictionary<string, Tensor> input_map = null, string[] return_elements = null,
// bool validate_colocation_constraints = true, string name = null, OpList producer_op_list = null)
//{
// graph_def = _ProcessGraphDefParam(graph_def);
// input_map = _ProcessInputMapParam(input_map);
// return_elements = _ProcessReturnElementsParam(return_elements);

// if(producer_op_list is not null)
// {
// _RemoveDefaultAttrs(producer_op_list, graph_def);
// }

// var graph = ops.get_default_graph();
// string prefix = null;
// tf_with(ops.name_scope(name, "import", input_map.Values), scope =>
// {
// if (scope is not null)
// {
// Debug.Assert(scope.scope_name.EndsWith("/"));
// prefix = scope.scope_name[scope.scope_name.Length - 1].ToString();
// }
// else
// {
// prefix = "";
// }

// input_map = _ConvertInputMapValues(name, input_map);
// });

// var scope_options = c_api_util.ScopedTFImportGraphDefOptions();
// var options = scope_options.Options;
// _PopulateTFImportGraphDefOptions(scope_options, prefix, input_map, return_elements, validate_colocation_constraints);
//}

private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements,
Graph graph,
TF_ImportGraphDefResults results)


+ 1
- 1
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -305,7 +305,7 @@ namespace Tensorflow.Functions

private Tensors _build_call_outputs(Tensors result)
{
// TODO(Rinne): dwal with `func_graph.structured_outputs`
// TODO(Rinne): deal with `func_graph.structured_outputs`

return result;
}


+ 18
- 1
src/TensorFlowNET.Core/Functions/Function.cs View File

@@ -4,7 +4,7 @@ using Tensorflow.Train;

namespace Tensorflow
{
public class Function: Trackable
public class Function: Trackable, IGenericFunction
{
#pragma warning disable CS0169 // The field 'Function._handle' is never used
private IntPtr _handle;
@@ -34,6 +34,11 @@ namespace Tensorflow
return result;
}

public ConcreteFunction get_concrete_function(params Tensor[] args)
{
return _get_concrete_function_garbage_collected(args);
}

protected virtual Tensors _call(Tensors inputs)
{
if(_variable_creation_fn is not null)
@@ -57,6 +62,18 @@ namespace Tensorflow
return false;
}

protected ConcreteFunction _get_concrete_function_garbage_collected(Tensor[] args)
{
if(_variable_creation_fn is null)
{
_initialize(args);
// TODO(Rinne): _initialize_uninitialized_variables
}

var concrete = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args);
return concrete;
}

private void _initialize(Tensor[] args)
{
_variable_creation_fn = _compiler(_csharp_function);


+ 2
- 2
src/TensorFlowNET.Core/Functions/IGenericFunction.cs View File

@@ -6,7 +6,7 @@ namespace Tensorflow.Functions
{
public interface IGenericFunction
{
object[] Apply(params object[] args);
ConcreteFunction get_concrete_function(params object[] args);
Tensors Apply(Tensors args);
ConcreteFunction get_concrete_function(params Tensor[] args);
}
}

+ 2
- 2
src/TensorFlowNET.Core/Functions/TracingCompiler.cs View File

@@ -49,7 +49,7 @@ namespace Tensorflow.Functions

private (ConcreteFunction, Tensor[]) _maybe_define_function(Tensor[] args)
{
var lookup_func_key = male_cache_key(args);
var lookup_func_key = make_cache_key(args);
if(_function_cache.TryGetValue(lookup_func_key, out var concrete_function))
{
return (concrete_function, args);
@@ -71,7 +71,7 @@ namespace Tensorflow.Functions
return concrete_function;
}

private static string male_cache_key(Tensor[] inputs)
private static string make_cache_key(Tensor[] inputs)
{
//string res = "";
//foreach (var input in inputs)


+ 1
- 1
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -727,7 +727,7 @@ namespace Tensorflow

private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn)
{
//scope = scope.TrimEnd('/').Replace('/', '_');
// scope = scope.TrimEnd('/').Replace('/', '_');
return grad_fn(op, out_grads);
}



+ 0
- 15
src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs View File

@@ -38,21 +38,6 @@ namespace Tensorflow.Graphs

// make function as an Operation by autograph
// need to restore mode when exits

//var func_graph = new FuncGraph(func_name);
//func_graph.as_default();
//var input_placeholders = args.Arguments.Select(x => tf.placeholder(((Tensor)x).dtype)).ToArray();
//// stop the function from recursive call.
//already_in_boundary = true;
//var outputs = args.Method.Invoke(args.Instance, input_placeholders) as Tensors;
//already_in_boundary = false;

//var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
//func_graph.ToGraph(opers,
// input_placeholders,
// outputs,
// null);
//func_graph.Exit();
function = new ConcreteFunction(func_name);
function.Enter();



+ 0
- 4
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -208,9 +208,5 @@ namespace Tensorflow

[DllImport(TensorFlowLibName)]
public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
}
}

+ 2
- 2
src/TensorFlowNET.Core/Operations/gen_ops.cs View File

@@ -10060,7 +10060,7 @@ namespace Tensorflow.Operations
}
catch (Exception)
{
Console.WriteLine();
}
try
{
@@ -10068,7 +10068,7 @@ namespace Tensorflow.Operations
}
catch (Exception)
{
Console.WriteLine();
}
}



Loading…
Cancel
Save