@@ -1006,24 +1006,28 @@ class CMRC2018Metric(MetricBase): | |||
self.total = 0 | |||
self.f1 = 0 | |||
def evaluate(self, answers, raw_chars, context_len, pred_start, pred_end): | |||
def evaluate(self, answers, raw_chars, pred_start, pred_end, context_len=None): | |||
""" | |||
:param list[str] answers: 如[["答案1", "答案2", "答案3"], [...], ...] | |||
:param list[str] raw_chars: [["这", "是", ...], [...]] | |||
:param tensor pred_start: batch_size x length 或 batch_size, | |||
:param tensor pred_end: batch_size x length 或 batch_size(是闭区间,包含end位置), | |||
:param tensor context_len: context长度, batch_size | |||
:param tensor pred_start: batch_size x length | |||
:param tensor pred_end: batch_size x length | |||
:return: | |||
""" | |||
batch_size, max_len = pred_start.size() | |||
context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(False) | |||
pred_start.masked_fill_(context_mask, float('-inf')) | |||
pred_end.masked_fill_(context_mask, float('-inf')) | |||
max_pred_start, pred_start_index = pred_start.max(dim=-1, keepdim=True) # batch_size, | |||
pred_start_mask = pred_start.eq(max_pred_start).cumsum(dim=-1).eq(0) # 只能预测这之后的值 | |||
pred_end.masked_fill_(pred_start_mask, float('-inf')) | |||
pred_end_index = pred_end.argmax(dim=-1) + 1 | |||
if pred_start.dim() > 1: | |||
batch_size, max_len = pred_start.size() | |||
context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(False) | |||
pred_start.masked_fill_(context_mask, float('-inf')) | |||
pred_end.masked_fill_(context_mask, float('-inf')) | |||
max_pred_start, pred_start_index = pred_start.max(dim=-1, keepdim=True) # batch_size, | |||
pred_start_mask = pred_start.eq(max_pred_start).cumsum(dim=-1).eq(0) # 只能预测这之后的值 | |||
pred_end.masked_fill_(pred_start_mask, float('-inf')) | |||
pred_end_index = pred_end.argmax(dim=-1) + 1 | |||
else: | |||
pred_start_index = pred_start | |||
pred_end_index = pred_end + 1 | |||
pred_ans = [] | |||
for index, (start, end) in enumerate(zip(pred_start_index.flatten().tolist(), pred_end_index.tolist())): | |||
pred_ans.append(''.join(raw_chars[index][start:end])) | |||
@@ -68,7 +68,11 @@ class StaticEmbedding(TokenEmbedding): | |||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||
:param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。 | |||
:param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。 | |||
:param dict kwarngs: only_train_min_freq, 仅对train中的词语使用min_freq筛选; only_norm_found_vector是否仅对在预训练中找到的词语使用normalize。 | |||
:param dict kwarngs: | |||
bool only_train_min_freq: 仅对train中的词语使用min_freq筛选; | |||
bool only_norm_found_vector: 是否仅对在预训练中找到的词语使用normalize; | |||
bool only_use_pretrain_word: 仅使用出现在pretrain词表中的词语。如果该词没有在预训练的词表中出现则为unk。如果词表 | |||
不需要更新设置为True。 | |||
""" | |||
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
if embedding_dim > 0: | |||
@@ -118,7 +122,8 @@ class StaticEmbedding(TokenEmbedding): | |||
truncated_words_to_words[index] = truncated_vocab.to_index(word) | |||
logger.info(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.") | |||
vocab = truncated_vocab | |||
self.only_use_pretrain_word = kwargs.get('only_use_pretrain_word', False) | |||
self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False) | |||
# 读取embedding | |||
if lower: | |||
@@ -249,12 +254,13 @@ class StaticEmbedding(TokenEmbedding): | |||
logger.error("Error occurred at the {} line.".format(idx)) | |||
raise e | |||
logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) | |||
for word, index in vocab: | |||
if index not in matrix and not vocab._is_word_no_create_entry(word): | |||
if found_unknown: # 如果有unkonwn,用unknown初始化 | |||
matrix[index] = matrix[vocab.unknown_idx] | |||
else: | |||
matrix[index] = None | |||
if not self.only_use_pretrain_word: # 如果只用pretrain中的值就不要为未找到的词创建entry了 | |||
for word, index in vocab: | |||
if index not in matrix and not vocab._is_word_no_create_entry(word): | |||
if found_unknown: # 如果有unkonwn,用unknown初始化 | |||
matrix[index] = matrix[vocab.unknown_idx] | |||
else: | |||
matrix[index] = None | |||
# matrix中代表是需要建立entry的词 | |||
vectors = self._randomly_init_embed(len(matrix), dim, init_method) | |||
@@ -16,9 +16,9 @@ from ...core import Vocabulary | |||
__all__ = ['CMRC2018BertPipe'] | |||
def _concat_clip(data_bundle, tokenizer, max_len, concat_field_name='raw_chars'): | |||
def _concat_clip(data_bundle, max_len, concat_field_name='raw_chars'): | |||
""" | |||
处理data_bundle中的DataSet,将context与question进行tokenize,然后使用[SEP]将两者连接起来。 | |||
处理data_bundle中的DataSet,将context与question按照character进行tokenize,然后使用[SEP]将两者连接起来。 | |||
会新增field: context_len(int), raw_words(list[str]), target_start(int), target_end(int)其中target_start | |||
与target_end是与raw_chars等长的。其中target_start和target_end是前闭后闭的区间。 | |||
@@ -26,6 +26,7 @@ def _concat_clip(data_bundle, tokenizer, max_len, concat_field_name='raw_chars') | |||
:param DataBundle data_bundle: 类似["a", "b", "[SEP]", "c", ] | |||
:return: | |||
""" | |||
tokenizer = get_tokenizer('cn-char', lang='cn') | |||
for name in list(data_bundle.datasets.keys()): | |||
ds = data_bundle.get_dataset(name) | |||
data_bundle.delete_dataset(name) | |||
@@ -87,8 +88,8 @@ class CMRC2018BertPipe(Pipe): | |||
".", "...", "...","...", "..." | |||
raw_words列是context与question拼起来的结果,words是转为index的值, target_start当当前位置为答案的开头时为1,target_end当当前 | |||
位置为答案的结尾是为1;context_len指示的是words列中context的长度。 | |||
raw_words列是context与question拼起来的结果(连接的地方加入了[SEP]),words是转为index的值, target_start为答案start的index,target_end为答案end的index | |||
(闭区间);context_len指示的是words列中context的长度。 | |||
其中各列的meta信息如下: | |||
+-------------+-------------+-----------+--------------+------------+-------+---------+ | |||
@@ -119,8 +120,7 @@ class CMRC2018BertPipe(Pipe): | |||
:param data_bundle: | |||
:return: | |||
""" | |||
_tokenizer = get_tokenizer('cn-char', lang='cn') | |||
data_bundle = _concat_clip(data_bundle, tokenizer=_tokenizer, max_len=self.max_len, concat_field_name='raw_chars') | |||
data_bundle = _concat_clip(data_bundle, max_len=self.max_len, concat_field_name='raw_chars') | |||
src_vocab = Vocabulary() | |||
src_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], | |||
@@ -35,6 +35,79 @@ class TestLoad(unittest.TestCase): | |||
words = torch.randint(1, 200, (batch, length)).long() | |||
embed(words) | |||
def test_only_use_pretrain_word(self): | |||
def check_word_unk(words, vocab, embed): | |||
for word in words: | |||
self.assertListEqual(embed(torch.LongTensor([vocab.to_index(word)])).tolist()[0], | |||
embed(torch.LongTensor([1])).tolist()[0]) | |||
def check_vector_equal(words, vocab, embed, embed_dict, lower=False): | |||
for word in words: | |||
index = vocab.to_index(word) | |||
v1 = embed(torch.LongTensor([index])).tolist()[0] | |||
if lower: | |||
word = word.lower() | |||
v2 = embed_dict[word] | |||
for v1i, v2i in zip(v1, v2): | |||
self.assertAlmostEqual(v1i, v2i, places=4) | |||
embed_dict = read_static_embed('test/data_for_tests/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt') | |||
# 测试是否只使用pretrain的word | |||
vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile']) | |||
vocab.add_word('of', no_create_entry=True) | |||
embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt', | |||
only_use_pretrain_word=True) | |||
# notinfile应该被置为unk | |||
check_vector_equal(['the', 'a', 'of'], vocab, embed, embed_dict) | |||
check_word_unk(['notinfile'], vocab, embed) | |||
# 测试在大小写情况下的使用 | |||
vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile']) | |||
vocab.add_word('Of', no_create_entry=True) | |||
embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt', | |||
only_use_pretrain_word=True) | |||
check_word_unk(['The', 'Of', 'notinfile'], vocab, embed) # 这些词应该找不到 | |||
check_vector_equal(['a'], vocab, embed, embed_dict) | |||
embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt', | |||
only_use_pretrain_word=True, lower=True) | |||
check_vector_equal(['The', 'Of', 'a'], vocab, embed, embed_dict, lower=True) | |||
check_word_unk(['notinfile'], vocab, embed) | |||
# 测试min_freq | |||
vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2']) | |||
vocab.add_word('Of', no_create_entry=True) | |||
embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt', | |||
only_use_pretrain_word=True, lower=True, min_freq=2, only_train_min_freq=True) | |||
check_vector_equal(['Of', 'a'], vocab, embed, embed_dict, lower=True) | |||
check_word_unk(['notinfile1', 'The', 'notinfile2'], vocab, embed) | |||
def read_static_embed(fp): | |||
""" | |||
:param str fp: embedding的路径 | |||
:return: {}, key是word, value是vector | |||
""" | |||
embed = {} | |||
with open(fp, 'r') as f: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split() | |||
vector = list(map(float, parts[1:])) | |||
word = parts[0] | |||
embed[word] = vector | |||
return embed | |||
class TestRandomSameEntry(unittest.TestCase): | |||
def test_same_vector(self): | |||
vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"]) | |||