You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

KerasInterface.cs 4.0 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Reflection;
  4. using System.Linq;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Datasets;
  7. using Tensorflow.Keras.Engine;
  8. using Tensorflow.Keras.Layers;
  9. using Tensorflow.Keras.Losses;
  10. using Tensorflow.Keras.Metrics;
  11. using Tensorflow.Keras.Models;
  12. using Tensorflow.Keras.Optimizers;
  13. using Tensorflow.Keras.Utils;
  14. using System.Threading;
  15. using Tensorflow.Framework.Models;
  namespace Tensorflow.Keras
  {
      /// <summary>
      /// Entry point of the Keras API: exposes the Keras sub-APIs (layers, losses, optimizers,
      /// metrics, ...) and factory methods for <see cref="Sequential"/> and functional models.
      /// </summary>
      public class KerasInterface : IKerasApi
      {
          private static KerasInterface _instance = null;
          private static readonly object _lock = new object();

          /// <summary>
          /// Shared instance, created on first access. Creation and retrieval are serialized on a
          /// private lock, so concurrent callers of this property all observe the same object.
          /// </summary>
          /// <remarks>
          /// The class has no explicit constructor, so its implicit public constructor remains
          /// available; this property only guarantees a single instance for callers that use it.
          /// </remarks>
          public static KerasInterface Instance
          {
              get
              {
                  lock (_lock)
                  {
                      if (_instance is null)
                      {
                          _instance = new KerasInterface();
                      }
                      return _instance;
                  }
              }
          }

          // NOTE: instance field initializers run in textual order; the order below is kept as-is
          // so construction side effects happen in the same sequence as before.
          public KerasDataset datasets { get; } = new KerasDataset();
          public IInitializersApi initializers { get; } = new InitializersApi();
          public Regularizers regularizers { get; } = new Regularizers();
          public ILayersApi layers { get; } = new LayersApi();
          public ILossesApi losses { get; } = new LossesApi();
          public IActivationsApi activations { get; } = new Activations();
          public Preprocessing preprocessing { get; } = new Preprocessing();

          // One BackendImpl per thread, created lazily on that thread's first access to `backend`.
          private readonly ThreadLocal<BackendImpl> _backend = new ThreadLocal<BackendImpl>(() => new BackendImpl());

          /// <summary>
          /// Backend implementation belonging to the calling thread (see <see cref="ThreadLocal{T}"/>).
          /// </summary>
          public BackendImpl backend => _backend.Value;

          public IOptimizerApi optimizers { get; } = new OptimizerApi();
          public IMetricsApi metrics { get; } = new MetricsApi();
          public IModelsApi models { get; } = new ModelsApi();
          public KerasUtils utils { get; } = new KerasUtils();

          /// <summary>
          /// Creates a <see cref="Keras.Engine.Sequential"/> model from an optional list of layers.
          /// </summary>
          /// <param name="layers">
          /// Layers to place in the model, in order. May be <c>null</c>; it is passed through to
          /// <c>SequentialArgs.Layers</c> unchanged.
          /// </param>
          /// <param name="name">Optional model name, passed through to <c>SequentialArgs.Name</c>.</param>
          /// <returns>A new sequential model.</returns>
          public Sequential Sequential(List<ILayer> layers = null,
              string name = null)
              => new Sequential(new SequentialArgs
              {
                  Layers = layers,
                  Name = name
              });

          /// <summary>
          /// Creates a <see cref="Keras.Engine.Sequential"/> model from the given layers.
          /// </summary>
          /// <param name="layers">
          /// Layers to place in the model, in order. An explicit <c>null</c> array is forwarded as
          /// <c>null</c> layers, the same as the default of the <see cref="List{T}"/> overload.
          /// </param>
          /// <returns>A new sequential model.</returns>
          public Sequential Sequential(params ILayer[] layers)
              => new Sequential(new SequentialArgs
              {
                  // `?.` avoids Enumerable.ToList throwing ArgumentNullException on a null array.
                  Layers = layers?.ToList()
              });

          /// <summary>
          /// `Model` groups layers into an object with training and inference features.
          /// </summary>
          /// <param name="inputs">Input tensor(s) of the model.</param>
          /// <param name="outputs">Output tensor(s) of the model.</param>
          /// <param name="name">Optional model name, passed through to <see cref="Functional"/>.</param>
          /// <returns>A new <see cref="Functional"/> model built from <paramref name="inputs"/> and <paramref name="outputs"/>.</returns>
          public IModel Model(Tensors inputs, Tensors outputs, string name = null)
              => new Functional(inputs, outputs, name: name);

          /// <summary>
          /// Instantiate a Keras tensor.
          /// </summary>
          /// <param name="shape">Input shape; forwarded unchanged to <see cref="ILayersApi"/>.</param>
          /// <param name="batch_size">Batch size; <c>-1</c> by default. Forwarded unchanged.</param>
          /// <param name="name">Optional layer name. Forwarded unchanged.</param>
          /// <param name="dtype">Data type; <see cref="TF_DataType.DtInvalid"/> by default. Forwarded unchanged.</param>
          /// <param name="sparse">
          /// A boolean specifying whether the placeholder to be created is sparse.
          /// </param>
          /// <param name="tensor">
          /// Optional existing tensor to wrap into the `Input` layer.
          /// If set, the layer will not create a placeholder tensor.
          /// </param>
          /// <param name="ragged">
          /// A boolean specifying whether the placeholder to be created is ragged.
          /// </param>
          /// <param name="type_spec">Optional type spec. Forwarded unchanged.</param>
          /// <param name="batch_input_shape">Optional full input shape. Forwarded unchanged.</param>
          /// <param name="batch_shape">Optional batch shape. Forwarded unchanged.</param>
          /// <returns>The tensor(s) produced by the layers API's <c>Input</c>.</returns>
          public Tensors Input(Shape shape = null,
              int batch_size = -1,
              string name = null,
              TF_DataType dtype = TF_DataType.DtInvalid,
              bool sparse = false,
              Tensor tensor = null,
              bool ragged = false,
              TypeSpec type_spec = null,
              Shape batch_input_shape = null,
              Shape batch_shape = null)
              // Use this instance's own layers API (like every other member of the class) rather than
              // the global `keras` accessor, so a directly constructed instance stays self-contained.
              => layers.Input(shape, batch_size, name,
                  dtype, sparse, tensor, ragged, type_spec, batch_input_shape, batch_shape);
      }
  }