2. Add from_dataset() and index_dataset() methods to Vocabulary, so indexing a dataset no longer takes several lines of code.
3. Add a cache_results() decorator in utils to cache a function's return value.
4. Add an update_every property to Callback.
@@ -1,5 +1,5 @@
from .batch import Batch
# from .dataset import DataSet
from .dataset import DataSet
from .fieldarray import FieldArray
from .instance import Instance
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
@@ -9,5 +9,5 @@ from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSample
from .tester import Tester
from .trainer import Trainer
from .vocabulary import Vocabulary
from ..io.dataset_loader import DataSet
from .callback import Callback
from .utils import cache_results
@@ -61,6 +61,10 @@ class Callback(object):
        """If use_tqdm, return trainer's tqdm print bar, else return None."""
        return self._trainer.pbar

    @property
    def update_every(self):
        """The model in trainer will update parameters every `update_every` batches."""
        return self._trainer.update_every

    def on_train_begin(self):
        # before the main training loop
        pass
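A minimal sketch of how the new property can be read from a user-defined callback; the callback name below is illustrative and not part of this change::

    from fastNLP import Callback

    class UpdateEveryReporter(Callback):
        """Illustrative callback: report how often the trainer updates parameters."""

        def on_train_begin(self):
            # `update_every` is forwarded from the trainer by the new property
            print("Parameters are updated every {} batches.".format(self.update_every))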
@@ -6,7 +6,6 @@ from fastNLP.core.fieldarray import AutoPadder
from fastNLP.core.fieldarray import FieldArray
from fastNLP.core.instance import Instance
from fastNLP.core.utils import get_func_signature
from fastNLP.io.base_loader import DataLoaderRegister


class DataSet(object):
@@ -105,11 +104,6 @@ class DataSet(object):
            raise AttributeError
        if isinstance(item, str) and item in self.field_arrays:
            return self.field_arrays[item]
        try:
            reader = DataLoaderRegister.get_reader(item)
            return reader
        except AttributeError:
            raise

    def __setstate__(self, state):
        self.__dict__ = state
@@ -369,7 +363,7 @@ class DataSet(object):
        :return dataset: the read data set
        """
        with open(csv_path, "r") as f:
        with open(csv_path, "r", encoding='utf-8') as f:
            start_idx = 0
            if headers is None:
                headers = f.readline().rstrip('\r\n')
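For reference, a small usage sketch of DataSet.read_csv after the encoding fix; the file path is a placeholder, and the sep keyword is assumed to be part of the existing signature (headers appears in the hunk above)::

    from fastNLP import DataSet

    # hypothetical tab-separated file with two columns
    ds = DataSet.read_csv('data/train.tsv', headers=('raw_words', 'label'), sep='\t')
    print(len(ds))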
@@ -11,6 +11,64 @@ import torch
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                   'varargs'])


def _prepare_cache_filepath(filepath):
    """
    Check whether `filepath` can be used as a cache file. If it can, the parent directory is created automatically.
    :param filepath: str.
    :return: None; if the path is not valid, this function raises an error
    """
    _cache_filepath = os.path.abspath(filepath)
    if os.path.isdir(_cache_filepath):
        raise RuntimeError("The cache_file_path must be a file, not a directory.")
    cache_dir = os.path.dirname(_cache_filepath)
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)


def cache_results(cache_filepath, refresh=False, verbose=1):
    def wrapper_(func):
        signature = inspect.signature(func)
        for key, _ in signature.parameters.items():
            if key in ('cache_filepath', 'refresh', 'verbose'):
                raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))

        def wrapper(*args, **kwargs):
            if 'cache_filepath' in kwargs:
                _cache_filepath = kwargs.pop('cache_filepath')
                assert isinstance(_cache_filepath, str), "cache_filepath can only be str."
            else:
                _cache_filepath = cache_filepath
            if 'refresh' in kwargs:
                _refresh = kwargs.pop('refresh')
                assert isinstance(_refresh, bool), "refresh can only be bool."
            else:
                _refresh = refresh
            if 'verbose' in kwargs:
                _verbose = kwargs.pop('verbose')
                assert isinstance(_verbose, int), "verbose can only be integer."
            else:
                _verbose = verbose
            refresh_flag = True

            if _cache_filepath is not None and _refresh is False:
                # load data
                if os.path.exists(_cache_filepath):
                    with open(_cache_filepath, 'rb') as f:
                        results = _pickle.load(f)
                    if _verbose == 1:
                        print("Read cache from {}.".format(_cache_filepath))
                    refresh_flag = False

            if refresh_flag:
                results = func(*args, **kwargs)
                if _cache_filepath is not None:
                    if results is None:
                        raise RuntimeError("The return value is None. Delete the decorator.")
                    _prepare_cache_filepath(_cache_filepath)
                    with open(_cache_filepath, 'wb') as f:
                        _pickle.dump(results, f)
                    print("Save cache to {}.".format(_cache_filepath))

            return results
        return wrapper
    return wrapper_
def save_pickle(obj, pickle_path, file_name):
    """Save an object into a pickle file.
@@ -1,5 +1,5 @@
from collections import Counter

from fastNLP.core.dataset import DataSet


def check_build_vocab(func):
    """A decorator to make sure the indexing is built before used.
@@ -151,6 +151,68 @@ class Vocabulary(object):
        else:
            raise ValueError("word {} not in vocabulary".format(w))
    @check_build_vocab
    def index_dataset(self, *datasets, field_name, new_field_name=None):
        """
        example:
            # remember to use `field_name`
            vocab.index_dataset(tr_data, dev_data, te_data, field_name='words')

        :param datasets: fastNLP DataSet type. You can pass multiple datasets.
        :param field_name: str, which field to index. Only fields with 0, 1 or 2 dimensions are supported.
        :param new_field_name: str, name of the indexed field; by default `field_name` is overwritten.
        :return:
        """
        def index_instance(ins):
            """
            The field may be a str, a 1d-list or a 2d-list.
            :param ins:
            :return:
            """
            field = ins[field_name]
            if isinstance(field, str):
                return self.to_index(field)
            elif isinstance(field, list):
                if not isinstance(field[0], list):
                    return [self.to_index(w) for w in field]
                else:
                    if isinstance(field[0][0], list):
                        raise RuntimeError("Only support field with 2 dimensions.")
                    return [[self.to_index(c) for c in w] for w in field]

        if new_field_name is None:
            new_field_name = field_name
        for dataset in datasets:
            if isinstance(dataset, DataSet):
                dataset.apply(index_instance, new_field_name=new_field_name)
            else:
                raise RuntimeError("Only DataSet type is allowed.")
    def from_dataset(self, *datasets, field_name):
        """
        Construct the vocabulary from datasets.

        :param datasets: fastNLP DataSet type. You can pass multiple datasets.
        :param field_name: str, which field is used to build the vocabulary. Only fields with 0, 1 or 2 dimensions are supported.
        :return:
        """
        def construct_vocab(ins):
            field = ins[field_name]
            if isinstance(field, str):
                self.add_word(field)
            elif isinstance(field, list):
                if not isinstance(field[0], list):
                    self.add_word_lst(field)
                else:
                    if isinstance(field[0][0], list):
                        raise RuntimeError("Only support field with 2 dimensions.")
                    [self.add_word_lst(w) for w in field]

        for dataset in datasets:
            if isinstance(dataset, DataSet):
                dataset.apply(construct_vocab)
            else:
                raise RuntimeError("Only DataSet type is allowed.")
    def to_index(self, w):
        """ Turn a word to an index. If w is not in Vocabulary, return the unknown label.
@@ -0,0 +1 @@
from .embed_loader import EmbedLoader
@@ -1,3 +1,5 @@
import os

import numpy as np
import torch
@@ -124,3 +126,97 @@ class EmbedLoader(BaseLoader):
                                           size=(len(vocab) - np.sum(hit_flags), emb_dim))
        embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
        return embedding_matrix
    @staticmethod
    def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True):
        """
        Load a pretrained embedding from `embed_filepath`, keeping only the words that appear in `vocab`. Words in
        vocab that are not found in the pretrained embedding are initialized from a normal distribution whose mean
        and std are taken from the found word vectors. The embedding format is detected automatically; both glove
        and word2vec (whose first line only has two elements) are supported.

        :param embed_filepath: str, where to read the pretrained embedding from
        :param vocab: Vocabulary.
        :param dtype: the dtype of the embedding matrix
        :param normalize: bool, whether to normalize each word vector so that every vector has norm 1.
        :return: np.ndarray of shape [len(vocab), dimension], where dimension is determined by the pretrained
            embedding
        """
        assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
        if not os.path.exists(embed_filepath):
            raise FileNotFoundError("`{}` does not exist.".format(embed_filepath))
        with open(embed_filepath, 'r', encoding='utf-8') as f:
            hit_flags = np.zeros(len(vocab), dtype=bool)
            line = f.readline().strip()
            parts = line.split()
            if len(parts) == 2:
                dim = int(parts[1])
            else:
                dim = len(parts) - 1
                f.seek(0)
            matrix = np.random.randn(len(vocab), dim).astype(dtype)
            for line in f:
                parts = line.strip().split()
                if parts[0] in vocab:
                    index = vocab.to_index(parts[0])
                    matrix[index] = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
                    hit_flags[index] = True
            total_hits = sum(hit_flags)
            print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
            found_vectors = matrix[hit_flags]
            if len(found_vectors) != 0:
                # per-dimension statistics of the found vectors
                mean = np.mean(found_vectors, axis=0, keepdims=True)
                std = np.std(found_vectors, axis=0, keepdims=True)
                unfound_vec_num = len(vocab) - total_hits
                r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean
                matrix[hit_flags == False] = r_vecs

            if normalize:
                matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)

            return matrix
    @staticmethod
    def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True):
        """
        Load a pretrained embedding from `embed_filepath` and construct a Vocabulary from the words it contains.
        The embedding format is detected automatically; both glove and word2vec (whose first line only has two
        elements) are supported.

        :param embed_filepath: str, where to read the pretrained embedding from
        :param dtype: the dtype of the embedding matrix
        :param padding: the padding tag for the vocabulary.
        :param unknown: the unknown tag for the vocabulary.
        :param normalize: bool, whether to normalize each word vector so that every vector has norm 1.
        :return: np.ndarray whose shape is determined by the pretrained embedding;
            Vocabulary containing all pretrained words plus the two special tags [<pad>, <unk>]
        """
        vocab = Vocabulary(padding=padding, unknown=unknown)
        vec_dict = {}

        with open(embed_filepath, 'r', encoding='utf-8') as f:
            line = f.readline()
            start = 1
            dim = -1
            if len(line.strip().split()) != 2:
                f.seek(0)
                start = 0
            for idx, line in enumerate(f, start=start):
                parts = line.strip().split()
                word = parts[0]
                if dim == -1:
                    dim = len(parts) - 1
                vec = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
                vec_dict[word] = vec
                vocab.add_word(word)
            if dim == -1:
                raise RuntimeError("{} is an empty file.".format(embed_filepath))

            matrix = np.random.randn(len(vocab), dim).astype(dtype)
            for key, vec in vec_dict.items():
                index = vocab.to_index(key)
                matrix[index] = vec

            if normalize:
                matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)

            return matrix, vocab
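A short usage sketch of the two loaders; the embedding file path is a placeholder::

    from fastNLP import Vocabulary
    from fastNLP.io import EmbedLoader

    # fetch only the vectors needed by an existing vocabulary
    vocab = Vocabulary()
    vocab.add_word_lst(['the', 'of', 'in'])
    matrix = EmbedLoader.load_with_vocab('embeddings/glove.6B.50d.txt', vocab)
    assert matrix.shape[0] == len(vocab)

    # or let the loader build the vocabulary from the embedding file itself
    matrix, vocab = EmbedLoader.load_without_vocab('embeddings/glove.6B.50d.txt', normalize=True)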
@@ -1,35 +0,0 @@
import logging
import os


def create_logger(logger_name, log_path, log_format=None, log_level=logging.INFO):
    """Create a logger.

    :param str logger_name:
    :param str log_path:
    :param log_format:
    :param log_level:
    :return: logger

    To use a logger::

        logger.debug("this is a debug message")
        logger.info("this is a info message")
        logger.warning("this is a warning message")
        logger.error("this is an error message")

    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(log_level)
    if log_path is None:
        handler = logging.StreamHandler()
    else:
        os.stat(os.path.dirname(os.path.abspath(log_path)))
        handler = logging.FileHandler(log_path)
    handler.setLevel(log_level)
    if log_format is None:
        log_format = "[%(asctime)s %(name)-13s %(levelname)s %(process)d %(thread)d " \
                     "%(filename)s:%(lineno)-5d] %(message)s"
    formatter = logging.Formatter(log_format)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger
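Since this helper is removed, equivalent behaviour is available directly from the standard library; a minimal sketch, with an illustrative logger name and file path::

    import logging

    logger = logging.getLogger("fastNLP")          # illustrative name
    handler = logging.FileHandler("train.log")     # or logging.StreamHandler() for stdout
    handler.setFormatter(logging.Formatter(
        "[%(asctime)s %(name)-13s %(levelname)s %(filename)s:%(lineno)-5d] %(message)s"))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)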
@@ -0,0 +1,115 @@
import unittest
import _pickle
from fastNLP import cache_results
from fastNLP.io.embed_loader import EmbedLoader
from fastNLP import DataSet
from fastNLP import Instance
import time
import os


@cache_results('test/demo1.pkl')
def process_data_1(embed_file, cws_train):
    embed, vocab = EmbedLoader.load_without_vocab(embed_file)
    time.sleep(1)  # sleep so that the second (cached) call is measurably faster
    with open(cws_train, 'r', encoding='utf-8') as f:
        d = DataSet()
        for line in f:
            line = line.strip()
            if len(line) > 0:
                d.append(Instance(raw=line))
    return embed, vocab, d


class TestCache(unittest.TestCase):
    def test_cache_save(self):
        try:
            start_time = time.time()
            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train')
            end_time = time.time()
            pre_time = end_time - start_time
            with open('test/demo1.pkl', 'rb') as f:
                _embed, _vocab, _d = _pickle.load(f)
            self.assertEqual(embed.shape, _embed.shape)
            for i in range(embed.shape[0]):
                self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
            start_time = time.time()
            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train')
            end_time = time.time()
            read_time = end_time - start_time
            print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
            self.assertGreater(pre_time - 0.5, read_time)
        finally:
            os.remove('test/demo1.pkl')

    def test_cache_save_overwrite_path(self):
        try:
            start_time = time.time()
            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
                                             cache_filepath='test/demo_overwrite.pkl')
            end_time = time.time()
            pre_time = end_time - start_time
            with open('test/demo_overwrite.pkl', 'rb') as f:
                _embed, _vocab, _d = _pickle.load(f)
            self.assertEqual(embed.shape, _embed.shape)
            for i in range(embed.shape[0]):
                self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
            start_time = time.time()
            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
                                             cache_filepath='test/demo_overwrite.pkl')
            end_time = time.time()
            read_time = end_time - start_time
            print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
            self.assertGreater(pre_time - 0.5, read_time)
        finally:
            os.remove('test/demo_overwrite.pkl')

    def test_cache_refresh(self):
        try:
            start_time = time.time()
            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
                                             refresh=True)
            end_time = time.time()
            pre_time = end_time - start_time
            with open('test/demo1.pkl', 'rb') as f:
                _embed, _vocab, _d = _pickle.load(f)
            self.assertEqual(embed.shape, _embed.shape)
            for i in range(embed.shape[0]):
                self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
            start_time = time.time()
            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
                                             refresh=True)
            end_time = time.time()
            read_time = end_time - start_time
            print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
            self.assertGreater(0.1, pre_time - read_time)
        finally:
            os.remove('test/demo1.pkl')

    def test_duplicate_keyword(self):
        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_verbose(a, verbose):
                pass
            func_verbose(0, 1)
        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_cache(a, cache_filepath):
                pass
            func_cache(1, 2)
        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_refresh(a, refresh):
                pass
            func_refresh(1, 2)

    def test_create_cache_dir(self):
        @cache_results('test/demo1/demo.pkl')
        def cache():
            return 1, 2
        try:
            results = cache()
            print(results)
        finally:
            os.remove('test/demo1/demo.pkl')
            os.rmdir('test/demo1')
@@ -2,6 +2,8 @@ import unittest
from collections import Counter

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
        "works", "well", "in", "most", "cases", "scales", "well"]
@@ -31,6 +33,42 @@ class TestAdd(unittest.TestCase):
        vocab.update(text)
        self.assertEqual(vocab.word_count, counter)

    def test_from_dataset(self):
        start_char = 65
        num_samples = 10

        # 0 dim
        dataset = DataSet()
        for i in range(num_samples):
            ins = Instance(char=chr(start_char + i))
            dataset.append(ins)
        vocab = Vocabulary()
        vocab.from_dataset(dataset, field_name='char')
        for i in range(num_samples):
            self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2)
        vocab.index_dataset(dataset, field_name='char')

        # 1 dim
        dataset = DataSet()
        for i in range(num_samples):
            ins = Instance(char=[chr(start_char + i)] * 6)
            dataset.append(ins)
        vocab = Vocabulary()
        vocab.from_dataset(dataset, field_name='char')
        for i in range(num_samples):
            self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2)
        vocab.index_dataset(dataset, field_name='char')

        # 2 dim
        dataset = DataSet()
        for i in range(num_samples):
            ins = Instance(char=[[chr(start_char + i) for _ in range(6)] for _ in range(6)])
            dataset.append(ins)
        vocab = Vocabulary()
        vocab.from_dataset(dataset, field_name='char')
        for i in range(num_samples):
            self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2)
        vocab.index_dataset(dataset, field_name='char')


class TestIndexing(unittest.TestCase):
    def test_len(self):
@@ -1,4 +1,5 @@
import unittest

import numpy as np

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader
@@ -10,3 +11,34 @@ class TestEmbedLoader(unittest.TestCase):
        vocab.update(["the", "in", "I", "to", "of", "hahaha"])
        embedding = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab)
        self.assertEqual(tuple(embedding.shape), (len(vocab), 50))

    def test_load_with_vocab(self):
        vocab = Vocabulary()
        glove = "test/data_for_tests/glove.6B.50d_test.txt"
        word2vec = "test/data_for_tests/word2vec_test.txt"
        vocab.add_word('the')
        g_m = EmbedLoader.load_with_vocab(glove, vocab)
        self.assertEqual(g_m.shape, (3, 50))
        w_m = EmbedLoader.load_with_vocab(word2vec, vocab, normalize=True)
        self.assertEqual(w_m.shape, (3, 50))
        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 3)

    def test_load_without_vocab(self):
        words = ['the', 'of', 'in', 'a', 'to', 'and']
        glove = "test/data_for_tests/glove.6B.50d_test.txt"
        word2vec = "test/data_for_tests/word2vec_test.txt"
        g_m, vocab = EmbedLoader.load_without_vocab(glove)
        self.assertEqual(g_m.shape, (8, 50))
        for word in words:
            self.assertIn(word, vocab)
        w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True)
        self.assertEqual(w_m.shape, (8, 50))
        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 8)
        for word in words:
            self.assertIn(word, vocab)

        # no unk
        w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True, unknown=None)
        self.assertEqual(w_m.shape, (7, 50))
        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 7)
        for word in words:
            self.assertIn(word, vocab)