
sequence_classification_trainer.py

import time
from typing import Dict, Optional, Tuple, Union

import numpy as np

from maas_lib.utils.constant import Tasks
from maas_lib.utils.logger import get_logger
from ..base import BaseTrainer
from ..builder import TRAINERS

# __all__ = ["SequenceClassificationTrainer"]

PATH = None
logger = get_logger(PATH)


@TRAINERS.register_module(
    Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationTrainer(BaseTrainer):

    def __init__(self, cfg_file: str, *args, **kwargs):
        """A trainer for sequence classification.

        Based on a config file (*.yaml or *.json), the trainer trains or
        evaluates a model on a dataset.

        Args:
            cfg_file (str): the path of the config file
        """
        super().__init__(cfg_file)

    def train(self, *args, **kwargs):
        logger.info('Train')
        ...
    def __attr_is_exist(self, attr: str) -> Tuple[Union[str, bool]]:
        """Get an attribute from the config; if the attribute does not exist,
        return False.

        Example:
            >>> self.__attr_is_exist('model path')
            out: ('model-path', '/workspace/bert-base-sst2')
            >>> self.__attr_is_exist('model weights')
            out: ('model-weights', False)

        Args:
            attr (str): attribute str, 'model path' -> config['model']['path']

        Returns:
            Tuple[Union[str, bool]]: (target attribute name, the target attribute or False)
        """
        paths = attr.split(' ')
        attr_str: str = '-'.join(paths)
        target = self.cfg[paths[0]] if hasattr(self.cfg, paths[0]) else None
        for path_ in paths[1:]:
            if not hasattr(target, path_):
                return attr_str, False
            target = target[path_]
        if target and target != '':
            return attr_str, target
        return attr_str, False
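
    # Illustrative config layout consumed by `evaluate` below, reconstructed
    # from the `__attr_is_exist` lookups it performs; the concrete values are
    # assumptions, not a shipped config:
    #
    #   evaluation:
    #     metrics: [accuracy, f1]
    #     batch_size: 32
    #     model_path: /workspace/bert-base-sst2
    #     max_sequence_length: 128
    #   dataset:
    #     valid:
    #       file: glue/sst2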
    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """Evaluate a dataset.

        Evaluate a dataset with the model loaded from `checkpoint_path`; if
        `checkpoint_path` is not given, read the model path from the config file.

        Args:
            checkpoint_path (Optional[str], optional): the model path. Defaults to None.

        Returns:
            Dict[str, float]: the results of the evaluation

        Example:
            {'accuracy': 0.5091743119266054, 'f1': 0.673780487804878}
        """
        import torch
        from easynlp.appzoo import load_dataset
        from easynlp.appzoo.dataset import GeneralDataset
        from easynlp.appzoo.sequence_classification.model import SequenceClassification
        from easynlp.utils import losses
        from sklearn.metrics import f1_score
        from torch.utils.data import DataLoader

        raise_str = 'Attribute {} is not given in config file!'
        metrics = self.__attr_is_exist('evaluation metrics')
        eval_batch_size = self.__attr_is_exist('evaluation batch_size')
        test_dataset_path = self.__attr_is_exist('dataset valid file')
        attrs = [metrics, eval_batch_size, test_dataset_path]
        for attr_ in attrs:
            if not attr_[-1]:
                raise AttributeError(raise_str.format(attr_[0]))

        if not checkpoint_path:
            checkpoint_path = self.__attr_is_exist('evaluation model_path')[-1]
        if not checkpoint_path:
            raise ValueError(
                'Argument checkpoint_path must be passed if the '
                'evaluation-model_path is not given in config file!')

        max_sequence_length = kwargs.get(
            'max_sequence_length',
            self.__attr_is_exist('evaluation max_sequence_length')[-1])
        if not max_sequence_length:
            raise ValueError(
                'Argument max_sequence_length must be passed '
                'if the evaluation-max_sequence_length does not exist in config file!')
        # get the raw online dataset
        raw_dataset = load_dataset(*test_dataset_path[-1].split('/'))
        valid_dataset = raw_dataset['validation']

        # generate a standard dataloader
        pre_dataset = GeneralDataset(valid_dataset, checkpoint_path,
                                     max_sequence_length)
        valid_dataloader = DataLoader(
            pre_dataset,
            batch_size=eval_batch_size[-1],
            shuffle=False,
            collate_fn=pre_dataset.batch_fn)
        # generate a model
        model = SequenceClassification.from_pretrained(checkpoint_path)

        # copy from easynlp (start)
        model.eval()
        total_loss = 0
        total_steps = 0
        total_samples = 0
        hit_num = 0
        total_num = 0
        logits_list = list()
        y_trues = list()
        total_spent_time = 0.0
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        model.to(device)
        for _step, batch in enumerate(valid_dataloader):
            try:
                batch = {
                    key: val.to(device) if isinstance(val, torch.Tensor) else val
                    for key, val in batch.items()
                }
            except RuntimeError:
                # fall back to keeping the batch on CPU
                batch = {key: val for key, val in batch.items()}
            infer_start_time = time.time()
            with torch.no_grad():
                label_ids = batch.pop('label_ids')
                outputs = model(batch)
            infer_end_time = time.time()
            total_spent_time += infer_end_time - infer_start_time

            assert 'logits' in outputs
            logits = outputs['logits']
            y_trues.extend(label_ids.tolist())
            logits_list.extend(logits.tolist())
            hit_num += torch.sum(
                torch.argmax(logits, dim=-1) == label_ids).item()
            total_num += label_ids.shape[0]

            if len(logits.shape) == 1 or logits.shape[-1] == 1:
                # single-logit output: regression-style mse loss
                tmp_loss = losses.mse_loss(logits, label_ids)
            elif len(logits.shape) == 2:
                tmp_loss = losses.cross_entropy(logits, label_ids)
            else:
                raise RuntimeError
            total_loss += tmp_loss.mean().item()
            total_steps += 1
            total_samples += valid_dataloader.batch_size

            if (_step + 1) % 100 == 0:
                total_step = len(
                    valid_dataloader.dataset) // valid_dataloader.batch_size
                logger.info('Eval: {}/{} steps finished'.format(
                    _step + 1, total_step))

        logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format(
            total_spent_time, total_spent_time * 1000 / total_samples))
        eval_loss = total_loss / total_steps
        logger.info('Eval loss: {}'.format(eval_loss))

        logits_list = np.array(logits_list)
        eval_outputs = list()
        for metric in metrics[-1]:
            if metric.endswith('accuracy'):
                acc = hit_num / total_num
                logger.info('Accuracy: {}'.format(acc))
                eval_outputs.append(('accuracy', acc))
            elif metric == 'f1':
                if model.config.num_labels == 2:
                    # binary task: plain f1
                    f1 = f1_score(y_trues, np.argmax(logits_list, axis=-1))
                    logger.info('F1: {}'.format(f1))
                    eval_outputs.append(('f1', f1))
                else:
                    # multi-class task: report both macro and micro f1
                    f1 = f1_score(
                        y_trues,
                        np.argmax(logits_list, axis=-1),
                        average='macro')
                    logger.info('Macro F1: {}'.format(f1))
                    eval_outputs.append(('macro-f1', f1))
                    f1 = f1_score(
                        y_trues,
                        np.argmax(logits_list, axis=-1),
                        average='micro')
                    logger.info('Micro F1: {}'.format(f1))
                    eval_outputs.append(('micro-f1', f1))
            else:
                raise NotImplementedError('Metric %s not implemented' % metric)
        # copy from easynlp (end)
        return dict(eval_outputs)
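

# Usage sketch (illustrative only; the config filename below is an assumption,
# and the checkpoint path is taken from the docstring example above):
#
#   trainer = SequenceClassificationTrainer('configs/bert_sst2.yaml')
#   trainer.train()
#   results = trainer.evaluate(checkpoint_path='/workspace/bert-base-sst2')
#   # e.g. {'accuracy': 0.5091743119266054, 'f1': 0.673780487804878}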
