You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

TracingCompiler.cs 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Security.Cryptography.X509Certificates;
  4. using System.Text;
  5. using Tensorflow.Graphs;
  namespace Tensorflow.Functions
  {
      /// <summary>
      /// Traces a C# tensor function into <see cref="ConcreteFunction"/>s and caches one traced
      /// function per input signature (argument count plus each argument's dtype and shape), so
      /// repeated calls with matching inputs reuse the same trace instead of re-tracing.
      /// </summary>
      public class TracingCompiler
      {
          // The user function that gets traced into a graph.
          readonly Func<Tensor[], Tensor[]> _csharp_function;
          internal string _name;
          readonly bool _autograph;
          // Traced functions, keyed by make_cache_key(args).
          readonly Dictionary<string, ConcreteFunction> _function_cache;
          // Attributes attached to every concrete function this compiler creates.
          readonly Dictionary<string, AttrValue> _function_attributes;
          // Number of times a new concrete function has been traced.
          int _tracing_count;

          /// <summary>
          /// Creates a compiler for <paramref name="csharp_function"/>.
          /// </summary>
          /// <param name="csharp_function">The function to trace.</param>
          /// <param name="name">Name passed to graph construction for each traced function.</param>
          /// <param name="input_signatures">Accepted for API parity; currently ignored.</param>
          /// <param name="attributes">Attributes attached to every traced concrete function; <c>null</c> means none.</param>
          /// <param name="autograph">Forwarded to graph construction as its <c>autograph</c> argument.</param>
          /// <param name="autograph_options">Accepted for API parity; currently ignored.</param>
          /// <param name="reduce_retracing">Accepted for API parity; currently ignored.</param>
          /// <param name="capture_by_value">Accepted for API parity; currently ignored.</param>
          public TracingCompiler(Func<Tensor[], Tensor[]> csharp_function, string name, object? input_signatures = null,
              Dictionary<string, AttrValue> attributes = null, bool autograph = true, object? autograph_options = null,
              bool reduce_retracing = false, bool capture_by_value = false)
          {
              _csharp_function = csharp_function;
              _name = name;
              _autograph = autograph;
              _function_attributes = attributes ?? new Dictionary<string, AttrValue>();
              _function_cache = new Dictionary<string, ConcreteFunction>();
              _tracing_count = 0;
          }

          /// <summary>
          /// Calls the traced function on <paramref name="inputs"/>. A new concrete function is
          /// traced first when no cached one matches the inputs' signature.
          /// </summary>
          /// <param name="inputs">The call arguments.</param>
          /// <returns>The outputs of the concrete function call.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="inputs"/> is null.</exception>
          public Tensor[] Apply(Tensor[] inputs)
          {
              // TODO(Rinne): add lock here — _function_cache is a plain, unsynchronized Dictionary.
              var (concrete_function, filtered_flat_args) = _maybe_define_function(inputs);
              return concrete_function.CallFlat(filtered_flat_args, concrete_function.CapturedInputs);
          }

          /// <summary>
          /// Returns the (possibly cached) concrete function for <paramref name="args"/> without calling it.
          /// </summary>
          internal ConcreteFunction _get_concrete_function_internal_garbage_collected(Tensor[] args)
          {
              var (concrete_function, _) = _maybe_define_concrete_function(args);
              return concrete_function;
          }

          private (ConcreteFunction, Tensor[]) _maybe_define_concrete_function(Tensor[] args)
          {
              return _maybe_define_function(args);
          }

          /// <summary>
          /// Looks up the concrete function for <paramref name="args"/> in the cache, tracing and
          /// storing a new one on a miss. The arguments are returned unchanged alongside it.
          /// </summary>
          /// <exception cref="ArgumentNullException"><paramref name="args"/> is null.</exception>
          private (ConcreteFunction, Tensor[]) _maybe_define_function(Tensor[] args)
          {
              if (args is null)
              {
                  throw new ArgumentNullException(nameof(args));
              }

              var lookup_func_key = make_cache_key(args);
              if (!_function_cache.TryGetValue(lookup_func_key, out var concrete_function))
              {
                  concrete_function = _create_concrete_function(args);
                  _function_cache[lookup_func_key] = concrete_function;
              }
              return (concrete_function, args);
          }

          /// <summary>
          /// Traces <see cref="_csharp_function"/> into a new graph for <paramref name="args"/> and
          /// wraps it in a <see cref="ConcreteFunction"/> carrying this compiler's attributes.
          /// </summary>
          private ConcreteFunction _create_concrete_function(Tensor[] args)
          {
              _tracing_count++;
              var func_graph = FuncGraph.func_graph_from_func(
                  _name,
                  // Only the non-null Tensor entries of the traced arguments reach the user function.
                  x => _csharp_function(x.OfType<Tensor>().ToArray()),
                  args, new Dictionary<string, object>(), autograph: _autograph);
              return new ConcreteFunction(func_graph, _function_attributes);
          }

          /// <summary>
          /// Builds the cache key for <paramref name="inputs"/> from the argument count and each
          /// argument's dtype and shape. Calls that differ in dtype or shape therefore get their
          /// own trace instead of reusing a function traced for a different signature.
          /// </summary>
          private static string make_cache_key(Tensor[] inputs)
          {
              var key = new StringBuilder();
              key.Append(inputs.Length);
              foreach (var input in inputs)
              {
                  // A null entry still occupies a position, so record it explicitly.
                  key.Append(input is null ? "|null" : $"|{input.dtype}:{input.shape}");
              }
              return key.ToString();
          }
      }
  }