You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

KerasInterface.cs 3.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Reflection;
  4. using System.Linq;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Datasets;
  7. using Tensorflow.Keras.Engine;
  8. using Tensorflow.Keras.Layers;
  9. using Tensorflow.Keras.Losses;
  10. using Tensorflow.Keras.Metrics;
  11. using Tensorflow.Keras.Models;
  12. using Tensorflow.Keras.Optimizers;
  13. using Tensorflow.Keras.Saving;
  14. using Tensorflow.Keras.Utils;
  namespace Tensorflow.Keras
  {
      /// <summary>
      /// Groups the Keras sub-APIs (datasets, layers, losses, optimizers, ...) behind
      /// a single object and offers factory helpers that build <see cref="Models.Sequential"/>
      /// and <see cref="Functional"/> models as well as input tensors.
      /// </summary>
      public class KerasInterface
      {
          // Sub-API objects. Each one is a get-only property whose value is created
          // once, when this KerasInterface instance is constructed.
          public KerasDataset datasets { get; } = new KerasDataset();
          public Initializers initializers { get; } = new Initializers();
          public Regularizers regularizers { get; } = new Regularizers();
          public LayersApi layers { get; } = new LayersApi();
          public LossesApi losses { get; } = new LossesApi();
          public Activations activations { get; } = new Activations();
          public Preprocessing preprocessing { get; } = new Preprocessing();
          public BackendImpl backend { get; } = new BackendImpl();
          public OptimizerApi optimizers { get; } = new OptimizerApi();
          public MetricsApi metrics { get; } = new MetricsApi();
          public ModelsApi models { get; } = new ModelsApi();
          public KerasUtils utils { get; } = new KerasUtils();

          /// <summary>
          /// Creates a new <see cref="Models.Sequential"/> model from the given arguments.
          /// </summary>
          /// <param name="layers">
          /// Value assigned to <c>SequentialArgs.Layers</c>; forwarded as-is, including <c>null</c>.
          /// </param>
          /// <param name="name">
          /// Value assigned to <c>SequentialArgs.Name</c>; forwarded as-is, including <c>null</c>.
          /// </param>
          /// <returns>The newly constructed <see cref="Models.Sequential"/> instance.</returns>
          public Sequential Sequential(List<ILayer> layers = null,
              string name = null)
              => new Sequential(new SequentialArgs
              {
                  Layers = layers,
                  Name = name
              });

          /// <summary>
          /// `Model` groups layers into an object with training and inference features.
          /// </summary>
          /// <param name="inputs">Input tensors passed to the <see cref="Functional"/> constructor.</param>
          /// <param name="outputs">Output tensors passed to the <see cref="Functional"/> constructor.</param>
          /// <param name="name">Name passed to the <see cref="Functional"/> constructor; may be <c>null</c>.</param>
          /// <returns>The newly constructed <see cref="Functional"/> model.</returns>
          public Functional Model(Tensors inputs, Tensors outputs, string name = null)
              => new Functional(inputs, outputs, name: name);

          /// <summary>
          /// Instantiate a Keras tensor.
          /// </summary>
          /// <param name="shape">
          /// Shape forwarded as <c>InputLayerArgs.InputShape</c>. Ignored (overwritten)
          /// when <paramref name="batch_input_shape"/> is supplied.
          /// </param>
          /// <param name="batch_size">Value forwarded as <c>InputLayerArgs.BatchSize</c>.</param>
          /// <param name="batch_input_shape">
          /// Optional shape forwarded as <c>InputLayerArgs.BatchInputShape</c>. When not
          /// <c>null</c>, <paramref name="shape"/> is replaced by this shape's dimensions
          /// with the first one dropped.
          /// </param>
          /// <param name="dtype">Value forwarded as <c>InputLayerArgs.DType</c>.</param>
          /// <param name="name">Value forwarded as <c>InputLayerArgs.Name</c>.</param>
          /// <param name="sparse">
          /// A boolean specifying whether the placeholder to be created is sparse.
          /// </param>
          /// <param name="ragged">
          /// A boolean specifying whether the placeholder to be created is ragged.
          /// </param>
          /// <param name="tensor">
          /// Optional existing tensor to wrap into the `Input` layer.
          /// If set, the layer will not create a placeholder tensor.
          /// </param>
          /// <returns>
          /// The outputs of the first inbound node of the <see cref="InputLayer"/>
          /// created by this call.
          /// </returns>
          public Tensor Input(Shape shape = null,
              int batch_size = -1,
              Shape batch_input_shape = null,
              TF_DataType dtype = TF_DataType.DtInvalid,
              string name = null,
              bool sparse = false,
              bool ragged = false,
              Tensor tensor = null)
          {
              // A full batch shape takes precedence over `shape`: the per-sample
              // shape is derived by skipping its first dimension.
              if (batch_input_shape != null)
              {
                  shape = batch_input_shape.dims.Skip(1).ToArray();
              }

              var args = new InputLayerArgs
              {
                  Name = name,
                  InputShape = shape,
                  BatchInputShape = batch_input_shape,
                  BatchSize = batch_size,
                  DType = dtype,
                  Sparse = sparse,
                  Ragged = ragged,
                  InputTensor = tensor
              };

              var layer = new InputLayer(args);

              // The tensor handed back to the caller is read from the layer's
              // first inbound node.
              return layer.InboundNodes[0].Outputs;
          }
      }
  }