You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

NDArray.Creation.cs 2.3 kB

4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  using System;
  using System.Collections.Generic;
  using System.Text;
  using Tensorflow.Eager;
  using static Tensorflow.Binding;

  namespace Tensorflow.NumPy
  {
      /// <summary>
      /// Construction-related members of <see cref="NDArray"/>: scalar, array,
      /// shape, raw-buffer and tensor-wrapping constructors, plus the
      /// <see cref="Scalar{T}(T)"/> factory.
      /// </summary>
      /// <remarks>
      /// Every constructor forwards its arguments to the base class and then calls
      /// <see cref="NewEagerTensorHandle"/>.
      /// </remarks>
      public partial class NDArray
      {
          /// <summary>Creates an array from a single <see cref="bool"/> value.</summary>
          /// <param name="value">The value forwarded to the base constructor.</param>
          public NDArray(bool value) : base(value) { NewEagerTensorHandle(); }

          /// <summary>Creates an array from a single <see cref="byte"/> value.</summary>
          /// <param name="value">The value forwarded to the base constructor.</param>
          public NDArray(byte value) : base(value) { NewEagerTensorHandle(); }

          /// <summary>Creates an array from a single <see cref="short"/> value.</summary>
          /// <param name="value">The value forwarded to the base constructor.</param>
          public NDArray(short value) : base(value) { NewEagerTensorHandle(); }

          /// <summary>Creates an array from a single <see cref="int"/> value.</summary>
          /// <param name="value">The value forwarded to the base constructor.</param>
          public NDArray(int value) : base(value) { NewEagerTensorHandle(); }

          /// <summary>Creates an array from a single <see cref="long"/> value.</summary>
          /// <param name="value">The value forwarded to the base constructor.</param>
          public NDArray(long value) : base(value) { NewEagerTensorHandle(); }

          /// <summary>Creates an array from a single <see cref="float"/> value.</summary>
          /// <param name="value">The value forwarded to the base constructor.</param>
          public NDArray(float value) : base(value) { NewEagerTensorHandle(); }

          /// <summary>Creates an array from a single <see cref="double"/> value.</summary>
          /// <param name="value">The value forwarded to the base constructor.</param>
          public NDArray(double value) : base(value) { NewEagerTensorHandle(); }

          /// <summary>Creates an array from a managed <see cref="Array"/>.</summary>
          /// <param name="value">The source array forwarded to the base constructor.</param>
          /// <param name="shape">
          /// Optional shape forwarded to the base constructor; <c>null</c> by default.
          /// </param>
          public NDArray(Array value, Shape? shape = null)
              : base(value, shape) { NewEagerTensorHandle(); }

          /// <summary>Creates an array with the given shape and element type.</summary>
          /// <param name="shape">The shape forwarded to the base constructor.</param>
          /// <param name="dtype">
          /// The element type; defaults to <see cref="TF_DataType.TF_DOUBLE"/>.
          /// </param>
          public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
              : base(shape, dtype: dtype) { NewEagerTensorHandle(); }

          /// <summary>Creates an array from a byte buffer, a shape and an element type.</summary>
          /// <param name="bytes">The byte buffer forwarded to the base constructor.</param>
          /// <param name="shape">The shape forwarded to the base constructor.</param>
          /// <param name="dtype">The element type forwarded to the base constructor.</param>
          public NDArray(byte[] bytes, Shape shape, TF_DataType dtype)
              : base(bytes, shape, dtype) { NewEagerTensorHandle(); }

          /// <summary>Creates an array from a native address, a shape and an element type.</summary>
          /// <param name="address">The pointer forwarded to the base constructor.</param>
          /// <param name="shape">The shape forwarded to the base constructor.</param>
          /// <param name="dtype">The element type forwarded to the base constructor.</param>
          public NDArray(IntPtr address, Shape shape, TF_DataType dtype)
              : base(address, shape, dtype) { NewEagerTensorHandle(); }

          /// <summary>Creates an array that wraps the handle of an existing tensor.</summary>
          /// <param name="tensor">
          /// The tensor whose <c>Handle</c> is passed to the base constructor.
          /// </param>
          public NDArray(Tensor tensor) : base(tensor.Handle)
          {
              // When no handle is available after base construction, evaluate the
              // tensor through the default session and adopt the resulting handle.
              // NOTE(review): presumably this is the graph-mode (non-eager) case - confirm.
              if (_handle is null)
              {
                  tensor = tf.defaultSession.eval(tensor);
                  _handle = tensor.Handle;
              }

              NewEagerTensorHandle();
          }

          /// <summary>
          /// Creates a scalar <see cref="NDArray"/> from an unmanaged value by
          /// dispatching to the constructor that matches its runtime type.
          /// </summary>
          /// <typeparam name="T">The unmanaged type of <paramref name="value"/>.</typeparam>
          /// <param name="value">The scalar value to wrap.</param>
          /// <returns>A new <see cref="NDArray"/> built from <paramref name="value"/>.</returns>
          /// <exception cref="NotImplementedException">
          /// <typeparamref name="T"/> is not one of <see cref="bool"/>, <see cref="byte"/>,
          /// <see cref="short"/>, <see cref="int"/>, <see cref="long"/>,
          /// <see cref="float"/> or <see cref="double"/>.
          /// </exception>
          public static NDArray Scalar<T>(T value) where T : unmanaged
              => value switch
              {
                  bool val => new NDArray(val),
                  byte val => new NDArray(val),
                  // A short constructor exists, so short scalars are dispatched too.
                  short val => new NDArray(val),
                  int val => new NDArray(val),
                  long val => new NDArray(val),
                  float val => new NDArray(val),
                  double val => new NDArray(val),
                  _ => throw new NotImplementedException(
                      $"Scalar creation is not supported for type {typeof(T).Name}.")
              };

          /// <summary>
          /// Assigns a fresh id from <c>ops.uid()</c> and creates the eager tensor
          /// handle for the current <c>_handle</c> via <c>TFE_NewTensorHandle</c>.
          /// </summary>
          void NewEagerTensorHandle()
          {
              _id = ops.uid();
              _eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
              // NOTE(review): presumably raises if the native call reported a failure - confirm.
              tf.Status.Check(true);
          }
      }
  }