* add more code comments
* fix tester
* refresh code styles
fastNLP/core/batch.py
@@ -10,7 +10,7 @@ class Batch(object):
     """
-    def __init__(self, dataset, batch_size, sampler, as_numpy=False,):
+    def __init__(self, dataset, batch_size, sampler, as_numpy=False):
         """
         :param dataset: a DataSet object
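Only the stray trailing comma is dropped here. For orientation, a minimal usage sketch of this constructor (assuming `ds` is a populated DataSet; that `as_numpy=False` yields torch tensors rather than numpy arrays is an inference from the Trainer/Tester call sites later in this commit):

```python
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler

# ds: any populated DataSet with is_input/is_target flags set on its fields
data_iterator = Batch(ds, batch_size=8, sampler=RandomSampler(), as_numpy=False)
for batch_x, batch_y in data_iterator:
    pass  # batch_x: dict of input fields, batch_y: dict of target fields
```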
fastNLP/core/dataset.py
@@ -1,6 +1,7 @@
 import numpy as np

 from fastNLP.core.fieldarray import FieldArray
+from fastNLP.core.instance import Instance

 _READERS = {}
@@ -27,10 +28,10 @@ class DataSet(object):
     """

     class Instance(object):
-        def __init__(self, dataset, idx=-1):
+        def __init__(self, dataset, idx=-1, **fields):
             self.dataset = dataset
             self.idx = idx
-            self.fields = None
+            self.fields = fields

         def __next__(self):
             self.idx += 1
@@ -38,6 +39,14 @@ class DataSet(object):
                 raise StopIteration
             return self

+        def add_field(self, field_name, field):
+            """Add a new field to the instance.
+
+            :param field_name: str, the name of the field.
+            :param field: the field content to attach to this instance.
+            """
+            self.fields[field_name] = field
+
         def __getitem__(self, name):
             return self.dataset[name][self.idx]
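Since `fields` is now a dict populated by the constructor (see the `__getitem__` change further down), `add_field` can attach extra values to a single instance view without touching the parent DataSet. A short sketch; the field name is illustrative:

```python
ins = ds[0]                    # DataSet.__getitem__ passes **fields into Instance
ins.add_field("note", "demo")  # stored in ins.fields only; the parent DataSet
                               # is left unchanged
print(ins.fields["note"])      # -> demo
```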
@@ -47,13 +56,6 @@ class DataSet(object):
             self.dataset.add_field(name, new_fields)
             self.dataset[name][self.idx] = val

-        def __getattr__(self, item):
-            if item == 'fields':
-                self.fields = {name: field[self.idx] for name, field in self.dataset.get_fields().items()}
-                return self.fields
-            else:
-                raise AttributeError('{} does not exist.'.format(item))
-
         def __repr__(self):
             return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name
                               in self.dataset.get_fields().keys()])
@@ -112,14 +114,13 @@ class DataSet(object):
             self.field_arrays[name].append(field)

     def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False):
-        """
-        :param str name:
-        :param fields:
-        :param int padding_val:
-        :param bool is_input:
-        :param bool is_target:
-        :return:
+        """Add a new field to the DataSet.
+
+        :param str name: the name of the field.
+        :param fields: a list of int, float, or other objects.
+        :param int padding_val: integer for padding.
+        :param bool is_input: whether this field is model input.
+        :param bool is_target: whether this field is label or target.
         """
         if len(self.field_arrays) != 0:
             assert len(self) == len(fields)
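A hedged sketch of the documented parameters in use (values are illustrative); note the `assert` above requires every new field to match the current dataset length:

```python
ds = DataSet()
ds.add_field("words", [["a", "b"], ["c"]], padding_val=0, is_input=True)
ds.add_field("label", [0, 1], is_target=True)  # must also have length 2
```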
@@ -127,28 +128,43 @@ class DataSet(object): | |||||
is_input=is_input) | is_input=is_input) | ||||
def delete_field(self, name): | def delete_field(self, name): | ||||
"""Delete a field based on the field name. | |||||
:param str name: the name of the field to be deleted. | |||||
""" | |||||
self.field_arrays.pop(name) | self.field_arrays.pop(name) | ||||
def get_fields(self): | def get_fields(self): | ||||
"""Return all the fields with their names. | |||||
:return dict field_arrays: the internal data structure of DataSet. | |||||
""" | |||||
return self.field_arrays | return self.field_arrays | ||||
def __getitem__(self, name): | |||||
if isinstance(name, int): | |||||
return self.Instance(self, idx=name) | |||||
elif isinstance(name, slice): | |||||
ds = DataSet() | |||||
def __getitem__(self, idx): | |||||
""" | |||||
:param idx: can be int, slice, or str. | |||||
:return: If `idx` is int, return an Instance object. | |||||
If `idx` is slice, return a DataSet object. | |||||
If `idx` is str, it must be a field name, return the field. | |||||
""" | |||||
if isinstance(idx, int): | |||||
return self.Instance(self, idx, **{name: self.field_arrays[name][idx] for name in self.field_arrays}) | |||||
elif isinstance(idx, slice): | |||||
data_set = DataSet() | |||||
for field in self.field_arrays.values(): | for field in self.field_arrays.values(): | ||||
ds.add_field(name=field.name, | |||||
fields=field.content[name], | |||||
padding_val=field.padding_val, | |||||
need_tensor=field.need_tensor, | |||||
is_target=field.is_target) | |||||
return ds | |||||
elif isinstance(name, str): | |||||
return self.field_arrays[name] | |||||
data_set.add_field(name=field.name, | |||||
fields=field.content[idx], | |||||
padding_val=field.padding_val, | |||||
is_input=field.is_input, | |||||
is_target=field.is_target) | |||||
return data_set | |||||
elif isinstance(idx, str): | |||||
return self.field_arrays[idx] | |||||
else: | else: | ||||
raise KeyError | |||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||||
def __len__(self): | def __len__(self): | ||||
if len(self.field_arrays) == 0: | if len(self.field_arrays) == 0: | ||||
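The three dispatch branches in action, using the toy data from the `__main__` block this commit deletes at the end of dataset.py:

```python
d = DataSet({'a': list('abc')})
ins = d[1]    # int   -> an Instance carrying {'a': 'b'}
sub = d[0:2]  # slice -> a new DataSet holding the first two rows
col = d['a']  # str   -> the FieldArray named 'a'
# d[1.5] raises KeyError("Unrecognized type <class 'float'> for idx ...")
```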
@@ -208,6 +224,7 @@ class DataSet(object):
             pass
         try:
             reader_cls = _READERS[item]
+
             # add read_*data() support
             def _read(*args, **kwargs):
                 data = reader_cls().load(*args, **kwargs)
@@ -231,6 +248,12 @@ class DataSet(object):
         return wrapper

     def apply(self, func, new_field_name=None):
+        """Apply a function to every instance of the DataSet.
+
+        :param func: a function that takes an instance as input.
+        :param str new_field_name: If not None, results of the function will be stored as a new field.
+        :return results: returned values of the function over all instances.
+        """
         results = []
         for ins in self:
             results.append(func(ins))
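Usage follows the new docstring; the lambda mirrors the one in the deleted `__main__` block, and the new field name is illustrative:

```python
values = d.apply(lambda ins: ins['a'])                           # -> ['a', 'b', 'c']
d.apply(lambda ins: ins['a'].upper(), new_field_name='upper_a')  # stored as a field
```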
@@ -247,28 +270,24 @@ class DataSet(object):
         else:
             return results

-    def split(self, test_ratio):
-        assert isinstance(test_ratio, float)
+    def split(self, dev_ratio):
+        """Split the dataset into a training set and a development (validation) set.
+
+        :param float dev_ratio: the ratio of the development set in all data.
+        :return DataSet train_set: the training set
+                DataSet dev_set: the development set
+        """
+        assert isinstance(dev_ratio, float)
+        assert 0 < dev_ratio < 1
         all_indices = [_ for _ in range(len(self))]
         np.random.shuffle(all_indices)
-        test_indices = all_indices[:int(test_ratio)]
-        train_indices = all_indices[int(test_ratio):]
-        test_set = DataSet()
+        split = int(dev_ratio * len(self))
+        dev_indices = all_indices[:split]
+        train_indices = all_indices[split:]
+        dev_set = DataSet()
         train_set = DataSet()
-        for idx in test_indices:
-            test_set.append(self[idx])
+        for idx in dev_indices:
+            dev_set.append(self[idx])
         for idx in train_indices:
             train_set.append(self[idx])
-        return train_set, test_set
-
-if __name__ == '__main__':
-    from fastNLP.core.instance import Instance
-    d = DataSet({'a': list('abc')})
-    _ = d.a
-    d.apply(lambda x: x['a'])
-    print(d[1])
-    import copy
-    dd = copy.deepcopy(d)
-    print(dd.a)
+        return train_set, dev_set
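This rewrite also fixes a real bug: the old code sliced with `int(test_ratio)`, which truncates any ratio below 1.0 to index 0, so the held-out set was always empty. A sketch on the toy DataSet from above:

```python
train_set, dev_set = d.split(0.34)
# len(d) == 3, so split = int(0.34 * 3) = 1: one shuffled instance goes to
# dev_set and the remaining two to train_set. The old int(test_ratio) cut
# evaluated to 0 for any ratio < 1.
```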
fastNLP/core/tester.py
@@ -3,61 +3,19 @@ from collections import defaultdict
 import torch

 from fastNLP.core.batch import Batch
-from fastNLP.core.metrics import Evaluator
 from fastNLP.core.sampler import RandomSampler

-# logger = create_logger(__name__, "./train_test.log")


 class Tester(object):
     """A collection of model inference and evaluation of performance, used over validation/dev set and test set. """

-    def __init__(self, **kwargs):
-        """
-        :param kwargs: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
-        """
+    def __init__(self, batch_size, evaluator, use_cuda, save_path="./save/", **kwargs):
         super(Tester, self).__init__()
-        """
-        "default_args" provides default value for important settings.
-        The initialization arguments "kwargs" with the same key (name) will override the default value.
-        "kwargs" must have the same type as "default_args" on corresponding keys.
-        Otherwise, error will raise.
-        """
-        default_args = {"batch_size": 8,
-                        "use_cuda": False,
-                        "pickle_path": "./save/",
-                        "model_name": "dev_best_model.pkl",
-                        "evaluator": Evaluator()
-                        }
-        """
-        "required_args" is the collection of arguments that users must pass to Trainer explicitly.
-        This is used to warn users of essential settings in the training.
-        Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
-        """
-        required_args = {}
-        for req_key in required_args:
-            if req_key not in kwargs:
-                raise ValueError("Tester lacks argument {}".format(req_key))
-        for key in default_args:
-            if key in kwargs:
-                if isinstance(kwargs[key], type(default_args[key])):
-                    default_args[key] = kwargs[key]
-                else:
-                    msg = "Argument %s type mismatch: expected %s while get %s" % (
-                        key, type(default_args[key]), type(kwargs[key]))
-                    raise ValueError(msg)
-            else:
-                # Tester doesn't care about extra arguments
-                pass
-        # print(default_args)
-        self.batch_size = default_args["batch_size"]
-        self.pickle_path = default_args["pickle_path"]
-        self.use_cuda = default_args["use_cuda"]
-        self._evaluator = default_args["evaluator"]
+        self.batch_size = batch_size
+        self.pickle_path = save_path
+        self.use_cuda = use_cuda
+        self._evaluator = evaluator

         self._model = None
         self.eval_history = []  # evaluation results of all batches
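Construction is now explicit instead of being driven by a `default_args` dict. A sketch mirroring the old defaults; since the `Evaluator` import was dropped from tester.py, the caller supplies one (and the `test` entry point named below is an assumption, as the method definition is not shown in this hunk):

```python
from fastNLP.core.metrics import Evaluator
from fastNLP.core.tester import Tester

tester = Tester(batch_size=8, evaluator=Evaluator(), use_cuda=False,
                save_path="./save/")
# eval_results = tester.test(model, dev_data)  # assumed entry point
```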
@@ -72,7 +30,7 @@ class Tester(object):
         self.mode(network, is_test=True)
         self.eval_history.clear()
         output, truths = defaultdict(list), defaultdict(list)
-        data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
+        data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), as_numpy=False)

         with torch.no_grad():
             for batch_x, batch_y in data_iterator:
fastNLP/core/trainer.py
@@ -15,6 +15,8 @@ from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
+from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _build_args
@@ -78,7 +80,7 @@ class Trainer(object):
                 epoch = 1
                 while epoch <= self.n_epochs:
-                    data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler())
+                    data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)
                     self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)
@@ -207,9 +209,9 @@ def best_eval_result(self, metrics):
 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2

-IGNORE_CHECK_LEVEL=0
-WARNING_CHECK_LEVEL=1
-STRICT_CHECK_LEVEL=2
+IGNORE_CHECK_LEVEL = 0
+WARNING_CHECK_LEVEL = 1
+STRICT_CHECK_LEVEL = 2

 def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1):
     # check the get_loss method
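The three constants gate how loudly `_check_code` reports signature mismatches; what STRICT does beyond WARNING is not visible in this hunk. Usage mirroring the `__main__` call at the bottom of trainer.py:

```python
_check_code(dataset=dataset, model=model, dev_data=dataset,
            check_level=STRICT_CHECK_LEVEL)  # STRICT_CHECK_LEVEL == 2
```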
@@ -220,11 +222,20 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
     batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
     batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())

     for batch_count, (batch_x, batch_y) in enumerate(batch):
-        _syn_model_data(model, batch_x, batch_y)
-        # forward check
-        if batch_count==0:
-            _check_forward_error(model=model, model_func=model.forward, check_level=check_level,
-                                 batch_x=batch_x)
+        if batch_count == 0:
+            check_res = _check_arg_dict_list(model.forward, batch_x)
+            _info_str = ''
+            if len(check_res.missing) > 0:
+                if check_level == WARNING_CHECK_LEVEL:
+                    for field_name in check_res.missing:
+                        if hasattr(dataset, field_name):
+                            _info_str += "{} "
+                    _info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n"
+                    _info_str += ""
+                    print("")
+            if len(check_res.unused) > 0:
+                if check_level == WARNING_CHECK_LEVEL:
+                    _info_str += ""

         refined_batch_x = _build_args(model.forward, **batch_x)
         output = model(**refined_batch_x)
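The rewritten check leans on `_check_arg_dict_list`, which, per the `CheckRes` namedtuple defined in utils.py at the bottom of this commit, compares a function signature against the provided dict(s). A hedged reading of what the WARNING-level branch inspects:

```python
check_res = _check_arg_dict_list(model.forward, batch_x)
# check_res.missing: forward() parameters with no matching key in batch_x
# check_res.unused:  batch_x keys that forward() never consumes
```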
@@ -233,10 +244,14 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
         # loss check
         if batch_count == 0:
-            _check_loss_evaluate(model=model, model_func=model.get_loss, check_level=check_level,
-                                 output=output, batch_y=batch_y)
-        loss_input = _build_args(model.get_loss, **output, **batch_y)
-        loss = model.get_loss(**loss_input)
+            _dict = _check_arg_dict_list(model.loss, [output, batch_y])
+            if len(_dict) != 0:
+                pass
+        loss_input = _build_args(model.loss, **output, **batch_y)
+        loss = model.loss(**loss_input)
+        if batch_count == 0:
+            if isinstance(loss, torch.Tensor):
+                pass

         # check loss output
         if batch_count == 0:
@@ -248,8 +263,7 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
                 model_name, loss.size()
             ))
         loss.backward()
-        model.zero_grad()
-        if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
+        if batch_count + 1 >= DEFAULT_CHECK_BATCH_SIZE:
             break

     if check_level > IGNORE_CHECK_LEVEL:
         print('Finish checking training process.', flush=True)
@@ -407,14 +421,7 @@ if __name__ == '__main__':
     # trainer = Trainer(dataset, model)
+    _check_code(dataset=dataset, model=model, dev_data=dataset, check_level=2)
-    # _check_forward_error(model=model, model_func=model.forward, check_level=1,
-    #                      batch_x=fake_data_dict)
-    # import inspect
-    # print(inspect.getfullargspec(model.forward))
-    if len(_dict) != 0:
-        pass
-    refined_batch_x = _build_args(model.forward, **batch_x)
-    output = model(**refined_batch_x)
fastNLP/core/utils.py
@@ -1,8 +1,8 @@
 import _pickle
-import os
 import inspect
-from collections import namedtuple
+import os
 from collections import Counter
+from collections import namedtuple

 CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)
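For orientation, a sketch of the `CheckRes` container these checks pass around (field values are illustrative). Note that `verbose=False` was removed from `collections.namedtuple` in Python 3.7, so this line pins utils.py to older interpreters:

```python
res = CheckRes(missing=['word_seq'], unused=['extra_feat'], duplicated=[],
               required=['word_seq'], all_needed=['word_seq'])
print(res.missing)  # ['word_seq'] -- arguments the callee needs but did not get
```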