You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_sequence_classification_trainer.py 1.2 kB

Lines 1–38
  1. import unittest
  2. import zipfile
  3. from pathlib import Path
  4. from maas_lib.fileio import File
  5. from maas_lib.trainers import build_trainer
  6. from maas_lib.utils.logger import get_logger
  7. logger = get_logger()
  8. class SequenceClassificationTrainerTest(unittest.TestCase):
  9. def test_sequence_classification(self):
  10. model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
  11. '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
  12. cache_path_str = r'.cache/easynlp/bert-base-sst2.zip'
  13. cache_path = Path(cache_path_str)
  14. if not cache_path.exists():
  15. cache_path.parent.mkdir(parents=True, exist_ok=True)
  16. cache_path.touch(exist_ok=True)
  17. with cache_path.open('wb') as ofile:
  18. ofile.write(File.read(model_url))
  19. with zipfile.ZipFile(cache_path_str, 'r') as zipf:
  20. zipf.extractall(cache_path.parent)
  21. path: str = './configs/nlp/sequence_classification_trainer.yaml'
  22. default_args = dict(cfg_file=path)
  23. trainer = build_trainer('bert-sentiment-analysis', default_args)
  24. trainer.train()
  25. trainer.evaluate()
  26. if __name__ == '__main__':
  27. unittest.main()
  28. ...

致力于通过开放的社区合作，开源AI模型以及相关创新技术，推动基于模型即服务的生态繁荣发展