
1. DataSet.apply() reports the index of the instance at which an error occurred.
2. Vocabulary.from_dataset() and index_dataset() report the position of the dataset in which an error occurred.
3. EmbedLoader skips a line when it encounters irregularly formatted data while reading embeddings.
tags/v0.4.10
yh_cc committed 6 years ago (parent commit c1ee0b27df)
5 changed files with 91 additions and 25 deletions:

  1. MANIFEST.in (+5, -0)
  2. fastNLP/core/dataset.py (+11, -1)
  3. fastNLP/core/vocabulary.py (+13, -4)
  4. fastNLP/io/embed_loader.py (+58, -17)
  5. test/io/test_embed_loader.py (+4, -3)

MANIFEST.in (+5, -0)

@@ -0,0 +1,5 @@
+include requirements.txt
+include LICENSE
+include README.md
+prune test/
+prune reproduction/

fastNLP/core/dataset.py (+11, -1)

@@ -277,7 +277,17 @@ class DataSet(object):
         (2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target.
         :return results: if new_field_name is not passed, returned values of the function over all instances.
         """
-        results = [func(ins) for ins in self._inner_iter()]
+        assert len(self)!=0, "Null dataset cannot use .apply()."
+        results = []
+        idx = -1
+        try:
+            for idx, ins in enumerate(self._inner_iter()):
+                results.append(func(ins))
+        except Exception as e:
+            if idx!=-1:
+                print("Exception happens at the `{}`th instance.".format(idx))
+            raise e
+        # results = [func(ins) for ins in self._inner_iter()]
         if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0:  # all None
             raise ValueError("{} always return None.".format(get_func_signature(func=func)))
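
A minimal sketch of the new behaviour (hypothetical toy data; building a DataSet from a dict of lists follows fastNLP's documented usage):

    from fastNLP.core.dataset import DataSet

    # Toy dataset; the third instance is malformed on purpose.
    ds = DataSet({"x": [1, 2, 0]})
    try:
        ds.apply(lambda ins: 1 / ins["x"], new_field_name="inv")
    except ZeroDivisionError:
        # apply() now prints "Exception happens at the `2`th instance."
        # before re-raising, so the offending row can be located directly.
        raise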




fastNLP/core/vocabulary.py (+13, -4)

@@ -182,9 +182,13 @@ class Vocabulary(object):
 
         if new_field_name is None:
             new_field_name = field_name
-        for dataset in datasets:
+        for idx, dataset in enumerate(datasets):
             if isinstance(dataset, DataSet):
-                dataset.apply(index_instance, new_field_name=new_field_name)
+                try:
+                    dataset.apply(index_instance, new_field_name=new_field_name)
+                except Exception as e:
+                    print("When processing the `{}` dataset, the following error occurred.".format(idx))
+                    raise e
             else:
                 raise RuntimeError("Only DataSet type is allowed.")
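
A sketch of how this surfaces for index_dataset() (toy data; Vocabulary(unknown=None) makes to_index() raise on out-of-vocabulary words, standing in here for any per-dataset failure):

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.vocabulary import Vocabulary

    train = DataSet({"words": [["a", "b"], ["b", "c"]]})
    dev = DataSet({"words": [["a"], ["never_seen"]]})  # OOV word on purpose

    vocab = Vocabulary(unknown=None, padding=None)
    vocab.from_dataset(train, field_name="words")
    try:
        vocab.index_dataset(train, dev, field_name="words", new_field_name="word_ids")
    except Exception:
        # Now prints "When processing the `1` dataset, the following error occurred."
        # before re-raising; together with apply()'s instance index, both the
        # failing dataset (`dev`) and the failing row inside it are reported.
        raise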


@@ -207,11 +211,16 @@ class Vocabulary(object):
             if isinstance(field[0][0], list):
                 raise RuntimeError("Only support field with 2 dimensions.")
             [self.add_word_lst(w) for w in field]
-        for dataset in datasets:
+        for idx, dataset in enumerate(datasets):
             if isinstance(dataset, DataSet):
-                dataset.apply(construct_vocab)
+                try:
+                    dataset.apply(construct_vocab)
+                except Exception as e:
+                    print("When processing the `{}` dataset, the following error occurred.".format(idx))
+                    raise e
             else:
                 raise RuntimeError("Only DataSet type is allowed.")
+        return self
 
     def to_index(self, w):
         """ Turn a word to an index. If w is not in Vocabulary, return the unknown label.


fastNLP/io/embed_loader.py (+58, -17)

@@ -6,6 +6,7 @@ import torch
 from fastNLP.core.vocabulary import Vocabulary
 from fastNLP.io.base_loader import BaseLoader
 
+import warnings
 
 class EmbedLoader(BaseLoader):
     """docstring for EmbedLoader"""
@@ -128,7 +129,7 @@ class EmbedLoader(BaseLoader):
         return embedding_matrix
 
     @staticmethod
-    def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True):
+    def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'):
         """
         load pretraining embedding in {embed_file} based on words in vocab. Words in vocab but not in the pretraining
         embedding are initialized from a normal distribution which has the mean and std of the found words vectors.
@@ -138,6 +139,8 @@ class EmbedLoader(BaseLoader):
         :param vocab: Vocabulary.
         :param dtype: the dtype of the embedding matrix
         :param normalize: bool, whether to normalize each word vector so that every vector has norm 1.
+        :param error: str, 'ignore', 'strict'; if 'ignore' errors will not raise. if strict, any bad format error will
+            raise
         :return: np.ndarray() will have the same [len(vocab), dimension], dimension is determined by the pretrain
             embedding
         """
@@ -148,24 +151,32 @@ class EmbedLoader(BaseLoader):
             hit_flags = np.zeros(len(vocab), dtype=bool)
             line = f.readline().strip()
             parts = line.split()
+            start_idx = 0
             if len(parts)==2:
                 dim = int(parts[1])
+                start_idx += 1
             else:
                 dim = len(parts)-1
                 f.seek(0)
             matrix = np.random.randn(len(vocab), dim).astype(dtype)
-            for line in f:
-                parts = line.strip().split()
-                if parts[0] in vocab:
-                    index = vocab.to_index(parts[0])
-                    matrix[index] = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
-                    hit_flags[index] = True
+            for idx, line in enumerate(f, start_idx):
+                try:
+                    parts = line.strip().split()
+                    if parts[0] in vocab:
+                        index = vocab.to_index(parts[0])
+                        matrix[index] = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
+                        hit_flags[index] = True
+                except Exception as e:
+                    if error == 'ignore':
+                        warnings.warn("Error occurred at the {} line.".format(idx))
+                    else:
+                        raise e
             total_hits = sum(hit_flags)
             print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
             found_vectors = matrix[hit_flags]
             if len(found_vectors)!=0:
-                mean = np.mean(found_vectors, axis=1, keepdims=True)
-                std = np.std(found_vectors, axis=1, keepdims=True)
+                mean = np.mean(found_vectors, axis=0, keepdims=True)
+                std = np.std(found_vectors, axis=0, keepdims=True)
                 unfound_vec_num = len(vocab) - total_hits
                 r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype)*std + mean
                 matrix[hit_flags==False] = r_vecs
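
Note the second fix in this hunk: mean and std are now taken along axis 0, i.e. per embedding dimension across all found vectors, which gives the (1, dim) shape needed to sample replacement vectors for missing words; axis=1 produced per-word statistics of the wrong shape. A usage sketch, with the repository's test file standing in for a real embedding:

    from fastNLP.core.vocabulary import Vocabulary
    from fastNLP.io.embed_loader import EmbedLoader

    vocab = Vocabulary()
    vocab.add_word_lst(["the", "of", "in"])
    # error='ignore' (the default) warns on a badly formatted line and skips it;
    # error='strict' re-raises instead. Words absent from the file get vectors
    # sampled from the per-dimension mean/std of the vectors that were found.
    matrix = EmbedLoader.load_with_vocab("test/data_for_tests/glove.6B.50d_test.txt",
                                         vocab, error="ignore")
    print(matrix.shape)  # (len(vocab), 50)
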
@@ -176,7 +187,8 @@ class EmbedLoader(BaseLoader):
         return matrix
 
     @staticmethod
-    def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True):
+    def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True,
+                           error='ignore'):
         """
         load pretraining embedding in {embed_file}. And construct a Vocabulary based on the pretraining embedding.
         The embedding type is determined automatically, support glove and word2vec(the first line only has two elements).
@@ -186,12 +198,16 @@ class EmbedLoader(BaseLoader):
         :param padding: the padding tag for vocabulary.
         :param unknown: the unknown tag for vocabulary.
         :param normalize: bool, whether to normalize each word vector so that every vector has norm 1.
+        :param error: str, 'ignore', 'strict'; if 'ignore' errors will not raise. if strict, any bad format error will
+            raise
         :return: np.ndarray() is determined by the pretraining embeddings
             Vocabulary: contain all pretraining words and two special tag[<pad>, <unk>]
 
         """
         vocab = Vocabulary(padding=padding, unknown=unknown)
         vec_dict = {}
+        found_unknown = False
+        found_pad = False
 
         with open(embed_filepath, 'r', encoding='utf-8') as f:
             line = f.readline()
@@ -201,16 +217,41 @@ class EmbedLoader(BaseLoader):
             f.seek(0)
             start = 0
             for idx, line in enumerate(f, start=start):
-                parts = line.strip().split()
-                word = parts[0]
-                if dim==-1:
-                    dim = len(parts)-1
-                vec = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
-                vec_dict[word] = vec
-                vocab.add_word(word)
+                try:
+                    parts = line.strip().split()
+                    word = parts[0]
+                    if dim==-1:
+                        dim = len(parts)-1
+                    vec = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
+                    vec_dict[word] = vec
+                    vocab.add_word(word)
+                    if unknown is not None and unknown==word:
+                        found_unknown = True
+                    if padding is not None and padding==word:
+                        found_pad = True
+                except Exception as e:
+                    if error=='ignore':
+                        warnings.warn("Error occurred at the {} line.".format(idx))
+                        pass
+                    else:
+                        raise e
             if dim==-1:
                 raise RuntimeError("{} is an empty file.".format(embed_filepath))
             matrix = np.random.randn(len(vocab), dim).astype(dtype)
+            # TODO do we need to ensure unk follows the same distribution as the other data?
+            if (unknown is not None and not found_unknown) or (padding is not None and not found_pad):
+                start_idx = 0
+                if padding is not None:
+                    start_idx += 1
+                if unknown is not None:
+                    start_idx += 1
+
+                mean = np.mean(matrix[start_idx:], axis=0, keepdims=True)
+                std = np.std(matrix[start_idx:], axis=0, keepdims=True)
+                if (unknown is not None and not found_unknown):
+                    matrix[start_idx-1] = np.random.randn(1, dim).astype(dtype)*std + mean
+                if (padding is not None and not found_pad):
+                    matrix[0] = np.random.randn(1, dim).astype(dtype)*std + mean
+
             for key, vec in vec_dict.items():
                 index = vocab.to_index(key)
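
load_without_vocab() returns both the matrix and a Vocabulary built from the file; when the file itself lacks the <pad>/<unk> tags, their rows are now re-sampled from the per-dimension statistics of the real vectors (the TODO above questions whether this is required). A sketch against the same test file:

    from fastNLP.io.embed_loader import EmbedLoader

    # Returns the embedding matrix plus a Vocabulary containing every word in
    # the file and the <pad>/<unk> tags (indices 0 and 1 by convention).
    matrix, vocab = EmbedLoader.load_without_vocab(
        "test/data_for_tests/glove.6B.50d_test.txt", error="ignore")
    print(matrix.shape, len(vocab))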


test/io/test_embed_loader.py (+4, -3)

@@ -17,11 +17,12 @@ class TestEmbedLoader(unittest.TestCase):
         glove = "test/data_for_tests/glove.6B.50d_test.txt"
         word2vec = "test/data_for_tests/word2vec_test.txt"
         vocab.add_word('the')
+        vocab.add_word('none')
         g_m = EmbedLoader.load_with_vocab(glove, vocab)
-        self.assertEqual(g_m.shape, (3, 50))
+        self.assertEqual(g_m.shape, (4, 50))
         w_m = EmbedLoader.load_with_vocab(word2vec, vocab, normalize=True)
-        self.assertEqual(w_m.shape, (3, 50))
-        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 3)
+        self.assertEqual(w_m.shape, (4, 50))
+        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 4)
 
     def test_load_without_vocab(self):
         words = ['the', 'of', 'in', 'a', 'to', 'and']

