You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_lineage.py 26 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This module is used to collect lineage information of model training."""
  16. import json
  17. import os
  18. import numpy as np
  19. from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, LineageErrors
  20. from mindinsight.lineagemgr.common.log import logger as log
  21. from mindinsight.utils.exceptions import MindInsightException
  22. from ._summary_record import LineageSummary
  23. from .base import Metadata
  24. from .utils import try_except, LineageParamRunContextError, LineageGetModelFileError, LineageLogError, \
  25. validate_int_params, validate_file_path, validate_raise_exception, \
  26. validate_user_defined_info, make_directory
  27. try:
  28. from mindspore.common.tensor import Tensor
  29. from mindspore.train.callback import Callback, RunContext, ModelCheckpoint, SummaryStep
  30. from mindspore.nn import Cell, Optimizer
  31. from mindspore.nn.loss.loss import _Loss
  32. from mindspore.dataset.engine import Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset, \
  33. VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset
  34. import mindspore.dataset as ds
  35. except (ImportError, ModuleNotFoundError):
  36. log.warning('MindSpore Not Found!')
  37. class TrainLineage(Callback):
  38. """
  39. Collect lineage of a training job.
  40. Args:
  41. summary_record (Union[SummaryRecord, str]): The `SummaryRecord` object which
  42. is used to record the summary value(see mindspore.train.summary.SummaryRecord),
  43. or a log dir(as a `str`) to be passed to `LineageSummary` to create
  44. a lineage summary recorder. It should be noted that instead of making
  45. use of summary_record to record lineage info directly, we obtain
  46. log dir from it then create a new summary file to write lineage info.
  47. raise_exception (bool): Whether to raise exception when error occurs in
  48. TrainLineage. If True, raise exception. If False, catch exception
  49. and continue. Default: False.
  50. user_defined_info (dict): User defined information. Only flatten dict with
  51. str key and int/float/str value is supported. Default: None.
  52. Raises:
  53. MindInsightException: If validating parameter fails.
  54. LineageLogError: If recording lineage information fails.
  55. Examples:
  56. >>> from mindinsight.lineagemgr import TrainLineage
  57. >>> from mindspore.train.callback import ModelCheckpoint, SummaryStep
  58. >>> from mindspore.train.summary import SummaryRecord
  59. >>> model = Model(train_network)
  60. >>> model_ckpt = ModelCheckpoint(directory='/dir/to/save/model/')
  61. >>> summary_writer = SummaryRecord(log_dir='./')
  62. >>> summary_callback = SummaryStep(summary_writer, flush_step=2)
  63. >>> lineagemgr = TrainLineage(summary_record=summary_writer)
  64. >>> model.train(epoch_num, dataset, callbacks=[model_ckpt, summary_callback, lineagemgr])
  65. """
  66. def __init__(self,
  67. summary_record,
  68. raise_exception=False,
  69. user_defined_info=None):
  70. super(TrainLineage, self).__init__()
  71. try:
  72. validate_raise_exception(raise_exception)
  73. self.raise_exception = raise_exception
  74. if isinstance(summary_record, str):
  75. # make directory if not exist
  76. self.lineage_log_dir = make_directory(summary_record)
  77. else:
  78. summary_log_path = summary_record.full_file_name
  79. validate_file_path(summary_log_path)
  80. self.lineage_log_dir = os.path.dirname(summary_log_path)
  81. self.lineage_summary = LineageSummary(self.lineage_log_dir)
  82. self.initial_learning_rate = None
  83. self.user_defined_info = user_defined_info
  84. if user_defined_info:
  85. validate_user_defined_info(user_defined_info)
  86. except MindInsightException as err:
  87. log.error(err)
  88. if raise_exception:
  89. raise
  90. @try_except(log)
  91. def begin(self, run_context):
  92. """
  93. Initialize the training progress when the training job begins.
  94. Args:
  95. run_context (RunContext): It contains all lineage information,
  96. see mindspore.train.callback.RunContext.
  97. Raises:
  98. MindInsightException: If validating parameter fails.
  99. """
  100. log.info('Initialize training lineage collection...')
  101. if self.user_defined_info:
  102. self.lineage_summary.record_user_defined_info(self.user_defined_info)
  103. if not isinstance(run_context, RunContext):
  104. error_msg = f'Invalid TrainLineage run_context.'
  105. log.error(error_msg)
  106. raise LineageParamRunContextError(error_msg)
  107. run_context_args = run_context.original_args()
  108. if not self.initial_learning_rate:
  109. optimizer = run_context_args.get('optimizer')
  110. if optimizer and not isinstance(optimizer, Optimizer):
  111. log.error("The parameter optimizer is invalid. It should be an instance of "
  112. "mindspore.nn.optim.optimizer.Optimizer.")
  113. raise MindInsightException(error=LineageErrors.PARAM_OPTIMIZER_ERROR,
  114. message=LineageErrorMsg.PARAM_OPTIMIZER_ERROR.value)
  115. if optimizer:
  116. log.info('Obtaining initial learning rate...')
  117. self.initial_learning_rate = AnalyzeObject.analyze_optimizer(optimizer)
  118. log.debug('initial_learning_rate: %s', self.initial_learning_rate)
  119. else:
  120. network = run_context_args.get('train_network')
  121. optimizer = AnalyzeObject.get_optimizer_by_network(network)
  122. self.initial_learning_rate = AnalyzeObject.analyze_optimizer(optimizer)
  123. log.debug('initial_learning_rate: %s', self.initial_learning_rate)
  124. # get train dataset graph
  125. train_dataset = run_context_args.get('train_dataset')
  126. dataset_graph_dict = ds.serialize(train_dataset)
  127. dataset_graph_json_str = json.dumps(dataset_graph_dict, indent=2)
  128. dataset_graph_dict = json.loads(dataset_graph_json_str)
  129. log.info('Logging dataset graph...')
  130. try:
  131. self.lineage_summary.record_dataset_graph(dataset_graph=dataset_graph_dict)
  132. except Exception as error:
  133. error_msg = f'Dataset graph log error in TrainLineage begin: {error}'
  134. log.error(error_msg)
  135. raise LineageLogError(error_msg)
  136. log.info('Dataset graph logged successfully.')
  137. @try_except(log)
  138. def end(self, run_context):
  139. """
  140. Collect lineage information when the training job ends.
  141. Args:
  142. run_context (RunContext): It contains all lineage information,
  143. see mindspore.train.callback.RunContext.
  144. Raises:
  145. LineageLogError: If recording lineage information fails.
  146. """
  147. log.info('Start to collect training lineage...')
  148. if not isinstance(run_context, RunContext):
  149. error_msg = f'Invalid TrainLineage run_context.'
  150. log.error(error_msg)
  151. raise LineageParamRunContextError(error_msg)
  152. run_context_args = run_context.original_args()
  153. train_lineage = dict()
  154. train_lineage = AnalyzeObject.get_network_args(
  155. run_context_args, train_lineage
  156. )
  157. train_dataset = run_context_args.get('train_dataset')
  158. callbacks = run_context_args.get('list_callback')
  159. list_callback = getattr(callbacks, '_callbacks', [])
  160. log.info('Obtaining model files...')
  161. ckpt_file_path, _ = AnalyzeObject.get_file_path(list_callback)
  162. train_lineage[Metadata.learning_rate] = self.initial_learning_rate
  163. train_lineage[Metadata.epoch] = run_context_args.get('epoch_num')
  164. train_lineage[Metadata.step_num] = run_context_args.get('cur_step_num')
  165. train_lineage[Metadata.parallel_mode] = run_context_args.get('parallel_mode')
  166. train_lineage[Metadata.device_num] = run_context_args.get('device_number')
  167. train_lineage[Metadata.batch_size] = run_context_args.get('batch_num')
  168. model_path_dict = {
  169. 'ckpt': ckpt_file_path
  170. }
  171. train_lineage[Metadata.model_path] = json.dumps(model_path_dict)
  172. log.info('Calculating model size...')
  173. train_lineage[Metadata.model_size] = AnalyzeObject.get_model_size(
  174. ckpt_file_path
  175. )
  176. log.debug('model_size: %s', train_lineage[Metadata.model_size])
  177. log.info('Analyzing dataset object...')
  178. train_lineage = AnalyzeObject.analyze_dataset(train_dataset, train_lineage, 'train')
  179. log.info('Logging lineage information...')
  180. try:
  181. self.lineage_summary.record_train_lineage(train_lineage)
  182. except IOError as error:
  183. error_msg = f'End error in TrainLineage: {error}'
  184. log.error(error_msg)
  185. raise LineageLogError(error_msg)
  186. except Exception as error:
  187. error_msg = f'End error in TrainLineage: {error}'
  188. log.error(error_msg)
  189. log.error('Fail to log the lineage of the training job.')
  190. raise LineageLogError(error_msg)
  191. log.info('The lineage of the training job has logged successfully.')
  192. class EvalLineage(Callback):
  193. """
  194. Collect lineage of an evaluation job.
  195. Args:
  196. summary_record (Union[SummaryRecord, str]): The `SummaryRecord` object which
  197. is used to record the summary value(see mindspore.train.summary.SummaryRecord),
  198. or a log dir(as a `str`) to be passed to `LineageSummary` to create
  199. a lineage summary recorder. It should be noted that instead of making
  200. use of summary_record to record lineage info directly, we obtain
  201. log dir from it then create a new summary file to write lineage info.
  202. raise_exception (bool): Whether to raise exception when error occurs in
  203. EvalLineage. If True, raise exception. If False, catch exception
  204. and continue. Default: False.
  205. user_defined_info (dict): User defined information. Only flatten dict with
  206. str key and int/float/str value is supported. Default: None.
  207. Raises:
  208. MindInsightException: If validating parameter fails.
  209. LineageLogError: If recording lineage information fails.
  210. Examples:
  211. >>> from mindinsight.lineagemgr import EvalLineage
  212. >>> from mindspore.train.callback import ModelCheckpoint, SummaryStep
  213. >>> from mindspore.train.summary import SummaryRecord
  214. >>> model = Model(train_network)
  215. >>> model_ckpt = ModelCheckpoint(directory='/dir/to/save/model/')
  216. >>> summary_writer = SummaryRecord(log_dir='./')
  217. >>> summary_callback = SummaryStep(summary_writer, flush_step=2)
  218. >>> lineagemgr = EvalLineage(summary_record=summary_writer)
  219. >>> model.eval(epoch_num, dataset, callbacks=[model_ckpt, summary_callback, lineagemgr])
  220. """
  221. def __init__(self,
  222. summary_record,
  223. raise_exception=False,
  224. user_defined_info=None):
  225. super(EvalLineage, self).__init__()
  226. try:
  227. validate_raise_exception(raise_exception)
  228. self.raise_exception = raise_exception
  229. if isinstance(summary_record, str):
  230. # make directory if not exist
  231. self.lineage_log_dir = make_directory(summary_record)
  232. else:
  233. summary_log_path = summary_record.full_file_name
  234. validate_file_path(summary_log_path)
  235. self.lineage_log_dir = os.path.dirname(summary_log_path)
  236. self.lineage_summary = LineageSummary(self.lineage_log_dir)
  237. self.user_defined_info = user_defined_info
  238. if self.user_defined_info:
  239. validate_user_defined_info(self.user_defined_info)
  240. except MindInsightException as err:
  241. log.error(err)
  242. if raise_exception:
  243. raise
  244. @try_except(log)
  245. def end(self, run_context):
  246. """
  247. Collect lineage information when the training job ends.
  248. Args:
  249. run_context (RunContext): It contains all lineage information,
  250. see mindspore.train.callback.RunContext.
  251. Raises:
  252. MindInsightException: If validating parameter fails.
  253. LineageLogError: If recording lineage information fails.
  254. """
  255. if self.user_defined_info:
  256. self.lineage_summary.record_user_defined_info(self.user_defined_info)
  257. if not isinstance(run_context, RunContext):
  258. error_msg = f'Invalid EvalLineage run_context.'
  259. log.error(error_msg)
  260. raise LineageParamRunContextError(error_msg)
  261. run_context_args = run_context.original_args()
  262. valid_dataset = run_context_args.get('valid_dataset')
  263. eval_lineage = dict()
  264. metrics = run_context_args.get('metrics')
  265. eval_lineage[Metadata.metrics] = json.dumps(metrics)
  266. eval_lineage[Metadata.step_num] = run_context_args.get('cur_step_num')
  267. log.info('Analyzing dataset object...')
  268. eval_lineage = AnalyzeObject.analyze_dataset(valid_dataset, eval_lineage, 'valid')
  269. log.info('Logging evaluation job lineage...')
  270. try:
  271. self.lineage_summary.record_evaluation_lineage(eval_lineage)
  272. except IOError as error:
  273. error_msg = f'End error in EvalLineage: {error}'
  274. log.error(error_msg)
  275. log.error('Fail to log the lineage of the evaluation job.')
  276. raise LineageLogError(error_msg)
  277. except Exception as error:
  278. error_msg = f'End error in EvalLineage: {error}'
  279. log.error(error_msg)
  280. log.error('Fail to log the lineage of the evaluation job.')
  281. raise LineageLogError(error_msg)
  282. log.info('The lineage of the evaluation job has logged successfully.')
  283. class AnalyzeObject:
  284. """Analyze class object in MindSpore."""
  285. @staticmethod
  286. def get_optimizer_by_network(network):
  287. """
  288. Get optimizer by analyzing network.
  289. Args:
  290. network (Cell): See mindspore.nn.Cell.
  291. Returns:
  292. Optimizer, an Optimizer object.
  293. """
  294. optimizer = None
  295. net_args = vars(network) if network else {}
  296. net_cell = net_args.get('_cells') if net_args else {}
  297. for _, value in net_cell.items():
  298. if isinstance(value, Optimizer):
  299. optimizer = value
  300. break
  301. return optimizer
  302. @staticmethod
  303. def get_loss_fn_by_network(network):
  304. """
  305. Get loss function by analyzing network.
  306. Args:
  307. network (Cell): See mindspore.nn.Cell.
  308. Returns:
  309. Loss_fn, a Cell object.
  310. """
  311. loss_fn = None
  312. inner_cell_list = []
  313. net_args = vars(network) if network else {}
  314. net_cell = net_args.get('_cells') if net_args else {}
  315. for _, value in net_cell.items():
  316. if isinstance(value, Cell) and \
  317. not isinstance(value, Optimizer):
  318. inner_cell_list.append(value)
  319. while inner_cell_list:
  320. inner_net_args = vars(inner_cell_list[0])
  321. inner_net_cell = inner_net_args.get('_cells')
  322. for value in inner_net_cell.values():
  323. if isinstance(value, _Loss):
  324. loss_fn = value
  325. break
  326. if isinstance(value, Cell):
  327. inner_cell_list.append(value)
  328. if loss_fn:
  329. break
  330. inner_cell_list.pop(0)
  331. return loss_fn
  332. @staticmethod
  333. def get_backbone_network(network):
  334. """
  335. Get the name of backbone network.
  336. Args:
  337. network (Cell): The train network.
  338. Returns:
  339. str, the name of the backbone network.
  340. """
  341. backbone_name = None
  342. has_network = False
  343. network_key = 'network'
  344. backbone_key = '_backbone'
  345. net_args = vars(network) if network else {}
  346. net_cell = net_args.get('_cells') if net_args else {}
  347. for key, value in net_cell.items():
  348. if key == network_key:
  349. network = value
  350. has_network = True
  351. break
  352. if has_network:
  353. while hasattr(network, network_key):
  354. network = getattr(network, network_key)
  355. if hasattr(network, backbone_key):
  356. backbone = getattr(network, backbone_key)
  357. backbone_name = type(backbone).__name__
  358. if backbone_name is None and network is not None:
  359. backbone_name = type(network).__name__
  360. return backbone_name
  361. @staticmethod
  362. def analyze_optimizer(optimizer):
  363. """
  364. Analyze Optimizer, a Cell object of MindSpore.
  365. In this way, we can obtain the following attributes:
  366. learning_rate (float),
  367. weight_decay (float),
  368. momentum (float),
  369. weights (float).
  370. Args:
  371. optimizer (Optimizer): See mindspore.nn.optim.Optimizer.
  372. Returns:
  373. float, the learning rate that the optimizer adopted.
  374. """
  375. learning_rate = None
  376. if isinstance(optimizer, Optimizer):
  377. learning_rate = getattr(optimizer, 'learning_rate', None)
  378. if learning_rate:
  379. learning_rate = learning_rate.default_input
  380. # Get the real learning rate value
  381. if isinstance(learning_rate, Tensor):
  382. learning_rate = learning_rate.asnumpy()
  383. if learning_rate.ndim == 0:
  384. learning_rate = np.atleast_1d(learning_rate)
  385. learning_rate = list(learning_rate)
  386. elif isinstance(learning_rate, float):
  387. learning_rate = [learning_rate]
  388. return learning_rate[0] if learning_rate else None
  389. @staticmethod
  390. def analyze_dataset(dataset, lineage_dict, dataset_type):
  391. """
  392. Analyze Dataset, a Dataset object of MindSpore.
  393. In this way, we can obtain the following attributes:
  394. dataset_path (str),
  395. train_dataset_size (int),
  396. valid_dataset_size (int),
  397. batch_size (int)
  398. Args:
  399. dataset (Dataset): See mindspore.dataengine.datasets.Dataset.
  400. lineage_dict (dict): A dict contains lineage metadata.
  401. dataset_type (str): Dataset type, train or valid.
  402. Returns:
  403. dict, the lineage metadata.
  404. """
  405. batch_num = dataset.get_dataset_size()
  406. batch_size = dataset.get_batch_size()
  407. if batch_num is not None:
  408. validate_int_params(batch_num, 'dataset_batch_num')
  409. validate_int_params(batch_num, 'dataset_batch_size')
  410. log.debug('dataset_batch_num: %d', batch_num)
  411. log.debug('dataset_batch_size: %d', batch_size)
  412. dataset_path = AnalyzeObject.get_dataset_path_wrapped(dataset)
  413. if dataset_path and os.path.isfile(dataset_path):
  414. dataset_path, _ = os.path.split(dataset_path)
  415. dataset_size = int(batch_num * batch_size)
  416. if dataset_type == 'train':
  417. lineage_dict[Metadata.train_dataset_path] = dataset_path
  418. lineage_dict[Metadata.train_dataset_size] = dataset_size
  419. elif dataset_type == 'valid':
  420. lineage_dict[Metadata.valid_dataset_path] = dataset_path
  421. lineage_dict[Metadata.valid_dataset_size] = dataset_size
  422. return lineage_dict
  423. def get_dataset_path(self, output_dataset):
  424. """
  425. Get dataset path of MindDataset object.
  426. Args:
  427. output_dataset (Union[Dataset, ImageFolderDatasetV2, MnistDataset, Cifar10Dataset, Cifar100Dataset,
  428. VOCDataset, CelebADataset, MindDataset, ManifestDataset, TFRecordDataset, TextFileDataset]):
  429. See mindspore.dataengine.datasets.Dataset.
  430. Returns:
  431. str, dataset path.
  432. """
  433. dataset_dir_set = (ImageFolderDatasetV2, MnistDataset, Cifar10Dataset,
  434. Cifar100Dataset, VOCDataset, CelebADataset)
  435. dataset_file_set = (MindDataset, ManifestDataset)
  436. dataset_files_set = (TFRecordDataset, TextFileDataset)
  437. if isinstance(output_dataset, dataset_file_set):
  438. return output_dataset.dataset_file
  439. if isinstance(output_dataset, dataset_dir_set):
  440. return output_dataset.dataset_dir
  441. if isinstance(output_dataset, dataset_files_set):
  442. return output_dataset.dataset_files[0]
  443. return self.get_dataset_path(output_dataset.inputs[0])
  444. @staticmethod
  445. def get_dataset_path_wrapped(dataset):
  446. """
  447. A wrapper for obtaining dataset path.
  448. Args:
  449. dataset (Union[MindDataset, Dataset]): See
  450. mindspore.dataengine.datasets.Dataset.
  451. Returns:
  452. str, dataset path.
  453. """
  454. dataset_path = None
  455. if isinstance(dataset, Dataset):
  456. try:
  457. dataset_path = AnalyzeObject().get_dataset_path(dataset)
  458. except IndexError:
  459. dataset_path = None
  460. dataset_path = validate_file_path(dataset_path, allow_empty=True)
  461. return dataset_path
  462. @staticmethod
  463. def get_file_path(list_callback):
  464. """
  465. Get ckpt_file_name and summary_log_path from MindSpore callback list.
  466. Args:
  467. list_callback (list[Callback]): The MindSpore training Callback list.
  468. Returns:
  469. tuple, contains ckpt_file_name and summary_log_path.
  470. """
  471. ckpt_file_path = None
  472. summary_log_path = None
  473. for callback in list_callback:
  474. if isinstance(callback, ModelCheckpoint):
  475. ckpt_file_path = callback.latest_ckpt_file_name
  476. if isinstance(callback, SummaryStep):
  477. summary_log_path = callback.summary_file_name
  478. if ckpt_file_path:
  479. validate_file_path(ckpt_file_path)
  480. ckpt_file_path = os.path.realpath(ckpt_file_path)
  481. if summary_log_path:
  482. validate_file_path(summary_log_path)
  483. summary_log_path = os.path.realpath(summary_log_path)
  484. return ckpt_file_path, summary_log_path
  485. @staticmethod
  486. def get_file_size(file_path):
  487. """
  488. Get the file size.
  489. Args:
  490. file_path (str): The file path.
  491. Returns:
  492. int, the file size.
  493. """
  494. try:
  495. return os.path.getsize(file_path)
  496. except (OSError, IOError) as error:
  497. error_msg = f"Error when get model file size: {error}"
  498. log.error(error_msg)
  499. raise LineageGetModelFileError(error_msg)
  500. @staticmethod
  501. def get_model_size(ckpt_file_path):
  502. """
  503. Get model the total size of the model file and the checkpoint file.
  504. Args:
  505. ckpt_file_path (str): The checkpoint file path.
  506. Returns:
  507. int, the total file size.
  508. """
  509. if ckpt_file_path:
  510. ckpt_file_path = os.path.realpath(ckpt_file_path)
  511. ckpt_file_size = AnalyzeObject.get_file_size(ckpt_file_path)
  512. else:
  513. ckpt_file_size = 0
  514. return ckpt_file_size
  515. @staticmethod
  516. def get_network_args(run_context_args, train_lineage):
  517. """
  518. Get the parameters related to the network,
  519. such as optimizer, loss function.
  520. Args:
  521. run_context_args (dict): It contains all information of the training job.
  522. train_lineage (dict): A dict contains lineage metadata.
  523. Returns:
  524. dict, the lineage metadata.
  525. """
  526. network = run_context_args.get('train_network')
  527. optimizer = run_context_args.get('optimizer')
  528. if not optimizer:
  529. optimizer = AnalyzeObject.get_optimizer_by_network(network)
  530. loss_fn = run_context_args.get('loss_fn')
  531. if not loss_fn:
  532. loss_fn = AnalyzeObject.get_loss_fn_by_network(network)
  533. loss = None
  534. else:
  535. loss = run_context_args.get('net_outputs')
  536. if loss:
  537. log.info('Calculating loss...')
  538. loss_numpy = loss.asnumpy()
  539. loss = float(np.atleast_1d(loss_numpy)[0])
  540. log.debug('loss: %s', loss)
  541. train_lineage[Metadata.loss] = loss
  542. else:
  543. train_lineage[Metadata.loss] = None
  544. # Analyze classname of optimizer, loss function and training network.
  545. train_lineage[Metadata.optimizer] = type(optimizer).__name__ \
  546. if optimizer else None
  547. train_lineage[Metadata.train_network] = AnalyzeObject.get_backbone_network(network)
  548. train_lineage[Metadata.loss_function] = type(loss_fn).__name__ \
  549. if loss_fn else None
  550. return train_lineage