Browse Source

修改dataset.py的apply signature; batch当中增加num_batches属性; tester的format_eval_results修改; metric增加fast_evaluate_call机制
(English: change the `apply` signature in dataset.py; add a `num_batches` attribute to Batch; revise the Tester's `format_eval_results`; add a fast-evaluate-call mechanism to metrics)

tags/v0.2.0^2
yh 5 years ago
parent
commit
88949ba1da
6 changed files with 21 additions and 16 deletions
  1. +4
    -0
      fastNLP/core/batch.py
  2. +4
    -3
      fastNLP/core/dataset.py
  3. +1
    -1
      fastNLP/core/metrics.py
  4. +6
    -6
      fastNLP/core/tester.py
  5. +4
    -4
      fastNLP/core/trainer.py
  6. +2
    -2
      fastNLP/core/utils.py

+ 4
- 0
fastNLP/core/batch.py View File

@@ -26,6 +26,7 @@ class Batch(object):
self.as_numpy = as_numpy
self.idx_list = None
self.curidx = 0
self.num_batches = len(dataset)//batch_size + int(len(dataset)%batch_size!=0)

def __iter__(self):
self.idx_list = self.sampler(self.dataset)
@@ -56,6 +57,9 @@ class Batch(object):

return batch_x, batch_y

def __len__(self):
return self.num_batches

def to_tensor(batch, dtype):
if dtype in (np.int8, np.int16, np.int32, np.int64):
batch = torch.LongTensor(batch)


+ 4
- 3
fastNLP/core/dataset.py View File

@@ -168,6 +168,7 @@ class DataSet(object):
"""
if old_name in self.field_arrays:
self.field_arrays[new_name] = self.field_arrays.pop(old_name)
self.field_arrays[new_name].name = new_name
else:
raise KeyError("{} is not a valid name. ".format(old_name))

@@ -213,12 +214,12 @@ class DataSet(object):

return wrapper

def apply(self, func, new_field_name=None):
def apply(self, func, new_field_name=None, is_input=False, is_target=False):
"""Apply a function to every instance of the DataSet.

:param func: a function that takes an instance as input.
:param str new_field_name: If not None, results of the function will be stored as a new field.
:return results: returned values of the function over all instances.
:return results: if new_field_name is not passed, returned values of the function over all instances.
"""
results = [func(ins) for ins in self]
if new_field_name is not None:
@@ -231,7 +232,7 @@ class DataSet(object):
is_input=old_field.is_input,
is_target=old_field.is_target)
else:
self.add_field(name=new_field_name, fields=results)
self.add_field(name=new_field_name, fields=results, is_input=is_input, is_target=is_target)
else:
return results



+ 1
- 1
fastNLP/core/metrics.py View File

@@ -245,7 +245,7 @@ class AccuracyMetric(MetricBase):
self.total += np.prod(list(pred.size()))

def get_metric(self, reset=True):
evaluate_result = {'acc': self.acc_count/self.total}
evaluate_result = {'acc': round(self.acc_count/self.total, 6)}
if reset:
self.acc_count = 0
self.total = 0


+ 6
- 6
fastNLP/core/tester.py View File

@@ -17,7 +17,7 @@ from fastNLP.core.utils import get_func_signature
class Tester(object):
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """

def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=0):
def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=1):
super(Tester, self).__init__()

if not isinstance(data, DataSet):
@@ -76,7 +76,7 @@ class Tester(object):
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
check_res=e.check_res, output=output, batch_y=truths, check_level=0)

if self.verbose >= 0:
if self.verbose >= 1:
print("[tester] \n{}".format(self._format_eval_results(eval_results)))
self._mode(network, is_test=False)
return eval_results
@@ -107,7 +107,7 @@ class Tester(object):
"""
_str = ''
for metric_name, metric_result in results.items():
_str += metric_name + '\n\t'
_str += ", ".join([str(key) + "=" + str(value) for key, value in results.items()])
_str += '\n'
return _str
_str += metric_name + ': '
_str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()])
_str += '\n'
return _str[:-1]

+ 4
- 4
fastNLP/core/trainer.py View File

@@ -28,9 +28,9 @@ class Trainer(object):

"""

def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1,
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
validate_every=-1,
dev_data=None, use_cuda=False, save_path="./save",
dev_data=None, use_cuda=False, save_path=None,
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
metric_key=None,
**kwargs):
@@ -307,8 +307,8 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
# forward check
if batch_count==0:
_check_forward_error(forward_func=model.forward, check_level=check_level,
batch_x=batch_x)
_check_forward_error(forward_func=model.forward, dataset=dataset,
batch_x=batch_x, check_level=check_level)

refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x)


+ 2
- 2
fastNLP/core/utils.py View File

@@ -207,7 +207,7 @@ class CheckError(Exception):
CheckError. Used in losses.LossBase, metrics.MetricBase.
"""
def __init__(self, check_res:CheckRes, func_signature:str):
errs = [f'The following problems occurred when calling {func_signature}']
errs = [f'The following problems occurred when calling `{func_signature}`']

if check_res.varargs:
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
@@ -255,7 +255,7 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:
warnings.warn(message=_unused_warn)


def _check_forward_error(forward_func, batch_x, check_level):
def _check_forward_error(forward_func, batch_x, dataset, check_level):
check_res = _check_arg_dict_list(forward_func, batch_x)
func_signature = get_func_signature(forward_func)



Loading…
Cancel
Save