From c1ee0b27dfb5daa8a0f83f161514f07d4075bb4f Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Sun, 21 Apr 2019 09:12:42 +0800
Subject: [PATCH] 1. DataSet.apply() reports the index of the failing instance
 when an error occurs. 2. Vocabulary.from_dataset() and index_dataset() report
 which dataset raised an error. 3. EmbedLoader skips malformed lines when
 reading an embedding file.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 MANIFEST.in                  |  5 +++
 fastNLP/core/dataset.py      | 12 +++++-
 fastNLP/core/vocabulary.py   | 17 ++++++--
 fastNLP/io/embed_loader.py   | 75 ++++++++++++++++++++++++++++--------
 test/io/test_embed_loader.py |  7 ++--
 5 files changed, 91 insertions(+), 25 deletions(-)
 create mode 100644 MANIFEST.in

diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 00000000..f04509c1
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,5 @@
+include requirements.txt
+include LICENSE
+include README.md
+prune test/
+prune reproduction/
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 7b0e3b9a..76a34655 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -277,7 +277,17 @@ class DataSet(object):
             (2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target.
         :return results: if new_field_name is not passed, returned values of the function over all instances.
         """
-        results = [func(ins) for ins in self._inner_iter()]
+        assert len(self)!=0, "Null dataset cannot use .apply()."
+        results = []
+        idx = -1
+        try:
+            for idx, ins in enumerate(self._inner_iter()):
+                results.append(func(ins))
+        except Exception as e:
+            if idx!=-1:
+                print("Exception happens at the `{}`th instance.".format(idx))
+            raise e
+        # results = [func(ins) for ins in self._inner_iter()]
         if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0:  # all None
             raise ValueError("{} always return None.".format(get_func_signature(func=func)))
diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index a73ce2c7..c580dbec 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -182,9 +182,13 @@ class Vocabulary(object):
 
         if new_field_name is None:
             new_field_name = field_name
-        for dataset in datasets:
+        for idx, dataset in enumerate(datasets):
             if isinstance(dataset, DataSet):
-                dataset.apply(index_instance, new_field_name=new_field_name)
+                try:
+                    dataset.apply(index_instance, new_field_name=new_field_name)
+                except Exception as e:
+                    print("When processing the `{}`th dataset, the following error occurred.".format(idx))
+                    raise e
             else:
                 raise RuntimeError("Only DataSet type is allowed.")
 
@@ -207,11 +211,16 @@ class Vocabulary(object):
             if isinstance(field[0][0], list):
                 raise RuntimeError("Only support field with 2 dimensions.")
             [self.add_word_lst(w) for w in field]
-        for dataset in datasets:
+        for idx, dataset in enumerate(datasets):
             if isinstance(dataset, DataSet):
-                dataset.apply(construct_vocab)
+                try:
+                    dataset.apply(construct_vocab)
+                except Exception as e:
+                    print("When processing the `{}`th dataset, the following error occurred.".format(idx))
+                    raise e
             else:
                 raise RuntimeError("Only DataSet type is allowed.")
+        return self
 
     def to_index(self, w):
         """ Turn a word to an index.
            If w is not in Vocabulary, return the unknown label.
diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py
index 08a55aa6..5ad27c53 100644
--- a/fastNLP/io/embed_loader.py
+++ b/fastNLP/io/embed_loader.py
@@ -6,6 +6,7 @@ import torch
 
 from fastNLP.core.vocabulary import Vocabulary
 from fastNLP.io.base_loader import BaseLoader
+import warnings
 
 class EmbedLoader(BaseLoader):
     """docstring for EmbedLoader"""
@@ -128,7 +129,7 @@ class EmbedLoader(BaseLoader):
         return embedding_matrix
 
     @staticmethod
-    def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True):
+    def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'):
         """
         load pretraining embedding in {embed_file} based on words in vocab. Words in vocab but not in the pretraining
         embedding are initialized from a normal distribution which has the mean and std of the found words vectors.
@@ -138,6 +139,8 @@ class EmbedLoader(BaseLoader):
         :param vocab: Vocabulary.
         :param dtype: the dtype of the embedding matrix
         :param normalize: bool, whether to normalize each word vector so that every vector has norm 1.
+        :param error: str, either 'ignore' or 'strict'. If 'ignore', a malformed line is skipped with a warning;
+            if 'strict', any bad-format line raises the underlying exception.
         :return: np.ndarray() will have the same [len(vocab), dimension], dimension is determined by the pretrain
             embedding
         """
@@ -148,24 +151,32 @@ class EmbedLoader(BaseLoader):
             hit_flags = np.zeros(len(vocab), dtype=bool)
             line = f.readline().strip()
             parts = line.split()
+            start_idx = 0
             if len(parts)==2:
                 dim = int(parts[1])
+                start_idx += 1
             else:
                 dim = len(parts)-1
                 f.seek(0)
             matrix = np.random.randn(len(vocab), dim).astype(dtype)
-            for line in f:
-                parts = line.strip().split()
-                if parts[0] in vocab:
-                    index = vocab.to_index(parts[0])
-                    matrix[index] = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
-                    hit_flags[index] = True
+            for idx, line in enumerate(f, start_idx):
+                try:
+                    parts = line.strip().split()
+                    if parts[0] in vocab:
+                        index = vocab.to_index(parts[0])
+                        matrix[index] = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
+                        hit_flags[index] = True
+                except Exception as e:
+                    if error == 'ignore':
+                        warnings.warn("Error occurred at line {}.".format(idx))
+                    else:
+                        raise e
             total_hits = sum(hit_flags)
             print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
             found_vectors = matrix[hit_flags]
             if len(found_vectors)!=0:
-                mean = np.mean(found_vectors, axis=1, keepdims=True)
-                std = np.std(found_vectors, axis=1, keepdims=True)
+                mean = np.mean(found_vectors, axis=0, keepdims=True)
+                std = np.std(found_vectors, axis=0, keepdims=True)
                 unfound_vec_num = len(vocab) - total_hits
                 r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype)*std + mean
                 matrix[hit_flags==False] = r_vecs
@@ -176,7 +187,8 @@ class EmbedLoader(BaseLoader):
         return matrix
 
     @staticmethod
-    def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True):
+    def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True,
+                           error='ignore'):
         """
         load pretraining embedding in {embed_file}. And construct a Vocabulary based on the pretraining embedding.
         The embedding type is determined automatically, supporting glove and word2vec (whose first line only has two elements).
@@ -186,12 +198,16 @@ class EmbedLoader(BaseLoader):
         :param padding: the padding tag for vocabulary.
         :param unknown: the unknown tag for vocabulary.
         :param normalize: bool, whether to normalize each word vector so that every vector has norm 1.
+        :param error: str, either 'ignore' or 'strict'. If 'ignore', a malformed line is skipped with a warning;
+            if 'strict', any bad-format line raises the underlying exception.
         :return: np.ndarray() is determined by the pretraining embeddings
             Vocabulary: contain all pretraining words and two special tags [<pad>, <unk>]
         """
         vocab = Vocabulary(padding=padding, unknown=unknown)
         vec_dict = {}
+        found_unknown = False
+        found_pad = False
 
         with open(embed_filepath, 'r', encoding='utf-8') as f:
             line = f.readline()
@@ -201,16 +217,41 @@ class EmbedLoader(BaseLoader):
             f.seek(0)
             start = 0
             for idx, line in enumerate(f, start=start):
-                parts = line.strip().split()
-                word = parts[0]
-                if dim==-1:
-                    dim = len(parts)-1
-                vec = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
-                vec_dict[word] = vec
-                vocab.add_word(word)
+                try:
+                    parts = line.strip().split()
+                    word = parts[0]
+                    if dim==-1:
+                        dim = len(parts)-1
+                    vec = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
+                    vec_dict[word] = vec
+                    vocab.add_word(word)
+                    if unknown is not None and unknown==word:
+                        found_unknown = True
+                    if padding is not None and padding==word:
+                        found_pad = True
+                except Exception as e:
+                    if error=='ignore':
+                        warnings.warn("Error occurred at line {}.".format(idx))
+                        pass
+                    else:
+                        raise e
             if dim==-1:
                 raise RuntimeError("{} is an empty file.".format(embed_filepath))
             matrix = np.random.randn(len(vocab), dim).astype(dtype)
+            # TODO: do we need to guarantee that unk follows the same distribution as the other vectors?
+            if (unknown is not None and not found_unknown) or (padding is not None and not found_pad):
+                start_idx = 0
+                if padding is not None:
+                    start_idx += 1
+                if unknown is not None:
+                    start_idx += 1
+
+                mean = np.mean(matrix[start_idx:], axis=0, keepdims=True)
+                std = np.std(matrix[start_idx:], axis=0, keepdims=True)
+                if (unknown is not None and not found_unknown):
+                    matrix[start_idx-1] = np.random.randn(1, dim).astype(dtype)*std + mean
+                if (padding is not None and not found_pad):
+                    matrix[0] = np.random.randn(1, dim).astype(dtype)*std + mean
 
             for key, vec in vec_dict.items():
                 index = vocab.to_index(key)
diff --git a/test/io/test_embed_loader.py b/test/io/test_embed_loader.py
index 3f1fb5e7..9e325334 100644
--- a/test/io/test_embed_loader.py
+++ b/test/io/test_embed_loader.py
@@ -17,11 +17,12 @@ class TestEmbedLoader(unittest.TestCase):
         glove = "test/data_for_tests/glove.6B.50d_test.txt"
         word2vec = "test/data_for_tests/word2vec_test.txt"
         vocab.add_word('the')
+        vocab.add_word('none')
         g_m = EmbedLoader.load_with_vocab(glove, vocab)
-        self.assertEqual(g_m.shape, (3, 50))
+        self.assertEqual(g_m.shape, (4, 50))
         w_m = EmbedLoader.load_with_vocab(word2vec, vocab, normalize=True)
-        self.assertEqual(w_m.shape, (3, 50))
-        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 3)
+        self.assertEqual(w_m.shape, (4, 50))
+        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 4)
 
     def test_load_without_vocab(self):
         words = ['the', 'of', 'in', 'a', 'to', 'and']
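-- 
A minimal sketch of the behavior this patch introduces, assuming the fastNLP
API as patched above (kept after the signature marker so `git am` ignores it).
The embedding path `my_embed.txt` is hypothetical; building a DataSet from a
dict of lists follows fastNLP's documented constructor.

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.vocabulary import Vocabulary
    from fastNLP.io.embed_loader import EmbedLoader

    # 1. DataSet.apply() now reports where it failed before re-raising.
    ds = DataSet({"x": [1, 2, 0, 4]})
    try:
        ds.apply(lambda ins: 1 / ins["x"])
    except ZeroDivisionError:
        pass  # apply() first prints: Exception happens at the `2`th instance.

    # 3. EmbedLoader skips a malformed embedding line with a warning by
    # default (error='ignore'); error='strict' re-raises the exception.
    vocab = Vocabulary()
    vocab.add_word("the")
    matrix = EmbedLoader.load_with_vocab("my_embed.txt", vocab, error="ignore")
    assert matrix.shape[0] == len(vocab)  # one row per vocab entry, incl. <pad>/<unk>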