diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 7ae34de9..2c52d104 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -690,11 +690,11 @@ class Trainer(object):
                     (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                         and self.dev_data is not None:
                     eval_res = self._do_validation(epoch=epoch, step=self.step)
-                    eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
-                                                                                       self.n_steps) + \
-                        self.tester._format_eval_results(eval_res)
+                    eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step,
+                                                                                       self.n_steps)
                     # pbar.write(eval_str + '\n')
-                    self.logger.info(eval_str + '\n')
+                    self.logger.info(eval_str)
+                    self.logger.info(self.tester._format_eval_results(eval_res)+'\n')
             # ================= mini-batch end ==================== #

             # lr decay; early stopping
@@ -907,7 +907,7 @@ def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAUL
             info_str += '\n'
         else:
             info_str += 'There is no target field.'
-        print(info_str)
+        logger.info(info_str)
     _check_forward_error(forward_func=forward_func, dataset=dataset,
                          batch_x=batch_x, check_level=check_level)
     refined_batch_x = _build_args(forward_func, **batch_x)
diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index cf0b57b0..bc0d46e2 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -67,8 +67,8 @@ class BertEmbedding(ContextualEmbedding):
             model_url = _get_embedding_url('bert', model_dir_or_name.lower())
             model_dir = cached_path(model_url, name='embedding')
             # 检查是否存在
-        elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))):
-            model_dir = os.path.expanduser(os.path.abspath(model_dir_or_name))
+        elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
+            model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
         else:
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")
diff --git a/fastNLP/embeddings/elmo_embedding.py b/fastNLP/embeddings/elmo_embedding.py
index 435e0b98..24cd052e 100644
--- a/fastNLP/embeddings/elmo_embedding.py
+++ b/fastNLP/embeddings/elmo_embedding.py
@@ -59,7 +59,7 @@ class ElmoEmbedding(ContextualEmbedding):
             model_url = _get_embedding_url('elmo', model_dir_or_name.lower())
             model_dir = cached_path(model_url, name='embedding')
             # 检查是否存在
-        elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))):
+        elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
             model_dir = model_dir_or_name
         else:
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")
diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py
index ac9611fe..4079b2a2 100644
--- a/fastNLP/embeddings/static_embedding.py
+++ b/fastNLP/embeddings/static_embedding.py
@@ -70,10 +70,10 @@ class StaticEmbedding(TokenEmbedding):
             model_url = _get_embedding_url('static', model_dir_or_name.lower())
             model_path = cached_path(model_url, name='embedding')
             # 检查是否存在
-        elif os.path.isfile(os.path.expanduser(os.path.abspath(model_dir_or_name))):
-            model_path = model_dir_or_name
-        elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))):
-            model_path = _get_file_name_base_on_postfix(model_dir_or_name, '.txt')
+        elif os.path.isfile(os.path.abspath(os.path.expanduser(model_dir_or_name))):
+            model_path = os.path.abspath(os.path.expanduser(model_dir_or_name))
+        elif
os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): + model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt') else: raise ValueError(f"Cannot recognize {model_dir_or_name}.") @@ -94,7 +94,7 @@ class StaticEmbedding(TokenEmbedding): no_create_entry=truncated_vocab._is_word_no_create_entry(word)) # 只限制在train里面的词语使用min_freq筛选 - if kwargs.get('only_train_min_freq', False): + if kwargs.get('only_train_min_freq', False) and model_dir_or_name is not None: for word in truncated_vocab.word_count.keys(): if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word]str: + """ + 如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, + 2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看 + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str + """ + if self.dataset_name is None: + return None + data_dir = self._get_dataset_path(dataset_name=self.dataset_name) + modify_time = 0 + for filepath in glob.glob(os.path.join(data_dir, '*')): + modify_time = os.stat(filepath).st_mtime + break + if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 + shutil.rmtree(data_dir) + data_dir = self._get_dataset_path(dataset_name=self.dataset_name) + + if not os.path.exists(os.path.join(data_dir, 'dev.txt')): + if dev_ratio > 0: + assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." + try: + with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \ + open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: + for line in f: + if random.random() < dev_ratio: + f2.write(line) + else: + f1.write(line) + os.remove(os.path.join(data_dir, 'train.txt')) + os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt')) + finally: + if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): + os.remove(os.path.join(data_dir, 'middle_file.txt')) + + return data_dir diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py index 9ffb9ed6..1907af4a 100644 --- a/fastNLP/io/pipe/__init__.py +++ b/fastNLP/io/pipe/__init__.py @@ -21,6 +21,7 @@ __all__ = [ "MsraNERPipe", "WeiboNERPipe", "PeopleDailyPipe", + "Conll2003Pipe", "MatchingBertPipe", "RTEBertPipe", @@ -41,3 +42,4 @@ from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe from .pipe import Pipe +from .conll import Conll2003Pipe diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py index d253f3be..617d1236 100644 --- a/fastNLP/io/pipe/conll.py +++ b/fastNLP/io/pipe/conll.py @@ -19,16 +19,14 @@ class _NERPipe(Pipe): :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 - :param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 """ - def __init__(self, encoding_type: str = 'bio', lower: bool = False, target_pad_val=0): + def __init__(self, encoding_type: str = 'bio', lower: bool = False): if encoding_type == 'bio': self.convert_tag = iob2 else: self.convert_tag = lambda words: iob2bioes(iob2(words)) self.lower = lower - self.target_pad_val = int(target_pad_val) def 
process(self, data_bundle: DataBundle) -> DataBundle: """ @@ -58,7 +56,6 @@ class _NERPipe(Pipe): target_fields = [Const.TARGET, Const.INPUT_LEN] for name, dataset in data_bundle.datasets.items(): - dataset.set_pad_val(Const.TARGET, self.target_pad_val) dataset.add_seq_len(Const.INPUT) data_bundle.set_input(*input_fields) @@ -86,7 +83,6 @@ class Conll2003NERPipe(_NERPipe): :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 - :param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 """ def process_from_file(self, paths) -> DataBundle: @@ -103,7 +99,7 @@ class Conll2003NERPipe(_NERPipe): class Conll2003Pipe(Pipe): - def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False, target_pad_val=0): + def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False): """ 经过该Pipe后,DataSet中的内容如下 @@ -119,7 +115,6 @@ class Conll2003Pipe(Pipe): :param str chunk_encoding_type: 支持bioes, bio。 :param str ner_encoding_type: 支持bioes, bio。 :param bool lower: 是否将words列小写化后再建立词表 - :param int target_pad_val: pos, ner, chunk列的padding值 """ if chunk_encoding_type == 'bio': self.chunk_convert_tag = iob2 @@ -130,7 +125,6 @@ class Conll2003Pipe(Pipe): else: self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags)) self.lower = lower - self.target_pad_val = int(target_pad_val) def process(self, data_bundle)->DataBundle: """ @@ -166,9 +160,6 @@ class Conll2003Pipe(Pipe): target_fields = ['pos', 'ner', 'chunk', Const.INPUT_LEN] for name, dataset in data_bundle.datasets.items(): - dataset.set_pad_val('pos', self.target_pad_val) - dataset.set_pad_val('ner', self.target_pad_val) - dataset.set_pad_val('chunk', self.target_pad_val) dataset.add_seq_len(Const.INPUT) data_bundle.set_input(*input_fields) @@ -202,7 +193,6 @@ class OntoNotesNERPipe(_NERPipe): :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 - :param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 """ def process_from_file(self, paths): @@ -220,15 +210,13 @@ class _CNNERPipe(Pipe): target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target, seq_len。 :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 - :param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 """ - def __init__(self, encoding_type: str = 'bio', target_pad_val=0): + def __init__(self, encoding_type: str = 'bio'): if encoding_type == 'bio': self.convert_tag = iob2 else: self.convert_tag = lambda words: iob2bioes(iob2(words)) - self.target_pad_val = int(target_pad_val) def process(self, data_bundle: DataBundle) -> DataBundle: """ @@ -261,7 +249,6 @@ class _CNNERPipe(Pipe): target_fields = [Const.TARGET, Const.INPUT_LEN] for name, dataset in data_bundle.datasets.items(): - dataset.set_pad_val(Const.TARGET, self.target_pad_val) dataset.add_seq_len(Const.CHAR_INPUT) data_bundle.set_input(*input_fields) @@ -324,7 +311,6 @@ class WeiboNERPipe(_CNNERPipe): target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 - :param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 """ def process_from_file(self, paths=None) -> DataBundle: data_bundle = WeiboNERLoader().load(paths) diff --git a/fastNLP/io/pipe/cws.py b/fastNLP/io/pipe/cws.py new file mode 100644 index 00000000..6ea1ae0c --- /dev/null +++ 
b/fastNLP/io/pipe/cws.py @@ -0,0 +1,246 @@ +from .pipe import Pipe +from .. import DataBundle +from ..loader import CWSLoader +from ... import Const +from itertools import chain +from .utils import _indexize +import re +def _word_lens_to_bmes(word_lens): + """ + + :param list word_lens: List[int], 每个词语的长度 + :return: List[str], BMES的序列 + """ + tags = [] + for word_len in word_lens: + if word_len==1: + tags.append('S') + else: + tags.append('B') + tags.extend(['M']*(word_len-2)) + tags.append('E') + return tags + + +def _word_lens_to_segapp(word_lens): + """ + + :param list word_lens: List[int], 每个词语的长度 + :return: List[str], BMES的序列 + """ + tags = [] + for word_len in word_lens: + if word_len==1: + tags.append('SEG') + else: + tags.extend(['APP']*(word_len-1)) + tags.append('SEG') + return tags + + +def _alpha_span_to_special_tag(span): + """ + 将span替换成特殊的字符 + + :param str span: + :return: + """ + if 'oo' == span.lower(): # speical case when represent 2OO8 + return span + if len(span) == 1: + return span + else: + return '' + + +def _find_and_replace_alpha_spans(line): + """ + 传入原始句子,替换其中的字母为特殊标记 + + :param str line:原始数据 + :return: str + """ + new_line = '' + pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%,.。!<-“])' + prev_end = 0 + for match in re.finditer(pattern, line): + start, end = match.span() + span = line[start:end] + new_line += line[prev_end:start] + _alpha_span_to_special_tag(span) + prev_end = end + new_line += line[prev_end:] + return new_line + + +def _digit_span_to_special_tag(span): + """ + + :param str span: 需要替换的str + :return: + """ + if span[0] == '0' and len(span) > 2: + return '' + decimal_point_count = 0 # one might have more than one decimal pointers + for idx, char in enumerate(span): + if char == '.' or char == '﹒' or char == '·': + decimal_point_count += 1 + if span[-1] == '.' or span[-1] == '﹒' or span[ + -1] == '·': # last digit being decimal point means this is not a number + if decimal_point_count == 1: + return span + else: + return '' + if decimal_point_count == 1: + return '' + elif decimal_point_count > 1: + return '' + else: + return '' + +def _find_and_replace_digit_spans(line): + # only consider words start with number, contains '.', characters. + # If ends with space, will be processed + # If ends with Chinese character, will be processed + # If ends with or contains english char, not handled. + # floats are replaced by + # otherwise unkdgt + new_line = '' + pattern = '\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%,。!<-“])' + prev_end = 0 + for match in re.finditer(pattern, line): + start, end = match.span() + span = line[start:end] + new_line += line[prev_end:start] + _digit_span_to_special_tag(span) + prev_end = end + new_line += line[prev_end:] + return new_line + + +class CWSPipe(Pipe): + """ + 对CWS数据进行预处理, 处理之后的数据,具备以下的结构 + + .. csv-table:: + :header: "raw_words", "chars", "target", "bigrams", "trigrams", "seq_len" + + "共同 创造 美好...", "[2, 3, 4...]", "[0, 2, 0, 2,...]", "[10, 4, 1,...]","[6, 4, 1,...]", 13 + "2001年 新年 钟声...", "[8, 9, 9, 7, ...]", "[0, 1, 1, 1, 2...]", "[11, 12, ...]","[3, 9, ...]", 20 + "...", "[...]","[...]", "[...]","[...]", . + + 其中bigrams仅当bigrams列为True的时候为真 + + :param str,None dataset_name: 支持'pku', 'msra', 'cityu', 'as', None + :param str encoding_type: 可以选择'bmes', 'segapp'两种。"我 来自 复旦大学...", bmes的tag为[S, B, E, B, M, M, E...]; segapp + 的tag为[seg, app, seg, app, app, app, seg, ...] + :param bool replace_num_alpha: 是否将数字和字母用特殊字符替换。 + :param bool bigrams: 是否增加一列bigram. bigram的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...] 
+ :param bool trigrams: 是否增加一列trigram. trigram的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] + """ + def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False): + if encoding_type=='bmes': + self.word_lens_to_tags = _word_lens_to_bmes + else: + self.word_lens_to_tags = _word_lens_to_segapp + + self.dataset_name = dataset_name + self.bigrams = bigrams + self.trigrams = trigrams + self.replace_num_alpha = replace_num_alpha + + def _tokenize(self, data_bundle): + """ + 将data_bundle中的'chars'列切分成一个一个的word. + 例如输入是"共同 创造 美好.."->[[共, 同], [创, 造], [...], ] + + :param data_bundle: + :return: + """ + def split_word_into_chars(raw_chars): + words = raw_chars.split() + chars = [] + for word in words: + char = [] + subchar = [] + for c in word: + if c=='<': + subchar.append(c) + continue + if c=='>' and subchar[0]=='<': + char.append(''.join(subchar)) + subchar = [] + if subchar: + subchar.append(c) + else: + char.append(c) + char.extend(subchar) + chars.append(char) + return chars + + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(split_word_into_chars, field_name=Const.CHAR_INPUT, + new_field_name=Const.CHAR_INPUT) + return data_bundle + + def process(self, data_bundle: DataBundle) -> DataBundle: + """ + 可以处理的DataSet需要包含raw_words列 + + .. csv-table:: + :header: "raw_words" + + "上海 浦东 开发 与 法制 建设 同步" + "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" + "..." + + :param data_bundle: + :return: + """ + data_bundle.copy_field(Const.RAW_WORD, Const.CHAR_INPUT) + + if self.replace_num_alpha: + data_bundle.apply_field(_find_and_replace_alpha_spans, Const.CHAR_INPUT, Const.CHAR_INPUT) + data_bundle.apply_field(_find_and_replace_digit_spans, Const.CHAR_INPUT, Const.CHAR_INPUT) + + self._tokenize(data_bundle) + + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(lambda chars:self.word_lens_to_tags(map(len, chars)), field_name=Const.CHAR_INPUT, + new_field_name=Const.TARGET) + dataset.apply_field(lambda chars:list(chain(*chars)), field_name=Const.CHAR_INPUT, + new_field_name=Const.CHAR_INPUT) + input_field_names = [Const.CHAR_INPUT] + if self.bigrams: + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(lambda chars: [c1+c2 for c1, c2 in zip(chars, chars[1:]+[''])], + field_name=Const.CHAR_INPUT, new_field_name='bigrams') + input_field_names.append('bigrams') + if self.trigrams: + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(lambda chars: [c1+c2+c3 for c1, c2, c3 in zip(chars, chars[1:]+[''], chars[2:]+['']*2)], + field_name=Const.CHAR_INPUT, new_field_name='trigrams') + input_field_names.append('trigrams') + + _indexize(data_bundle, input_field_names, Const.TARGET) + + input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names + target_fields = [Const.TARGET, Const.INPUT_LEN] + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.CHAR_INPUT) + + data_bundle.set_input(*input_fields) + data_bundle.set_target(*target_fields) + + return data_bundle + + def process_from_file(self, paths=None) -> DataBundle: + """ + + :param str paths: + :return: + """ + if self.dataset_name is None and paths is None: + raise RuntimeError("You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.") + if self.dataset_name is not None and paths is not None: + raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously") + data_bundle = CWSLoader(self.dataset_name).load(paths) + return 
self.process(data_bundle) \ No newline at end of file diff --git a/reproduction/seqence_labelling/cws/data/CWSDataLoader.py b/reproduction/seqence_labelling/cws/data/CWSDataLoader.py deleted file mode 100644 index 5f69c0ad..00000000 --- a/reproduction/seqence_labelling/cws/data/CWSDataLoader.py +++ /dev/null @@ -1,249 +0,0 @@ - -from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader -from fastNLP.core.vocabulary import VocabularyOption -from fastNLP.io.data_bundle import DataSetLoader, DataBundle -from typing import Union, Dict, List, Iterator -from fastNLP import DataSet -from fastNLP import Instance -from fastNLP import Vocabulary -from fastNLP import Const -from reproduction.utils import check_dataloader_paths -from functools import partial - -class SigHanLoader(DataSetLoader): - """ - 任务相关的说明可以在这里找到http://sighan.cs.uchicago.edu/ - 支持的数据格式为,一行一句,不同的word用空格隔开。如下例 - - 共同 创造 美好 的 新 世纪 —— 二○○一年 新年 - 女士 们 , 先生 们 , 同志 们 , 朋友 们 : - - 读取sighan中的数据集,返回的DataSet将包含以下的内容fields: - raw_chars: list(str), 每个元素是一个汉字 - chars: list(str), 每个元素是一个index(汉字对应的index) - target: list(int), 根据不同的encoding_type会有不同的变化 - - :param target_type: target的类型,当前支持以下的两种: "bmes", "shift_relay" - """ - - def __init__(self, target_type:str): - super().__init__() - - if target_type.lower() not in ('bmes', 'shift_relay'): - raise ValueError("target_type only supports 'bmes', 'shift_relay'.") - - self.target_type = target_type - if target_type=='bmes': - self._word_len_to_target = self._word_len_to_bems - elif target_type=='shift_relay': - self._word_len_to_target = self._word_lens_to_relay - - @staticmethod - def _word_lens_to_relay(word_lens: Iterator[int]): - """ - [1, 2, 3, ..] 转换为[0, 1, 0, 2, 1, 0,](start指示seg有多长); - :param word_lens: - :return: {'target': , 'end_seg_mask':, 'start_seg_mask':} - """ - tags = [] - end_seg_mask = [] - start_seg_mask = [] - for word_len in word_lens: - tags.extend([idx for idx in range(word_len - 1, -1, -1)]) - end_seg_mask.extend([0] * (word_len - 1) + [1]) - start_seg_mask.extend([1] + [0] * (word_len - 1)) - return {'target': tags, 'end_seg_mask': end_seg_mask, 'start_seg_mask': start_seg_mask} - - @staticmethod - def _word_len_to_bems(word_lens:Iterator[int])->Dict[str, List[str]]: - """ - - :param word_lens: 每个word的长度 - :return: - """ - tags = [] - for word_len in word_lens: - if word_len==1: - tags.append('S') - else: - tags.append('B') - for _ in range(word_len-2): - tags.append('M') - tags.append('E') - return {'target':tags} - - @staticmethod - def _gen_bigram(chars:List[str])->List[str]: - """ - - :param chars: - :return: - """ - return [c1+c2 for c1, c2 in zip(chars, chars[1:]+[''])] - - def load(self, path:str, bigram:bool=False)->DataSet: - """ - :param path: str - :param bigram: 是否使用bigram feature - :return: - """ - dataset = DataSet() - with open(path, 'r', encoding='utf-8') as f: - for line in f: - line = line.strip() - if not line: # 去掉空行 - continue - parts = line.split() - word_lens = map(len, parts) - chars = list(''.join(parts)) - tags = self._word_len_to_target(word_lens) - assert len(chars)==len(tags['target']) - dataset.append(Instance(raw_chars=chars, **tags, seq_len=len(chars))) - if len(dataset)==0: - raise RuntimeError(f"{path} has no valid data.") - if bigram: - dataset.apply_field(self._gen_bigram, field_name='raw_chars', new_field_name='bigrams') - return dataset - - def process(self, paths: Union[str, Dict[str, str]], char_vocab_opt:VocabularyOption=None, - char_embed_opt:EmbeddingOption=None, bigram_vocab_opt:VocabularyOption=None, - 
bigram_embed_opt:EmbeddingOption=None, L:int=4): - """ - 支持的数据格式为一行一个sample,并且用空格隔开不同的词语。例如 - - Option:: - - 共同 创造 美好 的 新 世纪 —— 二○○一年 新年 贺词 - ( 二○○○年 十二月 三十一日 ) ( 附 图片 1 张 ) - 女士 们 , 先生 们 , 同志 们 , 朋友 们 : - - paths支持两种格式,第一种是str,第二种是Dict[str, str]. - - Option:: - - # 1. str类型 - # 1.1 传入具体的文件路径 - data = SigHanLoader('bmes').process('/path/to/cws/data.txt') # 将读取data.txt的内容 - # 包含以下的内容data.vocabs['chars']:Vocabulary对象, - # data.vocabs['target']: Vocabulary对象,根据encoding_type可能会没有该值 - # data.embeddings['chars']: Embedding对象. 只有提供了预训练的词向量的路径才有该项 - # data.datasets['train']: DataSet对象 - # 包含的field有: - # raw_chars: list[str], 每个元素是一个汉字 - # chars: list[int], 每个元素是汉字对应的index - # target: list[int], 根据encoding_type有对应的变化 - # 1.2 传入一个目录, 里面必须包含train.txt文件 - data = SigHanLoader('bmes').process('path/to/cws/') #将尝试在该目录下读取 train.txt, test.txt以及dev.txt - # 包含以下的内容data.vocabs['chars']: Vocabulary对象 - # data.vocabs['target']:Vocabulary对象 - # data.embeddings['chars']: 仅在提供了预训练embedding路径的情况下,为Embedding对象; - # data.datasets['train']: DataSet对象 - # 包含的field有: - # raw_chars: list[str], 每个元素是一个汉字 - # chars: list[int], 每个元素是汉字对应的index - # target: list[int], 根据encoding_type有对应的变化 - # data.datasets['dev']: DataSet对象,如果文件夹下包含了dev.txt;内容与data.datasets['train']一样 - - # 2. dict类型, key是文件的名称,value是对应的读取路径. 必须包含'train'这个key - paths = {'train': '/path/to/train/train.txt', 'test':'/path/to/test/test.txt', 'dev':'/path/to/dev/dev.txt'} - data = SigHanLoader(paths).process(paths) - # 结果与传入目录时是一致的,但是可以传入多个数据集。data.datasets中的key将与这里传入的一致 - - :param paths: 支持传入目录,文件路径,以及dict。 - :param char_vocab_opt: 用于构建chars的vocabulary参数,默认为min_freq=2 - :param char_embed_opt: 用于读取chars的Embedding的参数,默认不读取pretrained的embedding - :param bigram_vocab_opt: 用于构建bigram的vocabulary参数,默认不使用bigram, 仅在指定该参数的情况下会带有bigrams这个field。 - 为List[int], 每个instance长度与chars一样, abcde的bigram为ab bc cd de e - :param bigram_embed_opt: 用于读取预训练bigram的参数,仅在传入bigram_vocab_opt有效 - :param L: 当target_type为shift_relay时传入的segment长度 - :return: - """ - # 推荐大家使用这个check_data_loader_paths进行paths的验证 - paths = check_dataloader_paths(paths) - datasets = {} - data = DataBundle() - bigram = bigram_vocab_opt is not None - for name, path in paths.items(): - dataset = self.load(path, bigram=bigram) - datasets[name] = dataset - input_fields = [] - target_fields = [] - # 创建vocab - char_vocab = Vocabulary(min_freq=2) if char_vocab_opt is None else Vocabulary(**char_vocab_opt) - char_vocab.from_dataset(datasets['train'], field_name='raw_chars') - char_vocab.index_dataset(*datasets.values(), field_name='raw_chars', new_field_name='chars') - data.vocabs[Const.CHAR_INPUT] = char_vocab - input_fields.extend([Const.CHAR_INPUT, Const.INPUT_LEN, Const.TARGET]) - target_fields.append(Const.TARGET) - # 创建target - if self.target_type == 'bmes': - target_vocab = Vocabulary(unknown=None, padding=None) - target_vocab.add_word_lst(['B']*4+['M']*3+['E']*2+['S']) - target_vocab.index_dataset(*datasets.values(), field_name='target') - data.vocabs[Const.TARGET] = target_vocab - if char_embed_opt is not None: - char_embed = EmbedLoader.load_with_vocab(**char_embed_opt, vocab=char_vocab) - data.embeddings['chars'] = char_embed - if bigram: - bigram_vocab = Vocabulary(**bigram_vocab_opt) - bigram_vocab.from_dataset(datasets['train'], field_name='bigrams') - bigram_vocab.index_dataset(*datasets.values(), field_name='bigrams') - data.vocabs['bigrams'] = bigram_vocab - if bigram_embed_opt is not None: - bigram_embed = EmbedLoader.load_with_vocab(**bigram_embed_opt, vocab=bigram_vocab) - data.embeddings['bigrams'] = 
bigram_embed - input_fields.append('bigrams') - if self.target_type == 'shift_relay': - func = partial(self._clip_target, L=L) - for name, dataset in datasets.items(): - res = dataset.apply_field(func, field_name='target') - relay_target = [res_i[0] for res_i in res] - relay_mask = [res_i[1] for res_i in res] - dataset.add_field('relay_target', relay_target, is_input=True, is_target=False, ignore_type=False) - dataset.add_field('relay_mask', relay_mask, is_input=True, is_target=False, ignore_type=False) - if self.target_type == 'shift_relay': - input_fields.extend(['end_seg_mask']) - target_fields.append('start_seg_mask') - # 将dataset加入DataInfo - for name, dataset in datasets.items(): - dataset.set_input(*input_fields) - dataset.set_target(*target_fields) - data.datasets[name] = dataset - - return data - - @staticmethod - def _clip_target(target:List[int], L:int): - """ - - 只有在target_type为shift_relay的使用 - :param target: List[int] - :param L: - :return: - """ - relay_target_i = [] - tmp = [] - for j in range(len(target) - 1): - tmp.append(target[j]) - if target[j] > target[j + 1]: - pass - else: - relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]]) - tmp = [] - # 处理未结束的部分 - if len(tmp) == 0: - relay_target_i.append(0) - else: - tmp.append(target[-1]) - relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]]) - relay_mask_i = [] - j = 0 - while j < len(target): - seg_len = target[j] + 1 - if target[j] < L: - relay_mask_i.extend([0] * (seg_len)) - else: - relay_mask_i.extend([1] * (seg_len - L) + [0] * L) - j = seg_len + j - return relay_target_i, relay_mask_i - diff --git a/reproduction/seqence_labelling/cws/data/cws_shift_pipe.py b/reproduction/seqence_labelling/cws/data/cws_shift_pipe.py new file mode 100644 index 00000000..0ae4064d --- /dev/null +++ b/reproduction/seqence_labelling/cws/data/cws_shift_pipe.py @@ -0,0 +1,202 @@ +from fastNLP.io.pipe import Pipe +from fastNLP.io import DataBundle +from fastNLP.io.loader import CWSLoader +from fastNLP import Const +from itertools import chain +from fastNLP.io.pipe.utils import _indexize +from functools import partial +from fastNLP.io.pipe.cws import _find_and_replace_alpha_spans, _find_and_replace_digit_spans + + +def _word_lens_to_relay(word_lens): + """ + [1, 2, 3, ..] 转换为[0, 1, 0, 2, 1, 0,](start指示seg有多长); + :param word_lens: + :return: + """ + tags = [] + for word_len in word_lens: + tags.extend([idx for idx in range(word_len - 1, -1, -1)]) + return tags + +def _word_lens_to_end_seg_mask(word_lens): + """ + [1, 2, 3, ..] 转换为[0, 1, 0, 2, 1, 0,](start指示seg有多长); + :param word_lens: + :return: + """ + end_seg_mask = [] + for word_len in word_lens: + end_seg_mask.extend([0] * (word_len - 1) + [1]) + return end_seg_mask + +def _word_lens_to_start_seg_mask(word_lens): + """ + [1, 2, 3, ..] 转换为[0, 1, 0, 2, 1, 0,](start指示seg有多长); + :param word_lens: + :return: + """ + start_seg_mask = [] + for word_len in word_lens: + start_seg_mask.extend([1] + [0] * (word_len - 1)) + return start_seg_mask + + +class CWSShiftRelayPipe(Pipe): + """ + + :param str,None dataset_name: 支持'pku', 'msra', 'cityu', 'as', None + :param int L: ShiftRelay模型的超参数 + :param bool replace_num_alpha: 是否将数字和字母用特殊字符替换。 + :param bool bigrams: 是否增加一列bigram. bigram的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...] + :param bool trigrams: 是否增加一列trigram. trigram的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] 
+ """ + def __init__(self, dataset_name=None, L=5, replace_num_alpha=True, bigrams=True): + self.dataset_name = dataset_name + self.bigrams = bigrams + self.replace_num_alpha = replace_num_alpha + self.L = L + + def _tokenize(self, data_bundle): + """ + 将data_bundle中的'chars'列切分成一个一个的word. + 例如输入是"共同 创造 美好.."->[[共, 同], [创, 造], [...], ] + + :param data_bundle: + :return: + """ + def split_word_into_chars(raw_chars): + words = raw_chars.split() + chars = [] + for word in words: + char = [] + subchar = [] + for c in word: + if c=='<': + subchar.append(c) + continue + if c=='>' and subchar[0]=='<': + char.append(''.join(subchar)) + subchar = [] + if subchar: + subchar.append(c) + else: + char.append(c) + char.extend(subchar) + chars.append(char) + return chars + + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(split_word_into_chars, field_name=Const.CHAR_INPUT, + new_field_name=Const.CHAR_INPUT) + return data_bundle + + def process(self, data_bundle: DataBundle) -> DataBundle: + """ + 可以处理的DataSet需要包含raw_words列 + + .. csv-table:: + :header: "raw_words" + + "上海 浦东 开发 与 法制 建设 同步" + "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" + "..." + + :param data_bundle: + :return: + """ + data_bundle.copy_field(Const.RAW_WORD, Const.CHAR_INPUT) + + if self.replace_num_alpha: + data_bundle.apply_field(_find_and_replace_alpha_spans, Const.CHAR_INPUT, Const.CHAR_INPUT) + data_bundle.apply_field(_find_and_replace_digit_spans, Const.CHAR_INPUT, Const.CHAR_INPUT) + + self._tokenize(data_bundle) + input_field_names = [Const.CHAR_INPUT] + target_field_names = [] + + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(lambda chars:_word_lens_to_relay(map(len, chars)), field_name=Const.CHAR_INPUT, + new_field_name=Const.TARGET) + dataset.apply_field(lambda chars:_word_lens_to_start_seg_mask(map(len, chars)), field_name=Const.CHAR_INPUT, + new_field_name='start_seg_mask') + dataset.apply_field(lambda chars:_word_lens_to_end_seg_mask(map(len, chars)), field_name=Const.CHAR_INPUT, + new_field_name='end_seg_mask') + dataset.apply_field(lambda chars:list(chain(*chars)), field_name=Const.CHAR_INPUT, + new_field_name=Const.CHAR_INPUT) + target_field_names.append('start_seg_mask') + input_field_names.append('end_seg_mask') + if self.bigrams: + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(lambda chars: [c1+c2 for c1, c2 in zip(chars, chars[1:]+[''])], + field_name=Const.CHAR_INPUT, new_field_name='bigrams') + input_field_names.append('bigrams') + + _indexize(data_bundle, ['chars', 'bigrams'], []) + + func = partial(_clip_target, L=self.L) + for name, dataset in data_bundle.datasets.items(): + res = dataset.apply_field(func, field_name='target') + relay_target = [res_i[0] for res_i in res] + relay_mask = [res_i[1] for res_i in res] + dataset.add_field('relay_target', relay_target, is_input=True, is_target=False, ignore_type=False) + dataset.add_field('relay_mask', relay_mask, is_input=True, is_target=False, ignore_type=False) + input_field_names.append('relay_target') + input_field_names.append('relay_mask') + + input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names + target_fields = [Const.TARGET, Const.INPUT_LEN] + target_field_names + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.CHAR_INPUT) + + data_bundle.set_input(*input_fields) + data_bundle.set_target(*target_fields) + + return data_bundle + + def process_from_file(self, paths=None) -> DataBundle: + """ + + :param str paths: + :return: + """ + if self.dataset_name 
is None and paths is None: + raise RuntimeError("You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.") + if self.dataset_name is not None and paths is not None: + raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously") + data_bundle = CWSLoader(self.dataset_name).load(paths) + return self.process(data_bundle) + +def _clip_target(target, L:int): + """ + + 只有在target_type为shift_relay的使用 + :param target: List[int] + :param L: + :return: + """ + relay_target_i = [] + tmp = [] + for j in range(len(target) - 1): + tmp.append(target[j]) + if target[j] > target[j + 1]: + pass + else: + relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]]) + tmp = [] + # 处理未结束的部分 + if len(tmp) == 0: + relay_target_i.append(0) + else: + tmp.append(target[-1]) + relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]]) + relay_mask_i = [] + j = 0 + while j < len(target): + seg_len = target[j] + 1 + if target[j] < L: + relay_mask_i.extend([0] * (seg_len)) + else: + relay_mask_i.extend([1] * (seg_len - L) + [0] * L) + j = seg_len + j + return relay_target_i, relay_mask_i diff --git a/reproduction/seqence_labelling/cws/model/bilstm_crf_cws.py b/reproduction/seqence_labelling/cws/model/bilstm_crf_cws.py new file mode 100644 index 00000000..4f87a81c --- /dev/null +++ b/reproduction/seqence_labelling/cws/model/bilstm_crf_cws.py @@ -0,0 +1,60 @@ + +import torch +from fastNLP.modules import LSTM +from fastNLP.modules import allowed_transitions, ConditionalRandomField +from fastNLP import seq_len_to_mask +from torch import nn +from fastNLP import Const +import torch.nn.functional as F + +class BiLSTMCRF(nn.Module): + def __init__(self, char_embed, hidden_size, num_layers, target_vocab=None, bigram_embed=None, trigram_embed=None, + dropout=0.5): + super().__init__() + + embed_size = char_embed.embed_size + self.char_embed = char_embed + if bigram_embed: + embed_size += bigram_embed.embed_size + self.bigram_embed = bigram_embed + if trigram_embed: + embed_size += trigram_embed.embed_size + self.trigram_embed = trigram_embed + + self.lstm = LSTM(embed_size, hidden_size=hidden_size//2, bidirectional=True, batch_first=True, + num_layers=num_layers) + self.dropout = nn.Dropout(p=dropout) + self.fc = nn.Linear(hidden_size, len(target_vocab)) + + transitions = None + if target_vocab: + transitions = allowed_transitions(target_vocab, include_start_end=True, encoding_type='bmes') + + self.crf = ConditionalRandomField(num_tags=len(target_vocab), allowed_transitions=transitions) + + def _forward(self, chars, bigrams, trigrams, seq_len, target=None): + chars = self.char_embed(chars) + if bigrams is not None: + bigrams = self.bigram_embed(bigrams) + chars = torch.cat([chars, bigrams], dim=-1) + if trigrams is not None: + trigrams = self.trigram_embed(trigrams) + chars = torch.cat([chars, trigrams], dim=-1) + + output, _ = self.lstm(chars, seq_len) + output = self.dropout(output) + output = self.fc(output) + output = F.log_softmax(output, dim=-1) + mask = seq_len_to_mask(seq_len) + if target is None: + pred, _ = self.crf.viterbi_decode(output, mask) + return {Const.OUTPUT:pred} + else: + loss = self.crf.forward(output, tags=target, mask=mask) + return {Const.LOSS:loss} + + def forward(self, chars, seq_len, target, bigrams=None, trigrams=None): + return self._forward(chars, bigrams, trigrams, seq_len, target) + + def predict(self, chars, seq_len, bigrams=None, trigrams=None): + return self._forward(chars, bigrams, trigrams, seq_len) \ No newline at 
end of file diff --git a/reproduction/seqence_labelling/cws/model/model.py b/reproduction/seqence_labelling/cws/model/bilstm_shift_relay.py similarity index 74% rename from reproduction/seqence_labelling/cws/model/model.py rename to reproduction/seqence_labelling/cws/model/bilstm_shift_relay.py index de945ac3..4ce1cc51 100644 --- a/reproduction/seqence_labelling/cws/model/model.py +++ b/reproduction/seqence_labelling/cws/model/bilstm_shift_relay.py @@ -1,7 +1,5 @@ from torch import nn import torch -from fastNLP.embeddings import Embedding -import numpy as np from reproduction.seqence_labelling.cws.model.module import FeatureFunMax, SemiCRFShiftRelay from fastNLP.modules import LSTM @@ -21,25 +19,21 @@ class ShiftRelayCWSModel(nn.Module): :param num_bigram_per_char: 每个character对应的bigram的数量 :param drop_p: Dropout的大小 """ - def __init__(self, char_embed:Embedding, bigram_embed:Embedding, hidden_size:int=400, num_layers:int=1, - L:int=6, num_bigram_per_char:int=1, drop_p:float=0.2): + def __init__(self, char_embed, bigram_embed, hidden_size:int=400, num_layers:int=1, L:int=6, drop_p:float=0.2): super().__init__() - self.char_embedding = Embedding(char_embed, dropout=drop_p) - self._pretrained_embed = False - if isinstance(char_embed, np.ndarray): - self._pretrained_embed = True - self.bigram_embedding = Embedding(bigram_embed, dropout=drop_p) - self.lstm = LSTM(100 * (num_bigram_per_char + 1), hidden_size // 2, num_layers=num_layers, bidirectional=True, + self.char_embedding = char_embed + self.bigram_embedding = bigram_embed + self.lstm = LSTM(char_embed.embed_size+bigram_embed.embed_size, hidden_size // 2, num_layers=num_layers, + bidirectional=True, batch_first=True) self.feature_fn = FeatureFunMax(hidden_size, L) self.semi_crf_relay = SemiCRFShiftRelay(L) self.feat_drop = nn.Dropout(drop_p) self.reset_param() - # self.feature_fn.reset_parameters() def reset_param(self): for name, param in self.named_parameters(): - if 'embedding' in name and self._pretrained_embed: + if 'embedding' in name: continue if 'bias_hh' in name: nn.init.constant_(param, 0) @@ -51,10 +45,8 @@ class ShiftRelayCWSModel(nn.Module): nn.init.xavier_uniform_(param) def get_feats(self, chars, bigrams, seq_len): - batch_size, max_len = chars.size() chars = self.char_embedding(chars) bigrams = self.bigram_embedding(bigrams) - bigrams = bigrams.view(bigrams.size(0), max_len, -1) chars = torch.cat([chars, bigrams], dim=-1) feats, _ = self.lstm(chars, seq_len) feats = self.feat_drop(feats) diff --git a/reproduction/seqence_labelling/cws/train_bilstm_crf.py b/reproduction/seqence_labelling/cws/train_bilstm_crf.py new file mode 100644 index 00000000..b9a77249 --- /dev/null +++ b/reproduction/seqence_labelling/cws/train_bilstm_crf.py @@ -0,0 +1,52 @@ +import sys +sys.path.append('../../..') + +from fastNLP.io.pipe.cws import CWSPipe +from reproduction.seqence_labelling.cws.model.bilstm_crf_cws import BiLSTMCRF +from fastNLP import Trainer, cache_results +from fastNLP.embeddings import StaticEmbedding +from fastNLP import EvaluateCallback, BucketSampler, SpanFPreRecMetric, GradientClipCallback +from torch.optim import Adagrad + +###########hyper +dataname = 'pku' +hidden_size = 400 +num_layers = 1 +lr = 0.05 +###########hyper + + +@cache_results('{}.pkl'.format(dataname), _refresh=False) +def get_data(): + data_bundle = CWSPipe(dataset_name=dataname, bigrams=True, trigrams=False).process_from_file() + char_embed = StaticEmbedding(data_bundle.get_vocab('chars'), dropout=0.33, word_dropout=0.01, + 
model_dir_or_name='~/exps/CWS/pretrain/vectors/1grams_t3_m50_corpus.txt') + bigram_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'), dropout=0.33,min_freq=3, word_dropout=0.01, + model_dir_or_name='~/exps/CWS/pretrain/vectors/2grams_t3_m50_corpus.txt') + return data_bundle, char_embed, bigram_embed + +data_bundle, char_embed, bigram_embed = get_data() +print(data_bundle) + +model = BiLSTMCRF(char_embed, hidden_size, num_layers, target_vocab=data_bundle.get_vocab('target'), bigram_embed=bigram_embed, + trigram_embed=None, dropout=0.3) +model.cuda() + +callbacks = [] +callbacks.append(EvaluateCallback(data_bundle.get_dataset('test'))) +callbacks.append(GradientClipCallback(clip_type='value', clip_value=5)) +optimizer = Adagrad(model.parameters(), lr=lr) + +metrics = [] +metric1 = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'), encoding_type='bmes') +metrics.append(metric1) + +trainer = Trainer(data_bundle.get_dataset('train'), model, optimizer=optimizer, loss=None, + batch_size=128, sampler=BucketSampler(), update_every=1, + num_workers=1, n_epochs=10, print_every=5, + dev_data=data_bundle.get_dataset('dev'), + metrics=metrics, + metric_key=None, + validate_every=-1, save_path=None, use_tqdm=True, device=0, + callbacks=callbacks, check_code_level=0, dev_batch_size=128) +trainer.train() diff --git a/reproduction/seqence_labelling/cws/train_shift_relay.py b/reproduction/seqence_labelling/cws/train_shift_relay.py index 55576575..322f42bb 100644 --- a/reproduction/seqence_labelling/cws/train_shift_relay.py +++ b/reproduction/seqence_labelling/cws/train_shift_relay.py @@ -1,64 +1,53 @@ -import os +import sys +sys.path.append('../../..') from fastNLP import cache_results -from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader -from reproduction.seqence_labelling.cws.model.model import ShiftRelayCWSModel -from fastNLP.io.embed_loader import EmbeddingOption -from fastNLP.core.vocabulary import VocabularyOption +from reproduction.seqence_labelling.cws.data.cws_shift_pipe import CWSShiftRelayPipe +from reproduction.seqence_labelling.cws.model.bilstm_shift_relay import ShiftRelayCWSModel from fastNLP import Trainer from torch.optim import Adam from fastNLP import BucketSampler from fastNLP import GradientClipCallback from reproduction.seqence_labelling.cws.model.metric import RelayMetric - - -# 借助一下fastNLP的自动缓存机制,但是只能缓存4G以下的结果 -@cache_results(None) -def prepare_data(): - data = SigHanLoader(target_type='shift_relay').process(file_dir, char_embed_opt=char_embed_opt, - bigram_vocab_opt=bigram_vocab_opt, - bigram_embed_opt=bigram_embed_opt, - L=L) - return data +from fastNLP.embeddings import StaticEmbedding +from fastNLP import EvaluateCallback #########hyper L = 4 hidden_size = 200 num_layers = 1 drop_p = 0.2 -lr = 0.02 - +lr = 0.008 +data_name = 'pku' #########hyper device = 0 -# !!!!这里千万不要放完全路径,因为这样会暴露你们在服务器上的用户名,比较危险。所以一定要使用相对路径,最好把数据放到 -# 你们的reproduction路径下,然后设置.gitignore -file_dir = '/path/to/' -char_embed_path = '/pretrain/vectors/1grams_t3_m50_corpus.txt' -bigram_embed_path = '/pretrain/vectors/2grams_t3_m50_corpus.txt' -bigram_vocab_opt = VocabularyOption(min_freq=3) -char_embed_opt = EmbeddingOption(embed_filepath=char_embed_path) -bigram_embed_opt = EmbeddingOption(embed_filepath=bigram_embed_path) - -data_name = os.path.basename(file_dir) cache_fp = 'caches/{}.pkl'.format(data_name) +@cache_results(_cache_fp=cache_fp, _refresh=True) # 将结果缓存到cache_fp中,这样下次运行就直接读取,而不需要再次运行 +def prepare_data(): + data_bundle = CWSShiftRelayPipe(dataset_name=data_name, 
L=L).process_from_file() + # 预训练的character embedding和bigram embedding + char_embed = StaticEmbedding(data_bundle.get_vocab('chars'), dropout=0.5, word_dropout=0.01, + model_dir_or_name='~/exps/CWS/pretrain/vectors/1grams_t3_m50_corpus.txt') + bigram_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'), dropout=0.5, min_freq=3, word_dropout=0.01, + model_dir_or_name='~/exps/CWS/pretrain/vectors/2grams_t3_m50_corpus.txt') -data = prepare_data(_cache_fp=cache_fp, _refresh=True) + return data_bundle, char_embed, bigram_embed -model = ShiftRelayCWSModel(char_embed=data.embeddings['chars'], bigram_embed=data.embeddings['bigrams'], - hidden_size=hidden_size, num_layers=num_layers, - L=L, num_bigram_per_char=1, drop_p=drop_p) +data, char_embed, bigram_embed = prepare_data() -sampler = BucketSampler(batch_size=32) +model = ShiftRelayCWSModel(char_embed=char_embed, bigram_embed=bigram_embed, + hidden_size=hidden_size, num_layers=num_layers, drop_p=drop_p, L=L) + +sampler = BucketSampler() optimizer = Adam(model.parameters(), lr=lr) -clipper = GradientClipCallback(clip_value=5, clip_type='value') -callbacks = [clipper] -# if pretrain: -# fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until) -# callbacks.append(fixer) -trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None, batch_size=32, sampler=sampler, - update_every=5, n_epochs=3, print_every=5, dev_data=data.datasets['dev'], metrics=RelayMetric(), +clipper = GradientClipCallback(clip_value=5, clip_type='value') # 截断太大的梯度 +evaluator = EvaluateCallback(data.get_dataset('test')) # 额外测试在test集上的效果 +callbacks = [clipper, evaluator] + +trainer = Trainer(data.get_dataset('train'), model, optimizer=optimizer, loss=None, batch_size=128, sampler=sampler, + update_every=1, n_epochs=10, print_every=5, dev_data=data.get_dataset('dev'), metrics=RelayMetric(), metric_key='f', validate_every=-1, save_path=None, use_tqdm=True, device=device, callbacks=callbacks, - check_code_level=0) + check_code_level=0, num_workers=1) trainer.train() \ No newline at end of file diff --git a/reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py b/reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py index 249e2851..c38dce38 100644 --- a/reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py +++ b/reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py @@ -8,11 +8,10 @@ import torch.nn.functional as F from fastNLP import Const class CNNBiLSTMCRF(nn.Module): - def __init__(self, embed, char_embed, hidden_size, num_layers, tag_vocab, dropout=0.5, encoding_type='bioes'): + def __init__(self, embed, hidden_size, num_layers, tag_vocab, dropout=0.5, encoding_type='bioes'): super().__init__() self.embedding = embed - self.char_embedding = char_embed - self.lstm = LSTM(input_size=self.embedding.embedding_dim+self.char_embedding.embedding_dim, + self.lstm = LSTM(input_size=self.embedding.embedding_dim, hidden_size=hidden_size//2, num_layers=num_layers, bidirectional=True, batch_first=True) self.fc = nn.Linear(hidden_size, len(tag_vocab)) @@ -32,9 +31,7 @@ class CNNBiLSTMCRF(nn.Module): nn.init.zeros_(param) def _forward(self, words, seq_len, target=None): - word_embeds = self.embedding(words) - char_embeds = self.char_embedding(words) - words = torch.cat((word_embeds, char_embeds), dim=-1) + words = self.embedding(words) outputs, _ = self.lstm(words, seq_len) self.dropout(outputs) diff --git a/reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py 
b/reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py index 10c5bdea..3138a6c2 100644 --- a/reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py +++ b/reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py @@ -1,7 +1,7 @@ import sys sys.path.append('../../..') -from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding +from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding, StackEmbedding from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF from fastNLP import Trainer @@ -22,7 +22,7 @@ def load_data(): paths = {'test':"NER/corpus/CoNLL-2003/eng.testb", 'train':"NER/corpus/CoNLL-2003/eng.train", 'dev':"NER/corpus/CoNLL-2003/eng.testa"} - data = Conll2003NERPipe(encoding_type=encoding_type, target_pad_val=0).process_from_file(paths) + data = Conll2003NERPipe(encoding_type=encoding_type).process_from_file(paths) return data data = load_data() print(data) @@ -33,8 +33,9 @@ word_embed = StaticEmbedding(vocab=data.get_vocab('words'), model_dir_or_name='en-glove-6b-100d', requires_grad=True, lower=True, word_dropout=0.01, dropout=0.5) word_embed.embedding.weight.data = word_embed.embedding.weight.data/word_embed.embedding.weight.data.std() +embed = StackEmbedding([word_embed, char_embed]) -model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], +model = CNNBiLSTMCRF(embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type) callbacks = [ diff --git a/reproduction/seqence_labelling/ner/train_ontonote.py b/reproduction/seqence_labelling/ner/train_ontonote.py index 7b465d77..ee80b6f7 100644 --- a/reproduction/seqence_labelling/ner/train_ontonote.py +++ b/reproduction/seqence_labelling/ner/train_ontonote.py @@ -2,7 +2,7 @@ import sys sys.path.append('../../..') -from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding +from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding, StackEmbedding from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF from fastNLP import Trainer @@ -35,7 +35,7 @@ def cache(): char_embed = CNNCharEmbedding(vocab=data.vocabs['words'], embed_size=30, char_emb_size=30, filter_nums=[30], kernel_sizes=[3], dropout=dropout) word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], - model_dir_or_name='en-glove-100d', + model_dir_or_name='en-glove-6b-100d', requires_grad=True, normalize=normalize, word_dropout=0.01, @@ -47,7 +47,8 @@ data, char_embed, word_embed = cache() print(data) -model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=1200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], +embed = StackEmbedding([word_embed, char_embed]) +model = CNNBiLSTMCRF(embed, hidden_size=1200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type, dropout=dropout) callbacks = [ diff --git a/test/io/loader/test_cws_loader.py b/test/io/loader/test_cws_loader.py new file mode 100644 index 00000000..6ad607c3 --- /dev/null +++ b/test/io/loader/test_cws_loader.py @@ -0,0 +1,13 @@ +import unittest +import os +from fastNLP.io.loader import CWSLoader + + +class CWSLoaderTest(unittest.TestCase): + @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") + def test_download(self): + dataset_names = ['pku', 'cityu', 'as', 'msra'] + for dataset_name in dataset_names: + with self.subTest(dataset_name=dataset_name): + data_bundle = CWSLoader(dataset_name=dataset_name).load() + print(data_bundle) \ No newline at end of file diff --git 
a/test/io/pipe/test_cws.py b/test/io/pipe/test_cws.py
new file mode 100644
index 00000000..2fc57ae2
--- /dev/null
+++ b/test/io/pipe/test_cws.py
@@ -0,0 +1,13 @@
+
+import unittest
+import os
+from fastNLP.io.pipe.cws import CWSPipe
+
+class CWSPipeTest(unittest.TestCase):
+    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+    def test_process_from_file(self):
+        dataset_names = ['pku', 'cityu', 'as', 'msra']
+        for dataset_name in dataset_names:
+            with self.subTest(dataset_name=dataset_name):
+                data_bundle = CWSPipe(dataset_name=dataset_name).process_from_file()
+                print(data_bundle)
\ No newline at end of file
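For reviewers, below is a minimal end-to-end sketch of how the pieces introduced in this patch fit together: CWSPipe for downloading and preprocessing, StaticEmbedding for character/bigram features, and BiLSTMCRF from reproduction/seqence_labelling/cws/model/bilstm_crf_cws.py for BMES tagging. It mirrors reproduction/seqence_labelling/cws/train_bilstm_crf.py; the randomly initialized embeddings (model_dir_or_name=None, embedding_dim=100, instead of the pretrained n-gram vector files) and the CPU device are placeholder assumptions for illustration, not part of the patch.

from torch.optim import Adagrad

from fastNLP import Trainer, BucketSampler, SpanFPreRecMetric
from fastNLP.embeddings import StaticEmbedding
from fastNLP.io.pipe.cws import CWSPipe
from reproduction.seqence_labelling.cws.model.bilstm_crf_cws import BiLSTMCRF

# Download and preprocess PKU: raw words -> chars, BMES targets, optional bigram column.
data_bundle = CWSPipe(dataset_name='pku', encoding_type='bmes', bigrams=True).process_from_file()

# Placeholder embeddings (random init); pass model_dir_or_name=<pretrained .txt file>
# to reproduce the settings used in train_bilstm_crf.py.
char_embed = StaticEmbedding(data_bundle.get_vocab('chars'),
                             model_dir_or_name=None, embedding_dim=100)
bigram_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'),
                               model_dir_or_name=None, embedding_dim=100)

model = BiLSTMCRF(char_embed, hidden_size=400, num_layers=1,
                  target_vocab=data_bundle.get_vocab('target'),
                  bigram_embed=bigram_embed, dropout=0.3)

metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'), encoding_type='bmes')
trainer = Trainer(data_bundle.get_dataset('train'), model,
                  optimizer=Adagrad(model.parameters(), lr=0.05),
                  batch_size=128, sampler=BucketSampler(), n_epochs=10,
                  dev_data=data_bundle.get_dataset('dev'), metrics=metric,
                  device='cpu', check_code_level=0)
trainer.train()

The shift-relay variant follows the same pattern, swapping in CWSShiftRelayPipe and ShiftRelayCWSModel as shown in reproduction/seqence_labelling/cws/train_shift_relay.py above.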