2. Vocabulary: add from_dataset() and index_dataset(), so that indexing a dataset no longer takes several lines of boilerplate.
3. utils: add a cache_results() decorator that caches a function's return value.
4. Callback: add an update_every property.
Tag: v0.4.10
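A rough usage sketch of the new APIs follows; the toy data, field names, and cache path are illustrative and not part of this diff:

```python
from fastNLP import DataSet, Instance, cache_results
from fastNLP.core.vocabulary import Vocabulary

# Build a toy dataset (illustrative data)
train_data = DataSet()
for sent in [["fastNLP", "works", "well"], ["fastNLP", "scales", "well"]]:
    train_data.append(Instance(words=sent))

# New in this PR: build the vocabulary and index the field in one call each
vocab = Vocabulary()
vocab.from_dataset(train_data, field_name='words')
vocab.index_dataset(train_data, field_name='words', new_field_name='word_ids')

# New in this PR: cache the return value of an expensive preprocessing step
@cache_results('caches/preprocess.pkl')  # illustrative path; the directory is created if missing
def preprocess():
    return train_data

d = preprocess()               # first call runs the function and writes the cache
d = preprocess()               # second call reads caches/preprocess.pkl instead of recomputing
d = preprocess(refresh=True)   # force recomputation and overwrite the cache
```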
@@ -1,5 +1,5 @@
 from .batch import Batch
-# from .dataset import DataSet
+from .dataset import DataSet
 from .fieldarray import FieldArray
 from .instance import Instance
 from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
@@ -9,5 +9,5 @@ from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSample
 from .tester import Tester
 from .trainer import Trainer
 from .vocabulary import Vocabulary
-from ..io.dataset_loader import DataSet
 from .callback import Callback
+from .utils import cache_results
@@ -61,6 +61,10 @@ class Callback(object):
         """If use_tqdm, return trainer's tqdm print bar, else return None."""
         return self._trainer.pbar
 
+    @property
+    def update_every(self):
+        """The model in trainer will update parameters every `update_every` batches."""
+        return self._trainer.update_every
 
     def on_train_begin(self):
         # before the main training loop
         pass
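For illustration, a minimal custom callback (hypothetical name) can read the new property; this sketch assumes only the `Callback` base class and the `on_train_begin` hook shown above:

```python
from fastNLP.core.callback import Callback

class LogUpdateEvery(Callback):
    """Hypothetical callback that reports the gradient-accumulation interval."""
    def on_train_begin(self):
        # update_every proxies self._trainer.update_every (added in this PR)
        print("Parameters will be updated every {} batches.".format(self.update_every))
```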
@@ -6,7 +6,6 @@ from fastNLP.core.fieldarray import AutoPadder
 from fastNLP.core.fieldarray import FieldArray
 from fastNLP.core.instance import Instance
 from fastNLP.core.utils import get_func_signature
-from fastNLP.io.base_loader import DataLoaderRegister
 
 
 class DataSet(object):
@@ -105,11 +104,6 @@ class DataSet(object):
             raise AttributeError
         if isinstance(item, str) and item in self.field_arrays:
             return self.field_arrays[item]
-        try:
-            reader = DataLoaderRegister.get_reader(item)
-            return reader
-        except AttributeError:
-            raise
 
     def __setstate__(self, state):
         self.__dict__ = state
@@ -369,7 +363,7 @@ class DataSet(object):
         :return dataset: the read data set
 
         """
-        with open(csv_path, "r") as f:
+        with open(csv_path, "r", encoding='utf-8') as f:
             start_idx = 0
             if headers is None:
                 headers = f.readline().rstrip('\r\n')
@@ -11,6 +11,64 @@ import torch
 CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                    'varargs'])
 
 
+def _prepare_cache_filepath(filepath):
+    """
+    Check whether `filepath` can be used as a cache file path. If the parent directory does not exist yet,
+    it is created automatically.
+    :param filepath: str.
+    :return: None; raises an error if the path cannot be used.
+    """
+    _cache_filepath = os.path.abspath(filepath)
+    if os.path.isdir(_cache_filepath):
+        raise RuntimeError("The cache_file_path must be a file, not a directory.")
+    cache_dir = os.path.dirname(_cache_filepath)
+    if not os.path.exists(cache_dir):
+        os.makedirs(cache_dir)
+
+
+def cache_results(cache_filepath, refresh=False, verbose=1):
+    """Decorator that pickles the decorated function's return value to `cache_filepath` and reuses it on later
+    calls. The decorated function additionally accepts `cache_filepath`, `refresh` and `verbose` as call-time
+    overrides."""
+    def wrapper_(func):
+        signature = inspect.signature(func)
+        for key, _ in signature.parameters.items():
+            if key in ('cache_filepath', 'refresh', 'verbose'):
+                raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))
+
+        def wrapper(*args, **kwargs):
+            if 'cache_filepath' in kwargs:
+                _cache_filepath = kwargs.pop('cache_filepath')
+                assert isinstance(_cache_filepath, str), "cache_filepath can only be str."
+            else:
+                _cache_filepath = cache_filepath
+            if 'refresh' in kwargs:
+                _refresh = kwargs.pop('refresh')
+                assert isinstance(_refresh, bool), "refresh can only be bool."
+            else:
+                _refresh = refresh
+            if 'verbose' in kwargs:
+                _verbose = kwargs.pop('verbose')
+                assert isinstance(_verbose, int), "verbose can only be integer."
+            else:
+                _verbose = verbose
+            refresh_flag = True
+
+            if _cache_filepath is not None and _refresh is False:
+                # load data
+                if os.path.exists(_cache_filepath):
+                    with open(_cache_filepath, 'rb') as f:
+                        results = _pickle.load(f)
+                    if _verbose == 1:
+                        print("Read cache from {}.".format(_cache_filepath))
+                    refresh_flag = False
+
+            if refresh_flag:
+                results = func(*args, **kwargs)
+                if _cache_filepath is not None:
+                    if results is None:
+                        raise RuntimeError("The return value is None. Delete the decorator.")
+                    _prepare_cache_filepath(_cache_filepath)
+                    with open(_cache_filepath, 'wb') as f:
+                        _pickle.dump(results, f)
+                    print("Save cache to {}.".format(_cache_filepath))
+
+            return results
+        return wrapper
+    return wrapper_
+
+
 def save_pickle(obj, pickle_path, file_name):
     """Save an object into a pickle file.
@@ -1,5 +1,5 @@
 from collections import Counter
+from fastNLP.core.dataset import DataSet
 
 
 def check_build_vocab(func):
     """A decorator to make sure the indexing is built before used.
@@ -151,6 +151,68 @@ class Vocabulary(object):
         else:
             raise ValueError("word {} not in vocabulary".format(w))
+    @check_build_vocab
+    def index_dataset(self, *datasets, field_name, new_field_name=None):
+        """
+        example:
+            # remember to use `field_name`
+            vocab.index_dataset(tr_data, dev_data, te_data, field_name='words')
+
+        :param datasets: one or more fastNLP DataSet objects.
+        :param field_name: str, the field to index. Only fields with 0, 1 or 2 dimensions are supported.
+        :param new_field_name: str, the name of the indexed field; by default `field_name` is overwritten.
+        :return:
+        """
+        def index_instance(ins):
+            """
+            The field may be a str, a 1-d list or a 2-d list.
+            :param ins:
+            :return:
+            """
+            field = ins[field_name]
+            if isinstance(field, str):
+                return self.to_index(field)
+            elif isinstance(field, list):
+                if not isinstance(field[0], list):
+                    return [self.to_index(w) for w in field]
+                else:
+                    if isinstance(field[0][0], list):
+                        raise RuntimeError("Only support field with 2 dimensions.")
+                    return [[self.to_index(c) for c in w] for w in field]
+
+        if new_field_name is None:
+            new_field_name = field_name
+        for dataset in datasets:
+            if isinstance(dataset, DataSet):
+                dataset.apply(index_instance, new_field_name=new_field_name)
+            else:
+                raise RuntimeError("Only DataSet type is allowed.")
+
+    def from_dataset(self, *datasets, field_name):
+        """
+        Construct the vocabulary from one or more datasets.
+
+        :param datasets: one or more fastNLP DataSet objects.
+        :param field_name: str, the field whose content is used to build the vocabulary.
+        :return:
+        """
+        def construct_vocab(ins):
+            field = ins[field_name]
+            if isinstance(field, str):
+                self.add_word(field)
+            elif isinstance(field, list):
+                if not isinstance(field[0], list):
+                    self.add_word_lst(field)
+                else:
+                    if isinstance(field[0][0], list):
+                        raise RuntimeError("Only support field with 2 dimensions.")
+                    [self.add_word_lst(w) for w in field]
+
+        for dataset in datasets:
+            if isinstance(dataset, DataSet):
+                dataset.apply(construct_vocab)
+            else:
+                raise RuntimeError("Only DataSet type is allowed.")
 
     def to_index(self, w):
         """ Turn a word to an index. If w is not in Vocabulary, return the unknown label.
@@ -0,0 +1 @@
+from .embed_loader import EmbedLoader
@@ -1,3 +1,5 @@
+import os
+
 import numpy as np
 import torch
@@ -124,3 +126,97 @@ class EmbedLoader(BaseLoader):
                                               size=(len(vocab) - np.sum(hit_flags), emb_dim))
             embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
         return embedding_matrix
+
+    @staticmethod
+    def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True):
+        """
+        Load pretrained embeddings from `embed_filepath` for the words in `vocab`. Words that are in the vocabulary
+        but not in the pretrained embeddings are initialized from a normal distribution whose mean and std are those
+        of the found word vectors. The embedding format is detected automatically; both GloVe and word2vec (whose
+        first line contains only two elements) are supported.
+
+        :param embed_filepath: str, path of the pretrained embedding file
+        :param vocab: Vocabulary.
+        :param dtype: the dtype of the embedding matrix
+        :param normalize: bool, whether to normalize each word vector so that every vector has norm 1.
+        :return: np.ndarray of shape (len(vocab), dimension), where the dimension is determined by the pretrained
+            embeddings
+        """
+        assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
+        if not os.path.exists(embed_filepath):
+            raise FileNotFoundError("`{}` does not exist.".format(embed_filepath))
+        with open(embed_filepath, 'r', encoding='utf-8') as f:
+            hit_flags = np.zeros(len(vocab), dtype=bool)
+            line = f.readline().strip()
+            parts = line.split()
+            if len(parts) == 2:
+                # word2vec format: the first line holds the vocabulary size and the dimension
+                dim = int(parts[1])
+            else:
+                # GloVe format: every line is a word vector, so rewind and read from the start
+                dim = len(parts) - 1
+                f.seek(0)
+            matrix = np.random.randn(len(vocab), dim).astype(dtype)
+            for line in f:
+                parts = line.strip().split()
+                if parts[0] in vocab:
+                    index = vocab.to_index(parts[0])
+                    matrix[index] = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
+                    hit_flags[index] = True
+            total_hits = sum(hit_flags)
+            print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
+            found_vectors = matrix[hit_flags]
+            if len(found_vectors) != 0:
+                # per-dimension statistics of the vectors that were found
+                mean = np.mean(found_vectors, axis=0, keepdims=True)
+                std = np.std(found_vectors, axis=0, keepdims=True)
+                unfound_vec_num = len(vocab) - total_hits
+                r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean
+                matrix[hit_flags == False] = r_vecs
+
+            if normalize:
+                matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)
+
+            return matrix
+
+    @staticmethod
+    def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True):
+        """
+        Load pretrained embeddings from `embed_filepath` and construct a Vocabulary from them.
+        The embedding format is detected automatically; both GloVe and word2vec (whose first line contains only
+        two elements) are supported.
+
+        :param embed_filepath: str, path of the pretrained embedding file
+        :param dtype: the dtype of the embedding matrix
+        :param padding: the padding tag for the vocabulary.
+        :param unknown: the unknown tag for the vocabulary.
+        :param normalize: bool, whether to normalize each word vector so that every vector has norm 1.
+        :return: np.ndarray, whose shape is determined by the pretrained embeddings;
+            Vocabulary, containing all pretrained words plus the two special tags [<pad>, <unk>]
+        """
+        vocab = Vocabulary(padding=padding, unknown=unknown)
+        vec_dict = {}
+
+        with open(embed_filepath, 'r', encoding='utf-8') as f:
+            line = f.readline()
+            start = 1
+            dim = -1
+            if len(line.strip().split()) != 2:
+                # GloVe format: no header line, rewind and read from the start
+                f.seek(0)
+                start = 0
+            for idx, line in enumerate(f, start=start):
+                parts = line.strip().split()
+                word = parts[0]
+                if dim == -1:
+                    dim = len(parts) - 1
+                vec = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
+                vec_dict[word] = vec
+                vocab.add_word(word)
+            if dim == -1:
+                raise RuntimeError("{} is an empty file.".format(embed_filepath))
+            matrix = np.random.randn(len(vocab), dim).astype(dtype)
+            for key, vec in vec_dict.items():
+                index = vocab.to_index(key)
+                matrix[index] = vec
+
+            if normalize:
+                matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)
+
+            return matrix, vocab
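A minimal sketch of how the two loaders added above can be called, mirroring the tests near the end of this diff (the embedding file path is a placeholder):

```python
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader

# Build the vocabulary directly from a pretrained file (GloVe or word2vec text format)
matrix, vocab = EmbedLoader.load_without_vocab("path/to/embeddings.txt", normalize=True)
print(matrix.shape)  # (len(vocab), dim); rows have unit norm because normalize=True

# Or load vectors only for an existing vocabulary; words missing from the file are
# sampled from a normal distribution fitted to the vectors that were found
my_vocab = Vocabulary()
my_vocab.add_word_lst(["the", "of", "and"])
matrix = EmbedLoader.load_with_vocab("path/to/embeddings.txt", my_vocab)
assert matrix.shape[0] == len(my_vocab)
```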
@@ -1,35 +0,0 @@
-import logging
-import os
-
-
-def create_logger(logger_name, log_path, log_format=None, log_level=logging.INFO):
-    """Create a logger.
-
-    :param str logger_name:
-    :param str log_path:
-    :param log_format:
-    :param log_level:
-    :return: logger
-
-    To use a logger::
-
-        logger.debug("this is a debug message")
-        logger.info("this is a info message")
-        logger.warning("this is a warning message")
-        logger.error("this is an error message")
-
-    """
-    logger = logging.getLogger(logger_name)
-    logger.setLevel(log_level)
-    if log_path is None:
-        handler = logging.StreamHandler()
-    else:
-        os.stat(os.path.dirname(os.path.abspath(log_path)))
-        handler = logging.FileHandler(log_path)
-    handler.setLevel(log_level)
-    if log_format is None:
-        log_format = "[%(asctime)s %(name)-13s %(levelname)s %(process)d %(thread)d " \
-                     "%(filename)s:%(lineno)-5d] %(message)s"
-    formatter = logging.Formatter(log_format)
-    handler.setFormatter(formatter)
-    logger.addHandler(handler)
-    return logger
@@ -0,0 +1,115 @@
+import unittest
+
+import _pickle
+from fastNLP import cache_results
+from fastNLP.io.embed_loader import EmbedLoader
+from fastNLP import DataSet
+from fastNLP import Instance
+import time
+import os
+
+
+@cache_results('test/demo1.pkl')
+def process_data_1(embed_file, cws_train):
+    embed, vocab = EmbedLoader.load_without_vocab(embed_file)
+    time.sleep(1)  # used to check whether the second call obtains the result from the cache
+    with open(cws_train, 'r', encoding='utf-8') as f:
+        d = DataSet()
+        for line in f:
+            line = line.strip()
+            if len(line) > 0:
+                d.append(Instance(raw=line))
+    return embed, vocab, d
+
+
+class TestCache(unittest.TestCase):
+    def test_cache_save(self):
+        try:
+            start_time = time.time()
+            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train')
+            end_time = time.time()
+            pre_time = end_time - start_time
+            with open('test/demo1.pkl', 'rb') as f:
+                _embed, _vocab, _d = _pickle.load(f)
+            self.assertEqual(embed.shape, _embed.shape)
+            for i in range(embed.shape[0]):
+                self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
+            start_time = time.time()
+            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train')
+            end_time = time.time()
+            read_time = end_time - start_time
+            print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
+            self.assertGreater(pre_time-0.5, read_time)
+        finally:
+            os.remove('test/demo1.pkl')
+
+    def test_cache_save_overwrite_path(self):
+        try:
+            start_time = time.time()
+            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
+                                             cache_filepath='test/demo_overwrite.pkl')
+            end_time = time.time()
+            pre_time = end_time - start_time
+            with open('test/demo_overwrite.pkl', 'rb') as f:
+                _embed, _vocab, _d = _pickle.load(f)
+            self.assertEqual(embed.shape, _embed.shape)
+            for i in range(embed.shape[0]):
+                self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
+            start_time = time.time()
+            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
+                                             cache_filepath='test/demo_overwrite.pkl')
+            end_time = time.time()
+            read_time = end_time - start_time
+            print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
+            self.assertGreater(pre_time-0.5, read_time)
+        finally:
+            os.remove('test/demo_overwrite.pkl')
+
+    def test_cache_refresh(self):
+        try:
+            start_time = time.time()
+            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
+                                             refresh=True)
+            end_time = time.time()
+            pre_time = end_time - start_time
+            with open('test/demo1.pkl', 'rb') as f:
+                _embed, _vocab, _d = _pickle.load(f)
+            self.assertEqual(embed.shape, _embed.shape)
+            for i in range(embed.shape[0]):
+                self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
+            start_time = time.time()
+            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
+                                             refresh=True)
+            end_time = time.time()
+            read_time = end_time - start_time
+            print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
+            self.assertGreater(0.1, pre_time-read_time)
+        finally:
+            os.remove('test/demo1.pkl')
+
+    def test_duplicate_keyword(self):
+        with self.assertRaises(RuntimeError):
+            @cache_results(None)
+            def func_verbose(a, verbose):
+                pass
+            func_verbose(0, 1)
+        with self.assertRaises(RuntimeError):
+            @cache_results(None)
+            def func_cache(a, cache_filepath):
+                pass
+            func_cache(1, 2)
+        with self.assertRaises(RuntimeError):
+            @cache_results(None)
+            def func_refresh(a, refresh):
+                pass
+            func_refresh(1, 2)
+
+    def test_create_cache_dir(self):
+        @cache_results('test/demo1/demo.pkl')
+        def cache():
+            return 1, 2
+
+        try:
+            results = cache()
+            print(results)
+        finally:
+            os.remove('test/demo1/demo.pkl')
+            os.rmdir('test/demo1')
@@ -2,6 +2,8 @@ import unittest
 from collections import Counter
 
 from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.instance import Instance
 
 text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
         "works", "well", "in", "most", "cases", "scales", "well"]
@@ -31,6 +33,42 @@ class TestAdd(unittest.TestCase):
         vocab.update(text)
         self.assertEqual(vocab.word_count, counter)
 
+    def test_from_dataset(self):
+        start_char = 65
+        num_samples = 10
+
+        # 0 dim
+        dataset = DataSet()
+        for i in range(num_samples):
+            ins = Instance(char=chr(start_char+i))
+            dataset.append(ins)
+        vocab = Vocabulary()
+        vocab.from_dataset(dataset, field_name='char')
+        for i in range(num_samples):
+            self.assertEqual(vocab.to_index(chr(start_char+i)), i+2)
+        vocab.index_dataset(dataset, field_name='char')
+
+        # 1 dim
+        dataset = DataSet()
+        for i in range(num_samples):
+            ins = Instance(char=[chr(start_char+i)]*6)
+            dataset.append(ins)
+        vocab = Vocabulary()
+        vocab.from_dataset(dataset, field_name='char')
+        for i in range(num_samples):
+            self.assertEqual(vocab.to_index(chr(start_char+i)), i+2)
+        vocab.index_dataset(dataset, field_name='char')
+
+        # 2 dim
+        dataset = DataSet()
+        for i in range(num_samples):
+            ins = Instance(char=[[chr(start_char+i) for _ in range(6)] for _ in range(6)])
+            dataset.append(ins)
+        vocab = Vocabulary()
+        vocab.from_dataset(dataset, field_name='char')
+        for i in range(num_samples):
+            self.assertEqual(vocab.to_index(chr(start_char+i)), i+2)
+        vocab.index_dataset(dataset, field_name='char')
+
 
 class TestIndexing(unittest.TestCase):
     def test_len(self):
@@ -1,4 +1,5 @@
 import unittest
+import numpy as np
 
 from fastNLP.core.vocabulary import Vocabulary
 from fastNLP.io.embed_loader import EmbedLoader
@@ -10,3 +11,34 @@ class TestEmbedLoader(unittest.TestCase):
         vocab.update(["the", "in", "I", "to", "of", "hahaha"])
         embedding = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab)
         self.assertEqual(tuple(embedding.shape), (len(vocab), 50))
+
+    def test_load_with_vocab(self):
+        vocab = Vocabulary()
+        glove = "test/data_for_tests/glove.6B.50d_test.txt"
+        word2vec = "test/data_for_tests/word2vec_test.txt"
+        vocab.add_word('the')
+        g_m = EmbedLoader.load_with_vocab(glove, vocab)
+        self.assertEqual(g_m.shape, (3, 50))
+        w_m = EmbedLoader.load_with_vocab(word2vec, vocab, normalize=True)
+        self.assertEqual(w_m.shape, (3, 50))
+        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 3)
+
+    def test_load_without_vocab(self):
+        words = ['the', 'of', 'in', 'a', 'to', 'and']
+        glove = "test/data_for_tests/glove.6B.50d_test.txt"
+        word2vec = "test/data_for_tests/word2vec_test.txt"
+        g_m, vocab = EmbedLoader.load_without_vocab(glove)
+        self.assertEqual(g_m.shape, (8, 50))
+        for word in words:
+            self.assertIn(word, vocab)
+        w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True)
+        self.assertEqual(w_m.shape, (8, 50))
+        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 8)
+        for word in words:
+            self.assertIn(word, vocab)
+        # no unk
+        w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True, unknown=None)
+        self.assertEqual(w_m.shape, (7, 50))
+        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 7)
+        for word in words:
+            self.assertIn(word, vocab)