diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 66aabd3d..d4fa6cc9 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -117,6 +117,8 @@ class Vocabulary(object): :param str word: 新词 """ + if word in self._no_create_word: + self._no_create_word.pop(word) self.add(word) @_check_build_status @@ -126,6 +128,9 @@ class Vocabulary(object): :param list[str] word_lst: 词的序列 """ + for word in word_lst: + if word in self._no_create_word: + self._no_create_word.pop(word) self.update(word_lst) def build_vocab(self): diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index 005cfe75..0d6f30e3 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -179,16 +179,16 @@ class StaticEmbedding(TokenEmbedding): :param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding的文件名,第二种是传入embedding 的名称。目前支持的embedding包括{`en` 或者 `en-glove-840b-300` : glove.840B.300d, `en-glove-6b-50` : glove.6B.50d, `en-word2vec-300` : GoogleNews-vectors-negative300}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。 - :param requires_grad: 是否需要gradient. 默认为True - :param init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对象。 - :param normailize: 是否对vector进行normalize,使得每个vector的norm为1。 + :param bool requires_grad: 是否需要gradient. 默认为True + :param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对象。 + :param bool normailize: 是否对vector进行normalize,使得每个vector的norm为1。 + :param bool lower: 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独 + 为大写的词语开辟一个vector表示,则将lower设置为False。 """ def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None, - normalize=False): + normalize=False, lower=False): super(StaticEmbedding, self).__init__(vocab) - # 优先定义需要下载的static embedding有哪些。这里估计需要自己搞一个server, - # 得到cache_path if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: PRETRAIN_URL = _get_base_url('static') @@ -202,8 +202,40 @@ class StaticEmbedding(TokenEmbedding): raise ValueError(f"Cannot recognize {model_dir_or_name}.") # 读取embedding - embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method, - normalize=normalize) + if lower: + lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown) + for word, index in vocab: + if not vocab._is_word_no_create_entry(word): + lowered_vocab.add_word(word.lower()) # 先加入需要创建entry的 + for word in vocab._no_create_word.keys(): # 不需要创建entry的 + if word in vocab: + lowered_word = word.lower() + if lowered_word not in lowered_vocab.word_count: + lowered_vocab.add_word(lowered_word) + lowered_vocab._no_create_word[lowered_word] += 1 + print(f"All word in vocab have been lowered. 
There are {len(vocab)} words, {len(lowered_vocab)} unique lowered " + f"words.") + embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method, + normalize=normalize) + # 需要适配一下 + if not hasattr(self, 'words_to_words'): + self.words_to_words = torch.arange(len(lowered_vocab, )).long() + if lowered_vocab.unknown: + unknown_idx = lowered_vocab.unknown_idx + else: + unknown_idx = embedding.size(0) - 1 # 否则是最后一个为unknow + words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), + requires_grad=False) + for word, index in vocab: + if word not in lowered_vocab: + word = word.lower() + if lowered_vocab._is_word_no_create_entry(word): # 如果不需要创建entry,已经默认unknown了 + continue + words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)] + self.words_to_words = words_to_words + else: + embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method, + normalize=normalize) self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], padding_idx=vocab.padding_idx, max_norm=None, norm_type=2, scale_grad_by_freq=False, @@ -301,7 +333,7 @@ class StaticEmbedding(TokenEmbedding): if vocab._no_create_word_length>0: if vocab.unknown is None: # 创建一个专门的unknown unknown_idx = len(matrix) - vectors = torch.cat([vectors, torch.zeros(1, dim)], dim=0).contiguous() + vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous() else: unknown_idx = vocab.unknown_idx words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), @@ -438,19 +470,15 @@ class ElmoEmbedding(ContextualEmbedding): :param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo权重的文件名,第二种是传入ELMo版本的名称, 目前支持的ELMo包括{`en` : 英文版本的ELMo, `cn` : 中文版本的ELMo,}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载 :param layers: str, 指定返回的层数, 以,隔开不同的层。如果要返回第二层的结果'2', 返回后两层的结果'1,2'。不同的层的结果 - 按照这个顺序concat起来。默认为'2'。 - :param requires_grad: bool, 该层是否需要gradient. 默认为False + 按照这个顺序concat起来。默认为'2'。'mix'会使用可学习的权重结合不同层的表示(权重是否可训练与requires_grad保持一致, + 初始化权重对三层结果进行mean-pooling, 可以通过ElmoEmbedding.set_mix_weights_requires_grad()方法只将mix weights设置为可学习。) + :param requires_grad: bool, 该层是否需要gradient, 默认为False. :param cache_word_reprs: 可以选择对word的表示进行cache; 设置为True的话,将在初始化的时候为每个word生成对应的embedding, 并删除character encoder,之后将直接使用cache的embedding。默认为False。 """ def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', layers: str='2', requires_grad: bool=False, cache_word_reprs: bool=False): super(ElmoEmbedding, self).__init__(vocab) - layers = list(map(int, layers.split(','))) - assert len(layers) > 0, "Must choose one output" - for layer in layers: - assert 0 <= layer <= 2, "Layer index should be in range [0, 2]." 
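
Stepping back to the `StaticEmbedding` change earlier in this hunk: with `lower=True`, vectors are only loaded for the lowercased vocabulary, and every index of the original case-sensitive vocab is then remapped onto a row of that smaller matrix through `words_to_words`. A minimal sketch of that remapping, with plain dicts standing in for fastNLP's `Vocabulary`, random vectors standing in for the pretrained file, and the no-create-entry bookkeeping omitted:

```python
import torch
import torch.nn as nn

# Plain dicts stand in for fastNLP's Vocabulary (illustration only).
vocab = {'<pad>': 0, '<unk>': 1, 'Apple': 2, 'apple': 3, 'Banana': 4, 'Kiwi': 5}
lowered_vocab = {'<pad>': 0, '<unk>': 1, 'apple': 2, 'banana': 3}
unknown_idx = lowered_vocab['<unk>']

# Pretend these rows were read from the pretrained file for lowered_vocab.
pretrained = torch.randn(len(lowered_vocab), 5)

# Map every index of the original vocab onto a row of the lowered matrix:
# exact match first, then the lowercased form, otherwise the unknown row.
words_to_words = torch.full((len(vocab),), fill_value=unknown_idx, dtype=torch.long)
for word, index in vocab.items():
    key = word if word in lowered_vocab else word.lower()
    if key in lowered_vocab:
        words_to_words[index] = lowered_vocab[key]

embed = nn.Embedding.from_pretrained(pretrained, freeze=False)
word_ids = torch.tensor([[vocab['Apple'], vocab['Banana'], vocab['Kiwi']]])
vectors = embed(words_to_words[word_ids])   # 'Apple' shares the 'apple' row, 'Kiwi' falls back to <unk>
print(vectors.shape)                        # torch.Size([1, 3, 5])
```

Every lookup goes through `words_to_words` first, so cased variants share one pretrained row while words whose lowercased form is still missing fall back to the unknown vector.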
- self.layers = layers # 根据model_dir_or_name检查是否存在并下载 if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: @@ -464,8 +492,49 @@ class ElmoEmbedding(ContextualEmbedding): else: raise ValueError(f"Cannot recognize {model_dir_or_name}.") self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs) + + if layers=='mix': + self.layer_weights = nn.Parameter(torch.zeros(self.model.config['encoder']['n_layers']+1), + requires_grad=requires_grad) + self.gamma = nn.Parameter(torch.ones(1), requires_grad=requires_grad) + self._get_outputs = self._get_mixed_outputs + self._embed_size = self.model.config['encoder']['projection_dim'] * 2 + else: + layers = list(map(int, layers.split(','))) + assert len(layers) > 0, "Must choose one output" + for layer in layers: + assert 0 <= layer <= 2, "Layer index should be in range [0, 2]." + self.layers = layers + self._get_outputs = self._get_layer_outputs + self._embed_size = len(self.layers) * self.model.config['encoder']['projection_dim'] * 2 + self.requires_grad = requires_grad - self._embed_size = len(self.layers) * self.model.config['encoder']['projection_dim'] * 2 + + def _get_mixed_outputs(self, outputs): + # outputs: num_layers x batch_size x max_len x hidden_size + # return: batch_size x max_len x hidden_size + weights = F.softmax(self.layer_weights+1/len(outputs), dim=0).to(outputs) + outputs = torch.einsum('l,lbij->bij', weights, outputs) + return self.gamma.to(outputs)*outputs + + def set_mix_weights_requires_grad(self, flag=True): + """ + 当初始化ElmoEmbedding时layers被设置为mix时,可以通过调用该方法设置mix weights是否可训练。如果layers不是mix,调用 + 该方法没有用。 + :param bool flag: 混合不同层表示的结果是否可以训练。 + :return: + """ + if hasattr(self, 'layer_weights'): + self.layer_weights.requires_grad = flag + self.gamma.requires_grad = flag + + def _get_layer_outputs(self, outputs): + if len(self.layers) == 1: + outputs = outputs[self.layers[0]] + else: + outputs = torch.cat(tuple([*outputs[self.layers]]), dim=-1) + + return outputs def forward(self, words: torch.LongTensor): """ @@ -480,15 +549,12 @@ class ElmoEmbedding(ContextualEmbedding): if outputs is not None: return outputs outputs = self.model(words) - if len(self.layers) == 1: - outputs = outputs[self.layers[0]] - else: - outputs = torch.cat([*outputs[self.layers]], dim=-1) - - return outputs + return self._get_outputs(outputs) def _delete_model_weights(self): - del self.layers, self.model + for name in ['layers', 'model', 'layer_weights', 'gamma']: + if hasattr(self, name): + delattr(self, name) @property def requires_grad(self): @@ -892,10 +958,11 @@ class StackEmbedding(TokenEmbedding): def __init__(self, embeds: List[TokenEmbedding]): vocabs = [] for embed in embeds: - vocabs.append(embed.get_word_vocab()) + if hasattr(embed, 'get_word_vocab'): + vocabs.append(embed.get_word_vocab()) _vocab = vocabs[0] for vocab in vocabs[1:]: - assert vocab == _vocab, "All embeddings should use the same word vocabulary." + assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary." 
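
The `layers='mix'` branch above is an ELMo-style scalar mix: one learnable weight per layer, softmax-normalized, plus a global scale `gamma`. A self-contained sketch of that combination (a toy module with made-up shapes, not the fastNLP class itself):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScalarMix(nn.Module):
    """Weighted combination of per-layer ELMo representations,
    in the spirit of the 'mix' branch above."""
    def __init__(self, num_layers: int, trainable: bool = False):
        super().__init__()
        # one weight per layer plus a global scale, the two parameters the 'mix' branch registers
        self.layer_weights = nn.Parameter(torch.zeros(num_layers), requires_grad=trainable)
        self.gamma = nn.Parameter(torch.ones(1), requires_grad=trainable)

    def forward(self, layer_outputs: torch.Tensor) -> torch.Tensor:
        # layer_outputs: num_layers x batch_size x max_len x hidden_size
        weights = F.softmax(self.layer_weights, dim=0)
        mixed = torch.einsum('l,lbij->bij', weights, layer_outputs)
        return self.gamma * mixed

mix = ScalarMix(num_layers=3)           # e.g. char-CNN layer + 2 LSTM layers
outputs = torch.randn(3, 4, 7, 1024)    # fake per-layer representations
print(mix(outputs).shape)               # torch.Size([4, 7, 1024])
```

With the zero initialization the softmax yields uniform weights, so the initial output is a mean over the layers, matching the updated docstring; `set_mix_weights_requires_grad()` can later make just these two parameters trainable.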
super(StackEmbedding, self).__init__(_vocab) assert isinstance(embeds, list) diff --git a/reproduction/seqence_labelling/ner/data/Conll2003Loader.py b/reproduction/seqence_labelling/ner/data/Conll2003Loader.py deleted file mode 100644 index 577987c6..00000000 --- a/reproduction/seqence_labelling/ner/data/Conll2003Loader.py +++ /dev/null @@ -1,93 +0,0 @@ - -from fastNLP.core.vocabulary import VocabularyOption -from fastNLP.io.base_loader import DataSetLoader, DataInfo -from typing import Union, Dict -from fastNLP import Vocabulary -from fastNLP import Const -from reproduction.utils import check_dataloader_paths - -from fastNLP.io.dataset_loader import ConllLoader -from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2 - - -class Conll2003DataLoader(DataSetLoader): - def __init__(self, task:str='ner', encoding_type:str='bioes'): - """ - 加载Conll2003格式的英语语料,该数据集的信息可以在https://www.clips.uantwerpen.be/conll2003/ner/找到。当task为pos - 时,返回的DataSet中target取值于第2列; 当task为chunk时,返回的DataSet中target取值于第3列;当task为ner时,返回 - 的DataSet中target取值于第4列。所有"-DOCSTART- -X- O O"将被忽略,这会导致数据的数量少于很多文献报道的值,但 - 鉴于"-DOCSTART- -X- O O"只是用于文档分割的符号,并不应该作为预测对象,所以我们忽略了数据中的-DOCTSTART-开头的行 - ner与chunk任务读取后的数据的target将为encoding_type类型。pos任务读取后就是pos列的数据。 - - :param task: 指定需要标注任务。可选ner, pos, chunk - """ - assert task in ('ner', 'pos', 'chunk') - index = {'ner':3, 'pos':1, 'chunk':2}[task] - self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index]) - self._tag_converters = None - if task in ('ner', 'chunk'): - self._tag_converters = [iob2] - if encoding_type == 'bioes': - self._tag_converters.append(iob2bioes) - - def load(self, path: str): - dataset = self._loader.load(path) - def convert_tag_schema(tags): - for converter in self._tag_converters: - tags = converter(tags) - return tags - if self._tag_converters: - dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET) - return dataset - - def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None, lower:bool=True): - """ - 读取并处理数据。数据中的'-DOCSTART-'开头的行会被忽略 - - :param paths: - :param word_vocab_opt: vocabulary的初始化值 - :param lower: 是否将所有字母转为小写 - :return: - """ - # 读取数据 - paths = check_dataloader_paths(paths) - data = DataInfo() - input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] - target_fields = [Const.TARGET, Const.INPUT_LEN] - for name, path in paths.items(): - dataset = self.load(path) - dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT) - if lower: - dataset.words.lower() - data.datasets[name] = dataset - - # 对construct vocab - word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt) - word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT, - no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train']) - word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT) - data.vocabs[Const.INPUT] = word_vocab - - # cap words - cap_word_vocab = Vocabulary() - cap_word_vocab.from_dataset(data.datasets['train'], field_name='raw_words', - no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train']) - cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words') - input_fields.append('cap_words') - data.vocabs['cap_words'] = cap_word_vocab - - # 对target建vocab - target_vocab = Vocabulary(unknown=None, padding=None) - 
target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET) - target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET) - data.vocabs[Const.TARGET] = target_vocab - - for name, dataset in data.datasets.items(): - dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN) - dataset.set_input(*input_fields) - dataset.set_target(*target_fields) - - return data - -if __name__ == '__main__': - pass \ No newline at end of file diff --git a/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py b/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py deleted file mode 100644 index 8a2c567d..00000000 --- a/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py +++ /dev/null @@ -1,152 +0,0 @@ -from fastNLP.core.vocabulary import VocabularyOption -from fastNLP.io.base_loader import DataSetLoader, DataInfo -from typing import Union, Dict -from fastNLP import DataSet -from fastNLP import Vocabulary -from fastNLP import Const -from reproduction.utils import check_dataloader_paths - -from fastNLP.io.dataset_loader import ConllLoader -from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2 - -class OntoNoteNERDataLoader(DataSetLoader): - """ - 用于读取处理为Conll格式后的OntoNote数据。将OntoNote数据处理为conll格式的过程可以参考https://github.com/yhcc/OntoNotes-5.0-NER。 - - """ - def __init__(self, encoding_type:str='bioes'): - assert encoding_type in ('bioes', 'bio') - self.encoding_type = encoding_type - if encoding_type=='bioes': - self.encoding_method = iob2bioes - else: - self.encoding_method = iob2 - - def load(self, path:str)->DataSet: - """ - 给定一个文件路径,读取数据。返回的DataSet包含以下的field - raw_words: List[str] - target: List[str] - - :param path: - :return: - """ - dataset = ConllLoader(headers=['raw_words', 'target'], indexes=[3, 10]).load(path) - def convert_to_bio(tags): - bio_tags = [] - flag = None - for tag in tags: - label = tag.strip("()*") - if '(' in tag: - bio_label = 'B-' + label - flag = label - elif flag: - bio_label = 'I-' + flag - else: - bio_label = 'O' - if ')' in tag: - flag = None - bio_tags.append(bio_label) - return self.encoding_method(bio_tags) - - def convert_word(words): - converted_words = [] - for word in words: - word = word.replace('/.', '.') # 有些结尾的.是/.形式的 - if not word.startswith('-'): - converted_words.append(word) - continue - # 以下是由于这些符号被转义了,再转回来 - tfrs = {'-LRB-':'(', - '-RRB-': ')', - '-LSB-': '[', - '-RSB-': ']', - '-LCB-': '{', - '-RCB-': '}' - } - if word in tfrs: - converted_words.append(tfrs[word]) - else: - converted_words.append(word) - return converted_words - - dataset.apply_field(convert_word, field_name='raw_words', new_field_name='raw_words') - dataset.apply_field(convert_to_bio, field_name='target', new_field_name='target') - - return dataset - - def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None, - lower:bool=True)->DataInfo: - """ - 读取并处理数据。返回的DataInfo包含以下的内容 - vocabs: - word: Vocabulary - target: Vocabulary - datasets: - train: DataSet - words: List[int], 被设置为input - target: int. label,被同时设置为input和target - seq_len: int. 
句子的长度,被同时设置为input和target - raw_words: List[str] - xxx(根据传入的paths可能有所变化) - - :param paths: - :param word_vocab_opt: vocabulary的初始化值 - :param lower: 是否使用小写 - :return: - """ - paths = check_dataloader_paths(paths) - data = DataInfo() - input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] - target_fields = [Const.TARGET, Const.INPUT_LEN] - for name, path in paths.items(): - dataset = self.load(path) - dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT) - if lower: - dataset.words.lower() - data.datasets[name] = dataset - - # 对construct vocab - word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt) - word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT, - no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train']) - word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT) - data.vocabs[Const.INPUT] = word_vocab - - # cap words - cap_word_vocab = Vocabulary() - cap_word_vocab.from_dataset(*data.datasets.values(), field_name='raw_words') - cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words') - input_fields.append('cap_words') - data.vocabs['cap_words'] = cap_word_vocab - - # 对target建vocab - target_vocab = Vocabulary(unknown=None, padding=None) - target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET) - target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET) - data.vocabs[Const.TARGET] = target_vocab - - for name, dataset in data.datasets.items(): - dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN) - dataset.set_input(*input_fields) - dataset.set_target(*target_fields) - - return data - - -if __name__ == '__main__': - loader = OntoNoteNERDataLoader() - dataset = loader.load('/hdd/fudanNLP/fastNLP/others/data/v4/english/test.txt') - print(dataset.target.value_count()) - print(dataset[:4]) - - -""" -train 115812 2200752 -development 15680 304684 -test 12217 230111 - -train 92403 1901772 -valid 13606 279180 -test 10258 204135 -""" \ No newline at end of file diff --git a/reproduction/seqence_labelling/ner/data/utils.py b/reproduction/seqence_labelling/ner/data/utils.py deleted file mode 100644 index 8f7af792..00000000 --- a/reproduction/seqence_labelling/ner/data/utils.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import List - -def iob2(tags:List[str])->List[str]: - """ - 检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。 - - :param tags: 需要转换的tags - """ - for i, tag in enumerate(tags): - if tag == "O": - continue - split = tag.split("-") - if len(split) != 2 or split[0] not in ["I", "B"]: - raise TypeError("The encoding schema is not a valid IOB type.") - if split[0] == "B": - continue - elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2 - tags[i] = "B" + tag[1:] - elif tags[i - 1][1:] == tag[1:]: - continue - else: # conversion IOB1 to IOB2 - tags[i] = "B" + tag[1:] - return tags - -def iob2bioes(tags:List[str])->List[str]: - """ - 将iob的tag转换为bmeso编码 - :param tags: - :return: - """ - new_tags = [] - for i, tag in enumerate(tags): - if tag == 'O': - new_tags.append(tag) - else: - split = tag.split('-')[0] - if split == 'B': - if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I': - new_tags.append(tag) - else: - new_tags.append(tag.replace('B-', 'S-')) - elif split == 'I': - if i + 1Dict[str, str]: path_pair = ('train', filename) if 'dev' in filename: if path_pair: - raise Exception("File:{} in {} contains bot 
`{}` and `dev`.".format(filename, paths, path_pair[0]))
+                        raise Exception("File:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0]))
                     path_pair = ('dev', filename)
                 if 'test' in filename:
                     if path_pair:
-                        raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0]))
+                        raise Exception("File:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0]))
                     path_pair = ('test', filename)
                 if path_pair:
+                    if path_pair[0] in files:
+                        raise RuntimeError(f"Multiple files under {paths} have '{path_pair[0]}' in their filename.")
                     files[path_pair[0]] = os.path.join(paths, path_pair[1])
             return files
         else:
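
The final hunk adds a guard so that two files in the same directory matching the same split name (e.g. `train_1.txt` and `train_2.txt`) raise an error instead of silently overwriting each other. A simplified, self-contained sketch of the directory branch of that detection (a hypothetical helper for illustration; the real `check_dataloader_paths` handles other input forms as well, as the trailing `else:` suggests):

```python
import os
from typing import Dict

def detect_split_files(paths: str) -> Dict[str, str]:
    """Map 'train'/'dev'/'test' to files in `paths` whose names contain
    those words, refusing ambiguous layouts (simplified sketch, not the
    reproduction.utils implementation itself)."""
    files = {}
    for filename in sorted(os.listdir(paths)):
        path_pair = None
        for split in ('train', 'dev', 'test'):
            if split in filename:
                if path_pair is not None:
                    raise Exception(f"File:{filename} in {paths} contains both "
                                    f"`{path_pair[0]}` and `{split}`.")
                path_pair = (split, filename)
        if path_pair:
            if path_pair[0] in files:   # the guard added in this diff
                raise RuntimeError(f"Multiple files under {paths} have "
                                    f"'{path_pair[0]}' in their filename.")
            files[path_pair[0]] = os.path.join(paths, path_pair[1])
    return files

# A folder containing train.txt, dev.txt and test.txt would yield
# {'train': '.../train.txt', 'dev': '.../dev.txt', 'test': '.../test.txt'}.
```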