You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

load.cs 3.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. using Google.Protobuf;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Text;
  6. using Tensorflow.Keras.Engine;
  7. using Tensorflow.Train;
  8. using ThirdParty.Tensorflow.Python.Keras.Protobuf;
  9. using static Tensorflow.Binding;
  10. using static Tensorflow.KerasApi;
  11. namespace Tensorflow.Keras.Saving.SavedModel
  12. {
  /// <summary>
  /// Entry points for loading a Keras model that was written to disk in the
  /// SavedModel directory layout.
  /// </summary>
  public class KerasLoadModelUtils
  {
      /// <summary>
      /// Loads a Keras model from <paramref name="filepath"/>.
      /// Corresponding to keras/saving/save.py/load_model
      /// </summary>
      /// <param name="filepath">
      /// Path of the saved model. Only a directory can be loaded; an existing
      /// regular file is reported as the unsupported h5 format.
      /// </param>
      /// <param name="custom_objects">Not read by this method.</param>
      /// <param name="compile">Forwarded to the loading routine and, from there, to the layer loader.</param>
      /// <param name="options">
      /// Options used to enter the load context; also forwarded to the SavedModel loader.
      /// </param>
      /// <returns>The object loaded from the SavedModel directory.</returns>
      /// <exception cref="IOException">
      /// Neither a file nor a directory exists at <paramref name="filepath"/>.
      /// </exception>
      /// <exception cref="NotImplementedException">
      /// <paramref name="filepath"/> is a regular file (h5 format), or the SavedModel
      /// directory does not contain the Keras metadata file.
      /// </exception>
      public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null,
          bool compile = true, LoadOptions? options = null)
      {
          // Both scopes stay active for the whole load and are disposed in
          // reverse order of acquisition, exactly like nested using blocks.
          using (SharedObjectSavingScope.Enter())
          using (LoadContext.load_context(options))
          {
              bool isDirectory = Directory.Exists(filepath);

              if (!isDirectory && !File.Exists(filepath))
              {
                  throw new IOException($"No file or directory found at {filepath}.");
              }

              if (!isDirectory)
              {
                  throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed.");
              }

              return load(filepath, compile, options);
          }
      }

      /// <summary>
      /// Loads a Keras object graph from the SavedModel directory at <paramref name="path"/>.
      /// </summary>
      /// <param name="path">Directory containing the SavedModel and the Keras metadata file.</param>
      /// <param name="compile">Forwarded to <c>KerasObjectLoader.load_layers</c>.</param>
      /// <param name="options">Forwarded to the SavedModel loader.</param>
      /// <returns>
      /// The node registered under the key <c>"root"</c> after the partial load, or the
      /// result of a plain SavedModel load when the metadata contains no nodes.
      /// </returns>
      /// <exception cref="NotImplementedException">
      /// The Keras metadata file is missing from <paramref name="path"/>.
      /// </exception>
      private static Trackable load(string path, bool compile = true, LoadOptions? options = null)
      {
          SavedMetadata metadata = new SavedMetadata();
          var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0];
          var object_graph_def = meta_graph_def.ObjectGraphDef;
          string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH);

          if (!File.Exists(path_to_metadata_pb))
          {
              throw new NotImplementedException("SavedModel saved prior to TF 2.5 detected when loading Keras model, please" +
                  " use higher version or submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues. to let us know you need it.");
          }

          // Dispose the stream as soon as the metadata has been parsed so the
          // file handle is not held open until the finalizer runs.
          using (var metadata_stream = new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read))
          {
              metadata.MergeFrom(metadata_stream);
          }

          if (metadata.Nodes is null || metadata.Nodes.Count == 0)
          {
              // NOTE(review): the `as` cast yields null when the loaded object is
              // not a Model; kept as-is to preserve what callers currently observe.
              return Loader.load(path, options: options) as Model;
          }

          var keras_loader = new KerasObjectLoader(metadata, object_graph_def);
          keras_loader.load_layers(compile: compile);

          // Map every node path produced by the Keras loader to its loaded entry;
          // "root" is registered with empty placeholders.
          Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new();
          nodes_to_load["root"] = (null, null);
          foreach (var item in keras_loader.LoadedNodes)
          {
              nodes_to_load[keras_loader.get_path(item.Key)] = item.Value;
          }

          var loaded = Loader.load_partial(path, nodes_to_load, options);

          keras_loader.finalize_objects();
          // keras_loader.del_tracking();

          var model = loaded["root"];

          if (model is Model && compile)
          {
              // TODO(Rinne): implement it.
          }

          if (!tf.Context.executing_eagerly())
          {
              // TODO(Rinne): implement it.
          }

          return model;
      }
  }
  87. }