Browse Source

1.DataSet.apply()报错时提供错误的index

2.Vocabulary.from_dataset(), index_dataset()提供报错时的vocab顺序
3.embedloader在embed读取时遇到不规则的数据跳过这一行.
tags/v0.4.10
yh_cc 6 years ago
parent
commit
c1ee0b27df
5 changed files with 91 additions and 25 deletions
  1. +5
    -0
      MANIFEST.in
  2. +11
    -1
      fastNLP/core/dataset.py
  3. +13
    -4
      fastNLP/core/vocabulary.py
  4. +58
    -17
      fastNLP/io/embed_loader.py
  5. +4
    -3
      test/io/test_embed_loader.py

+ 5
- 0
MANIFEST.in View File

@@ -0,0 +1,5 @@
include requirements.txt
include LICENSE
include README.md
prune test/
prune reproduction/

+ 11
- 1
fastNLP/core/dataset.py View File

@@ -277,7 +277,17 @@ class DataSet(object):
(2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target.
:return results: if new_field_name is not passed, returned values of the function over all instances.
"""
results = [func(ins) for ins in self._inner_iter()]
assert len(self)!=0, "Null dataset cannot use .apply()."
results = []
idx = -1
try:
for idx, ins in enumerate(self._inner_iter()):
results.append(func(ins))
except Exception as e:
if idx!=-1:
print("Exception happens at the `{}`th instance.".format(idx))
raise e
# results = [func(ins) for ins in self._inner_iter()]
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(get_func_signature(func=func)))



+ 13
- 4
fastNLP/core/vocabulary.py View File

@@ -182,9 +182,13 @@ class Vocabulary(object):

if new_field_name is None:
new_field_name = field_name
for dataset in datasets:
for idx, dataset in enumerate(datasets):
if isinstance(dataset, DataSet):
dataset.apply(index_instance, new_field_name=new_field_name)
try:
dataset.apply(index_instance, new_field_name=new_field_name)
except Exception as e:
print("When processing the `{}` dataset, the following error occurred.".format(idx))
raise e
else:
raise RuntimeError("Only DataSet type is allowed.")

@@ -207,11 +211,16 @@ class Vocabulary(object):
if isinstance(field[0][0], list):
raise RuntimeError("Only support field with 2 dimensions.")
[self.add_word_lst(w) for w in field]
for dataset in datasets:
for idx, dataset in enumerate(datasets):
if isinstance(dataset, DataSet):
dataset.apply(construct_vocab)
try:
dataset.apply(construct_vocab)
except Exception as e:
print("When processing the `{}` dataset, the following error occurred.".format(idx))
raise e
else:
raise RuntimeError("Only DataSet type is allowed.")
return self

def to_index(self, w):
""" Turn a word to an index. If w is not in Vocabulary, return the unknown label.


+ 58
- 17
fastNLP/io/embed_loader.py View File

@@ -6,6 +6,7 @@ import torch
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.base_loader import BaseLoader

import warnings

class EmbedLoader(BaseLoader):
"""docstring for EmbedLoader"""
@@ -128,7 +129,7 @@ class EmbedLoader(BaseLoader):
return embedding_matrix

@staticmethod
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True):
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'):
"""
load pretraining embedding in {embed_file} based on words in vocab. Words in vocab but not in the pretraining
embedding are initialized from a normal distribution which has the mean and std of the found words vectors.
@@ -138,6 +139,8 @@ class EmbedLoader(BaseLoader):
:param vocab: Vocabulary.
:param dtype: the dtype of the embedding matrix
:param normalize: bool, whether to normalize each word vector so that every vector has norm 1.
:param error: str, 'ignore', 'strict'; if 'ignore' errors will not raise. if strict, any bad format error will
raise
:return: np.ndarray() will have the same [len(vocab), dimension], dimension is determined by the pretrain
embedding
"""
@@ -148,24 +151,32 @@ class EmbedLoader(BaseLoader):
hit_flags = np.zeros(len(vocab), dtype=bool)
line = f.readline().strip()
parts = line.split()
start_idx = 0
if len(parts)==2:
dim = int(parts[1])
start_idx += 1
else:
dim = len(parts)-1
f.seek(0)
matrix = np.random.randn(len(vocab), dim).astype(dtype)
for line in f:
parts = line.strip().split()
if parts[0] in vocab:
index = vocab.to_index(parts[0])
matrix[index] = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
hit_flags[index] = True
for idx, line in enumerate(f, start_idx):
try:
parts = line.strip().split()
if parts[0] in vocab:
index = vocab.to_index(parts[0])
matrix[index] = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
hit_flags[index] = True
except Exception as e:
if error == 'ignore':
warnings.warn("Error occurred at the {} line.".format(idx))
else:
raise e
total_hits = sum(hit_flags)
print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
found_vectors = matrix[hit_flags]
if len(found_vectors)!=0:
mean = np.mean(found_vectors, axis=1, keepdims=True)
std = np.std(found_vectors, axis=1, keepdims=True)
mean = np.mean(found_vectors, axis=0, keepdims=True)
std = np.std(found_vectors, axis=0, keepdims=True)
unfound_vec_num = len(vocab) - total_hits
r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype)*std + mean
matrix[hit_flags==False] = r_vecs
@@ -176,7 +187,8 @@ class EmbedLoader(BaseLoader):
return matrix

