You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

Reshape.cs 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. using Tensorflow.Keras.ArgsDefinition;
  2. using Tensorflow.Keras.Engine;
  3. using static Tensorflow.Binding;
  4. using System.Collections.Generic;
  5. using System;
  6. using System.Linq;
  7. namespace Tensorflow.Keras.Layers
  8. {
  /// <summary>
  /// Layer that reshapes inputs into the given shape.
  /// </summary>
  /// <remarks>
  /// The size of the first input dimension is kept as-is; the dimensions listed in
  /// <c>ReshapeArgs.TargetShape</c> (if any) are appended after it to form the new shape.
  /// Target shapes supplied through <c>ReshapeArgs.TargetShapeObjects</c> are not supported yet.
  /// </remarks>
  public class Reshape : Layer
  {
      // Layer configuration; assigned once in the constructor and never replaced.
      private readonly ReshapeArgs args;

      /// <summary>
      /// Creates the layer from its argument object.
      /// </summary>
      /// <param name="args">Layer arguments; also forwarded to the base <c>Layer</c> constructor.</param>
      public Reshape(ReshapeArgs args)
          : base(args)
      {
          this.args = args;
      }

      /// <summary>
      /// Reshapes <paramref name="inputs"/> to the size of its first dimension followed by the
      /// configured target dimensions.
      /// </summary>
      /// <param name="inputs">Tensor(s) to reshape.</param>
      /// <param name="state">Not used by this layer.</param>
      /// <param name="training">Not used by this layer.</param>
      /// <returns>The reshaped tensor.</returns>
      /// <exception cref="NotImplementedException">
      /// Thrown when <c>TargetShapeObjects</c> is set, or (outside eager execution) when the
      /// static output shape cannot be computed.
      /// </exception>
      protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
      {
          // Fail before any shape op is created: object-based target shapes are not implemented.
          if (args.TargetShapeObjects != null)
          {
              throw new NotImplementedException(
                  "Reshape does not support TargetShapeObjects yet; use TargetShape instead.");
          }

          // The first entry of the new shape is read from the input at run time.
          var shapes = new List<Tensor> { array_ops.shape(inputs)[0] };
          var dtype = shapes[0].dtype;

          if (args.TargetShape != null)
          {
              // Target dimensions are emitted with the same dtype as the first entry so the
              // list can be converted into a single shape tensor.
              shapes.AddRange(args.TargetShape.dims.Select(x => constant_op.constant(x, dtype)));
          }

          var shape = ops.convert_to_tensor(shapes);
          var result = array_ops.reshape(inputs, shape);

          // Outside eager execution, attach the statically computed shape to the result.
          if (!tf.Context.executing_eagerly())
          {
              result.shape = ComputeOutputShape(inputs.shape);
          }

          return result;
      }

      /// <summary>
      /// Computes the static output shape: the first dimension of <paramref name="input_shape"/>
      /// followed by the configured target dimensions.
      /// </summary>
      /// <param name="input_shape">Static shape of the layer input.</param>
      /// <returns>
      /// The first input dimension concatenated with <c>TargetShape</c>; only the first input
      /// dimension when no <c>TargetShape</c> is configured.
      /// </returns>
      /// <exception cref="NotImplementedException">
      /// Thrown when any dimension after the first is -1 in <paramref name="input_shape"/>.
      /// </exception>
      public override Shape ComputeOutputShape(Shape input_shape)
      {
          if (input_shape.dims.Skip(1).Contains(-1))
          {
              throw new NotImplementedException(
                  "Reshape cannot compute the output shape when an input dimension after the first is -1.");
          }

          var leading = new Shape(input_shape.dims[0]);

          // Mirror Call: a missing TargetShape contributes no extra dimensions, instead of
          // failing with a NullReferenceException.
          if (args.TargetShape != null)
          {
              return leading.concatenate(args.TargetShape.dims);
          }

          return leading;
      }
  }
  50. }