@@ -9,20 +9,17 @@ class Batch(object): | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size, sampler, as_numpy=False):
    """Iterate over a DataSet in mini-batches.

    :param dataset: a DataSet object
    :param batch_size: int, the size of the batch
    :param sampler: a Sampler object
    :param as_numpy: bool, if True keep batches as numpy arrays instead of
        converting them to torch tensors. Default: False.
    """
    self.dataset = dataset
    self.batch_size = batch_size
    self.sampler = sampler
    self.as_numpy = as_numpy
    # iteration state; (re)initialized when iteration starts
    self.idx_list = None
    self.curidx = 0
@@ -53,15 +50,13 @@ class Batch(object): | |||||
indices = self.idx_list[self.curidx:endidx] | indices = self.idx_list[self.curidx:endidx] | ||||
for field_name, field in self.dataset.get_fields().items(): | for field_name, field in self.dataset.get_fields().items(): | ||||
if field.need_tensor: | |||||
if field.is_target or field.is_input: | |||||
batch = field.get(indices) | batch = field.get(indices) | ||||
if not self.as_numpy: | if not self.as_numpy: | ||||
batch = torch.from_numpy(batch) | batch = torch.from_numpy(batch) | ||||
if self.use_cuda: | |||||
batch = batch.cuda() | |||||
if field.is_target: | if field.is_target: | ||||
batch_y[field_name] = batch | batch_y[field_name] = batch | ||||
else: | |||||
if field.is_input: | |||||
batch_x[field_name] = batch | batch_x[field_name] = batch | ||||
self.curidx = endidx | self.curidx = endidx | ||||
@@ -189,26 +189,15 @@ class DataSet(object): | |||||
self.field_arrays[name].is_target = val | self.field_arrays[name].is_target = val | ||||
else: | else: | ||||
raise KeyError("{} is not a valid field name.".format(name)) | raise KeyError("{} is not a valid field name.".format(name)) | ||||
self._set_need_tensor(**fields) | |||||
return self | return self | ||||
def set_input(self, **fields):
    """Flag the named fields as model inputs.

    :param fields: field_name=bool pairs; True marks that field as an input
        field, False unmarks it.
    :return: self, so calls can be chained
    :raises KeyError: if a given name is not an existing field
    """
    for field_name, flag in fields.items():
        if field_name not in self.field_arrays:
            raise KeyError("{} is not a valid field name.".format(field_name))
        assert isinstance(flag, bool)
        self.field_arrays[field_name].is_input = flag
    return self
def __getattr__(self, item): | def __getattr__(self, item): | ||||
@@ -2,12 +2,12 @@ import numpy as np | |||||
class FieldArray(object): | class FieldArray(object): | ||||
def __init__(self, name, content, padding_val=0, is_target=False, is_input=False):
    """Store one field (one column of per-instance values) of a DataSet.

    :param name: str, the field's name
    :param content: list of per-instance values for this field
    :param padding_val: padding value — presumably used when batching
        variable-length entries in get(); TODO(review) confirm
    :param is_target: bool, whether this field is used as a training target
    :param is_input: bool, whether this field is fed to the model as input
    """
    self.name = name
    self.content = content
    self.padding_val = padding_val
    self.is_target = is_target
    self.is_input = is_input
    # element dtype is not known at construction time
    self.dtype = None
def __repr__(self): | def __repr__(self): | ||||
@@ -27,7 +27,7 @@ class FieldArray(object): | |||||
def get(self, idxes): | def get(self, idxes): | ||||
if isinstance(idxes, int): | if isinstance(idxes, int): | ||||
return self.content[idxes] | return self.content[idxes] | ||||
assert self.need_tensor is True | |||||
assert self.is_input is True or self.is_target is True | |||||
batch_size = len(idxes) | batch_size = len(idxes) | ||||
# TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 | # TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 | ||||
if isinstance(self.content[0], int) or isinstance(self.content[0], float): | if isinstance(self.content[0], int) or isinstance(self.content[0], float): | ||||
@@ -9,6 +9,7 @@ from fastNLP.core.loss import Loss | |||||
from fastNLP.core.metrics import Evaluator | from fastNLP.core.metrics import Evaluator | ||||
from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
from fastNLP.core.sampler import RandomSampler | from fastNLP.core.sampler import RandomSampler | ||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.tester import Tester | from fastNLP.core.tester import Tester | ||||
@@ -194,3 +195,77 @@ def best_eval_result(self, metrics): | |||||
return True | return True | ||||
else: | else: | ||||
return False | return False | ||||
from fastNLP.core.utils import _check_arg_dict_list | |||||
from fastNLP.core.utils import _build_args | |||||
# Defaults used by _check_code when dry-running a model on a few batches.
DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2

