You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Metric.cs 2.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. using System;
  2. using Tensorflow.Keras.ArgsDefinition;
  3. using Tensorflow.Keras.Engine;
  4. using static Tensorflow.Binding;
  namespace Tensorflow.Keras.Metrics
  {
      /// <summary>
      /// Encapsulates metric logic and state.
      /// </summary>
      /// <remarks>
      /// This base class only wires up the layer arguments and provides a
      /// non-trainable <see cref="add_weight"/> helper. <see cref="update_state"/>
      /// and <see cref="result"/> throw <see cref="NotImplementedException"/> here
      /// and are expected to be overridden.
      /// </remarks>
      public class Metric : Layer
      {
          // State variables read by ToString(). They are NOT assigned anywhere in
          // this class, so they stay null until a subclass sets them.
          // NOTE(review): presumably the running numerator/denominator - confirm in subclasses.
          protected IVariableV1 total;
          protected IVariableV1 count;

          // Not read or written in this class; available to subclasses.
          protected string _reduction;

          // Not assigned in this class (the constructor forwards dtype to LayerArgs only).
          protected TF_DataType _dtype;

          /// <summary>
          /// Creates the metric, forwarding <paramref name="name"/> and
          /// <paramref name="dtype"/> to the base layer through <see cref="LayerArgs"/>,
          /// then flags the layer as stateful and already built.
          /// </summary>
          /// <param name="name">Layer name; may be null.</param>
          /// <param name="dtype">Data type passed to the base layer arguments.</param>
          public Metric(string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
              : base(new LayerArgs
              {
                  Name = name,
                  DType = dtype
              })
          {
              stateful = true;
              built = true;
          }

          /// <summary>
          /// Adds a state variable for this metric by delegating to the base
          /// implementation inside <c>ops.init_scope()</c>.
          /// </summary>
          /// <param name="name">Name of the variable.</param>
          /// <param name="shape">
          /// Variable shape; when null, a shape built from an empty dimension array is used.
          /// </param>
          /// <param name="dtype">Variable data type.</param>
          /// <param name="initializer">Initializer forwarded to the base implementation.</param>
          /// <param name="regularizer">Accepted for signature compatibility; not forwarded.</param>
          /// <param name="synchronization">Forwarded to the base implementation.</param>
          /// <param name="aggregation">Forwarded to the base implementation.</param>
          /// <param name="trainable">
          /// Ignored: the variable is always created with <c>trainable: false</c>.
          /// </param>
          /// <param name="getter">Accepted for signature compatibility; not forwarded.</param>
          /// <returns>The variable returned by the base <c>add_weight</c>.</returns>
          protected override IVariableV1 add_weight(string name,
              TensorShape shape = null,
              TF_DataType dtype = TF_DataType.TF_FLOAT,
              IInitializer initializer = null,
              IRegularizer regularizer = null,
              VariableSynchronization synchronization = VariableSynchronization.OnRead,
              VariableAggregation aggregation = VariableAggregation.Sum,
              bool trainable = true,
              Func<VariableArgs, IVariableV1> getter = null)
          {
              // Default to a shape with no dimensions when the caller gives none.
              if (shape == null)
                  shape = new TensorShape(new int[0]);

              return tf_with(ops.init_scope(), delegate
              {
                  // trainable is forced to false regardless of the argument.
                  return base.add_weight(name, shape,
                      dtype: dtype,
                      trainable: false,
                      initializer: initializer,
                      synchronization: synchronization,
                      aggregation: aggregation);
              });
          }

          /// <summary>
          /// Accumulates statistics for the metric. Not implemented in the base class.
          /// </summary>
          /// <param name="y_true">Tensor passed by the caller as the ground truth.</param>
          /// <param name="y_pred">Tensor passed by the caller as the prediction.</param>
          /// <param name="sample_weight">Optional weighting tensor; may be null.</param>
          /// <exception cref="NotImplementedException">Always, unless overridden.</exception>
          public virtual Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
              => throw new NotImplementedException(
                  $"{GetType().Name} does not implement {nameof(update_state)}.");

          /// <summary>
          /// Computes the metric value. Not implemented in the base class.
          /// </summary>
          /// <exception cref="NotImplementedException">Always, unless overridden.</exception>
          public virtual Tensor result()
              => throw new NotImplementedException(
                  $"{GetType().Name} does not implement {nameof(result)}.");

          /// <summary>
          /// Formats the metric as "name total/count". If either state variable has
          /// not been assigned yet, only the name is returned.
          /// </summary>
          public override string ToString()
          {
              // total/count are never assigned by this class; guard so that ToString
              // (often invoked by debuggers and loggers) does not throw.
              if (total == null || count == null)
                  return $"{name}";

              return $"{name} {(float)total.numpy()}/{(float)count.numpy()}";
          }
      }
  }