Browse Source

修复Vocabulary在load的时候可能发生的bug

tags/v0.5.5
yh_cc 4 years ago
parent
commit
18747e632e
2 changed files with 29 additions and 7 deletions
  1. fastNLP/core/vocabulary.py (+6, -6)
  2. test/core/test_vocabulary.py (+23, -1)

+ 6
- 6
fastNLP/core/vocabulary.py View File

@@ -519,11 +519,11 @@ class Vocabulary(object):
line = line.strip()
if line:
name, value = line.split()
if name == 'max_size':
vocab.max_size = int(value) if value!='None' else None
elif name == 'min_freq':
vocab.min_freq = int(value) if value!='None' else None
if name in ('max_size', 'min_freq'):
value = int(value) if value!='None' else None
setattr(vocab, name, value)
elif name in ('unknown', 'padding'):
value = value if value!='None' else None
setattr(vocab, name, value)
elif name == 'rebuild':
vocab.rebuild = True if value=='True' else False
@@ -535,12 +535,12 @@ class Vocabulary(object):
for line in f:
line = line.strip()
if line:
parts = line.split()
parts = line.split('\t')
word,count,idx,no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3])
if idx >= 0:
word2idx[word] = idx
word_counter[word] = count
if no_create_entry_counter:
if no_create_entry:
no_create_entry_counter[word] = count

word_counter = Counter(word_counter)


+ 23
- 1
test/core/test_vocabulary.py View File

@@ -214,7 +214,29 @@ class TestOther(unittest.TestCase):
for idx in range(len(vocab)):
self.assertEqual(vocab.to_word(idx), new_vocab.to_word(idx))
self.assertEqual(vocab.unknown, new_vocab.unknown)
except:

# 测试vocab中包含None的padding和unk
vocab= Vocabulary(padding=None, unknown=None)
words = list('abcdefaddfdkjfe')
no_create_entry = list('12342331')

vocab.add_word_lst(words)
vocab.add_word_lst(no_create_entry, no_create_entry=True)
vocab.save(fp)

new_vocab = Vocabulary.load(fp)

for word, index in vocab:
self.assertEqual(new_vocab.to_index(word), index)
for word in no_create_entry:
self.assertTrue(new_vocab._is_word_no_create_entry(word))
for word in words:
self.assertFalse(new_vocab._is_word_no_create_entry(word))
for idx in range(len(vocab)):
self.assertEqual(vocab.to_word(idx), new_vocab.to_word(idx))
self.assertEqual(vocab.unknown, new_vocab.unknown)

finally:
import os
if os.path.exists(fp):
os.remove(fp)

Loading…
Cancel
Save