You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_lineage.py 24 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This module is used to collect lineage information of model training."""
  16. import json
  17. import os
  18. import numpy as np
  19. from mindinsight.lineagemgr.summary.summary_record import LineageSummary
  20. from mindinsight.utils.exceptions import \
  21. MindInsightException
  22. from mindinsight.lineagemgr.common.validator.validate import validate_train_run_context, \
  23. validate_eval_run_context, validate_file_path, validate_network, \
  24. validate_int_params, validate_summary_record, validate_raise_exception
  25. from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrors, LineageErrorMsg
  26. from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamRunContextError, \
  27. LineageGetModelFileError, LineageLogError
  28. from mindinsight.lineagemgr.common.log import logger as log
  29. from mindinsight.lineagemgr.common.utils import try_except
  30. from mindinsight.lineagemgr.common.validator.model_parameter import RunContextArgs, \
  31. EvalParameter
  32. from mindinsight.lineagemgr.collection.model.base import Metadata
  33. try:
  34. from mindspore.common.tensor import Tensor
  35. from mindspore.train.callback import Callback, RunContext, ModelCheckpoint, SummaryStep
  36. from mindspore.nn import Cell, Optimizer, WithLossCell, TrainOneStepWithLossScaleCell
  37. from mindspore.nn.loss.loss import _Loss
  38. from mindspore.dataset.engine import Dataset, MindDataset
  39. import mindspore.dataset as ds
  40. except (ImportError, ModuleNotFoundError):
  41. log.warning('MindSpore Not Found!')
  42. class TrainLineage(Callback):
  43. """
  44. Collect lineage of a training job.
  45. Args:
  46. summary_record (SummaryRecord): SummaryRecord is used to record
  47. the summary value, and summary_record is an instance of SummaryRecord,
  48. see mindspore.train.summary.SummaryRecord.
  49. raise_exception (bool): Whether to raise exception when error occurs in
  50. TrainLineage. If True, raise exception. If False, catch exception
  51. and continue. Default: False.
  52. Raises:
  53. MindInsightException: If validating parameter fails.
  54. LineageLogError: If recording lineage information fails.
  55. Examples:
  56. >>> from mindinsight.lineagemgr import TrainLineage
  57. >>> from mindspore.train.callback import ModelCheckpoint, SummaryStep
  58. >>> from mindspore.train.summary import SummaryRecord
  59. >>> model = Model(train_network)
  60. >>> model_ckpt = ModelCheckpoint(directory='/dir/to/save/model/')
  61. >>> summary_writer = SummaryRecord(log_dir='./')
  62. >>> summary_callback = SummaryStep(summary_writer, flush_step=2)
  63. >>> lineagemgr = TrainLineage(summary_record=summary_writer)
  64. >>> model.train(epoch_num, dataset, callbacks=[model_ckpt, summary_callback, lineagemgr])
  65. """
  66. def __init__(self, summary_record, raise_exception=False):
  67. super(TrainLineage, self).__init__()
  68. try:
  69. validate_raise_exception(raise_exception)
  70. self.raise_exception = raise_exception
  71. validate_summary_record(summary_record)
  72. self.summary_record = summary_record
  73. summary_log_path = summary_record.full_file_name
  74. validate_file_path(summary_log_path)
  75. self.lineage_log_path = summary_log_path + '_lineage'
  76. self.initial_learning_rate = None
  77. except MindInsightException as err:
  78. log.error(err)
  79. if raise_exception:
  80. raise
  81. @try_except(log)
  82. def begin(self, run_context):
  83. """
  84. Initialize the training progress when the training job begins.
  85. Args:
  86. run_context (RunContext): It contains all lineage information,
  87. see mindspore.train.callback.RunContext.
  88. Raises:
  89. MindInsightException: If validating parameter fails.
  90. """
  91. log.info('Initialize training lineage collection...')
  92. if not isinstance(run_context, RunContext):
  93. error_msg = f'Invalid TrainLineage run_context.'
  94. log.error(error_msg)
  95. raise LineageParamRunContextError(error_msg)
  96. run_context_args = run_context.original_args()
  97. if not self.initial_learning_rate:
  98. optimizer = run_context_args.get('optimizer')
  99. if optimizer and not isinstance(optimizer, Optimizer):
  100. log.error("The parameter optimizer is invalid. It should be an instance of "
  101. "mindspore.nn.optim.optimizer.Optimizer.")
  102. raise MindInsightException(error=LineageErrors.PARAM_OPTIMIZER_ERROR,
  103. message=LineageErrorMsg.PARAM_OPTIMIZER_ERROR.value)
  104. if optimizer:
  105. log.info('Obtaining initial learning rate...')
  106. self.initial_learning_rate = AnalyzeObject.analyze_optimizer(optimizer)
  107. log.debug('initial_learning_rate: %s', self.initial_learning_rate)
  108. else:
  109. network = run_context_args.get('train_network')
  110. validate_network(network)
  111. optimizer = AnalyzeObject.get_optimizer_by_network(network)
  112. self.initial_learning_rate = AnalyzeObject.analyze_optimizer(optimizer)
  113. log.debug('initial_learning_rate: %s', self.initial_learning_rate)
  114. # get train dataset graph
  115. train_dataset = run_context_args.get('train_dataset')
  116. dataset_graph_dict = ds.serialize(train_dataset)
  117. dataset_graph_json_str = json.dumps(dataset_graph_dict, indent=2)
  118. dataset_graph_dict = json.loads(dataset_graph_json_str)
  119. log.info('Logging dataset graph...')
  120. try:
  121. lineage_summary = LineageSummary(self.lineage_log_path)
  122. lineage_summary.record_dataset_graph(dataset_graph=dataset_graph_dict)
  123. except Exception as error:
  124. error_msg = f'Dataset graph log error in TrainLineage begin: {error}'
  125. log.error(error_msg)
  126. raise LineageLogError(error_msg)
  127. log.info('Dataset graph logged successfully.')
  128. @try_except(log)
  129. def end(self, run_context):
  130. """
  131. Collect lineage information when the training job ends.
  132. Args:
  133. run_context (RunContext): It contains all lineage information,
  134. see mindspore.train.callback.RunContext.
  135. Raises:
  136. LineageLogError: If recording lineage information fails.
  137. """
  138. log.info('Start to collect training lineage...')
  139. if not isinstance(run_context, RunContext):
  140. error_msg = f'Invalid TrainLineage run_context.'
  141. log.error(error_msg)
  142. raise LineageParamRunContextError(error_msg)
  143. run_context_args = run_context.original_args()
  144. validate_train_run_context(RunContextArgs, run_context_args)
  145. train_lineage = dict()
  146. train_lineage = AnalyzeObject.get_network_args(
  147. run_context_args, train_lineage
  148. )
  149. train_dataset = run_context_args.get('train_dataset')
  150. callbacks = run_context_args.get('list_callback')
  151. list_callback = getattr(callbacks, '_callbacks', [])
  152. log.info('Obtaining model files...')
  153. ckpt_file_path, _ = AnalyzeObject.get_file_path(list_callback)
  154. train_lineage[Metadata.learning_rate] = self.initial_learning_rate
  155. train_lineage[Metadata.epoch] = run_context_args.get('epoch_num')
  156. train_lineage[Metadata.step_num] = run_context_args.get('cur_step_num')
  157. train_lineage[Metadata.parallel_mode] = run_context_args.get('parallel_mode')
  158. train_lineage[Metadata.device_num] = run_context_args.get('device_number')
  159. train_lineage[Metadata.batch_size] = run_context_args.get('batch_num')
  160. model_path_dict = {
  161. 'ckpt': ckpt_file_path
  162. }
  163. train_lineage[Metadata.model_path] = json.dumps(model_path_dict)
  164. log.info('Calculating model size...')
  165. train_lineage[Metadata.model_size] = AnalyzeObject.get_model_size(
  166. ckpt_file_path
  167. )
  168. log.debug('model_size: %s', train_lineage[Metadata.model_size])
  169. log.info('Analyzing dataset object...')
  170. train_lineage = AnalyzeObject.analyze_dataset(train_dataset, train_lineage, 'train')
  171. log.info('Logging lineage information...')
  172. try:
  173. lineage_summary = LineageSummary(self.lineage_log_path)
  174. lineage_summary.record_train_lineage(train_lineage)
  175. except IOError as error:
  176. error_msg = f'End error in TrainLineage: {error}'
  177. log.error(error_msg)
  178. raise LineageLogError(error_msg)
  179. except Exception as error:
  180. error_msg = f'End error in TrainLineage: {error}'
  181. log.error(error_msg)
  182. log.error('Fail to log the lineage of the training job.')
  183. raise LineageLogError(error_msg)
  184. log.info('The lineage of the training job has logged successfully.')
  185. class EvalLineage(Callback):
  186. """
  187. Collect lineage of an evaluation job.
  188. Args:
  189. summary_record (SummaryRecord): SummaryRecord is used to record
  190. the summary value, and summary_record is an instance of SummaryRecord,
  191. see mindspore.train.summary.SummaryRecord.
  192. raise_exception (bool): Whether to raise exception when error occurs in
  193. EvalLineage. If True, raise exception. If False, catch exception
  194. and continue. Default: False.
  195. Raises:
  196. MindInsightException: If validating parameter fails.
  197. LineageLogError: If recording lineage information fails.
  198. Examples:
  199. >>> from mindinsight.lineagemgr import EvalLineage
  200. >>> from mindspore.train.callback import ModelCheckpoint, SummaryStep
  201. >>> from mindspore.train.summary import SummaryRecord
  202. >>> model = Model(train_network)
  203. >>> model_ckpt = ModelCheckpoint(directory='/dir/to/save/model/')
  204. >>> summary_writer = SummaryRecord(log_dir='./')
  205. >>> summary_callback = SummaryStep(summary_writer, flush_step=2)
  206. >>> lineagemgr = EvalLineage(summary_record=summary_writer)
  207. >>> model.eval(epoch_num, dataset, callbacks=[model_ckpt, summary_callback, lineagemgr])
  208. """
  209. def __init__(self, summary_record, raise_exception=False):
  210. super(EvalLineage, self).__init__()
  211. try:
  212. validate_raise_exception(raise_exception)
  213. self.raise_exception = raise_exception
  214. validate_summary_record(summary_record)
  215. self.summary_record = summary_record
  216. summary_log_path = summary_record.full_file_name
  217. validate_file_path(summary_log_path)
  218. self.lineage_log_path = summary_log_path + '_lineage'
  219. except MindInsightException as err:
  220. log.error(err)
  221. if raise_exception:
  222. raise
  223. @try_except(log)
  224. def end(self, run_context):
  225. """
  226. Collect lineage information when the training job ends.
  227. Args:
  228. run_context (RunContext): It contains all lineage information,
  229. see mindspore.train.callback.RunContext.
  230. Raises:
  231. MindInsightException: If validating parameter fails.
  232. LineageLogError: If recording lineage information fails.
  233. """
  234. if not isinstance(run_context, RunContext):
  235. error_msg = f'Invalid EvalLineage run_context.'
  236. log.error(error_msg)
  237. raise LineageParamRunContextError(error_msg)
  238. run_context_args = run_context.original_args()
  239. validate_eval_run_context(EvalParameter, run_context_args)
  240. valid_dataset = run_context_args.get('valid_dataset')
  241. eval_lineage = dict()
  242. metrics = run_context_args.get('metrics')
  243. eval_lineage[Metadata.metrics] = json.dumps(metrics)
  244. eval_lineage[Metadata.step_num] = run_context_args.get('cur_step_num')
  245. log.info('Analyzing dataset object...')
  246. eval_lineage = AnalyzeObject.analyze_dataset(valid_dataset, eval_lineage, 'valid')
  247. log.info('Logging evaluation job lineage...')
  248. try:
  249. lineage_summary = LineageSummary(self.lineage_log_path)
  250. lineage_summary.record_evaluation_lineage(eval_lineage)
  251. except IOError as error:
  252. error_msg = f'End error in EvalLineage: {error}'
  253. log.error(error_msg)
  254. log.error('Fail to log the lineage of the evaluation job.')
  255. raise LineageLogError(error_msg)
  256. except Exception as error:
  257. error_msg = f'End error in EvalLineage: {error}'
  258. log.error(error_msg)
  259. log.error('Fail to log the lineage of the evaluation job.')
  260. raise LineageLogError(error_msg)
  261. log.info('The lineage of the evaluation job has logged successfully.')
  262. class AnalyzeObject:
  263. """Analyze class object in MindSpore."""
  264. @staticmethod
  265. def get_optimizer_by_network(network):
  266. """
  267. Get optimizer by analyzing network.
  268. Args:
  269. network (Cell): See mindspore.nn.Cell.
  270. Returns:
  271. Optimizer, an Optimizer object.
  272. """
  273. optimizer = None
  274. net_args = vars(network) if network else {}
  275. net_cell = net_args.get('_cells') if net_args else {}
  276. for _, value in net_cell.items():
  277. if isinstance(value, Optimizer):
  278. optimizer = value
  279. break
  280. return optimizer
  281. @staticmethod
  282. def get_loss_fn_by_network(network):
  283. """
  284. Get loss function by analyzing network.
  285. Args:
  286. network (Cell): See mindspore.nn.Cell.
  287. Returns:
  288. Loss_fn, a Cell object.
  289. """
  290. loss_fn = None
  291. inner_cell_list = []
  292. net_args = vars(network) if network else {}
  293. net_cell = net_args.get('_cells') if net_args else {}
  294. for _, value in net_cell.items():
  295. if isinstance(value, Cell) and \
  296. not isinstance(value, Optimizer):
  297. inner_cell_list.append(value)
  298. while inner_cell_list:
  299. inner_net_args = vars(inner_cell_list[0])
  300. inner_net_cell = inner_net_args.get('_cells')
  301. for value in inner_net_cell.values():
  302. if isinstance(value, _Loss):
  303. loss_fn = value
  304. break
  305. if isinstance(value, Cell):
  306. inner_cell_list.append(value)
  307. if loss_fn:
  308. break
  309. inner_cell_list.pop(0)
  310. return loss_fn
  311. @staticmethod
  312. def get_backbone_network(network):
  313. """
  314. Get the name of backbone network.
  315. Args:
  316. network (Cell): The train network.
  317. Returns:
  318. str, the name of the backbone network.
  319. """
  320. with_loss_cell = False
  321. backbone = None
  322. net_args = vars(network) if network else {}
  323. net_cell = net_args.get('_cells') if net_args else {}
  324. for _, value in net_cell.items():
  325. if isinstance(value, WithLossCell):
  326. backbone = getattr(value, '_backbone')
  327. with_loss_cell = True
  328. break
  329. if with_loss_cell:
  330. backbone_name = type(backbone).__name__ \
  331. if backbone else None
  332. elif isinstance(network, TrainOneStepWithLossScaleCell):
  333. backbone = getattr(network, 'network')
  334. backbone_name = type(backbone).__name__ \
  335. if backbone else None
  336. else:
  337. backbone_name = type(network).__name__ \
  338. if network else None
  339. return backbone_name
  340. @staticmethod
  341. def analyze_optimizer(optimizer):
  342. """
  343. Analyze Optimizer, a Cell object of MindSpore.
  344. In this way, we can obtain the following attributes:
  345. learning_rate (float),
  346. weight_decay (float),
  347. momentum (float),
  348. weights (float).
  349. Args:
  350. optimizer (Optimizer): See mindspore.nn.optim.Optimizer.
  351. Returns:
  352. float, the learning rate that the optimizer adopted.
  353. """
  354. learning_rate = None
  355. if isinstance(optimizer, Optimizer):
  356. learning_rate = getattr(optimizer, 'learning_rate', None)
  357. if learning_rate:
  358. learning_rate = learning_rate.default_input
  359. # Get the real learning rate value
  360. if isinstance(learning_rate, Tensor):
  361. learning_rate = learning_rate.asnumpy()
  362. if learning_rate.ndim == 0:
  363. learning_rate = np.atleast_1d(learning_rate)
  364. learning_rate = list(learning_rate)
  365. elif isinstance(learning_rate, float):
  366. learning_rate = [learning_rate]
  367. return learning_rate[0] if learning_rate else None
  368. @staticmethod
  369. def analyze_dataset(dataset, lineage_dict, dataset_type):
  370. """
  371. Analyze Dataset, a Dataset object of MindSpore.
  372. In this way, we can obtain the following attributes:
  373. dataset_path (str),
  374. train_dataset_size (int),
  375. valid_dataset_size (int),
  376. batch_size (int)
  377. Args:
  378. dataset (Dataset): See mindspore.dataengine.datasets.Dataset.
  379. lineage_dict (dict): A dict contains lineage metadata.
  380. dataset_type (str): Dataset type, train or valid.
  381. Returns:
  382. dict, the lineage metadata.
  383. """
  384. dataset_batch_size = dataset.get_dataset_size()
  385. if dataset_batch_size is not None:
  386. validate_int_params(dataset_batch_size, 'dataset_batch_size')
  387. log.debug('dataset_batch_size: %d', dataset_batch_size)
  388. dataset_path = AnalyzeObject.get_dataset_path_wrapped(dataset)
  389. if dataset_path:
  390. dataset_path = '/'.join(dataset_path.split('/')[:-1])
  391. step_num = lineage_dict.get('step_num')
  392. validate_int_params(step_num, 'step_num')
  393. log.debug('step_num: %d', step_num)
  394. if dataset_type == 'train':
  395. lineage_dict[Metadata.train_dataset_path] = dataset_path
  396. epoch = lineage_dict.get('epoch')
  397. train_dataset_size = dataset_batch_size * (step_num / epoch)
  398. lineage_dict[Metadata.train_dataset_size] = int(train_dataset_size)
  399. elif dataset_type == 'valid':
  400. lineage_dict[Metadata.valid_dataset_path] = dataset_path
  401. lineage_dict[Metadata.valid_dataset_size] = dataset_batch_size * step_num
  402. return lineage_dict
  403. def get_dataset_path(self, output_dataset):
  404. """
  405. Get dataset path of MindDataset object.
  406. Args:
  407. output_dataset (Union[MindDataset, Dataset]): See
  408. mindspore.dataengine.datasets.Dataset.
  409. Returns:
  410. str, dataset path.
  411. """
  412. if isinstance(output_dataset, MindDataset):
  413. return output_dataset.dataset_file
  414. return self.get_dataset_path(output_dataset.input[0])
  415. @staticmethod
  416. def get_dataset_path_wrapped(dataset):
  417. """
  418. A wrapper for obtaining dataset path.
  419. Args:
  420. dataset (Union[MindDataset, Dataset]): See
  421. mindspore.dataengine.datasets.Dataset.
  422. Returns:
  423. str, dataset path.
  424. """
  425. dataset_path = None
  426. if isinstance(dataset, Dataset):
  427. try:
  428. dataset_path = AnalyzeObject().get_dataset_path(dataset)
  429. except IndexError:
  430. dataset_path = None
  431. validate_file_path(dataset_path, allow_empty=True)
  432. return dataset_path
  433. @staticmethod
  434. def get_file_path(list_callback):
  435. """
  436. Get ckpt_file_name and summary_log_path from MindSpore callback list.
  437. Args:
  438. list_callback (list[Callback]): The MindSpore training Callback list.
  439. Returns:
  440. tuple, contains ckpt_file_name and summary_log_path.
  441. """
  442. ckpt_file_path = None
  443. summary_log_path = None
  444. for callback in list_callback:
  445. if isinstance(callback, ModelCheckpoint):
  446. ckpt_file_path = callback.latest_ckpt_file_name
  447. if isinstance(callback, SummaryStep):
  448. summary_log_path = callback.summary_file_name
  449. if ckpt_file_path:
  450. validate_file_path(ckpt_file_path)
  451. ckpt_file_path = os.path.realpath(ckpt_file_path)
  452. if summary_log_path:
  453. validate_file_path(summary_log_path)
  454. summary_log_path = os.path.realpath(summary_log_path)
  455. return ckpt_file_path, summary_log_path
  456. @staticmethod
  457. def get_file_size(file_path):
  458. """
  459. Get the file size.
  460. Args:
  461. file_path (str): The file path.
  462. Returns:
  463. int, the file size.
  464. """
  465. try:
  466. return os.path.getsize(file_path)
  467. except (OSError, IOError) as error:
  468. error_msg = f"Error when get model file size: {error}"
  469. log.error(error_msg)
  470. raise LineageGetModelFileError(error_msg)
  471. @staticmethod
  472. def get_model_size(ckpt_file_path):
  473. """
  474. Get model the total size of the model file and the checkpoint file.
  475. Args:
  476. ckpt_file_path (str): The checkpoint file path.
  477. Returns:
  478. int, the total file size.
  479. """
  480. if ckpt_file_path:
  481. ckpt_file_path = os.path.realpath(ckpt_file_path)
  482. ckpt_file_size = AnalyzeObject.get_file_size(ckpt_file_path)
  483. else:
  484. ckpt_file_size = 0
  485. return ckpt_file_size
  486. @staticmethod
  487. def get_network_args(run_context_args, train_lineage):
  488. """
  489. Get the parameters related to the network,
  490. such as optimizer, loss function.
  491. Args:
  492. run_context_args (dict): It contains all information of the training job.
  493. train_lineage (dict): A dict contains lineage metadata.
  494. Returns:
  495. dict, the lineage metadata.
  496. """
  497. network = run_context_args.get('train_network')
  498. validate_network(network)
  499. optimizer = run_context_args.get('optimizer')
  500. if not optimizer:
  501. optimizer = AnalyzeObject.get_optimizer_by_network(network)
  502. loss_fn = run_context_args.get('loss_fn')
  503. if not loss_fn:
  504. loss_fn = AnalyzeObject.get_loss_fn_by_network(network)
  505. loss = None
  506. else:
  507. loss = run_context_args.get('net_outputs')
  508. if loss:
  509. log.info('Calculating loss...')
  510. loss_numpy = loss.asnumpy()
  511. loss = float(np.atleast_1d(loss_numpy)[0])
  512. log.debug('loss: %s', loss)
  513. train_lineage[Metadata.loss] = loss
  514. else:
  515. train_lineage[Metadata.loss] = None
  516. # Analyze classname of optimizer, loss function and training network.
  517. train_lineage[Metadata.optimizer] = type(optimizer).__name__ \
  518. if optimizer else None
  519. train_lineage[Metadata.train_network] = AnalyzeObject.get_backbone_network(network)
  520. train_lineage[Metadata.loss_function] = type(loss_fn).__name__ \
  521. if loss_fn else None
  522. return train_lineage

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中,可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中,通过MindInsight可视化页面进行查看及分析。

Contributors (1)