You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Sequential.cs 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. /*****************************************************************************
  2. Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using System.Collections.Generic;
  14. using Tensorflow.Keras.ArgsDefinition;
  15. using Tensorflow.Keras.Layers;
  16. using static Tensorflow.KerasApi;
  17. namespace Tensorflow.Keras.Engine
  18. {
  /// <summary>
  /// `Sequential` groups a linear stack of layers into a `tf.keras.Model`.
  /// `Sequential` provides training and inference features on this model.
  /// </summary>
  public class Sequential : Model
  {
      // Construction arguments; also owns the layer list exposed through `layers`.
      SequentialArgs args;

      // True once _init_graph_network has run, i.e. once an input is known
      // and every added layer is being chained onto `outputs`.
      bool _is_graph_network;

      // NOTE(review): within this class `inputs` is only ever assigned from
      // itself (via _init_graph_network), so it stays null here. Confirm
      // whether the source inputs should be derived from `outputs`.
      Tensor inputs;

      // Output of the most recently applied layer; null until an input is known.
      Tensor outputs;

      // The flags below are initialised in the constructor but not read
      // anywhere in this class.
      bool computeOutputAndMaskJointly;
      bool autoTrackSubLayers;
      TensorShape inferredInputShape;
      bool hasExplicitInputShape;
      TF_DataType inputDType;

      // Shortcut to the layer list stored on the construction arguments.
      List<ILayer> layers => args.Layers;

      /// <summary>
      /// Shape of the current output tensor. Dereferences <c>outputs</c>, so it
      /// throws if no layer with a known input has been added yet.
      /// </summary>
      public TensorShape output_shape => outputs.TensorShape;

      // Reset to false by every add(); set to true by _init_graph_network.
      bool built = false;

      /// <summary>
      /// Creates a sequential model from <paramref name="args"/>.
      /// </summary>
      /// <param name="args">
      /// Model arguments. <c>Name</c> is forwarded to the base model; a null
      /// <c>Layers</c> list is replaced with an empty one.
      /// </param>
      public Sequential(SequentialArgs args)
          : base(new ModelArgs
          {
              Name = args.Name
          })
      {
          this.args = args;
          args.Layers ??= new List<ILayer>();

          // SupportsMasking = true;
          computeOutputAndMaskJointly = true;
          autoTrackSubLayers = false;
          hasExplicitInputShape = false;
          _is_graph_network = false;
      }

      /// <summary>
      /// Adds the layer that produced <paramref name="tensor"/> (taken from the
      /// tensor's Keras history) on top of the layer stack.
      /// </summary>
      /// <param name="tensor">Tensor whose producing layer should be added.</param>
      /// <exception cref="System.ArgumentNullException">
      /// <paramref name="tensor"/> is null.
      /// </exception>
      /// <exception cref="System.ArgumentException">
      /// The tensor's producing layer is not a <see cref="Layer"/>.
      /// </exception>
      public void add(Tensor tensor)
      {
          if (tensor is null)
              throw new System.ArgumentNullException(nameof(tensor));

          // `as` yields null when the producer is not a Layer; fail here with a
          // clear message instead of a NullReferenceException further down.
          var layer = tensor.KerasHistory.Layer as Layer;
          if (layer is null)
              throw new System.ArgumentException(
                  "The tensor was not produced by a Layer instance.",
                  nameof(tensor));

          add(layer);
      }

      /// <summary>
      /// Adds a layer instance on top of the layer stack.
      /// </summary>
      /// <param name="layer">The layer to append.</param>
      /// <exception cref="System.ArgumentNullException">
      /// <paramref name="layer"/> is null.
      /// </exception>
      public void add(Layer layer)
      {
          if (layer is null)
              throw new System.ArgumentNullException(nameof(layer));

          built = false;
          var set_inputs = false;

          if (layers.Count == 0)
          {
              if (layer is InputLayer)
              {
                  set_inputs = true;
              }
              else if (layer.BatchInputShape != null)
              {
                  // Instantiate an input layer.
                  var x = keras.Input(
                      shape: layer.BatchInputShape,
                      dtype: layer.DType,
                      name: layer.Name + "_input");

                  // This will build the current layer
                  // and create the node connecting the current layer
                  // to the input layer we just created.
                  layer.Apply(x);
                  set_inputs = true;
              }

              if (set_inputs)
              {
                  // If an input layer (placeholder) is available.
                  outputs = layer.InboundNodes[^1].Outputs;
              }
          }
          else if (outputs != null)
          {
              // Chain the new layer onto the current output.
              outputs = layer.Apply(outputs);
          }

          if (set_inputs || _is_graph_network)
          {
              // Registers the layer that produced `outputs` in `layers`.
              _init_graph_network(inputs, outputs);
          }
          else
          {
              // No input is known yet, so the layer cannot be applied. Keep it
              // in the stack anyway; previously this branch was empty and the
              // layer was silently dropped.
              layers.Add(layer);
          }
      }

      // Marks the model as a built graph network, stores its endpoints and
      // records the layer that produced `outputs`.
      void _init_graph_network(Tensor inputs, Tensor outputs)
      {
          _is_graph_network = true;
          this.inputs = inputs;
          this.outputs = outputs;
          built = true;
          _map_graph_network(inputs, outputs);
      }

      // Appends the layer that produced `outputs` to the layer list.
      // `inputs` is currently unused.
      void _map_graph_network(Tensor inputs, Tensor outputs)
      {
          layers.Add(outputs.KerasHistory.Layer);
      }
  }
  118. }