
hdf5_format.cs

using System;
using System.Collections.Generic;
using System.Linq;
using HDF.PInvoke;
using HDF5CSharp;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Saving
{
    public class hdf5_format
    {
        // Attribute payloads larger than this many bytes must be chunked,
        // because HDF5 limits how much metadata fits in an object header.
        private static int HDF5_OBJECT_HEADER_LIMIT = 64512;
        public static void load_model_from_hdf5(string filepath = "", Dictionary<string, object> custom_objects = null, bool compile = false)
        {
            long root = Hdf5.OpenFile(filepath, true);
            load_model_from_hdf5(root, custom_objects, compile);
        }
        public static void load_model_from_hdf5(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
        {
            // TODO: not implemented yet. The commented sketch below reads the
            // "model_config" attribute from the root group and is kept for reference.
            //long fileId = filepath;
            //long groupId = -1;
            //try
            //{
            //    groupId = H5G.open(fileId, "/");
            //    (bool success, string[] attrId) = Hdf5.ReadStringAttributes(groupId, "model_config", "");
            //    H5G.close(groupId);
            //    if (success)
            //        Console.WriteLine(attrId[0]);
            //}
            //catch (Exception ex)
            //{
            //    if (fileId != -1)
            //        Hdf5.CloseFile(fileId);
            //    if (groupId != -1)
            //        H5G.close(groupId);
            //    throw new Exception(ex.ToString());
            //}
        }
        public static void save_model_to_hdf5(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
        {
            // TODO: not implemented yet.
        }
        /// <summary>
        /// Preprocess layer weights between different Keras formats.
        /// </summary>
        /// <param name="layer">Target layer instance.</param>
        /// <param name="weights">List of weight values (NumPy arrays).</param>
        /// <param name="original_keras_version">Keras version the weights were saved with.</param>
        /// <param name="original_backend">Keras backend the weights were saved with.</param>
        public static List<NDArray> preprocess_weights_for_loading(ILayer layer, List<NDArray> weights, string original_keras_version = null, string original_backend = null)
        {
            // Convert CuDNN layer weights if necessary.
            return _convert_rnn_weights(layer, weights);
        }
        /// <summary>
        /// Converts weights for RNN layers between native and CuDNN format.
        /// </summary>
        /// <param name="layer">Target layer instance.</param>
        /// <param name="weights">List of weight values (NumPy arrays).</param>
        static List<NDArray> _convert_rnn_weights(ILayer layer, List<NDArray> weights)
        {
            var target_class = layer.GetType().Name;
            // TODO: the actual CuDNN <-> native conversion is not implemented yet;
            // weights are returned unchanged.
            return weights;
        }
        public static void save_optimizer_weights_to_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
        {
            // TODO: not implemented yet.
        }

        public static void load_optimizer_weights_from_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
        {
            // TODO: not implemented yet.
        }
        public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
        {
            string original_keras_version = "2.5.0";
            string original_backend = null;

            var (success, attr) = Hdf5.ReadStringAttributes(f, "keras_version", "", true);
            if (success)
                original_keras_version = attr.First();

            // The saving Keras version must be 2.5.0 or later.
            var ver_major = int.Parse(original_keras_version.Split('.')[0]);
            var ver_minor = int.Parse(original_keras_version.Split('.')[1]);
            if (ver_major < 2 || (ver_major == 2 && ver_minor < 5))
                throw new ValueError("keras version should be 2.5.0 or later.");

            (success, attr) = Hdf5.ReadStringAttributes(f, "backend", "", true);
            if (success)
                original_backend = attr.First();

            // Keep only the layers that actually carry weights.
            var filtered_layers = new List<ILayer>();
            foreach (var layer in layers)
            {
                var weights = _legacy_weights(layer);
                if (weights.Count > 0)
                    filtered_layers.append(layer);
            }

            string[] layer_names = load_attributes_from_hdf5_group(f, "layer_names");
            var filtered_layer_names = new List<string>();
            foreach (var name in layer_names)
            {
                if (!filtered_layers.Select(x => x.Name).Contains(name))
                    continue;
                long g = H5G.open(f, name);
                var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
                if (weight_names.Count() > 0)
                    filtered_layer_names.Add(name);
                H5G.close(g);
            }

            layer_names = filtered_layer_names.ToArray();
            if (layer_names.Length != filtered_layers.Count)
                throw new ValueError("You are trying to load a weight file " +
                    $"containing {layer_names.Length} layers into a model with " +
                    $"{filtered_layers.Count} layers.");

            var weight_value_tuples = new List<(IVariableV1, NDArray)>();
            foreach (var (k, name) in enumerate(layer_names))
            {
                var weight_values = new List<NDArray>();
                long g = H5G.open(f, name);
                var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
                foreach (var i_ in weight_names)
                {
                    (success, Array result) = Hdf5.ReadDataset<float>(g, i_);
                    if (success)
                        weight_values.Add(np.array(result));
                }
                H5G.close(g);

                var layer = filtered_layers[k];
                var symbolic_weights = _legacy_weights(layer);
                weight_values = preprocess_weights_for_loading(layer, weight_values, original_keras_version, original_backend);
                if (weight_values.Count != symbolic_weights.Count)
                    throw new ValueError($"Layer #{k} (named {layer.Name} " +
                        "in the current model) was found to " +
                        $"correspond to layer {name} in the save file. " +
                        $"However the new layer {layer.Name} expects " +
                        $"{symbolic_weights.Count} weights, but the saved weights have " +
                        $"{weight_values.Count} elements.");
                weight_value_tuples.AddRange(zip(symbolic_weights, weight_values));
            }

            keras.backend.batch_set_value(weight_value_tuples);
        }
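        // Note: load_weights_from_hdf5_group matches weights by position, not by
        // name. Layers without weights are filtered out on both sides, so the
        // k-th weighted layer in the model receives the weights of the k-th
        // layer group listed in the file's "layer_names" attribute.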
        public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
        {
            // TODO: not implemented yet.
        }

        public static void load_weights_from_hdf5_group_by_name(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
        {
            // TODO: not implemented yet.
        }
        public static void save_weights_to_hdf5_group(long f, List<ILayer> layers)
        {
            var layerName = new List<string>();
            foreach (var layer in layers)
                layerName.Add(layer.Name);
            save_attributes_to_hdf5_group(f, "layer_names", layerName.ToArray());
            Hdf5.WriteAttribute(f, "backend", "tensorflow");
            Hdf5.WriteAttribute(f, "keras_version", "2.5.0");

            foreach (var layer in layers)
            {
                var weights = _legacy_weights(layer);
                if (weights.Count == 0)
                    continue;

                var weight_names = new List<string>();
                // weight_values = keras.backend.batch_get_value(weights);
                foreach (var weight in weights)
                    weight_names.Add(weight.Name);

                var g = Hdf5.CreateOrOpenGroup(f, Hdf5Utils.NormalizedName(layer.Name));
                save_attributes_to_hdf5_group(g, "weight_names", weight_names.ToArray());
                foreach (var (name, val) in zip(weight_names, weights))
                {
                    var tensor = val.AsTensor();
                    if (name.IndexOf("/") > 1)
                    {
                        // Nested weight names such as "dense/kernel:0" are stored in
                        // nested subgroups, one per path segment before the last.
                        var crDataGroup = g;
                        string[] name_split = name.Split('/');
                        for (int i = 0; i < name_split.Length - 1; i++)
                            crDataGroup = Hdf5.CreateOrOpenGroup(crDataGroup, Hdf5Utils.NormalizedName(name_split[i]));
                        WriteDataset(crDataGroup, name_split[name_split.Length - 1], tensor);
                        Hdf5.CloseGroup(crDataGroup);
                    }
                    else
                    {
                        WriteDataset(g, name, tensor);
                    }
                }
                Hdf5.CloseGroup(g);
            }
        }
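        // For reference, the layout written above (and read back by
        // load_weights_from_hdf5_group); the layer and weight names here are
        // illustrative only:
        //
        //   /              attrs: layer_names=["dense"], backend, keras_version
        //   /dense         attrs: weight_names=["dense/kernel:0", "dense/bias:0"]
        //   /dense/dense   datasets: "kernel:0", "bias:0"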
        private static void save_attributes_to_hdf5_group(long f, string name, Array data)
        {
            int num_chunks = 1;
            var chunked_data = Split(data, num_chunks);
            int getSize = 0;
            string getType = data.Length > 0 ? data.GetValue(0).GetType().Name.ToLower() : "string";

            switch (getType)
            {
                case "single":
                    getSize = sizeof(float);
                    break;
                case "double":
                    getSize = sizeof(double);
                    break;
                case "string":
                    getSize = -1;
                    break;
                case "int32":
                    getSize = sizeof(int);
                    break;
                case "int64":
                    getSize = sizeof(long);
                    break;
                default:
                    getSize = -1;
                    break;
            }

            // With a chunk size of 1, the chunk count equals the element count.
            int getCount = chunked_data.Count;
            if (getSize != -1)
            {
                // If the attribute would overflow the HDF5 object header limit,
                // split it into num_chunks roughly equal pieces. Split takes a
                // chunk *size*, so derive it from the element count.
                num_chunks = (int)Math.Ceiling((double)(getCount * getSize) / HDF5_OBJECT_HEADER_LIMIT);
                if (num_chunks > 1)
                    chunked_data = Split(data, (int)Math.Ceiling((double)data.Length / num_chunks));
            }

            if (num_chunks > 1)
            {
                // Write each chunk as a separate attribute: name0, name1, ...
                foreach (var (chunk_id, chunk_data) in enumerate(chunked_data))
                    WriteAttrs(f, getType, $"{name}{chunk_id}", chunk_data.ToArray());
            }
            else
            {
                WriteAttrs(f, getType, name, data);
            }
        }
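        // Worked example (illustrative): 20,000 float32 values occupy
        // 20,000 * 4 = 80,000 bytes, which exceeds HDF5_OBJECT_HEADER_LIMIT
        // (64,512), so num_chunks = ceil(80,000 / 64,512) = 2 and the values
        // are written as two attributes named "{name}0" and "{name}1".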
        private static void WriteDataset(long f, string name, Tensor data)
        {
            switch (data.dtype)
            {
                case TF_DataType.TF_FLOAT:
                    Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMultiDimArray<float>());
                    break;
                case TF_DataType.TF_DOUBLE:
                    Hdf5.WriteDatasetFromArray<double>(f, name, data.numpy().ToMultiDimArray<double>());
                    break;
                case TF_DataType.TF_INT32:
                    Hdf5.WriteDatasetFromArray<int>(f, name, data.numpy().ToMultiDimArray<int>());
                    break;
                case TF_DataType.TF_INT64:
                    Hdf5.WriteDatasetFromArray<long>(f, name, data.numpy().ToMultiDimArray<long>());
                    break;
                default:
                    // Fall back to float32 for any other dtype.
                    Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMultiDimArray<float>());
                    break;
            }
        }
        private static void WriteAttrs(long f, string typename, string name, Array data)
        {
            switch (typename)
            {
                case "single":
                    Hdf5.WriteAttributes<float>(f, name, data);
                    break;
                case "double":
                    Hdf5.WriteAttributes<double>(f, name, data);
                    break;
                case "string":
                    Hdf5.WriteAttributes<string>(f, name, data);
                    break;
                case "int32":
                    Hdf5.WriteAttributes<int>(f, name, data);
                    break;
                case "int64":
                    Hdf5.WriteAttributes<long>(f, name, data);
                    break;
                default:
                    Hdf5.WriteAttributes<string>(f, name, data);
                    break;
            }
        }
        private static List<List<object>> Split(Array list, int chunkSize)
        {
            var splitList = new List<List<object>>();
            var chunkCount = (int)Math.Ceiling((double)list.Length / (double)chunkSize);
            for (int c = 0; c < chunkCount; c++)
            {
                var skip = c * chunkSize;
                var take = skip + chunkSize;
                var chunk = new List<object>(chunkSize);
                for (int e = skip; e < take && e < list.Length; e++)
                    chunk.Add(list.GetValue(e));
                splitList.Add(chunk);
            }
            return splitList;
        }
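        // Example (illustrative): Split(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, 4)
        // yields three chunks: [1, 2, 3, 4], [5, 6, 7, 8], [9, 10].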
        public static string[] load_attributes_from_hdf5_group(long group, string name)
        {
            var (success, attr) = Hdf5.ReadStringAttributes(group, name, "", true);
            if (success)
                return attr.ToArray();
            // Return an empty array rather than null so callers can iterate the
            // result without a null check.
            return new string[0];
        }

        public static void load_attributes_from_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
        {
            // TODO: not implemented yet.
        }
        public static List<IVariableV1> _legacy_weights(ILayer layer)
        {
            // Trainable weights first, then non-trainable, matching the order
            // Keras uses in legacy HDF5 weight files.
            var weights = layer.TrainableWeights.ToList();
            weights.AddRange(layer.NonTrainableWeights);
            return weights;
        }
    }
}
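For context, a minimal usage sketch (not part of the file above). It assumes `model` is a built TF.NET Keras model exposing its layers as a `Layers` collection, and relies on HDF5CSharp's `Hdf5.CreateFile`/`Hdf5.OpenFile`/`Hdf5.CloseFile` helpers; the file name and variable names are illustrative only.

// Save: create an HDF5 file and write each layer's weights into it.
long fid = Hdf5.CreateFile("weights.h5");
hdf5_format.save_weights_to_hdf5_group(fid, model.Layers.ToList());
Hdf5.CloseFile(fid);

// Load: reopen the file read-only and restore weights layer by layer.
fid = Hdf5.OpenFile("weights.h5", true);
hdf5_format.load_weights_from_hdf5_group(fid, model.Layers.ToList());
Hdf5.CloseFile(fid);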