diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index ce7e25c0..d8c61047 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -9,20 +9,17 @@ class Batch(object):
 
     """
 
-    def __init__(self, dataset, batch_size, sampler, as_numpy=False, use_cuda=False):
+    def __init__(self, dataset, batch_size, sampler, as_numpy=False,):
         """
 
         :param dataset: a DataSet object
         :param batch_size: int, the size of the batch
         :param sampler: a Sampler object
-        :param use_cuda: bool, whether to use GPU
-
         """
         self.dataset = dataset
         self.batch_size = batch_size
         self.sampler = sampler
         self.as_numpy = as_numpy
-        self.use_cuda = use_cuda
         self.idx_list = None
         self.curidx = 0
 
@@ -53,15 +50,13 @@ class Batch(object):
             indices = self.idx_list[self.curidx:endidx]
 
             for field_name, field in self.dataset.get_fields().items():
-                if field.need_tensor:
+                if field.is_target or field.is_input:
                     batch = field.get(indices)
                     if not self.as_numpy:
                         batch = torch.from_numpy(batch)
-                        if self.use_cuda:
-                            batch = batch.cuda()
                     if field.is_target:
                         batch_y[field_name] = batch
-                    else:
+                    if field.is_input:
                         batch_x[field_name] = batch
 
             self.curidx = endidx
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 32f109e4..39af672c 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -189,26 +189,15 @@ class DataSet(object):
                 self.field_arrays[name].is_target = val
             else:
                 raise KeyError("{} is not a valid field name.".format(name))
-        self._set_need_tensor(**fields)
         return self
 
     def set_input(self, **fields):
         for name, val in fields.items():
             if name in self.field_arrays:
                 assert isinstance(val, bool)
-                self.field_arrays[name].is_target = not val
+                self.field_arrays[name].is_input = val
             else:
                 raise KeyError("{} is not a valid field name.".format(name))
-        self._set_need_tensor(**fields)
-        return self
-
-    def _set_need_tensor(self, **kwargs):
-        for name, val in kwargs.items():
-            if name in self.field_arrays:
-                assert isinstance(val, bool)
-                self.field_arrays[name].need_tensor = val
-            else:
-                raise KeyError
         return self
 
     def __getattr__(self, item):
diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py
index 7ead3a64..473738b0 100644
--- a/fastNLP/core/fieldarray.py
+++ b/fastNLP/core/fieldarray.py
@@ -2,12 +2,12 @@ import numpy as np
 
 
 class FieldArray(object):
-    def __init__(self, name, content, padding_val=0, is_target=False, need_tensor=False):
+    def __init__(self, name, content, padding_val=0, is_target=False, is_input=False):
         self.name = name
         self.content = content
         self.padding_val = padding_val
         self.is_target = is_target
-        self.need_tensor = need_tensor
+        self.is_input = is_input
         self.dtype = None
 
     def __repr__(self):
@@ -27,7 +27,7 @@ class FieldArray(object):
     def get(self, idxes):
         if isinstance(idxes, int):
             return self.content[idxes]
-        assert self.need_tensor is True
+        assert self.is_input is True or self.is_target is True
         batch_size = len(idxes)
         # TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下
         if isinstance(self.content[0], int) or isinstance(self.content[0], float):
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index b4f11090..9538d3fc 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -9,6 +9,7 @@ from fastNLP.core.loss import Loss
 from fastNLP.core.metrics import Evaluator
 from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.sampler import RandomSampler
+from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
 
 
@@ -194,3 +195,86 @@ def best_eval_result(self, metrics):
         return True
     else:
         return False
+
+
+from fastNLP.core.utils import _check_arg_dict_list
+from fastNLP.core.utils import _build_args
+
+DEFAULT_CHECK_BATCH_SIZE = 2
+DEFAULT_CHECK_NUM_BATCH = 2
+
+IGNORE_CHECK_LEVEL=0
+WARNING_CHECK_LEVEL=1
+STRICT_CHECK_LEVEL=2
+
+
+def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
+    """Dry-run a few small batches through the model to surface wiring errors early.
+
+    :param dataset: a DataSet object, batched with a SequentialSampler for the forward/loss/backward pass
+    :param model: the model to check; needs 'get_loss', and 'evaluate' when dev_data is given
+    :param batch_size: int, the check uses min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
+    :param dev_data: a DataSet object or None; when not None the evaluate arguments are checked on it
+    :param check_level: only compared against WARNING_CHECK_LEVEL by the current code
+    :raises AttributeError: if the model lacks 'get_loss', or lacks 'evaluate' while dev_data is given
+    """
+    # the loss is computed through model.get_loss below, so it has to exist
+    if not hasattr(model, 'get_loss'):
+        raise AttributeError("{} has to have a 'get_loss' function.".format(type(model)))
+
+    batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
+    batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
+    for batch_count, (batch_x, batch_y) in enumerate(batch):
+        if batch_count==0:
+            check_res = _check_arg_dict_list(model.forward, batch_x)
+            _info_str = ''
+            # TODO: the messages below are assembled but not yet formatted or printed
+            if len(check_res.missing)>0:
+                if check_level == WARNING_CHECK_LEVEL:
+                    for field_name in check_res.missing:
+                        if hasattr(dataset, field_name):
+                            _info_str += "{} "
+                    _info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n"
+                    _info_str += ""
+                    print("")
+            if len(check_res.unused)>0:
+                if check_level == WARNING_CHECK_LEVEL:
+                    _info_str += ""
+
+        refined_batch_x = _build_args(model.forward, **batch_x)
+        output = model(**refined_batch_x)
+        if batch_count == 0:
+            _dict = _check_arg_dict_list(model.get_loss, [output, batch_y])
+            if len(_dict)!=0:
+                pass
+        loss_input = _build_args(model.get_loss, **output, **batch_y)
+        loss = model.get_loss(**loss_input)
+        if batch_count == 0:
+            if isinstance(loss, torch.Tensor):
+                pass
+
+        loss.backward()
+
+        # stop after DEFAULT_CHECK_NUM_BATCH batches (a batch count, not a batch size)
+        if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
+            break
+
+    if dev_data is not None:
+        if not hasattr(model, 'evaluate'):
+            raise AttributeError("If {} wants to do evaluation, {} has to have a 'evaluate' function. Or you can set "
+                                 "dev_data to 'None'."
+                                 .format(type(model), type(model)))
+
+        # batch dev_data itself; batching `dataset` here would never look at the dev set
+        dev_batch = Batch(dataset=dev_data, batch_size=batch_size, sampler=SequentialSampler())
+        for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
+            refined_batch_x = _build_args(model.forward, **batch_x)
+            output = model(**refined_batch_x)
+            if batch_count == 0:
+                # check against this dev batch's output rather than the last training output
+                _dict = _check_arg_dict_list(model.evaluate, [output, batch_y])
+                if len(_dict)!=0:
+                    pass
+
+            if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
+                break
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index b672be77..6a284ab9 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -1,6 +1,11 @@
 import _pickle
 import os
 import inspect
+from collections import namedtuple
+from collections import Counter
+
+CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'])
+
 
 def save_pickle(obj, pickle_path, file_name):
     """Save an object into a pickle file.
@@ -45,7 +50,7 @@ def pickle_exist(pickle_path, pickle_name):
     else:
         return False
 
-def build_args(func, **kwargs):
+def _build_args(func, **kwargs):
     spect = inspect.getfullargspec(func)
     if spect.varkw is not None:
         return kwargs
@@ -55,11 +60,9 @@ def build_args(func, **kwargs):
     output.update({name: val for name, val in kwargs.items() if name in needed_args})
     return output
 
-from collections import namedtuple, Counter
-CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'], verbose=True)
 
 # check args
-def check_arg_dict_list(func, args):
+def _check_arg_dict_list(func, args):
     if isinstance(args, dict):
         arg_dict_list = [args]
     else:
diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index 55a1e3f8..a9370be5 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -60,7 +60,6 @@ class Vocabulary(object):
         """
         self.word_count.update(word_lst)
 
-
     def add(self, word):
         self.word_count[word] += 1
 