@@ -26,6 +26,7 @@ class Batch(object):
         self.as_numpy = as_numpy
         self.idx_list = None
         self.curidx = 0
+        self.num_batches = len(dataset)//batch_size + int(len(dataset)%batch_size!=0)
 
     def __iter__(self):
         self.idx_list = self.sampler(self.dataset)
@@ -56,6 +57,9 @@ class Batch(object):
             return batch_x, batch_y
 
+    def __len__(self):
+        return self.num_batches
+
 
 def to_tensor(batch, dtype):
     if dtype in (np.int8, np.int16, np.int32, np.int64):
         batch = torch.LongTensor(batch)
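
Note: with `num_batches` computed by ceiling division, `len(batch_iter)` now reports the number of batches per epoch, counting a trailing partial batch. A quick pure-Python check of the formula (illustrative values only):

    # 10 instances with batch_size 3 -> 3 full batches + 1 partial = 4
    n, bs = 10, 3
    assert n // bs + int(n % bs != 0) == 4
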
@@ -168,6 +168,7 @@ class DataSet(object):
         """
         if old_name in self.field_arrays:
             self.field_arrays[new_name] = self.field_arrays.pop(old_name)
+            self.field_arrays[new_name].name = new_name
         else:
             raise KeyError("{} is not a valid name. ".format(old_name))
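
Note: previously the renamed FieldArray kept its old `.name`, which could disagree with the DataSet key. A minimal sketch of the fixed behavior (assumes DataSet accepts a dict of lists; the import path may differ by version):

    from fastNLP.core.dataset import DataSet

    ds = DataSet({"raw": ["a", "b"]})
    ds.rename_field("raw", "text")
    assert ds.field_arrays["text"].name == "text"  # kept in sync by this change
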
@@ -213,12 +214,12 @@ class DataSet(object):
         return wrapper
 
-    def apply(self, func, new_field_name=None):
+    def apply(self, func, new_field_name=None, is_input=False, is_target=False):
         """Apply a function to every instance of the DataSet.
 
         :param func: a function that takes an instance as input.
         :param str new_field_name: If not None, results of the function will be stored as a new field.
-        :return results: returned values of the function over all instances.
+        :return results: if new_field_name is not passed, returned values of the function over all instances.
         """
         results = [func(ins) for ins in self]
         if new_field_name is not None:
@@ -231,7 +232,7 @@ class DataSet(object):
                                is_input=old_field.is_input,
                                is_target=old_field.is_target)
             else:
-                self.add_field(name=new_field_name, fields=results)
+                self.add_field(name=new_field_name, fields=results, is_input=is_input, is_target=is_target)
         else:
             return results
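
Note: `apply` can now flag the produced field as model input or target in one call, instead of requiring a separate `set_input`/`set_target` step. A sketch under the signature shown above (`seq_len` is an illustrative field name, reusing the DataSet sketch from the rename_field note):

    ds.apply(lambda ins: len(ins["text"]), new_field_name="seq_len", is_input=True)
    lengths = ds.apply(lambda ins: len(ins["text"]))  # no new_field_name: returns the list of results
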
@@ -245,7 +245,7 @@ class AccuracyMetric(MetricBase):
         self.total += np.prod(list(pred.size()))
 
     def get_metric(self, reset=True):
-        evaluate_result = {'acc': self.acc_count/self.total}
+        evaluate_result = {'acc': round(self.acc_count/self.total, 6)}
         if reset:
             self.acc_count = 0
             self.total = 0
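
Note: rounding to six decimal places only affects the reported value, not the running counts. Pure-Python illustration:

    acc_count, total = 2, 3
    print({'acc': acc_count / total})            # {'acc': 0.6666666666666666}
    print({'acc': round(acc_count / total, 6)})  # {'acc': 0.666667}
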
@@ -17,7 +17,7 @@ from fastNLP.core.utils import get_func_signature
 class Tester(object):
     """A collection of model inference and evaluation of performance, used over validation/dev set and test set. """
 
-    def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=0):
+    def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=1):
         super(Tester, self).__init__()
 
         if not isinstance(data, DataSet):
@@ -76,7 +76,7 @@ class Tester(object):
                     _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
                                          check_res=e.check_res, output=output, batch_y=truths, check_level=0)
 
-        if self.verbose >= 0:
+        if self.verbose >= 1:
            print("[tester] \n{}".format(self._format_eval_results(eval_results)))
         self._mode(network, is_test=False)
         return eval_results
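
Note: the default is now verbose=1, so `Tester.test()` prints its results unless the caller opts out, and the print guard changed from `>= 0` to `>= 1` so that verbose=0 is actually silent. Sketch with placeholder `dev_data` and `model`:

    tester = Tester(data=dev_data, model=model, metrics=AccuracyMetric(), verbose=0)
    eval_results = tester.test()  # evaluates silently; still returns the results dict
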
@@ -107,7 +107,7 @@ class Tester(object):
         """
         _str = ''
         for metric_name, metric_result in results.items():
-            _str += metric_name + '\n\t'
-            _str += ", ".join([str(key) + "=" + str(value) for key, value in results.items()])
-            _str += '\n'
-        return _str
+            _str += metric_name + ': '
+            _str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()])
+            _str += '\n'
+        return _str[:-1]
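
Note: the old loop joined over `results.items()` (the outer dict) instead of `metric_result.items()`, so every metric line repeated the whole results dict; the fix also drops the trailing newline. Runnable illustration of the corrected formatting:

    results = {"AccuracyMetric": {"acc": 0.666667}}
    _str = ''
    for metric_name, metric_result in results.items():
        _str += metric_name + ': '
        _str += ", ".join(str(k) + "=" + str(v) for k, v in metric_result.items())
        _str += '\n'
    print(_str[:-1])  # -> AccuracyMetric: acc=0.666667
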
@@ -28,9 +28,9 @@ class Trainer(object):
     """
-    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1,
+    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1,
-                 dev_data=None, use_cuda=False, save_path="./save",
+                 dev_data=None, use_cuda=False, save_path=None,
                  optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
                  metric_key=None,
                  **kwargs):
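
Note: under the new defaults the Trainer reports training progress every 50 steps, and with `save_path=None` it presumably writes no checkpoints unless a path is supplied (an inference from the diff, not verified against the docs). Sketch with placeholder data and model:

    trainer = Trainer(train_data=train_data, model=model,
                      metrics=AccuracyMetric(), dev_data=dev_data)
    trainer.train()  # equivalent to passing print_every=50, save_path=None
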
@@ -307,8 +307,8 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         _move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
 
         # forward check
         if batch_count==0:
-            _check_forward_error(forward_func=model.forward, check_level=check_level,
-                                 batch_x=batch_x)
+            _check_forward_error(forward_func=model.forward, dataset=dataset,
+                                 batch_x=batch_x, check_level=check_level)
 
         refined_batch_x = _build_args(model.forward, **batch_x)
         output = model(**refined_batch_x)
@@ -207,7 +207,7 @@ class CheckError(Exception):
     CheckError. Used in losses.LossBase, metrics.MetricBase.
     """
     def __init__(self, check_res:CheckRes, func_signature:str):
-        errs = [f'The following problems occurred when calling {func_signature}']
+        errs = [f'The following problems occurred when calling `{func_signature}`']
 
         if check_res.varargs:
             errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
@@ -255,7 +255,7 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:
         warnings.warn(message=_unused_warn)
 
-def _check_forward_error(forward_func, batch_x, check_level):
+def _check_forward_error(forward_func, batch_x, dataset, check_level):
     check_res = _check_arg_dict_list(forward_func, batch_x)
     func_signature = get_func_signature(forward_func)
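
Note: `_check_forward_error` now also receives the DataSet, presumably so its error messages can point at which fields exist and could be marked as input (an assumption; the message-building code is outside this diff). Call sites mirror the updated keyword order, as in the `_check_code` hunk above:

    _check_forward_error(forward_func=model.forward, dataset=dataset,
                         batch_x=batch_x, check_level=check_level)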