@staticmethod
def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True):
def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True,
error='ignore'):
"""
load pretraining embedding in {embed_file}. And construct a Vocabulary based on the pretraining embedding.
The embedding type is determined automatically, support glove and word2vec(the first line only has two elements).
@@ -186,12 +198,16 @@ class EmbedLoader(BaseLoader):
:param padding: the padding tag for vocabulary.
:param unknown: the unknown tag for vocabulary.
:param normalize: bool, whether to normalize each word vector so that every vector has norm 1.
:param error: str, 'ignore', 'strict'; if 'ignore' errors will not raise. if strict, any bad format error will
:raise
:return: np.ndarray() is determined by the pretraining embeddings
Vocabulary: contain all pretraining words and two special tag[<pad>, <unk>]

"""
vocab = Vocabulary(padding=padding, unknown=unknown)
vec_dict = {}
found_unknown = False
found_pad = False

with open(embed_filepath, 'r', encoding='utf-8') as f:
line = f.readline()
@@ -201,16 +217,41 @@ class EmbedLoader(BaseLoader):
f.seek(0)
start = 0
for idx, line in enumerate(f, start=start):
parts = line.strip().split()
word = parts[0]
if dim==-1:
dim = len(parts)-1
vec = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
vec_dict[word] = vec
vocab.add_word(word)
try:
parts = line.strip().split()
word = parts[0]
if dim==-1:
dim = len(parts)-1
vec = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
vec_dict[word] = vec
vocab.add_word(word)
if unknown is not None and unknown==word:
found_unknown = True
if found_pad is not None and padding==word:
found_pad = True
except Exception as e:
if error=='ignore':
warnings.warn("Error occurred at the {} line.".format(idx))
pass
else:
raise e
if dim==-1:
raise RuntimeError("{} is an empty file.".format(embed_filepath))
matrix = np.random.randn(len(vocab), dim).astype(dtype)
# TODO 需要保证unk其它数据同分布的吗?
if (unknown is not None and not found_unknown) or (padding is not None and not found_pad):
start_idx = 0
if padding is not None:
start_idx += 1
if unknown is not None:
start_idx += 1

mean = np.mean(matrix[start_idx:], axis=0, keepdims=True)
std = np.std(matrix[start_idx:], axis=0, keepdims=True)
if (unknown is not None and not found_unknown):
matrix[start_idx-1] = np.random.randn(1, dim).astype(dtype)*std + mean
if (padding is not None and not found_pad):
matrix[0] = np.random.randn(1, dim).astype(dtype)*std + mean

for key, vec in vec_dict.items():
index = vocab.to_index(key)


+ 4
- 3
test/io/test_embed_loader.py View File

@@ -17,11 +17,12 @@ class TestEmbedLoader(unittest.TestCase):
glove = "test/data_for_tests/glove.6B.50d_test.txt"
word2vec = "test/data_for_tests/word2vec_test.txt"
vocab.add_word('the')
vocab.add_word('none')
g_m = EmbedLoader.load_with_vocab(glove, vocab)
self.assertEqual(g_m.shape, (3, 50))
self.assertEqual(g_m.shape, (4, 50))
w_m = EmbedLoader.load_with_vocab(word2vec, vocab, normalize=True)
self.assertEqual(w_m.shape, (3, 50))
self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 3)
self.assertEqual(w_m.shape, (4, 50))
self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 4)

def test_load_without_vocab(self):
words = ['the', 'of', 'in', 'a', 'to', 'and']


Loading…
Cancel
Save