You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

IEagerRunner.cs 1.6 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
using System;
using System.Collections.Generic;
using Tensorflow.Contexts;
using Tensorflow.Gradients;
using static Tensorflow.tensorflow;
  5. namespace Tensorflow.Eager
  6. {
  7. public interface IEagerRunner
  8. {
  9. Tensor[] Execute(Context ctx, string op_name,
  10. int num_outputs,
  11. Tensor[] inputs,
  12. object[] attrs,
  13. string name = null);
  14. (TF_DataType, Tensor[]) ArgsToMatchingEager(Context ctx,
  15. TF_DataType default_dtype = TF_DataType.DtInvalid,
  16. object[] args = null);
  17. Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info);
  18. Tensor[] TFE_Execute(Context ctx,
  19. string device_name,
  20. string op_name,
  21. Tensor[] inputs,
  22. object[] attrs,
  23. int num_outputs);
  24. Tensor[] TFE_TapeGradient(ITape tape,
  25. Tensor[] target,
  26. Tensor[] sources,
  27. List<Tensor> output_gradients,
  28. Tensor[] sources_raw,
  29. string unconnected_gradients);
  30. void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors,
  31. Tensor[] input_tensors, BackwardFunction backward_function);
  32. int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors);
  33. bool RecordGradient(string op_name,
  34. Tensor[] inputs,
  35. object[] attrs,
  36. Tensor[] results,
  37. BackwardFunction getBackwardFunction = null);
  38. bool MustRecordGradient();
  39. int TapeSetPossibleGradientTypes(params Tensor[] args);
  40. void ClearEagerOperationMap();
  41. }
  42. }