You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

AutoTrackable.cs 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. using System.Collections.Generic;
  2. using System.Linq;
  3. using Tensorflow.Functions;
  4. using Tensorflow.Keras.Saving.SavedModel;
  5. using Tensorflow.Operations.Activation;
  6. using Tensorflow.Training;
  7. using static Tensorflow.Binding;
  8. namespace Tensorflow.Train
  9. {
  10. public class AutoTrackable : Trackable
  11. {
  12. public void _delete_tracking(string name)
  13. {
  14. _maybe_initialize_trackable();
  15. if (_unconditional_dependency_names.ContainsKey(name))
  16. {
  17. _unconditional_dependency_names.Remove(name);
  18. for (int i = _unconditional_checkpoint_dependencies.Count - 1; i >= 0; i--)
  19. {
  20. if (_unconditional_checkpoint_dependencies[i].Name == name)
  21. {
  22. _unconditional_checkpoint_dependencies.RemoveAt(i);
  23. }
  24. }
  25. }
  26. }
  27. public override void SetAttr(string name, object value)
  28. {
  29. // TODO(Rinne): deal with `self_setattr_tracking`.
  30. value = TrackableDataStructure.sticky_attribute_assignment(this, name, value);
  31. base.SetAttr(name, value);
  32. }
  33. public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
  34. {
  35. if(save_type != SaveType.SAVEDMODEL)
  36. {
  37. return base._trackable_children(save_type, cache);
  38. }
  39. Dictionary<string, Trackable> functions = new();
  40. // TODO: process of logs.
  41. // TODO(Rinne): deal with members.
  42. var properties = this.GetType().GetProperties();
  43. foreach ( var property in properties )
  44. {
  45. if(property.PropertyType == typeof(Function) || property.PropertyType == typeof(ConcreteFunction))
  46. {
  47. string name = property.Name;
  48. object value = property.GetValue(this, null);
  49. functions[name] = (Trackable)value;
  50. }
  51. }
  52. foreach(var item in CustomizedFields)
  53. {
  54. var name = item.Key;
  55. var value = item.Value;
  56. if (value is Function or ConcreteFunction)
  57. {
  58. functions[name] = (Trackable)value;
  59. }
  60. }
  61. // TODO: process the type `core_types.GenericFunction`.
  62. Dictionary<string, Trackable> children = new();
  63. foreach(var pair in CheckpointDependencies)
  64. {
  65. var name = pair.Name;
  66. var child = pair.Refer;
  67. if(child is ConcreteFunction) // or Generic function
  68. {
  69. continue;
  70. }
  71. if(functions.ContainsKey(name) && functions[name] != child)
  72. {
  73. throw new ValueError($"Can't save object because it has multiple children with the same " +
  74. $"name. Object: {this}, attribute name: {name}, child 1: " +
  75. $"{child}, child 2: {functions[name]}");
  76. }
  77. children[name] = child;
  78. }
  79. return children.Concat(functions).ToDictionary(x => x.Key, x => x.Value);
  80. }
  81. }
  82. }