* add more code comments * fix tester * refresh code stylestags/v0.2.0
@@ -10,7 +10,7 @@ class Batch(object): | |||
""" | |||
def __init__(self, dataset, batch_size, sampler, as_numpy=False,): | |||
def __init__(self, dataset, batch_size, sampler, as_numpy=False): | |||
""" | |||
:param dataset: a DataSet object | |||
@@ -1,6 +1,7 @@ | |||
import numpy as np | |||
from fastNLP.core.fieldarray import FieldArray | |||
from fastNLP.core.instance import Instance | |||
_READERS = {} | |||
@@ -27,10 +28,10 @@ class DataSet(object): | |||
""" | |||
class Instance(object): | |||
def __init__(self, dataset, idx=-1): | |||
def __init__(self, dataset, idx=-1, **fields): | |||
self.dataset = dataset | |||
self.idx = idx | |||
self.fields = None | |||
self.fields = fields | |||
def __next__(self): | |||
self.idx += 1 | |||
@@ -38,6 +39,14 @@ class DataSet(object): | |||
raise StopIteration | |||
return self | |||
def add_field(self, field_name, field): | |||
"""Add a new field to the instance. | |||
:param field_name: str, the name of the field. | |||
:param field: | |||
""" | |||
self.fields[field_name] = field | |||
def __getitem__(self, name): | |||
return self.dataset[name][self.idx] | |||
@@ -47,13 +56,6 @@ class DataSet(object): | |||
self.dataset.add_field(name, new_fields) | |||
self.dataset[name][self.idx] = val | |||
def __getattr__(self, item): | |||
if item == 'fields': | |||
self.fields = {name: field[self.idx] for name, field in self.dataset.get_fields().items()} | |||
return self.fields | |||
else: | |||
raise AttributeError('{} does not exist.'.format(item)) | |||
def __repr__(self): | |||
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name | |||
in self.dataset.get_fields().keys()]) | |||
@@ -112,14 +114,13 @@ class DataSet(object): | |||
self.field_arrays[name].append(field) | |||
def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False): | |||
""" | |||
"""Add a new field to the DataSet. | |||
:param str name: | |||
:param fields: | |||
:param int padding_val: | |||
:param bool is_input: | |||
:param bool is_target: | |||
:return: | |||
:param str name: the name of the field. | |||
:param fields: a list of int, float, or other objects. | |||
:param int padding_val: integer for padding. | |||
:param bool is_input: whether this field is model input. | |||
:param bool is_target: whether this field is label or target. | |||
""" | |||
if len(self.field_arrays) != 0: | |||
assert len(self) == len(fields) | |||
@@ -127,28 +128,43 @@ class DataSet(object): | |||
is_input=is_input) | |||
def delete_field(self, name): | |||
"""Delete a field based on the field name. | |||
:param str name: the name of the field to be deleted. | |||
""" | |||
self.field_arrays.pop(name) | |||
def get_fields(self): | |||
"""Return all the fields with their names. | |||
:return dict field_arrays: the internal data structure of DataSet. | |||
""" | |||
return self.field_arrays | |||
def __getitem__(self, name): | |||
if isinstance(name, int): | |||
return self.Instance(self, idx=name) | |||
elif isinstance(name, slice): | |||
ds = DataSet() | |||
def __getitem__(self, idx): | |||
""" | |||
:param idx: can be int, slice, or str. | |||
:return: If `idx` is int, return an Instance object. | |||
If `idx` is slice, return a DataSet object. | |||
If `idx` is str, it must be a field name, return the field. | |||
""" | |||
if isinstance(idx, int): | |||
return self.Instance(self, idx, **{name: self.field_arrays[name][idx] for name in self.field_arrays}) | |||
elif isinstance(idx, slice): | |||
data_set = DataSet() | |||
for field in self.field_arrays.values(): | |||
ds.add_field(name=field.name, | |||
fields=field.content[name], | |||
padding_val=field.padding_val, | |||
need_tensor=field.need_tensor, | |||
is_target=field.is_target) | |||
return ds | |||
elif isinstance(name, str): | |||
return self.field_arrays[name] | |||
data_set.add_field(name=field.name, | |||
fields=field.content[idx], | |||
padding_val=field.padding_val, | |||
is_input=field.is_input, | |||
is_target=field.is_target) | |||
return data_set | |||
elif isinstance(idx, str): | |||
return self.field_arrays[idx] | |||
else: | |||
raise KeyError | |||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||
def __len__(self): | |||
if len(self.field_arrays) == 0: | |||
@@ -208,6 +224,7 @@ class DataSet(object): | |||
pass | |||
try: | |||
reader_cls = _READERS[item] | |||
# add read_*data() support | |||
def _read(*args, **kwargs): | |||
data = reader_cls().load(*args, **kwargs) | |||
@@ -231,6 +248,12 @@ class DataSet(object): | |||
return wrapper | |||
def apply(self, func, new_field_name=None): | |||
"""Apply a function to every instance of the DataSet. | |||
:param func: a function that takes an instance as input. | |||
:param str new_field_name: If not None, results of the function will be stored as a new field. | |||
:return results: returned values of the function over all instances. | |||
""" | |||
results = [] | |||
for ins in self: | |||
results.append(func(ins)) | |||
@@ -247,28 +270,24 @@ class DataSet(object): | |||
else: | |||
return results | |||
def split(self, test_ratio): | |||
assert isinstance(test_ratio, float) | |||
def split(self, dev_ratio): | |||
"""Split the dataset into training and development(validation) set. | |||
:param float dev_ratio: the ratio of test set in all data. | |||
:return DataSet train_set: the training set | |||
DataSet dev_set: the development set | |||
""" | |||
assert isinstance(dev_ratio, float) | |||
assert 0 < dev_ratio < 1 | |||
all_indices = [_ for _ in range(len(self))] | |||
np.random.shuffle(all_indices) | |||
test_indices = all_indices[:int(test_ratio)] | |||
train_indices = all_indices[int(test_ratio):] | |||
test_set = DataSet() | |||
split = int(dev_ratio * len(self)) | |||
dev_indices = all_indices[:split] | |||
train_indices = all_indices[split:] | |||
dev_set = DataSet() | |||
train_set = DataSet() | |||
for idx in test_indices: | |||
test_set.append(self[idx]) | |||
for idx in dev_indices: | |||
dev_set.append(self[idx]) | |||
for idx in train_indices: | |||
train_set.append(self[idx]) | |||
return train_set, test_set | |||
if __name__ == '__main__': | |||
from fastNLP.core.instance import Instance | |||
d = DataSet({'a': list('abc')}) | |||
_ = d.a | |||
d.apply(lambda x: x['a']) | |||
print(d[1]) | |||
import copy | |||
dd = copy.deepcopy(d) | |||
print(dd.a) | |||
return train_set, dev_set |
@@ -3,61 +3,19 @@ from collections import defaultdict | |||
import torch | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.metrics import Evaluator | |||
from fastNLP.core.sampler import RandomSampler | |||
# logger = create_logger(__name__, "./train_test.log") | |||
class Tester(object): | |||
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | |||
def __init__(self, **kwargs): | |||
""" | |||
:param kwargs: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | |||
""" | |||
def __init__(self, batch_size, evaluator, use_cuda, save_path="./save/", **kwargs): | |||
super(Tester, self).__init__() | |||
""" | |||
"default_args" provides default value for important settings. | |||
The initialization arguments "kwargs" with the same key (name) will override the default value. | |||
"kwargs" must have the same type as "default_args" on corresponding keys. | |||
Otherwise, error will raise. | |||
""" | |||
default_args = {"batch_size": 8, | |||
"use_cuda": False, | |||
"pickle_path": "./save/", | |||
"model_name": "dev_best_model.pkl", | |||
"evaluator": Evaluator() | |||
} | |||
""" | |||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||
This is used to warn users of essential settings in the training. | |||
Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | |||
""" | |||
required_args = {} | |||
for req_key in required_args: | |||
if req_key not in kwargs: | |||
raise ValueError("Tester lacks argument {}".format(req_key)) | |||
for key in default_args: | |||
if key in kwargs: | |||
if isinstance(kwargs[key], type(default_args[key])): | |||
default_args[key] = kwargs[key] | |||
else: | |||
msg = "Argument %s type mismatch: expected %s while get %s" % ( | |||
key, type(default_args[key]), type(kwargs[key])) | |||
raise ValueError(msg) | |||
else: | |||
# Tester doesn't care about extra arguments | |||
pass | |||
# print(default_args) | |||
self.batch_size = default_args["batch_size"] | |||
self.pickle_path = default_args["pickle_path"] | |||
self.use_cuda = default_args["use_cuda"] | |||
self._evaluator = default_args["evaluator"] | |||
self.batch_size = batch_size | |||
self.pickle_path = save_path | |||
self.use_cuda = use_cuda | |||
self._evaluator = evaluator | |||
self._model = None | |||
self.eval_history = [] # evaluation results of all batches | |||
@@ -72,7 +30,7 @@ class Tester(object): | |||
self.mode(network, is_test=True) | |||
self.eval_history.clear() | |||
output, truths = defaultdict(list), defaultdict(list) | |||
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) | |||
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), as_numpy=False) | |||
with torch.no_grad(): | |||
for batch_x, batch_y in data_iterator: | |||
@@ -15,6 +15,8 @@ from fastNLP.core.optimizer import Optimizer | |||
from fastNLP.core.sampler import RandomSampler | |||
from fastNLP.core.sampler import SequentialSampler | |||
from fastNLP.core.tester import Tester | |||
from fastNLP.core.utils import _build_args | |||
from fastNLP.core.utils import _check_arg_dict_list | |||
from fastNLP.core.utils import _check_arg_dict_list | |||
from fastNLP.core.utils import _build_args | |||
@@ -78,7 +80,7 @@ class Trainer(object): | |||
epoch = 1 | |||
while epoch <= self.n_epochs: | |||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler()) | |||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False) | |||
self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) | |||
@@ -207,9 +209,9 @@ def best_eval_result(self, metrics): | |||
DEFAULT_CHECK_BATCH_SIZE = 2 | |||
DEFAULT_CHECK_NUM_BATCH = 2 | |||
IGNORE_CHECK_LEVEL=0 | |||
WARNING_CHECK_LEVEL=1 | |||
STRICT_CHECK_LEVEL=2 | |||
IGNORE_CHECK_LEVEL = 0 | |||
WARNING_CHECK_LEVEL = 1 | |||
STRICT_CHECK_LEVEL = 2 | |||
def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1): | |||
# check get_loss 方法 | |||
@@ -220,11 +222,20 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No | |||
batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size) | |||
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
for batch_count, (batch_x, batch_y) in enumerate(batch): | |||
_syn_model_data(model, batch_x, batch_y) | |||
# forward check | |||
if batch_count==0: | |||
_check_forward_error(model=model, model_func=model.forward, check_level=check_level, | |||
batch_x=batch_x) | |||
if batch_count == 0: | |||
check_res = _check_arg_dict_list(model.forward, batch_x) | |||
_info_str = '' | |||
if len(check_res.missing) > 0: | |||
if check_level == WARNING_CHECK_LEVEL: | |||
for field_name in check_res.missing: | |||
if hasattr(dataset, field_name): | |||
_info_str += "{} " | |||
_info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n" | |||
_info_str += "" | |||
print("") | |||
if len(check_res.unused) > 0: | |||
if check_level == WARNING_CHECK_LEVEL: | |||
_info_str += "" | |||
refined_batch_x = _build_args(model.forward, **batch_x) | |||
output = model(**refined_batch_x) | |||
@@ -233,10 +244,14 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No | |||
# loss check | |||
if batch_count == 0: | |||
_check_loss_evaluate(model=model, model_func=model.get_loss, check_level=check_level, | |||
output=output, batch_y=batch_y) | |||
loss_input = _build_args(model.get_loss, **output, **batch_y) | |||
loss = model.get_loss(**loss_input) | |||
_dict = _check_arg_dict_list(model.loss, [output, batch_y]) | |||
if len(_dict) != 0: | |||
pass | |||
loss_input = _build_args(model.loss, **output, **batch_y) | |||
loss = model.loss(**loss_input) | |||
if batch_count == 0: | |||
if isinstance(loss, torch.Tensor): | |||
pass | |||
# check loss output | |||
if batch_count == 0: | |||
@@ -248,8 +263,7 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No | |||
model_name, loss.size() | |||
)) | |||
loss.backward() | |||
model.zero_grad() | |||
if batch_count+1>=DEFAULT_CHECK_NUM_BATCH: | |||
if batch_count + 1 >= DEFAULT_CHECK_BATCH_SIZE: | |||
break | |||
if check_level > IGNORE_CHECK_LEVEL: | |||
print('Finish checking training process.', flush=True) | |||
@@ -407,14 +421,7 @@ if __name__ == '__main__': | |||
# trainer = Trainer(dataset, model) | |||
_check_code(dataset=dataset, model=model, dev_data=dataset, check_level=2) | |||
# _check_forward_error(model=model, model_func=model.forward, check_level=1, | |||
# batch_x=fake_data_dict) | |||
# import inspect | |||
# print(inspect.getfullargspec(model.forward)) | |||
if len(_dict) != 0: | |||
pass | |||
refined_batch_x = _build_args(model.forward, **batch_x) | |||
output = model(**refined_batch_x) |
@@ -1,8 +1,8 @@ | |||
import _pickle | |||
import os | |||
import inspect | |||
from collections import namedtuple | |||
import os | |||
from collections import Counter | |||
from collections import namedtuple | |||
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False) | |||