diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index fa6d90a2..58436a35 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -22,7 +22,7 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require 'varargs']) -class Example(dict): +class Option(dict): """a dict can treat keys as attributes""" def __getattr__(self, item): try: diff --git a/reproduction/seqence_labelling/Chinese_Word_Segmentation/data/CWSDataLoader.py b/reproduction/seqence_labelling/Chinese_Word_Segmentation/data/CWSDataLoader.py new file mode 100644 index 00000000..1000c204 --- /dev/null +++ b/reproduction/seqence_labelling/Chinese_Word_Segmentation/data/CWSDataLoader.py @@ -0,0 +1,210 @@ + +from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader +from fastNLP.core.vocabulary import VocabularyOption +from fastNLP.io.base_loader import DataSetLoader, DataInfo +from typing import Union, Dict, List, Iterator +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP import Vocabulary + +class SigHanLoader(DataSetLoader): + """ + 任务相关的说明可以在这里找到http://sighan.cs.uchicago.edu/ + 支持的数据格式为,一行一句,不同的word用空格隔开。如下例 + + 共同 创造 美好 的 新 世纪 —— 二○○一年 新年 + 女士 们 , 先生 们 , 同志 们 , 朋友 们 : + + 读取sighan中的数据集,返回的DataSet将包含以下的内容fields: + raw_chars: list(str), 每个元素是一个汉字 + chars: list(str), 每个元素是一个index(汉字对应的index) + target: list(int), 根据不同的encoding_type会有不同的变化 + + :param target_type: target的类型,当前支持以下的两种: "bmes", "pointer" + """ + + def __init__(self, target_type:str): + super().__init__() + + if target_type.lower() not in ('bmes', 'pointer'): + raise ValueError("target_type only supports 'bmes', 'pointer'.") + + self.target_type = target_type + if target_type=='bmes': + self._word_len_to_target = self._word_len_to_bems + + + + @staticmethod + def _word_len_to_bems(word_lens:Iterator[int])->List[str]: + """ + + :param word_lens: 每个word的长度 + :return: 返回对应的BMES的str + """ + tags = [] + for word_len in word_lens: + if word_len==1: + tags.append('S') + else: + tags.append('B') + for _ in range(word_len-2): + tags.append('M') + tags.append('E') + return tags + + @staticmethod + def _gen_bigram(chars:List[str])->List[str]: + """ + + :param chars: + :return: + """ + return [c1+c2 for c1, c2 in zip(chars, chars[1:]+[''])] + + def load(self, path:str, bigram:bool=False)->DataSet: + """ + :param path: str + :param bigram: 是否使用bigram feature + :return: + """ + dataset = DataSet() + with open(path, 'r', encoding='utf-8') as f: + for line in f: + parts = line.split() + word_lens = map(len, parts) + chars = list(line) + tags = self._word_len_to_target(word_lens) + dataset.append(Instance(raw_chars=chars, target=tags)) + if len(dataset)==0: + raise RuntimeError(f"{path} has no valid data.") + if bigram: + dataset.apply_field(self._gen_bigram, field_name='raw_chars', new_field_name='bigrams') + return dataset + + def process(self, paths: Union[str, Dict[str, str]], char_vocab_opt:VocabularyOption=None, + char_embed_opt:EmbeddingOption=None, bigram_vocab_opt:VocabularyOption=None, + bigram_embed_opt:EmbeddingOption=None): + """ + 支持的数据格式为一行一个sample,并且用空格隔开不同的词语。例如 + + Option:: + + 共同 创造 美好 的 新 世纪 —— 二○○一年 新年 贺词 + ( 二○○○年 十二月 三十一日 ) ( 附 图片 1 张 ) + 女士 们 , 先生 们 , 同志 们 , 朋友 们 : + + paths支持两种格式,第一种是str,第二种是Dict[str, str]. + + Option:: + + # 1. str类型 + # 1.1 传入具体的文件路径 + data = SigHanLoader('bmes').process('/path/to/cws/data.txt') # 将读取data.txt的内容 + # 包含以下的内容data.vocabs['chars']:Vocabulary对象, + # data.vocabs['target']: Vocabulary对象,根据encoding_type可能会没有该值 + # data.embeddings['chars']: Embedding对象. 只有提供了预训练的词向量的路径才有该项 + # data.datasets['train']: DataSet对象 + # 包含的field有: + # raw_chars: list[str], 每个元素是一个汉字 + # chars: list[int], 每个元素是汉字对应的index + # target: list[int], 根据encoding_type有对应的变化 + # 1.2 传入一个目录, 里面必须包含train.txt文件 + data = SigHanLoader('bmes').process('path/to/cws/') #将尝试在该目录下读取 train.txt, test.txt以及dev.txt + # 包含以下的内容data.vocabs['chars']: Vocabulary对象 + # data.vocabs['target']:Vocabulary对象 + # data.embeddings['chars']: Embedding对象. 只有提供了预训练的词向量的路径才有该项 + # data.datasets['train']: DataSet对象 + # 包含的field有: + # raw_chars: list[str], 每个元素是一个汉字 + # chars: list[int], 每个元素是汉字对应的index + # target: list[int], 根据encoding_type有对应的变化 + # data.datasets['dev']: DataSet对象,如果文件夹下包含了dev.txt;内容与data.datasets['train']一样 + + # 2. dict类型, key是文件的名称,value是对应的读取路径. 必须包含'train'这个key + paths = {'train': '/path/to/train/train.txt', 'test':'/path/to/test/test.txt', 'dev':'/path/to/dev/dev.txt'} + data = SigHanLoader(paths).process(paths) + # 结果与传入目录时是一致的,但是可以传入多个数据集。data.datasets中的key将与这里传入的一致 + + :param paths: 支持传入目录,文件路径,以及dict。 + :param char_vocab_opt: 用于构建chars的vocabulary参数,默认为min_freq=2 + :param char_embed_opt: 用于读取chars的Embedding的参数,默认不读取pretrained的embedding + :param bigram_vocab_opt: 用于构建bigram的vocabulary参数,默认不使用bigram, 仅在指定该参数的情况下会带有bigrams这个field。 + 为List[int], 每个instance长度与chars一样, abcde的bigram为ab bc cd de e + :param bigram_embed_opt: 用于读取预训练bigram的参数,仅在传入bigram_vocab_opt有效 + :return: + """ + # 推荐大家使用这个check_data_loader_paths进行paths的验证 + paths = check_dataloader_paths(paths) + datasets = {} + bigram = bigram_vocab_opt is not None + for name, path in paths.items(): + dataset = self.load(path, bigram=bigram) + datasets[name] = dataset + # 创建vocab + char_vocab = Vocabulary(min_freq=2) if char_vocab_opt is None else Vocabulary(**char_vocab_opt) + char_vocab.from_dataset(datasets['train'], field_name='raw_chars') + char_vocab.index_dataset(*datasets.values(), field_name='raw_chars', new_field_name='chars') + # 创建target + if self.target_type == 'bmes': + target_vocab = Vocabulary(unknown=None, padding=None) + target_vocab.add_word_lst(['B']*4+['M']*3+['E']*2+['S']) + target_vocab.index_dataset(*datasets.values(), field_name='target') + if bigram: + bigram_vocab = Vocabulary(**bigram_vocab_opt) + bigram_vocab.from_dataset(datasets['train'], field_name='bigrams') + bigram_vocab.index_dataset(*datasets.values(), field_name='bigrams') + if bigram_embed_opt is not None: + pass + + + + +import os + +def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: + """ + 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果 + { + 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 + 'test': 'xxx' # 可能有,也可能没有 + ... + } + 如果paths为不合法的,将直接进行raise相应的错误 + + :param paths: 路径 + :return: + """ + if isinstance(paths, str): + if os.path.isfile(paths): + return {'train': paths} + elif os.path.isdir(paths): + train_fp = os.path.join(paths, 'train.txt') + if not os.path.isfile(train_fp): + raise FileNotFoundError(f"train.txt is not found in folder {paths}.") + files = {'train': train_fp} + for filename in ['test.txt', 'dev.txt']: + fp = os.path.join(paths, filename) + if os.path.isfile(fp): + files[filename.split('.')[0]] = fp + return files + else: + raise FileNotFoundError(f"{paths} is not a valid file path.") + + elif isinstance(paths, dict): + if paths: + if 'train' not in paths: + raise KeyError("You have to include `train` in your dict.") + for key, value in paths.items(): + if isinstance(key, str) and isinstance(value, str): + if not os.path.isfile(value): + raise TypeError(f"{value} is not a valid file.") + else: + raise TypeError("All keys and values in paths should be str.") + return paths + else: + raise ValueError("Empty paths is not allowed.") + else: + raise TypeError(f"paths only supports str and dict. not {type(paths)}.") + +