You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

AutoTrackable.cs 2.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. using System.Collections.Generic;
  2. using System.Linq;
  3. using Tensorflow.Functions;
  4. using Tensorflow.Operations.Activation;
  5. using static Tensorflow.Binding;
  6. namespace Tensorflow.Train
  7. {
  /// <summary>
  /// A <see cref="Trackable"/> that, when exported as a SavedModel, also reports the
  /// <see cref="Function"/> / <see cref="ConcreteFunction"/> values exposed through its
  /// public properties as children, alongside its checkpoint dependencies.
  /// </summary>
  public abstract class AutoTrackable : Trackable
  {
    /// <summary>
    /// Stops tracking the unconditional dependency registered under <paramref name="name"/>.
    /// Removes the name from <c>_unconditional_dependency_names</c> and drops every entry of
    /// <c>_unconditional_checkpoint_dependencies</c> whose <c>Name</c> equals it.
    /// If the name is not registered, nothing is removed.
    /// </summary>
    /// <param name="name">Name of the dependency to forget.</param>
    public void _delete_tracking(string name)
    {
      _maybe_initialize_trackable();

      // Remove returns false when the key is absent, so this replaces ContainsKey + Remove.
      if (!_unconditional_dependency_names.Remove(name))
      {
        return;
      }

      // Iterate backwards so RemoveAt does not shift elements that are still to be visited.
      for (int i = _unconditional_checkpoint_dependencies.Count - 1; i >= 0; i--)
      {
        if (_unconditional_checkpoint_dependencies[i].Name == name)
        {
          _unconditional_checkpoint_dependencies.RemoveAt(i);
        }
      }
    }

    /// <summary>
    /// Returns the children of this object for the given save type.
    /// </summary>
    /// <remarks>
    /// For any save type other than <see cref="SaveType.SAVEDMODEL"/> this defers to the base
    /// implementation. For SavedModel export the result combines:
    /// <list type="bullet">
    /// <item><description>
    /// every <c>CheckpointDependencies</c> entry whose referenced object is not a
    /// <see cref="ConcreteFunction"/>;
    /// </description></item>
    /// <item><description>
    /// every readable, non-indexed public property of the runtime type whose current value is a
    /// <see cref="Function"/> or <see cref="ConcreteFunction"/>, keyed by property name.
    /// </description></item>
    /// </list>
    /// A dependency and a function property may share a name only if they are the same object.
    /// </remarks>
    /// <param name="save_type">Kind of serialization being performed.</param>
    /// <param name="cache">Passed through to the base implementation; unused for SavedModel.</param>
    /// <returns>Mapping from child name to child trackable.</returns>
    /// <exception cref="ValueError">
    /// A checkpoint dependency and a function-valued property share a name but are different objects.
    /// </exception>
    public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null)
    {
      if (save_type != SaveType.SAVEDMODEL)
      {
        return base._trackable_children(save_type, cache);
      }

      // TODO: process of logs.
      Dictionary<string, Trackable> functions = new();
      foreach (var property in GetType().GetProperties())
      {
        // Indexers need index arguments and write-only properties have no getter;
        // GetValue throws for both, and neither can expose a function attribute.
        if (!property.CanRead || property.GetIndexParameters().Length > 0)
        {
          continue;
        }

        object? value;
        try
        {
          value = property.GetValue(this, null);
        }
        catch (System.Reflection.TargetInvocationException)
        {
          // A broken property getter should not make the whole object unsavable;
          // treat it as an attribute that does not hold a function.
          continue;
        }

        if (value is Function || value is ConcreteFunction)
        {
          functions[property.Name] = (Trackable)value;
        }
      }
      // TODO: process the type `core_types.GenericFunction`.

      Dictionary<string, Trackable> children = new();
      foreach (var pair in CheckpointDependencies)
      {
        var name = pair.Name;
        var child = pair.Refer;
        if (child is ConcreteFunction) // or Generic function
        {
          continue;
        }

        // Sharing a name is only legal when both entries are literally the same object.
        if (functions.TryGetValue(name, out var function) && !object.ReferenceEquals(function, child))
        {
          throw new ValueError($"Can't save object because it has multiple children with the same " +
            $"name. Object: {this}, attribute name: {name}, child 1: " +
            $"{child}, child 2: {function}");
        }
        children[name] = child;
      }

      // Merge via the indexer: a dependency that is the same object as a function property
      // shares its key, and ToDictionary over a concatenation would throw on that duplicate.
      foreach (var entry in functions)
      {
        children[entry.Key] = entry.Value;
      }
      return children;
    }
  }
  64. }