diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 4247d1de..a76d7ecc 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -1006,24 +1006,28 @@ class CMRC2018Metric(MetricBase): self.total = 0 self.f1 = 0 - def evaluate(self, answers, raw_chars, context_len, pred_start, pred_end): + def evaluate(self, answers, raw_chars, pred_start, pred_end, context_len=None): """ :param list[str] answers: 如[["答案1", "答案2", "答案3"], [...], ...] :param list[str] raw_chars: [["这", "是", ...], [...]] + :param tensor pred_start: batch_size x length 或 batch_size, + :param tensor pred_end: batch_size x length 或 batch_size(是闭区间,包含end位置), :param tensor context_len: context长度, batch_size - :param tensor pred_start: batch_size x length - :param tensor pred_end: batch_size x length :return: """ - batch_size, max_len = pred_start.size() - context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(False) - pred_start.masked_fill_(context_mask, float('-inf')) - pred_end.masked_fill_(context_mask, float('-inf')) - max_pred_start, pred_start_index = pred_start.max(dim=-1, keepdim=True) # batch_size, - pred_start_mask = pred_start.eq(max_pred_start).cumsum(dim=-1).eq(0) # 只能预测这之后的值 - pred_end.masked_fill_(pred_start_mask, float('-inf')) - pred_end_index = pred_end.argmax(dim=-1) + 1 + if pred_start.dim() > 1: + batch_size, max_len = pred_start.size() + context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(False) + pred_start.masked_fill_(context_mask, float('-inf')) + pred_end.masked_fill_(context_mask, float('-inf')) + max_pred_start, pred_start_index = pred_start.max(dim=-1, keepdim=True) # batch_size, + pred_start_mask = pred_start.eq(max_pred_start).cumsum(dim=-1).eq(0) # 只能预测这之后的值 + pred_end.masked_fill_(pred_start_mask, float('-inf')) + pred_end_index = pred_end.argmax(dim=-1) + 1 + else: + pred_start_index = pred_start + pred_end_index = pred_end + 1 pred_ans = [] for index, (start, end) in enumerate(zip(pred_start_index.flatten().tolist(), pred_end_index.tolist())): pred_ans.append(''.join(raw_chars[index][start:end])) diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py index f519e705..a50ce25d 100644 --- a/fastNLP/embeddings/static_embedding.py +++ b/fastNLP/embeddings/static_embedding.py @@ -68,7 +68,11 @@ class StaticEmbedding(TokenEmbedding): :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 :param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。 :param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。 - :param dict kwarngs: only_train_min_freq, 仅对train中的词语使用min_freq筛选; only_norm_found_vector是否仅对在预训练中找到的词语使用normalize。 + :param dict kwarngs: + bool only_train_min_freq: 仅对train中的词语使用min_freq筛选; + bool only_norm_found_vector: 是否仅对在预训练中找到的词语使用normalize; + bool only_use_pretrain_word: 仅使用出现在pretrain词表中的词语。如果该词没有在预训练的词表中出现则为unk。如果词表 + 不需要更新设置为True。 """ super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) if embedding_dim > 0: @@ -118,7 +122,8 @@ class StaticEmbedding(TokenEmbedding): truncated_words_to_words[index] = truncated_vocab.to_index(word) logger.info(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.") vocab = truncated_vocab - + + self.only_use_pretrain_word = kwargs.get('only_use_pretrain_word', False) self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False) # 读取embedding if lower: @@ -249,12 +254,13 @@ class StaticEmbedding(TokenEmbedding): logger.error("Error occurred at the {} line.".format(idx)) raise e logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) - for word, index in vocab: - if index not in matrix and not vocab._is_word_no_create_entry(word): - if found_unknown: # 如果有unkonwn,用unknown初始化 - matrix[index] = matrix[vocab.unknown_idx] - else: - matrix[index] = None + if not self.only_use_pretrain_word: # 如果只用pretrain中的值就不要为未找到的词创建entry了 + for word, index in vocab: + if index not in matrix and not vocab._is_word_no_create_entry(word): + if found_unknown: # 如果有unkonwn,用unknown初始化 + matrix[index] = matrix[vocab.unknown_idx] + else: + matrix[index] = None # matrix中代表是需要建立entry的词 vectors = self._randomly_init_embed(len(matrix), dim, init_method) diff --git a/fastNLP/io/pipe/qa.py b/fastNLP/io/pipe/qa.py index ea989545..e8c0c69b 100644 --- a/fastNLP/io/pipe/qa.py +++ b/fastNLP/io/pipe/qa.py @@ -16,9 +16,9 @@ from ...core import Vocabulary __all__ = ['CMRC2018BertPipe'] -def _concat_clip(data_bundle, tokenizer, max_len, concat_field_name='raw_chars'): +def _concat_clip(data_bundle, max_len, concat_field_name='raw_chars'): """ - 处理data_bundle中的DataSet,将context与question进行tokenize,然后使用[SEP]将两者连接起来。 + 处理data_bundle中的DataSet,将context与question按照character进行tokenize,然后使用[SEP]将两者连接起来。 会新增field: context_len(int), raw_words(list[str]), target_start(int), target_end(int)其中target_start 与target_end是与raw_chars等长的。其中target_start和target_end是前闭后闭的区间。 @@ -26,6 +26,7 @@ def _concat_clip(data_bundle, tokenizer, max_len, concat_field_name='raw_chars') :param DataBundle data_bundle: 类似["a", "b", "[SEP]", "c", ] :return: """ + tokenizer = get_tokenizer('cn-char', lang='cn') for name in list(data_bundle.datasets.keys()): ds = data_bundle.get_dataset(name) data_bundle.delete_dataset(name) @@ -87,8 +88,8 @@ class CMRC2018BertPipe(Pipe): ".", "...", "...","...", "..." - raw_words列是context与question拼起来的结果,words是转为index的值, target_start当当前位置为答案的开头时为1,target_end当当前 - 位置为答案的结尾是为1;context_len指示的是words列中context的长度。 + raw_words列是context与question拼起来的结果(连接的地方加入了[SEP]),words是转为index的值, target_start为答案start的index,target_end为答案end的index + (闭区间);context_len指示的是words列中context的长度。 其中各列的meta信息如下: +-------------+-------------+-----------+--------------+------------+-------+---------+ @@ -119,8 +120,7 @@ class CMRC2018BertPipe(Pipe): :param data_bundle: :return: """ - _tokenizer = get_tokenizer('cn-char', lang='cn') - data_bundle = _concat_clip(data_bundle, tokenizer=_tokenizer, max_len=self.max_len, concat_field_name='raw_chars') + data_bundle = _concat_clip(data_bundle, max_len=self.max_len, concat_field_name='raw_chars') src_vocab = Vocabulary() src_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], diff --git a/test/embeddings/test_static_embedding.py b/test/embeddings/test_static_embedding.py index 7d1e8302..61b7f2ed 100644 --- a/test/embeddings/test_static_embedding.py +++ b/test/embeddings/test_static_embedding.py @@ -35,6 +35,79 @@ class TestLoad(unittest.TestCase): words = torch.randint(1, 200, (batch, length)).long() embed(words) + def test_only_use_pretrain_word(self): + def check_word_unk(words, vocab, embed): + for word in words: + self.assertListEqual(embed(torch.LongTensor([vocab.to_index(word)])).tolist()[0], + embed(torch.LongTensor([1])).tolist()[0]) + + def check_vector_equal(words, vocab, embed, embed_dict, lower=False): + for word in words: + index = vocab.to_index(word) + v1 = embed(torch.LongTensor([index])).tolist()[0] + if lower: + word = word.lower() + v2 = embed_dict[word] + for v1i, v2i in zip(v1, v2): + self.assertAlmostEqual(v1i, v2i, places=4) + embed_dict = read_static_embed('test/data_for_tests/embedding/small_static_embedding/' + 'glove.6B.50d_test.txt') + + # 测试是否只使用pretrain的word + vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile']) + vocab.add_word('of', no_create_entry=True) + embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' + 'glove.6B.50d_test.txt', + only_use_pretrain_word=True) + # notinfile应该被置为unk + check_vector_equal(['the', 'a', 'of'], vocab, embed, embed_dict) + check_word_unk(['notinfile'], vocab, embed) + + # 测试在大小写情况下的使用 + vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile']) + vocab.add_word('Of', no_create_entry=True) + embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' + 'glove.6B.50d_test.txt', + only_use_pretrain_word=True) + check_word_unk(['The', 'Of', 'notinfile'], vocab, embed) # 这些词应该找不到 + check_vector_equal(['a'], vocab, embed, embed_dict) + + embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' + 'glove.6B.50d_test.txt', + only_use_pretrain_word=True, lower=True) + check_vector_equal(['The', 'Of', 'a'], vocab, embed, embed_dict, lower=True) + check_word_unk(['notinfile'], vocab, embed) + + # 测试min_freq + vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2']) + vocab.add_word('Of', no_create_entry=True) + + embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' + 'glove.6B.50d_test.txt', + only_use_pretrain_word=True, lower=True, min_freq=2, only_train_min_freq=True) + + check_vector_equal(['Of', 'a'], vocab, embed, embed_dict, lower=True) + check_word_unk(['notinfile1', 'The', 'notinfile2'], vocab, embed) + + +def read_static_embed(fp): + """ + + :param str fp: embedding的路径 + :return: {}, key是word, value是vector + """ + embed = {} + with open(fp, 'r') as f: + for line in f: + line = line.strip() + if line: + parts = line.split() + vector = list(map(float, parts[1:])) + word = parts[0] + embed[word] = vector + return embed + + class TestRandomSameEntry(unittest.TestCase): def test_same_vector(self): vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])