You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

KerasApi.cs 3.0 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Data;
  4. using System.Linq;
  5. using Tensorflow.Keras;
  6. using Tensorflow.Keras.ArgsDefinition;
  7. using Tensorflow.Keras.Datasets;
  8. using Tensorflow.Keras.Engine;
  9. using Tensorflow.Keras.Layers;
  10. using Tensorflow.Keras.Losses;
  11. using static Tensorflow.Binding;
  namespace Tensorflow
  {
      /// <summary>
      /// Keras-style entry point: exposes the sub-APIs (datasets, initializers, layers,
      /// losses, activations, preprocessing, backend) and factory methods for building
      /// models and input tensors.
      /// </summary>
      public class KerasApi
      {
          // Each sub-API is created eagerly, once per KerasApi instance. Auto-property
          // initializers run in declaration order; that order is kept as-is.
          // NOTE(review): whether these constructors have side effects is not visible here.

          /// <summary>Dataset helpers (a <see cref="KerasDataset"/> instance).</summary>
          public KerasDataset datasets { get; } = new KerasDataset();

          /// <summary>Weight-initializer factories (an <see cref="Initializers"/> instance).</summary>
          public Initializers initializers { get; } = new Initializers();

          /// <summary>Layer factories (a <see cref="LayersApi"/> instance).</summary>
          public LayersApi layers { get; } = new LayersApi();

          /// <summary>Loss-function factories (a <see cref="LossesApi"/> instance).</summary>
          public LossesApi losses { get; } = new LossesApi();

          /// <summary>Activation functions (an <see cref="Activations"/> instance).</summary>
          public Activations activations { get; } = new Activations();

          /// <summary>Preprocessing helpers (a <see cref="Preprocessing"/> instance).</summary>
          public Preprocessing preprocessing { get; } = new Preprocessing();

          /// <summary>Backend implementation (a <see cref="BackendImpl"/> instance).</summary>
          public BackendImpl backend { get; } = new BackendImpl();

          /// <summary>
          /// Builds a <c>Sequential</c> model: the given layers are forwarded, together with
          /// the name, through a <see cref="SequentialArgs"/>.
          /// </summary>
          /// <param name="layers">Layers to stack, in order. <c>null</c> (the default) is forwarded unchanged.</param>
          /// <param name="name">Optional model name; <c>null</c> by default.</param>
          /// <returns>A newly constructed <c>Sequential</c> model.</returns>
          public Sequential Sequential(List<Layer> layers = null,
              string name = null)
          {
              var sequentialArgs = new SequentialArgs
              {
                  Layers = layers,
                  Name = name
              };

              return new Sequential(sequentialArgs);
          }

          /// <summary>
          /// `Model` groups layers into an object with training and inference features.
          /// This overload builds a model from exactly one input tensor and one output tensor.
          /// </summary>
          /// <param name="input">The model's single input tensor.</param>
          /// <param name="output">The model's single output tensor.</param>
          /// <returns>A newly constructed <c>Model</c> built from a <see cref="ModelArgs"/>.</returns>
          public Model Model(Tensor input, Tensor output)
          {
              // ModelArgs receives collections of inputs/outputs, so each single
              // tensor is wrapped in a one-element array.
              var inputs = new[] { input };
              var outputs = new[] { output };

              return new Model(new ModelArgs
              {
                  Inputs = inputs,
                  Outputs = outputs
              });
          }

          /// <summary>
          /// Instantiate a Keras tensor by creating an <see cref="InputLayer"/> from the
          /// supplied arguments and returning the outputs of its first inbound node.
          /// </summary>
          /// <param name="shape">Forwarded to <c>InputLayerArgs.InputShape</c>; <c>null</c> by default.</param>
          /// <param name="batch_size">
          /// Forwarded to <c>InputLayerArgs.BatchSize</c>. Defaults to -1
          /// (NOTE(review): presumably meaning "unspecified" — confirm in InputLayer).
          /// </param>
          /// <param name="dtype">
          /// Forwarded to <c>InputLayerArgs.DType</c>. Defaults to <c>TF_DataType.DtInvalid</c>
          /// (NOTE(review): presumably lets the layer pick a dtype — confirm).
          /// </param>
          /// <param name="name">Optional layer name; <c>null</c> by default.</param>
          /// <param name="sparse">
          /// A boolean specifying whether the placeholder to be created is sparse.
          /// </param>
          /// <param name="ragged">
          /// A boolean specifying whether the placeholder to be created is ragged.
          /// </param>
          /// <param name="tensor">
          /// Optional existing tensor to wrap into the `Input` layer.
          /// If set, the layer will not create a placeholder tensor.
          /// </param>
          /// <returns>The outputs of the new input layer's first inbound node.</returns>
          public Tensor Input(TensorShape shape = null,
              int batch_size = -1,
              TF_DataType dtype = TF_DataType.DtInvalid,
              string name = null,
              bool sparse = false,
              bool ragged = false,
              Tensor tensor = null)
          {
              var inputLayer = new InputLayer(new InputLayerArgs
              {
                  Name = name,
                  InputShape = shape,
                  BatchSize = batch_size,
                  DType = dtype,
                  Sparse = sparse,
                  Ragged = ragged,
                  InputTensor = tensor
              });

              // NOTE(review): relies on InputLayer registering at least one inbound node
              // during construction; index 0 would throw otherwise.
              var firstInboundNode = inputLayer.InboundNodes[0];
              return firstInboundNode.Outputs;
          }
      }
  }