
helper.py

# pylint: disable=E0401
"""
Some utils while building models
"""
import collections.abc
import difflib
import logging
import os
from copy import deepcopy
from itertools import repeat
from typing import Callable, Dict, List, Optional

from mindspore import nn
from mindspore import load_checkpoint, load_param_into_net

from utils.download import DownLoad, get_default_download_root
# from ..utils.download import DownLoad, get_default_download_root
from model.features import FeatureExtractWrapper


def get_checkpoint_download_root():
    """ get checkpoint download root """
    return os.path.join(get_default_download_root(), "models")


class ConfigDict(dict):
    """dot.notation access to dictionary attributes"""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

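# Illustrative usage sketch for ConfigDict (not part of the original module):
# dictionary keys behave as attributes, which keeps config access terse.
#
#   cfg = ConfigDict(num_classes=1000, url="https://example.com/x.ckpt")  # placeholder URL
#   cfg.num_classes        # -> 1000 (lookup via dict.get)
#   cfg.in_channels = 3    # assignment via dict.__setitem__
#   del cfg.url            # deletion via dict.__delitem__
#   cfg.not_set            # -> None rather than AttributeError, since dict.get returns None
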

def load_pretrained(model, default_cfg, num_classes=1000, in_channels=3, filter_fn=None):
    """load pretrained model depending on cfgs of model"""
    if "url" not in default_cfg or not default_cfg["url"]:
        logging.warning("Pretrained model URL is invalid")
        return

    # download files
    download_path = get_checkpoint_download_root()
    os.makedirs(download_path, exist_ok=True)
    DownLoad().download_url(default_cfg["url"], path=download_path)

    param_dict = load_checkpoint(os.path.join(download_path, os.path.basename(default_cfg["url"])))

    if in_channels == 1:
        # collapse the 3-channel first-conv weights into a single input channel
        conv1_name = default_cfg["first_conv"]
        logging.info("Converting first conv (%s) from 3 to 1 channel", conv1_name)
        conv1_weight = param_dict[conv1_name + ".weight"]
        conv1_weight.set_data(conv1_weight.sum(axis=1, keepdims=True), slice_shape=True)
    elif in_channels != 3:
        raise ValueError("Invalid in_channels for pretrained weights")

    classifier_name = default_cfg["classifier"]
    if num_classes == 1000 and default_cfg["num_classes"] == 1001:
        # drop the extra background class of 1001-class checkpoints
        classifier_weight = param_dict[classifier_name + ".weight"]
        classifier_weight.set_data(classifier_weight[:1000], slice_shape=True)
        classifier_bias = param_dict[classifier_name + ".bias"]
        classifier_bias.set_data(classifier_bias[:1000], slice_shape=True)
    elif num_classes != default_cfg["num_classes"]:
        # class counts differ from the checkpoint: discard the classifier weights entirely
        params_names = list(param_dict.keys())
        for param_name in _search_param_name(params_names, classifier_name + ".weight"):
            param_dict.pop(
                param_name,
                f"Parameter {param_name} has been deleted from ParamDict.",
            )
        for param_name in _search_param_name(params_names, classifier_name + ".bias"):
            param_dict.pop(
                param_name,
                f"Parameter {param_name} has been deleted from ParamDict.",
            )

    if filter_fn is not None:
        param_dict = filter_fn(param_dict)

    load_param_into_net(model, param_dict)

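# Illustrative call (the cfg keys mirror exactly what load_pretrained reads;
# the URL and net are placeholders, not a real checkpoint or model):
#
#   default_cfg = {
#       "url": "https://example.com/resnet50.ckpt",
#       "num_classes": 1000,
#       "first_conv": "conv1",
#       "classifier": "classifier",
#   }
#   load_pretrained(net, default_cfg, num_classes=10, in_channels=3)
#   # num_classes (10) != default_cfg["num_classes"] (1000), so the classifier
#   # weights are popped from the checkpoint before load_param_into_net runs.
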

def make_divisible(
    v: float,
    divisor: int,
    min_value: Optional[int] = None,
) -> int:
    """Round v to the nearest integer divisible by divisor, never going
    below min_value or more than 10% below v."""
    if not min_value:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

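# Worked examples of the rounding rule above:
#   make_divisible(37.4, 8)  # -> 40: int(37.4 + 4) // 8 * 8 = 40, within 10% of 37.4
#   make_divisible(20, 16)   # -> 32: nearest multiple is 16, but 16 < 0.9 * 20, so add one divisor
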

def _ntuple(n):
    """Return a parser that repeats a scalar n times but passes iterables through."""
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse

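# Typical use: normalizing kernel/stride arguments that may be int or tuple.
#   to_2tuple = _ntuple(2)
#   to_2tuple(3)       # -> (3, 3)
#   to_2tuple((1, 2))  # -> (1, 2), iterables pass through unchanged
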

def _search_param_name(params_names: List, param_name: str) -> list:
    """Return every name in params_names containing param_name as a substring."""
    search_results = []
    for pi in params_names:
        if param_name in pi:
            search_results.append(pi)
    return search_results


def auto_map(model, param_dict):
    """Rename part of the param_dict such that names from checkpoint and model are consistent"""
    updated_param_dict = deepcopy(param_dict)
    net_param = model.get_parameters()
    ckpt_param = list(updated_param_dict.keys())
    remap = {}

    for param in net_param:
        if param.name not in ckpt_param:
            print("Cannot find a param to load: ", param.name)
            poss = difflib.get_close_matches(param.name, ckpt_param, n=3, cutoff=0.6)
            if len(poss) > 0:
                print("=> Found the most matched param: ", poss[0], ", loaded")
                updated_param_dict[param.name] = updated_param_dict.pop(poss[0])  # replace
                remap[param.name] = poss[0]
            else:
                raise ValueError(f"Cannot find any matching param from: {ckpt_param}")

    if len(remap) != 0:
        print("WARNING: Auto mapping succeeded. Please check the found mapping names to ensure correctness")
        print("\tNet Param\t<---\tCkpt Param")
        for net_name, ckpt_name in remap.items():
            print(f"\t{net_name}\t<---\t{ckpt_name}")
    return updated_param_dict

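# Illustrative flow (checkpoint path is a placeholder): difflib.get_close_matches
# lets a checkpoint key such as "backbone.conv1.weight" be remapped onto a model
# parameter named "conv1.weight" when their similarity exceeds the 0.6 cutoff.
#
#   param_dict = load_checkpoint("old_naming.ckpt")
#   param_dict = auto_map(net, param_dict)   # prints every remapped pair
#   load_param_into_net(net, param_dict)
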

def load_model_checkpoint(model: nn.Cell, checkpoint_path: str = "", ema: bool = False, auto_mapping: bool = False):
    """Load a checkpoint into the model.

    Args:
        model (nn.Cell): The model which loads the checkpoint.
        checkpoint_path (str): The path of checkpoint files. Default: "".
        ema (bool): Whether to load the EMA (exponential moving average) weights. Default: False.
        auto_mapping (bool): Whether to automatically map the names of checkpoint weights
            to the names of model weights when there are differences in names. Default: False.
    """
    if os.path.exists(checkpoint_path):
        checkpoint_param = load_checkpoint(checkpoint_path)
        if auto_mapping:
            checkpoint_param = auto_map(model, checkpoint_param)

        # collect EMA weights, stripping the "ema." prefix from their names
        ema_param_dict = {}
        for param in checkpoint_param:
            if param.startswith("ema"):
                new_name = param.split("ema.")[1]
                ema_data = checkpoint_param[param]
                ema_data.name = new_name
                ema_param_dict[new_name] = ema_data

        if ema_param_dict and ema:
            load_param_into_net(model, ema_param_dict)
        elif ema:
            raise ValueError("checkpoint_param does not contain EMA parameters; please set ema=False.")
        else:
            load_param_into_net(model, checkpoint_param)

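# Illustrative calls (paths are placeholders):
#   load_model_checkpoint(net, "ckpt/best.ckpt")                        # plain weights
#   load_model_checkpoint(net, "ckpt/best.ckpt", ema=True)              # load the "ema.*" shadow weights
#   load_model_checkpoint(net, "ckpt/renamed.ckpt", auto_mapping=True)  # fuzzy-rename mismatched keys
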

def build_model_with_cfg(
    model_cls: Callable,
    pretrained: bool,
    default_cfg: Dict,
    features_only: bool = False,
    out_indices=None,
    **kwargs,
):
    """Build model with specific model configurations

    Args:
        model_cls (Callable): Model class.
        pretrained (bool): Whether to load pretrained weights.
        default_cfg (Dict): Configuration for pretrained weights.
        features_only (bool): Output the features at different strides instead. Default: False
        out_indices (list[int]): The indices of the output features when `features_only` is `True`.
            Default: [0, 1, 2, 3, 4]
    """
    if out_indices is None:
        out_indices = [0, 1, 2, 3, 4]

    model = model_cls(**kwargs)

    if pretrained:
        load_pretrained(model, default_cfg, kwargs.get("num_classes", 1000), kwargs.get("in_channels", 3))

    if features_only:
        # wrap the model, output the feature pyramid instead
        try:
            model = FeatureExtractWrapper(model, out_indices=out_indices)
        except AttributeError as e:
            raise RuntimeError(f"`features_only` is not implemented for `{model_cls.__name__}` model.") from e

    return model

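# Illustrative usage (MyNet and the cfg values are placeholders for whatever
# model registry calls this helper):
#   cfg = {"url": "", "num_classes": 1000, "first_conv": "conv1", "classifier": "classifier"}
#   net = build_model_with_cfg(MyNet, pretrained=False, default_cfg=cfg, num_classes=10)
#   backbone = build_model_with_cfg(MyNet, False, cfg, features_only=True, out_indices=[2, 3, 4])
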

Research on a MindSpore-based multimodal stock price prediction system: Informer, LSTM, RNN