diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 5e0be4c3..38da83da 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -10,7 +10,7 @@ class Batch(object):

     """

-    def __init__(self, dataset, batch_size, sampler, as_numpy=False,):
+    def __init__(self, dataset, batch_size, sampler, as_numpy=False):
         """

         :param dataset: a DataSet object
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 5e72106f..34ce56ba 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -1,6 +1,7 @@
 import numpy as np

 from fastNLP.core.fieldarray import FieldArray
+from fastNLP.core.instance import Instance

 _READERS = {}
@@ -27,10 +28,10 @@ class DataSet(object):
     """

     class Instance(object):
-        def __init__(self, dataset, idx=-1):
+        def __init__(self, dataset, idx=-1, **fields):
             self.dataset = dataset
             self.idx = idx
-            self.fields = None
+            self.fields = fields

         def __next__(self):
             self.idx += 1
@@ -38,6 +39,14 @@ class DataSet(object):
                 raise StopIteration
             return self

+        def add_field(self, field_name, field):
+            """Add a new field to the instance.
+
+            :param field_name: str, the name of the field.
+            :param field: the content of the field.
+            """
+            self.fields[field_name] = field
+
         def __getitem__(self, name):
             return self.dataset[name][self.idx]
@@ -47,13 +56,6 @@ class DataSet(object):
             self.dataset.add_field(name, new_fields)
             self.dataset[name][self.idx] = val

-        def __getattr__(self, item):
-            if item == 'fields':
-                self.fields = {name: field[self.idx] for name, field in self.dataset.get_fields().items()}
-                return self.fields
-            else:
-                raise AttributeError('{} does not exist.'.format(item))
-
         def __repr__(self):
             return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx]))
                               for name in self.dataset.get_fields().keys()])
@@ -112,14 +114,13 @@ class DataSet(object):
             self.field_arrays[name].append(field)

     def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False):
-        """
+        """Add a new field to the DataSet.

-        :param str name:
-        :param fields:
-        :param int padding_val:
-        :param bool is_input:
-        :param bool is_target:
-        :return:
+        :param str name: the name of the field.
+        :param fields: a list of int, float, or other objects.
+        :param int padding_val: integer for padding.
+        :param bool is_input: whether this field is model input.
+        :param bool is_target: whether this field is label or target.
         """
         if len(self.field_arrays) != 0:
             assert len(self) == len(fields)
@@ -127,28 +128,43 @@ class DataSet(object):
                                              is_input=is_input)

     def delete_field(self, name):
+        """Delete a field based on the field name.
+
+        :param str name: the name of the field to be deleted.
+        """
         self.field_arrays.pop(name)

     def get_fields(self):
+        """Return all the fields with their names.
+
+        :return dict field_arrays: the internal data structure of DataSet.
+        """
         return self.field_arrays

-    def __getitem__(self, name):
-        if isinstance(name, int):
-            return self.Instance(self, idx=name)
-        elif isinstance(name, slice):
-            ds = DataSet()
+    def __getitem__(self, idx):
+        """
+
+        :param idx: can be int, slice, or str.
+        :return: If `idx` is int, return an Instance object.
+                 If `idx` is slice, return a DataSet object.
+                 If `idx` is str, it must be a field name, return the field.
+ + """ + if isinstance(idx, int): + return self.Instance(self, idx, **{name: self.field_arrays[name][idx] for name in self.field_arrays}) + elif isinstance(idx, slice): + data_set = DataSet() for field in self.field_arrays.values(): - ds.add_field(name=field.name, - fields=field.content[name], - padding_val=field.padding_val, - need_tensor=field.need_tensor, - is_target=field.is_target) - return ds - - elif isinstance(name, str): - return self.field_arrays[name] + data_set.add_field(name=field.name, + fields=field.content[idx], + padding_val=field.padding_val, + is_input=field.is_input, + is_target=field.is_target) + return data_set + elif isinstance(idx, str): + return self.field_arrays[idx] else: - raise KeyError + raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) def __len__(self): if len(self.field_arrays) == 0: @@ -208,6 +224,7 @@ class DataSet(object): pass try: reader_cls = _READERS[item] + # add read_*data() support def _read(*args, **kwargs): data = reader_cls().load(*args, **kwargs) @@ -231,6 +248,12 @@ class DataSet(object): return wrapper def apply(self, func, new_field_name=None): + """Apply a function to every instance of the DataSet. + + :param func: a function that takes an instance as input. + :param str new_field_name: If not None, results of the function will be stored as a new field. + :return results: returned values of the function over all instances. + """ results = [] for ins in self: results.append(func(ins)) @@ -247,28 +270,24 @@ class DataSet(object): else: return results - def split(self, test_ratio): - assert isinstance(test_ratio, float) + def split(self, dev_ratio): + """Split the dataset into training and development(validation) set. + + :param float dev_ratio: the ratio of test set in all data. + :return DataSet train_set: the training set + DataSet dev_set: the development set + """ + assert isinstance(dev_ratio, float) + assert 0 < dev_ratio < 1 all_indices = [_ for _ in range(len(self))] np.random.shuffle(all_indices) - test_indices = all_indices[:int(test_ratio)] - train_indices = all_indices[int(test_ratio):] - test_set = DataSet() + split = int(dev_ratio * len(self)) + dev_indices = all_indices[:split] + train_indices = all_indices[split:] + dev_set = DataSet() train_set = DataSet() - for idx in test_indices: - test_set.append(self[idx]) + for idx in dev_indices: + dev_set.append(self[idx]) for idx in train_indices: train_set.append(self[idx]) - return train_set, test_set - - -if __name__ == '__main__': - from fastNLP.core.instance import Instance - - d = DataSet({'a': list('abc')}) - _ = d.a - d.apply(lambda x: x['a']) - print(d[1]) - import copy - dd = copy.deepcopy(d) - print(dd.a) + return train_set, dev_set diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index d6ef9c1e..5495dbec 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -3,61 +3,19 @@ from collections import defaultdict import torch from fastNLP.core.batch import Batch -from fastNLP.core.metrics import Evaluator from fastNLP.core.sampler import RandomSampler -# logger = create_logger(__name__, "./train_test.log") - - class Tester(object): """An collection of model inference and evaluation of performance, used over validation/dev set and test set. 
""" - def __init__(self, **kwargs): - """ - :param kwargs: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" - """ + def __init__(self, batch_size, evaluator, use_cuda, save_path="./save/", **kwargs): super(Tester, self).__init__() - """ - "default_args" provides default value for important settings. - The initialization arguments "kwargs" with the same key (name) will override the default value. - "kwargs" must have the same type as "default_args" on corresponding keys. - Otherwise, error will raise. - """ - default_args = {"batch_size": 8, - "use_cuda": False, - "pickle_path": "./save/", - "model_name": "dev_best_model.pkl", - "evaluator": Evaluator() - } - """ - "required_args" is the collection of arguments that users must pass to Trainer explicitly. - This is used to warn users of essential settings in the training. - Specially, "required_args" does not have default value, so they have nothing to do with "default_args". - """ - required_args = {} - - for req_key in required_args: - if req_key not in kwargs: - raise ValueError("Tester lacks argument {}".format(req_key)) - - for key in default_args: - if key in kwargs: - if isinstance(kwargs[key], type(default_args[key])): - default_args[key] = kwargs[key] - else: - msg = "Argument %s type mismatch: expected %s while get %s" % ( - key, type(default_args[key]), type(kwargs[key])) - raise ValueError(msg) - else: - # Tester doesn't care about extra arguments - pass - # print(default_args) - - self.batch_size = default_args["batch_size"] - self.pickle_path = default_args["pickle_path"] - self.use_cuda = default_args["use_cuda"] - self._evaluator = default_args["evaluator"] + + self.batch_size = batch_size + self.pickle_path = save_path + self.use_cuda = use_cuda + self._evaluator = evaluator self._model = None self.eval_history = [] # evaluation results of all batches @@ -72,7 +30,7 @@ class Tester(object): self.mode(network, is_test=True) self.eval_history.clear() output, truths = defaultdict(list), defaultdict(list) - data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) + data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), as_numpy=False) with torch.no_grad(): for batch_x, batch_y in data_iterator: diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index eb727317..063de676 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -15,6 +15,8 @@ from fastNLP.core.optimizer import Optimizer from fastNLP.core.sampler import RandomSampler from fastNLP.core.sampler import SequentialSampler from fastNLP.core.tester import Tester +from fastNLP.core.utils import _build_args +from fastNLP.core.utils import _check_arg_dict_list from fastNLP.core.utils import _check_arg_dict_list from fastNLP.core.utils import _build_args @@ -78,7 +80,7 @@ class Trainer(object): epoch = 1 while epoch <= self.n_epochs: - data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler()) + data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False) self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) @@ -207,9 +209,9 @@ def best_eval_result(self, metrics): DEFAULT_CHECK_BATCH_SIZE = 2 DEFAULT_CHECK_NUM_BATCH = 2 -IGNORE_CHECK_LEVEL=0 -WARNING_CHECK_LEVEL=1 -STRICT_CHECK_LEVEL=2 +IGNORE_CHECK_LEVEL = 0 +WARNING_CHECK_LEVEL = 1 +STRICT_CHECK_LEVEL = 2 def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1): 
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index eb727317..063de676 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -15,6 +15,8 @@
 from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
+from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _build_args
@@ -78,7 +80,7 @@ class Trainer(object):

         epoch = 1
         while epoch <= self.n_epochs:
-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler())
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)

             self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)
@@ -207,9 +209,9 @@ def best_eval_result(self, metrics):
 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2

-IGNORE_CHECK_LEVEL=0
-WARNING_CHECK_LEVEL=1
-STRICT_CHECK_LEVEL=2
+IGNORE_CHECK_LEVEL = 0
+WARNING_CHECK_LEVEL = 1
+STRICT_CHECK_LEVEL = 2


 def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1):
     # check the get_loss method
@@ -220,11 +222,20 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
     batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
     batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())

     for batch_count, (batch_x, batch_y) in enumerate(batch):
-        _syn_model_data(model, batch_x, batch_y)
-        # forward check
-        if batch_count==0:
-            _check_forward_error(model=model, model_func=model.forward, check_level=check_level,
-                                 batch_x=batch_x)
+        if batch_count == 0:
+            check_res = _check_arg_dict_list(model.forward, batch_x)
+            _info_str = ''
+            if len(check_res.missing) > 0:
+                if check_level == WARNING_CHECK_LEVEL:
+                    for field_name in check_res.missing:
+                        if hasattr(dataset, field_name):
+                            _info_str += "{} ".format(field_name)
+                    _info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n".format(
+                        ", ".join(check_res.missing), model.__class__.__name__)
+                    print(_info_str)
+            if len(check_res.unused) > 0:
+                if check_level == WARNING_CHECK_LEVEL:
+                    _info_str += "Unused field: [{}] is presented in the input but not used by '{}.forward'.\n".format(
+                        ", ".join(check_res.unused), model.__class__.__name__)

         refined_batch_x = _build_args(model.forward, **batch_x)
         output = model(**refined_batch_x)
@@ -233,10 +244,14 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No

         # loss check
         if batch_count == 0:
-            _check_loss_evaluate(model=model, model_func=model.get_loss, check_level=check_level,
-                                 output=output, batch_y=batch_y)
-        loss_input = _build_args(model.get_loss, **output, **batch_y)
-        loss = model.get_loss(**loss_input)
+            _dict = _check_arg_dict_list(model.loss, [output, batch_y])
+            if len(_dict) != 0:
+                pass
+        loss_input = _build_args(model.loss, **output, **batch_y)
+        loss = model.loss(**loss_input)
+        if batch_count == 0:
+            if isinstance(loss, torch.Tensor):
+                pass

         # check loss output
         if batch_count == 0:
@@ -248,8 +263,7 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
                     model_name, loss.size()
                 ))
         loss.backward()
-        model.zero_grad()
-        if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
+        if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
             break
     if check_level > IGNORE_CHECK_LEVEL:
         print('Finish checking training process.', flush=True)
@@ -407,14 +421,7 @@ if __name__ == '__main__':

     # trainer = Trainer(dataset, model)

-    _check_code(dataset=dataset, model=model, dev_data=dataset, check_level=2)
-
-    # _check_forward_error(model=model, model_func=model.forward, check_level=1,
-    #                      batch_x=fake_data_dict)
-
-    # import inspect
-    # print(inspect.getfullargspec(model.forward))
-
-
-
-
+    if len(_dict) != 0:
+        pass
+    refined_batch_x = _build_args(model.forward, **batch_x)
+    output = model(**refined_batch_x)
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index 84ed11e6..d816136e 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -1,8 +1,8 @@
 import _pickle
-import os
 import inspect
-from collections import namedtuple
+import os
 from collections import Counter
+from collections import namedtuple

 CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'],
                       verbose=False)
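
For context on the `CheckRes` namedtuple that `_check_arg_dict_list` populates in the trainer hunks above, here is a rough, illustrative re-implementation of a signature check. It is not the library's actual code; the field semantics are inferred from the names, and the `duplicated` field is left empty for simplicity:

    import inspect
    from collections import namedtuple

    CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'])

    def check_arg_dict(func, arg_dict):
        # compare a function's signature against a dict of provided arguments
        spec = inspect.getfullargspec(func)
        all_needed = [arg for arg in spec.args if arg != 'self']
        n_defaults = len(spec.defaults) if spec.defaults else 0
        required = all_needed[:len(all_needed) - n_defaults]
        missing = [arg for arg in required if arg not in arg_dict]
        unused = [key for key in arg_dict if key not in all_needed]
        return CheckRes(missing=missing, unused=unused, duplicated=[],
                        required=required, all_needed=all_needed)
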