# Severity levels for the pre-training code check.
IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2
def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
    """Dry-run `model` on a couple of batches of `dataset` to surface
    signature mismatches between the data fields and `model.forward` /
    `model.loss` before real training starts.

    NOTE(review): this function looks work-in-progress — several diagnostic
    branches only build never-used strings or `pass`, an empty line is
    printed instead of the assembled message, and the hasattr check requires
    'get_loss' while the code below calls `model.loss`. Behavior is
    preserved as-is; flagged for follow-up.
    """
    # the model must expose a loss method
    if not hasattr(model, 'get_loss'):
        raise AttributeError("{} has to have a 'get_loss' function.".format(type(model)))

    batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
    batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())

    for batch_count, (batch_x, batch_y) in enumerate(batch):
        if batch_count == 0:
            # compare the first batch's fields against model.forward's signature
            check_res = _check_arg_dict_list(model.forward, batch_x)
            info_msg = ''
            if len(check_res.missing) > 0:
                if check_level == WARNING_CHECK_LEVEL:
                    # TODO(review): message assembly is unfinished — the format
                    # placeholders are never filled in, only an empty line prints
                    for field_name in check_res.missing:
                        if hasattr(dataset, field_name):
                            info_msg += "{} "
                    info_msg += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n"
                    info_msg += ""
                    print("")
            if len(check_res.unused) > 0:
                if check_level == WARNING_CHECK_LEVEL:
                    info_msg += ""

        # forward pass restricted to the arguments model.forward accepts
        refined_batch_x = _build_args(model.forward, **batch_x)
        output = model(**refined_batch_x)

        if batch_count == 0:
            # TODO(review): result of this argument check is currently ignored
            loss_arg_check = _check_arg_dict_list(model.loss, [output, batch_y])
            if len(loss_arg_check) != 0:
                pass

        loss_input = _build_args(model.loss, **output, **batch_y)
        loss = model.loss(**loss_input)
        if batch_count == 0:
            if isinstance(loss, torch.Tensor):
                pass  # TODO(review): placeholder — no tensor-specific handling yet
        loss.backward()

        if batch_count + 1 >= DEFAULT_CHECK_BATCH_SIZE:
            break

    dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
    if dev_data is not None:
        # evaluation requires the model to implement evaluate()
        # NOTE(review): the message below is missing a space between "set"
        # and "dev_data" — reproduced byte-for-byte here
        if not hasattr(model, 'evaluate'):
            raise AttributeError("If {} wants to do evaluation, {} has to have a 'evaluate' function. Or you can set"
                                 "dev_data to 'None'."
                                 .format(type(model), type(model)))
        for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
            if batch_count == 0:
                # TODO(review): result ignored, same as the loss check above
                eval_arg_check = _check_arg_dict_list(model.evaluate, [output, batch_y])
                if len(eval_arg_check) != 0:
                    pass
            refined_batch_x = _build_args(model.forward, **batch_x)
            output = model(**refined_batch_x)
@@ -1,6 +1,11 @@ | |||||
import _pickle | import _pickle | ||||
import os | import os | ||||
import inspect | import inspect | ||||
from collections import namedtuple | |||||
from collections import Counter | |||||
# Result of _check_arg_dict_list: which call arguments are missing, unused,
# or duplicated. The `verbose=True` kwarg was dropped: it only echoed the
# generated class source, and the parameter was removed in Python 3.7, where
# passing it raises TypeError.
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'])
def save_pickle(obj, pickle_path, file_name): | def save_pickle(obj, pickle_path, file_name): | ||||
"""Save an object into a pickle file. | """Save an object into a pickle file. | ||||
@@ -45,7 +50,7 @@ def pickle_exist(pickle_path, pickle_name): | |||||
else: | else: | ||||
return False | return False | ||||
def build_args(func, **kwargs): | |||||
def _build_args(func, **kwargs): | |||||
spect = inspect.getfullargspec(func) | spect = inspect.getfullargspec(func) | ||||
if spect.varkw is not None: | if spect.varkw is not None: | ||||
return kwargs | return kwargs | ||||
@@ -55,11 +60,9 @@ def build_args(func, **kwargs): | |||||
output.update({name: val for name, val in kwargs.items() if name in needed_args}) | output.update({name: val for name, val in kwargs.items() if name in needed_args}) | ||||
return output | return output | ||||
from collections import namedtuple, Counter | |||||
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'], verbose=True) | |||||
# check args | # check args | ||||
def check_arg_dict_list(func, args): | |||||
def _check_arg_dict_list(func, args): | |||||
if isinstance(args, dict): | if isinstance(args, dict): | ||||
arg_dict_list = [args] | arg_dict_list = [args] | ||||
else: | else: | ||||
@@ -60,7 +60,6 @@ class Vocabulary(object): | |||||
""" | """ | ||||
self.word_count.update(word_lst) | self.word_count.update(word_lst) | ||||
def add(self, word):
    """Record one occurrence of `word`, incrementing its count by one.

    :param word: the token (hashable) to record
    """
    self.word_count[word] += 1