@@ -9,20 +9,17 @@ class Batch(object):
     """
-    def __init__(self, dataset, batch_size, sampler, as_numpy=False, use_cuda=False):
+    def __init__(self, dataset, batch_size, sampler, as_numpy=False):
         """
         :param dataset: a DataSet object
         :param batch_size: int, the size of the batch
         :param sampler: a Sampler object
-        :param use_cuda: bool, whether to use GPU
         """
         self.dataset = dataset
         self.batch_size = batch_size
         self.sampler = sampler
         self.as_numpy = as_numpy
-        self.use_cuda = use_cuda
         self.idx_list = None
         self.curidx = 0
@@ -53,15 +50,13 @@ class Batch(object):
         indices = self.idx_list[self.curidx:endidx]
         for field_name, field in self.dataset.get_fields().items():
-            if field.need_tensor:
+            if field.is_target or field.is_input:
                 batch = field.get(indices)
                 if not self.as_numpy:
                     batch = torch.from_numpy(batch)
-                if self.use_cuda:
-                    batch = batch.cuda()
                 if field.is_target:
                     batch_y[field_name] = batch
-                else:
-                    batch_x[field_name] = batch
+                if field.is_input:
+                    batch_x[field_name] = batch
         self.curidx = endidx
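
Note on the two Batch hunks above: the single need_tensor switch is replaced by the is_input/is_target pair, and the use_cuda option is dropped, so moving tensors to the GPU is now the caller's responsibility. A minimal usage sketch, assuming a DataSet ds whose fields were already flagged via set_input/set_target (the field names below are hypothetical):

from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler

batch_iter = Batch(dataset=ds, batch_size=16, sampler=RandomSampler())
for batch_x, batch_y in batch_iter:
    # batch_x collects fields with is_input=True, batch_y those with is_target=True;
    # a field flagged as both appears in both dicts
    word_seq = batch_x["word_seq"].cuda()     # use_cuda is gone, so .cuda() is manual now
    label_seq = batch_y["label_seq"].cuda()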
@@ -189,26 +189,15 @@ class DataSet(object):
                 self.field_arrays[name].is_target = val
             else:
                 raise KeyError("{} is not a valid field name.".format(name))
-        self._set_need_tensor(**fields)
         return self
 
     def set_input(self, **fields):
         for name, val in fields.items():
             if name in self.field_arrays:
                 assert isinstance(val, bool)
-                self.field_arrays[name].is_target = not val
+                self.field_arrays[name].is_input = val
             else:
                 raise KeyError("{} is not a valid field name.".format(name))
-        self._set_need_tensor(**fields)
         return self
 
-    def _set_need_tensor(self, **kwargs):
-        for name, val in kwargs.items():
-            if name in self.field_arrays:
-                assert isinstance(val, bool)
-                self.field_arrays[name].need_tensor = val
-            else:
-                raise KeyError
-        return self
 
     def __getattr__(self, item):
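
With _set_need_tensor removed, the two setters now write their flags independently; in particular, set_input no longer forces is_target = not val as a side effect, so a field may legitimately be both an input and a target. A short sketch under that reading (field names are made up):

ds.set_input(word_seq=True, seq_len=True)
ds.set_target(label_seq=True)
ds.set_input(label_seq=True)        # now legal: label_seq is both input and target
ds.set_input(no_such_field=True)    # raises KeyError("no_such_field is not a valid field name.")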
@@ -2,12 +2,12 @@ import numpy as np
 
 class FieldArray(object):
-    def __init__(self, name, content, padding_val=0, is_target=False, need_tensor=False):
+    def __init__(self, name, content, padding_val=0, is_target=False, is_input=False):
         self.name = name
         self.content = content
         self.padding_val = padding_val
         self.is_target = is_target
-        self.need_tensor = need_tensor
+        self.is_input = is_input
         self.dtype = None
 
     def __repr__(self):
@@ -27,7 +27,7 @@ class FieldArray(object):
     def get(self, idxes):
         if isinstance(idxes, int):
             return self.content[idxes]
-        assert self.need_tensor is True
+        assert self.is_input is True or self.is_target is True
         batch_size = len(idxes)
         # TODO: when this FieldArray holds single values such as seq_length, no padding is needed; to be discussed further
         if isinstance(self.content[0], int) or isinstance(self.content[0], float):
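
For intuition about what get returns for sequence fields, here is a toy re-implementation of the padding step performed below this hunk; it mirrors the behavior rather than reproducing the exact source:

import numpy as np

def pad_rows(rows, padding_val=0):
    # pad variable-length rows to the longest row, as FieldArray.get does for list contents
    max_len = max(len(row) for row in rows)
    padded = np.full((len(rows), max_len), padding_val, dtype=np.int64)
    for i, row in enumerate(rows):
        padded[i, :len(row)] = row
    return padded

print(pad_rows([[1, 2, 3], [4, 5]]))   # [[1 2 3] [4 5 0]]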
@@ -9,6 +9,7 @@ from fastNLP.core.loss import Loss
 from fastNLP.core.metrics import Evaluator
 from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.sampler import RandomSampler
+from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
@@ -194,3 +195,77 @@ def best_eval_result(self, metrics):
             return True
         else:
             return False
+
+
+from fastNLP.core.utils import _check_arg_dict_list
+from fastNLP.core.utils import _build_args
+
+DEFAULT_CHECK_BATCH_SIZE = 2
+DEFAULT_CHECK_NUM_BATCH = 2
+
+IGNORE_CHECK_LEVEL = 0
+WARNING_CHECK_LEVEL = 1
+STRICT_CHECK_LEVEL = 2
+
+
+def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
+    # check that the model exposes a loss method
+    if not hasattr(model, 'get_loss'):
+        raise AttributeError("{} has to have a 'get_loss' function.".format(type(model)))
+
+    batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
+    batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
+    for batch_count, (batch_x, batch_y) in enumerate(batch):
+        if batch_count == 0:
+            check_res = _check_arg_dict_list(model.forward, batch_x)
+            _info_str = ''
+            if len(check_res.missing) > 0:
+                if check_level == WARNING_CHECK_LEVEL:
+                    for field_name in check_res.missing:
+                        if hasattr(dataset, field_name):
+                            _info_str += "{} "
+                    _info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n"
+                    _info_str += ""
+                print("")
+            if len(check_res.unused) > 0:
+                if check_level == WARNING_CHECK_LEVEL:
+                    _info_str += ""
+        refined_batch_x = _build_args(model.forward, **batch_x)
+        output = model(**refined_batch_x)
+        if batch_count == 0:
+            _dict = _check_arg_dict_list(model.loss, [output, batch_y])
+            if len(_dict) != 0:
+                pass
+        loss_input = _build_args(model.loss, **output, **batch_y)
+        loss = model.loss(**loss_input)
+        if batch_count == 0:
+            if isinstance(loss, torch.Tensor):
+                pass
+        loss.backward()
+        if batch_count + 1 >= DEFAULT_CHECK_BATCH_SIZE:
+            break
+
+    if dev_data is not None:
+        if not hasattr(model, 'evaluate'):
+            raise AttributeError("If {} wants to do evaluation, {} has to have an 'evaluate' function. Or you can set "
+                                 "dev_data to 'None'."
+                                 .format(type(model), type(model)))
+        dev_batch = Batch(dataset=dev_data, batch_size=batch_size, sampler=SequentialSampler())
+        for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
+            refined_batch_x = _build_args(model.forward, **batch_x)
+            output = model(**refined_batch_x)
+            if batch_count == 0:
+                _dict = _check_arg_dict_list(model.evaluate, [output, batch_y])
+                if len(_dict) != 0:
+                    pass
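
The diff does not show where _check_code is wired in, so the call below is an assumption about the intended call site; the point of the function is to fail fast on signature mismatches by pushing a couple of tiny batches through forward, loss, and backward (plus evaluate when dev_data is given) before real training starts:

# hypothetical call near the top of Trainer.train()
_check_code(dataset=train_data, model=model, dev_data=dev_data,
            check_level=WARNING_CHECK_LEVEL)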
@@ -1,6 +1,11 @@
 import _pickle
 import os
+import inspect
+from collections import namedtuple
+from collections import Counter
+
+CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'], verbose=True)
 
 
 def save_pickle(obj, pickle_path, file_name):
     """Save an object into a pickle file.
@@ -45,7 +50,7 @@ def pickle_exist(pickle_path, pickle_name): | |||
else: | |||
return False | |||
def build_args(func, **kwargs): | |||
def _build_args(func, **kwargs): | |||
spect = inspect.getfullargspec(func) | |||
if spect.varkw is not None: | |||
return kwargs | |||
@@ -55,11 +60,9 @@ def build_args(func, **kwargs):
     output.update({name: val for name, val in kwargs.items() if name in needed_args})
     return output
 
-from collections import namedtuple, Counter
-CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'], verbose=True)
 
 # check args
-def check_arg_dict_list(func, args):
+def _check_arg_dict_list(func, args):
     if isinstance(args, dict):
         arg_dict_list = [args]
     else:
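
Both helpers are now private (renamed with a leading underscore), and the CheckRes definition moves to the top of utils.py. From the visible logic, _build_args filters a kwargs dict down to the parameters func actually declares, and _check_arg_dict_list reports the mismatch in both directions. A toy sketch of the expected behavior (the full function bodies are not in this diff):

def toy_forward(word_seq, seq_len):
    return word_seq

kwargs = {"word_seq": [1, 2], "seq_len": [2], "label_seq": [0]}
print(_build_args(toy_forward, **kwargs))   # {'word_seq': [1, 2], 'seq_len': [2]}
res = _check_arg_dict_list(toy_forward, kwargs)
print(res.missing, res.unused)              # [] ['label_seq']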
@@ -60,7 +60,6 @@ class Vocabulary(object):
         """
         self.word_count.update(word_lst)
 
-
     def add(self, word):
         self.word_count[word] += 1
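
For completeness, the Vocabulary hunk only tightens whitespace; update and add are thin wrappers over a collections.Counter. A hedged sketch, assuming Vocabulary() needs no required constructor arguments (not shown in this diff):

vocab = Vocabulary()
vocab.update(["the", "cat", "the"])   # batch path: word_count.update
vocab.add("dog")                      # single-word path: word_count["dog"] += 1
# vocab.word_count is now Counter({'the': 2, 'cat': 1, 'dog': 1})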