@@ -1006,24 +1006,28 @@ class CMRC2018Metric(MetricBase):
         self.total = 0
         self.f1 = 0
 
-    def evaluate(self, answers, raw_chars, context_len, pred_start, pred_end):
+    def evaluate(self, answers, raw_chars, pred_start, pred_end, context_len=None):
         """
 
         :param list[str] answers: e.g. [["答案1", "答案2", "答案3"], [...], ...]
         :param list[str] raw_chars: [["这", "是", ...], [...]]
+        :param tensor pred_start: batch_size x length, or batch_size,
+        :param tensor pred_end: batch_size x length, or batch_size (closed interval, the end position is included),
         :param tensor context_len: length of the context, batch_size
-        :param tensor pred_start: batch_size x length
-        :param tensor pred_end: batch_size x length
         :return:
         """
-        batch_size, max_len = pred_start.size()
-        context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(False)
-        pred_start.masked_fill_(context_mask, float('-inf'))
-        pred_end.masked_fill_(context_mask, float('-inf'))
-        max_pred_start, pred_start_index = pred_start.max(dim=-1, keepdim=True)  # batch_size,
-        pred_start_mask = pred_start.eq(max_pred_start).cumsum(dim=-1).eq(0)  # only positions at or after the predicted start remain valid
-        pred_end.masked_fill_(pred_start_mask, float('-inf'))
-        pred_end_index = pred_end.argmax(dim=-1) + 1
+        if pred_start.dim() > 1:
+            batch_size, max_len = pred_start.size()
+            context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(False)
+            pred_start.masked_fill_(context_mask, float('-inf'))
+            pred_end.masked_fill_(context_mask, float('-inf'))
+            max_pred_start, pred_start_index = pred_start.max(dim=-1, keepdim=True)  # batch_size,
+            pred_start_mask = pred_start.eq(max_pred_start).cumsum(dim=-1).eq(0)  # only positions at or after the predicted start remain valid
+            pred_end.masked_fill_(pred_start_mask, float('-inf'))
+            pred_end_index = pred_end.argmax(dim=-1) + 1
+        else:
+            pred_start_index = pred_start
+            pred_end_index = pred_end + 1
         pred_ans = []
         for index, (start, end) in enumerate(zip(pred_start_index.flatten().tolist(), pred_end_index.tolist())):
             pred_ans.append(''.join(raw_chars[index][start:end]))
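
For reference, below is a minimal sketch of the 2-D decoding branch added above, written in plain PyTorch so that fastNLP's `seq_len_to_mask` is not needed; the toy logits, context length, and expected output are invented for illustration.

```python
import torch

# Toy logits and lengths, invented for illustration; batch_size=1, length=4.
pred_start = torch.tensor([[0.1, 2.0, 0.3, 0.0]])
pred_end = torch.tensor([[1.5, 0.2, 0.4, 3.0]])
context_len = torch.tensor([3])  # only the first 3 positions belong to the context

max_len = pred_start.size(1)
# mask positions outside the context so they can never be selected
context_mask = torch.arange(max_len).unsqueeze(0) >= context_len.unsqueeze(1)
pred_start = pred_start.masked_fill(context_mask, float('-inf'))
pred_end = pred_end.masked_fill(context_mask, float('-inf'))

max_pred_start, start_index = pred_start.max(dim=-1, keepdim=True)
# forbid end positions that come before the predicted start
before_start = pred_start.eq(max_pred_start).cumsum(dim=-1).eq(0)
pred_end = pred_end.masked_fill(before_start, float('-inf'))
end_index = pred_end.argmax(dim=-1) + 1  # exclusive end, ready for slicing

print(start_index.flatten().tolist(), end_index.tolist())  # [1] [3]
```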
@@ -68,7 +68,11 @@ class StaticEmbedding(TokenEmbedding):
        :param float word_dropout: the probability of replacing a word with unk; this both trains the unk vector and acts as a form of regularization.
        :param bool normalize: whether to normalize each vector so that its norm is 1.
        :param int min_freq: words whose frequency in the Vocabulary is below this value are mapped to unk.
-        :param dict kwargs: only_train_min_freq, apply the min_freq filtering only to words from the train set; only_norm_found_vector, whether to normalize only the vectors of words found in the pre-trained embedding.
+        :param dict kwargs:
+            bool only_train_min_freq: apply the min_freq filtering only to words from the train set;
+            bool only_norm_found_vector: whether to normalize only the vectors of words found in the pre-trained embedding;
+            bool only_use_pretrain_word: only use words that appear in the pre-trained vocabulary; a word not found there is mapped to unk.
+                Set this to True if the vocabulary does not need to be updated.
        """
        super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
        if embedding_dim > 0:
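
A short usage sketch of the new keyword arguments; the vocabulary is made up, and the embedding path reuses the small test file from this repository purely as an example:

```python
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary().add_word_lst(['the', 'a', 'word_not_in_pretrain'])
embed = StaticEmbedding(vocab,
                        model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt',
                        min_freq=2,
                        only_train_min_freq=True,     # apply min_freq filtering only to train words
                        only_norm_found_vector=True,  # normalize only vectors found in the pre-trained file
                        only_use_pretrain_word=True)  # words missing from the pre-trained file fall back to unk
```

With `only_use_pretrain_word=True`, `word_not_in_pretrain` is expected to fall back to the unk vector, which is what `test_only_use_pretrain_word` further down checks.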
@@ -118,7 +122,8 @@ class StaticEmbedding(TokenEmbedding):
                 truncated_words_to_words[index] = truncated_vocab.to_index(word)
             logger.info(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.")
             vocab = truncated_vocab
 
+        self.only_use_pretrain_word = kwargs.get('only_use_pretrain_word', False)
         self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False)
         # read the pre-trained embedding
         if lower:
@@ -249,12 +254,13 @@ class StaticEmbedding(TokenEmbedding): | |||||
logger.error("Error occurred at the {} line.".format(idx)) | logger.error("Error occurred at the {} line.".format(idx)) | ||||
raise e | raise e | ||||
logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) | logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) | ||||
for word, index in vocab: | |||||
if index not in matrix and not vocab._is_word_no_create_entry(word): | |||||
if found_unknown: # 如果有unkonwn,用unknown初始化 | |||||
matrix[index] = matrix[vocab.unknown_idx] | |||||
else: | |||||
matrix[index] = None | |||||
if not self.only_use_pretrain_word: # 如果只用pretrain中的值就不要为未找到的词创建entry了 | |||||
for word, index in vocab: | |||||
if index not in matrix and not vocab._is_word_no_create_entry(word): | |||||
if found_unknown: # 如果有unkonwn,用unknown初始化 | |||||
matrix[index] = matrix[vocab.unknown_idx] | |||||
else: | |||||
matrix[index] = None | |||||
# matrix中代表是需要建立entry的词 | # matrix中代表是需要建立entry的词 | ||||
vectors = self._randomly_init_embed(len(matrix), dim, init_method) | vectors = self._randomly_init_embed(len(matrix), dim, init_method) | ||||
@@ -16,9 +16,9 @@ from ...core import Vocabulary
 
 __all__ = ['CMRC2018BertPipe']
 
-def _concat_clip(data_bundle, tokenizer, max_len, concat_field_name='raw_chars'):
+def _concat_clip(data_bundle, max_len, concat_field_name='raw_chars'):
     """
 
-    Processes the DataSets in data_bundle: tokenizes the context and the question, then joins the two with [SEP].
+    Processes the DataSets in data_bundle: tokenizes the context and the question at the character level, then joins the two with [SEP].
     The following fields are added: context_len(int), raw_words(list[str]), target_start(int), target_end(int), where target_start
     and target_end have the same length as raw_chars, and target_start and target_end form a closed interval (both ends inclusive).
@@ -26,6 +26,7 @@ def _concat_clip(data_bundle, tokenizer, max_len, concat_field_name='raw_chars')
     :param DataBundle data_bundle: e.g. ["a", "b", "[SEP]", "c", ]
     :return:
     """
+    tokenizer = get_tokenizer('cn-char', lang='cn')
     for name in list(data_bundle.datasets.keys()):
         ds = data_bundle.get_dataset(name)
         data_bundle.delete_dataset(name)
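
The `cn-char` tokenizer is now created inside `_concat_clip`, so callers no longer pass one in. As a rough illustration of the character-level split it performs (a plain-Python approximation; fastNLP's actual `cn-char` implementation may treat ASCII runs differently):

```python
def char_tokenize(text):
    # rough character-level split: each CJK character becomes its own token,
    # while consecutive ASCII letters/digits are grouped into one token
    tokens, buf = [], ''
    for ch in text:
        if ch.isascii() and ch.isalnum():
            buf += ch
        else:
            if buf:
                tokens.append(buf)
                buf = ''
            if not ch.isspace():
                tokens.append(ch)
    if buf:
        tokens.append(buf)
    return tokens

print(char_tokenize('BERT在CMRC2018上'))  # ['BERT', '在', 'CMRC2018', '上']
```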
@@ -87,8 +88,8 @@ class CMRC2018BertPipe(Pipe):
        ".", "...", "...","...", "..."
 
-    The raw_words column is the result of concatenating the context and the question; words holds the index values; target_start is 1 at the position
-    where the answer starts and target_end is 1 at the position where the answer ends; context_len indicates the length of the context in the words column.
+    The raw_words column is the result of concatenating the context and the question (with [SEP] inserted at the junction); words holds the index values;
+    target_start is the index of the answer start and target_end is the index of the answer end (closed interval); context_len indicates the length of the context in the words column.
 
     The meta information of each column is as follows:
 
     +-------------+-------------+-----------+--------------+------------+-------+---------+
@@ -119,8 +120,7 @@ class CMRC2018BertPipe(Pipe):
         :param data_bundle:
         :return:
         """
-        _tokenizer = get_tokenizer('cn-char', lang='cn')
-        data_bundle = _concat_clip(data_bundle, tokenizer=_tokenizer, max_len=self.max_len, concat_field_name='raw_chars')
+        data_bundle = _concat_clip(data_bundle, max_len=self.max_len, concat_field_name='raw_chars')
 
         src_vocab = Vocabulary()
         src_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name],
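
For context, a hedged sketch of driving the pipe end to end; the import path follows the io/pipe/qa module layout seen in this diff, the default constructor arguments are assumed, and the data path is a placeholder:

```python
from fastNLP.io.pipe.qa import CMRC2018BertPipe

pipe = CMRC2018BertPipe()  # max_len keeps its default here
data_bundle = pipe.process_from_file('path/to/cmrc2018/json/files')  # placeholder path
print(data_bundle)  # shows the datasets and the fields added by _concat_clip
```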
@@ -35,6 +35,79 @@ class TestLoad(unittest.TestCase):
             words = torch.randint(1, 200, (batch, length)).long()
             embed(words)
 
+    def test_only_use_pretrain_word(self):
+        def check_word_unk(words, vocab, embed):
+            # these words are expected to share the unk vector (index 1)
+            for word in words:
+                self.assertListEqual(embed(torch.LongTensor([vocab.to_index(word)])).tolist()[0],
+                                     embed(torch.LongTensor([1])).tolist()[0])
+
+        def check_vector_equal(words, vocab, embed, embed_dict, lower=False):
+            # these words are expected to keep their pre-trained vector
+            for word in words:
+                index = vocab.to_index(word)
+                v1 = embed(torch.LongTensor([index])).tolist()[0]
+                if lower:
+                    word = word.lower()
+                v2 = embed_dict[word]
+                for v1i, v2i in zip(v1, v2):
+                    self.assertAlmostEqual(v1i, v2i, places=4)
+
+        embed_dict = read_static_embed('test/data_for_tests/embedding/small_static_embedding/'
+                                       'glove.6B.50d_test.txt')
+
+        # test that only pre-trained words are used
+        vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
+        vocab.add_word('of', no_create_entry=True)
+        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+                                                         'glove.6B.50d_test.txt',
+                                only_use_pretrain_word=True)
+        # notinfile should be mapped to unk
+        check_vector_equal(['the', 'a', 'of'], vocab, embed, embed_dict)
+        check_word_unk(['notinfile'], vocab, embed)
+
+        # test the behaviour with respect to casing
+        vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile'])
+        vocab.add_word('Of', no_create_entry=True)
+        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+                                                         'glove.6B.50d_test.txt',
+                                only_use_pretrain_word=True)
+        check_word_unk(['The', 'Of', 'notinfile'], vocab, embed)  # these words should not be found
+        check_vector_equal(['a'], vocab, embed, embed_dict)
+
+        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+                                                         'glove.6B.50d_test.txt',
+                                only_use_pretrain_word=True, lower=True)
+        check_vector_equal(['The', 'Of', 'a'], vocab, embed, embed_dict, lower=True)
+        check_word_unk(['notinfile'], vocab, embed)
+
+        # test min_freq
+        vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2'])
+        vocab.add_word('Of', no_create_entry=True)
+        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
+                                                         'glove.6B.50d_test.txt',
+                                only_use_pretrain_word=True, lower=True, min_freq=2, only_train_min_freq=True)
+        check_vector_equal(['Of', 'a'], vocab, embed, embed_dict, lower=True)
+        check_word_unk(['notinfile1', 'The', 'notinfile2'], vocab, embed)
+
+
+def read_static_embed(fp):
+    """
+    :param str fp: path to the embedding file
+    :return: dict, the key is a word and the value is its vector
+    """
+    embed = {}
+    with open(fp, 'r') as f:
+        for line in f:
+            line = line.strip()
+            if line:
+                parts = line.split()
+                vector = list(map(float, parts[1:]))
+                word = parts[0]
+                embed[word] = vector
+    return embed
+
+
 class TestRandomSameEntry(unittest.TestCase):
     def test_same_vector(self):
         vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])