1. Remove commented-out code
2. Fix a bug where WS (word segmentation) hit an assert failure on English input
3. Add an English input test case for WS
4. Remove a test case whose dataset cannot be accessed from outside the internal network
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9657140
master
@@ -379,9 +379,6 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
         kwargs['padding'] = kwargs.get(
             'padding', False if mode == ModeKeys.INFERENCE else 'max_length')
         kwargs['max_length'] = kwargs.pop('sequence_length', 128)
-        kwargs['is_split_into_words'] = kwargs.pop(
-            'is_split_into_words',
-            False if mode == ModeKeys.INFERENCE else True)
         self.label_all_tokens = kwargs.pop('label_all_tokens', False)
         super().__init__(model_dir, pair=False, mode=mode, **kwargs)
@@ -397,22 +394,6 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
             Dict[str, Any]: the preprocessed data
         """
         # preprocess the data for the model input
-        # if isinstance(data, dict):
-        #     data = data[self.first_sequence]
-        # text = data.replace(' ', '').strip()
-        # tokens = []
-        # for token in text:
-        #     token = self.tokenizer.tokenize(token)
-        #     tokens.extend(token)
-        # input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
-        # input_ids = self.tokenizer.build_inputs_with_special_tokens(input_ids)
-        # attention_mask = [1] * len(input_ids)
-        # token_type_ids = [0] * len(input_ids)
-        # new code to deal with labels
-        # tokenized_inputs = self.tokenizer(data, truncation=True, is_split_into_words=True)
         text_a = None
         labels_list = None
         if isinstance(data, str):
@@ -420,10 +401,14 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
         elif isinstance(data, dict):
             text_a = data.get(self.first_sequence)
             labels_list = data.get(self.label)
-        text_a = text_a.replace(' ', '').strip()
+        if isinstance(text_a, str):
+            text_a = text_a.replace(' ', '').strip()
         tokenized_inputs = self.tokenizer(
-            text_a,
+            [t for t in text_a],
             return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None,
+            is_split_into_words=True,
             **self.tokenize_kwargs)
         if labels_list is not None:
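The character-level split combined with is_split_into_words=True is the core of the fix: the tokenizer receives pre-tokenized input, so English text goes through the same path as Chinese instead of tripping the assert. A minimal standalone sketch of the pattern, not the repository's code, assuming a HuggingFace fast tokenizer and using 'bert-base-uncased' purely as a placeholder model id:

    # Sketch of the pre-tokenized call pattern used above (assumption: any
    # HuggingFace *fast* tokenizer; 'bert-base-uncased' is only a placeholder).
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    text = 'I am a program.'.replace(' ', '').strip()  # mirrors the preprocessor's cleanup
    encoded = tokenizer([t for t in text], is_split_into_words=True)

    # With a fast tokenizer, word_ids() maps every produced token back to the
    # character it came from (None for special tokens), which keeps label
    # alignment consistent for both Chinese characters and English letters.
    print(encoded.tokens())
    print(encoded.word_ids())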
@@ -15,6 +15,7 @@ from modelscope.utils.test_utils import test_level
 class WordSegmentationTest(unittest.TestCase):
     model_id = 'damo/nlp_structbert_word-segmentation_chinese-base'
     sentence = '今天天气不错,适合出去游玩'
+    sentence_eng = 'I am a program.'

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_by_direct_model_download(self):
@@ -42,6 +43,7 @@ class WordSegmentationTest(unittest.TestCase):
         pipeline_ins = pipeline(
             task=Tasks.word_segmentation, model=self.model_id)
         print(pipeline_ins(input=self.sentence))
+        print(pipeline_ins(input=self.sentence_eng))

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_default_model(self):
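The new test line exercises the same call path a user would hit. A hedged usage sketch (assumes a local modelscope installation and network access to download the model; the model id is the one this test already uses):

    # Run the word segmentation pipeline on an English sentence, the input that
    # previously triggered the assert failure. Assumes modelscope is installed.
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    ws = pipeline(
        task=Tasks.word_segmentation,
        model='damo/nlp_structbert_word-segmentation_chinese-base')

    print(ws(input='今天天气不错,适合出去游玩'))  # Chinese input, as before
    print(ws(input='I am a program.'))  # English input, the case this PR fixes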
@@ -45,83 +45,6 @@ class TestFinetuneTokenClassification(unittest.TestCase):
         for i in range(10):
             self.assertIn(f'epoch_{i+1}.pth', results_files)

-    @unittest.skip
-    def test_token_classification(self):
-        # WS task
-        os.system(
-            f'curl http://dingkun.oss-cn-hangzhou-zmf.aliyuncs.com/atemp/train.txt > {self.tmp_dir}/train.txt'
-        )
-        os.system(
-            f'curl http://dingkun.oss-cn-hangzhou-zmf.aliyuncs.com/atemp/dev.txt > {self.tmp_dir}/dev.txt'
-        )
-        from datasets import load_dataset
-        dataset = load_dataset(
-            'text',
-            data_files={
-                'train': f'{self.tmp_dir}/train.txt',
-                'test': f'{self.tmp_dir}/dev.txt'
-            })
-
-        def split_to_dict(examples):
-            text, label = examples['text'].split('\t')
-            return {
-                'first_sequence': text.split(' '),
-                'labels': label.split(' ')
-            }
-
-        dataset = dataset.map(split_to_dict, batched=False)
-
-        def reducer(x, y):
-            x = x.split(' ') if isinstance(x, str) else x
-            y = y.split(' ') if isinstance(y, str) else y
-            return x + y
-
-        label_enumerate_values = list(
-            set(reduce(reducer, dataset['train'][:1000]['labels'])))
-        label_enumerate_values.sort()
-
-        def cfg_modify_fn(cfg):
-            cfg.task = 'token-classification'
-            cfg['preprocessor'] = {'type': 'token-cls-tokenizer'}
-            cfg['dataset'] = {
-                'train': {
-                    'labels': label_enumerate_values,
-                    'first_sequence': 'first_sequence',
-                    'label': 'labels',
-                }
-            }
-            cfg.train.max_epochs = 3
-            cfg.train.lr_scheduler = {
-                'type': 'LinearLR',
-                'start_factor': 1.0,
-                'end_factor': 0.0,
-                'total_iters':
-                int(len(dataset['train']) / 32) * cfg.train.max_epochs,
-                'options': {
-                    'by_epoch': False
-                }
-            }
-            cfg.train.hooks = [{
-                'type': 'CheckpointHook',
-                'interval': 1
-            }, {
-                'type': 'TextLoggerHook',
-                'interval': 1
-            }, {
-                'type': 'IterTimerHook'
-            }, {
-                'type': 'EvaluationHook',
-                'by_epoch': False,
-                'interval': 300
-            }]
-            return cfg
-
-        self.finetune(
-            'damo/nlp_structbert_backbone_tiny_std',
-            dataset['train'],
-            dataset['test'],
-            cfg_modify_fn=cfg_modify_fn)
-
-    @unittest.skip
-    def test_word_segmentation(self):
-        os.system(