You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

execute.cs 1.4 kB

123456789101112131415161718192021222324252627282930313233343536373839
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using System.Xml.Linq;
  6. using Tensorflow.Contexts;
  7. using static Tensorflow.ApiDef.Types;
  8. using static Tensorflow.CostGraphDef.Types;
  9. using static Tensorflow.Binding;
  10. namespace Tensorflow.Eager
  11. {
  /// <summary>
  /// Static helpers that convert inputs to tensors and dispatch a single
  /// operation through <c>tf.Runner.TFE_Execute</c>.
  /// </summary>
  internal static class execute
  {
      /// <summary>
      /// Converts every element of <paramref name="values"/> with
      /// <c>ops.convert_to_tensor</c> and reports the data-type enum of each result.
      /// </summary>
      /// <param name="values">Tensors to convert.</param>
      /// <param name="ctx">Context forwarded to <c>ops.convert_to_tensor</c>.</param>
      /// <returns>
      /// A tuple of (data types, converted tensors); both arrays have the same
      /// length and ordering as <paramref name="values"/>.
      /// </returns>
      public static (DataType[], Tensor[]) convert_to_mixed_eager_tensors(Tensor[] values, Context ctx)
      {
          // Materialize the conversion exactly once. A deferred Select would be
          // re-run for every enumeration, converting each input twice and
          // reporting dtypes taken from a different pass than the returned tensors.
          var converted = values.Select(t => ops.convert_to_tensor(t, ctx: ctx)).ToArray();
          var types = converted.Select(t => t.dtype.as_datatype_enum()).ToArray();
          return (types, converted);
      }

      /// <summary>
      /// Misspelled historical name of <see cref="convert_to_mixed_eager_tensors"/>,
      /// kept so existing callers continue to compile. Forwards unchanged.
      /// </summary>
      /// <param name="values">Tensors to convert.</param>
      /// <param name="ctx">Context forwarded to the conversion.</param>
      /// <returns>A tuple of (data types, converted tensors).</returns>
      public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx)
          => convert_to_mixed_eager_tensors(values, ctx);

      /// <summary>
      /// Runs an operation; a direct pass-through to <see cref="quick_execute"/>.
      /// </summary>
      /// <param name="op_name">Name of the operation to run.</param>
      /// <param name="num_outputs">Number of outputs requested from the runner.</param>
      /// <param name="inputs">Input tensors for the operation.</param>
      /// <param name="attrs">Operation attributes, passed through untouched.</param>
      /// <param name="ctx">Context the operation runs in.</param>
      /// <param name="name">Optional name; forwarded as-is.</param>
      /// <returns>The tensors returned by <see cref="quick_execute"/>.</returns>
      public static Tensor[] executes(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
          => quick_execute(op_name, num_outputs, inputs, attrs, ctx, name);

      /// <summary>
      /// Executes an operation via <c>tf.Runner.TFE_Execute</c> using the device
      /// name currently set on <paramref name="ctx"/>.
      /// </summary>
      /// <param name="op_name">Name of the operation to run.</param>
      /// <param name="num_outputs">Number of outputs requested from the runner.</param>
      /// <param name="inputs">Input tensors for the operation.</param>
      /// <param name="attrs">Operation attributes, passed through untouched.</param>
      /// <param name="ctx">Context the operation runs in.</param>
      /// <param name="name">Accepted for signature compatibility; not used by this method.</param>
      /// <returns>The tensors produced by the runner.</returns>
      public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
      {
          // The device name is captured before ensure_initialized() is called;
          // this ordering is kept exactly as in the original implementation.
          string deviceName = ctx.DeviceName;
          ctx.ensure_initialized();
          return tf.Runner.TFE_Execute(ctx, deviceName, op_name, inputs, attrs, num_outputs);
      }

      /// <summary>
      /// Reports whether gradients must be recorded. Currently always <c>false</c>.
      /// </summary>
      /// <returns><c>false</c>.</returns>
      // NOTE(review): presumably a placeholder until gradient recording is wired in; confirm.
      public static bool must_record_gradient() => false;
  }
  36. }