@@ -26,6 +26,7 @@ class Batch(object): | |||
self.as_numpy = as_numpy | |||
self.idx_list = None | |||
self.curidx = 0 | |||
self.num_batches = len(dataset)//batch_size + int(len(dataset)%batch_size!=0) | |||
def __iter__(self): | |||
self.idx_list = self.sampler(self.dataset) | |||
@@ -56,6 +57,9 @@ class Batch(object): | |||
return batch_x, batch_y | |||
def __len__(self):
    """Number of batches this iterator yields per epoch.

    Delegates to ``self.num_batches``, which ``__init__`` precomputes as
    ``ceil(len(dataset) / batch_size)``.
    """
    return self.num_batches
def to_tensor(batch, dtype): | |||
if dtype in (np.int8, np.int16, np.int32, np.int64): | |||
batch = torch.LongTensor(batch) | |||
@@ -168,6 +168,7 @@ class DataSet(object): | |||
""" | |||
if old_name in self.field_arrays: | |||
self.field_arrays[new_name] = self.field_arrays.pop(old_name) | |||
self.field_arrays[new_name].name = new_name | |||
else: | |||
raise KeyError("{} is not a valid name. ".format(old_name)) | |||
@@ -213,12 +214,12 @@ class DataSet(object): | |||
return wrapper | |||
def apply(self, func, new_field_name=None): | |||
def apply(self, func, new_field_name=None, is_input=False, is_target=False): | |||
"""Apply a function to every instance of the DataSet. | |||
:param func: a function that takes an instance as input. | |||
:param str new_field_name: If not None, results of the function will be stored as a new field. | |||
:return results: returned values of the function over all instances. | |||
:return results: if new_field_name is not passed, returned values of the function over all instances. | |||
""" | |||
results = [func(ins) for ins in self] | |||
if new_field_name is not None: | |||
@@ -231,7 +232,7 @@ class DataSet(object): | |||
is_input=old_field.is_input, | |||
is_target=old_field.is_target) | |||
else: | |||
self.add_field(name=new_field_name, fields=results) | |||
self.add_field(name=new_field_name, fields=results, is_input=is_input, is_target=is_target) | |||
else: | |||
return results | |||
@@ -245,7 +245,7 @@ class AccuracyMetric(MetricBase): | |||
self.total += np.prod(list(pred.size())) | |||
def get_metric(self, reset=True): | |||
evaluate_result = {'acc': self.acc_count/self.total} | |||
evaluate_result = {'acc': round(self.acc_count/self.total, 6)} | |||
if reset: | |||
self.acc_count = 0 | |||
self.total = 0 | |||
@@ -17,7 +17,7 @@ from fastNLP.core.utils import get_func_signature | |||
class Tester(object): | |||
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | |||
def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=0): | |||
def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=1): | |||
super(Tester, self).__init__() | |||
if not isinstance(data, DataSet): | |||
@@ -76,7 +76,7 @@ class Tester(object): | |||
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, | |||
check_res=e.check_res, output=output, batch_y=truths, check_level=0) | |||
if self.verbose >= 0: | |||
if self.verbose >= 1: | |||
print("[tester] \n{}".format(self._format_eval_results(eval_results))) | |||
self._mode(network, is_test=False) | |||
return eval_results | |||
@@ -107,7 +107,7 @@ class Tester(object): | |||
""" | |||
_str = '' | |||
for metric_name, metric_result in results.items(): | |||
_str += metric_name + '\n\t' | |||
_str += ", ".join([str(key) + "=" + str(value) for key, value in results.items()]) | |||
_str += '\n' | |||
return _str | |||
_str += metric_name + ': ' | |||
_str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()]) | |||
_str += '\n' | |||
        return _str[:-1] | |||
@@ -28,9 +28,9 @@ class Trainer(object): | |||
""" | |||
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, | |||
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | |||
validate_every=-1, | |||
dev_data=None, use_cuda=False, save_path="./save", | |||
dev_data=None, use_cuda=False, save_path=None, | |||
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, | |||
metric_key=None, | |||
**kwargs): | |||
@@ -307,8 +307,8 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie) | |||
# forward check | |||
if batch_count==0: | |||
_check_forward_error(forward_func=model.forward, check_level=check_level, | |||
batch_x=batch_x) | |||
_check_forward_error(forward_func=model.forward, dataset=dataset, | |||
batch_x=batch_x, check_level=check_level) | |||
refined_batch_x = _build_args(model.forward, **batch_x) | |||
output = model(**refined_batch_x) | |||
@@ -207,7 +207,7 @@ class CheckError(Exception): | |||
CheckError. Used in losses.LossBase, metrics.MetricBase. | |||
""" | |||
def __init__(self, check_res:CheckRes, func_signature:str): | |||
errs = [f'The following problems occurred when calling {func_signature}'] | |||
errs = [f'The following problems occurred when calling `{func_signature}`'] | |||
if check_res.varargs: | |||
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") | |||
@@ -255,7 +255,7 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res: | |||
warnings.warn(message=_unused_warn) | |||
def _check_forward_error(forward_func, batch_x, check_level): | |||
def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||
check_res = _check_arg_dict_list(forward_func, batch_x) | |||
func_signature = get_func_signature(forward_func) | |||