
Merge branch 'trainer' of github.com:FengZiYjun/fastNLP into trainer

# Conflicts:
#	fastNLP/core/dataset.py
#	fastNLP/core/trainer.py
#	test/core/test_trainer.py

Trainer now supports both a plain-print training loop and tqdm-based training progress.
tags/v0.2.0^2
yh 6 years ago
parent commit 785c41ded5
16 changed files with 359 additions and 220 deletions
  1. fastNLP/core/dataset.py                    +46  -21
  2. fastNLP/core/instance.py                     +2   -3
  3. fastNLP/core/losses.py                      +28  -15
  4. fastNLP/core/metrics.py                     +35  -32
  5. fastNLP/core/optimizer.py                    +3  -53
  6. fastNLP/core/trainer.py                     +67  -24
  7. fastNLP/core/utils.py                       +74  -15
  8. fastNLP/io/embed_loader.py                   +3   -3
  9. fastNLP/modules/encoder/char_embedding.py    +1   -1
 10. test/core/test_dataset.py                   +61   -0
 11. test/core/test_instance.py                   +6   -0
 12. test/core/test_loss.py                      +16  -24
 13. test/core/test_optimizer.py                  +0   -8
 14. test/core/test_trainer.py                    +6   -6
 15. test/io/test_embed_loader.py                 +3   -3
 16. test/test_tutorial.py                        +8  -12
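The headline change is the new use_tqdm switch on Trainer, which selects between a tqdm progress bar and the plain-print loop (_print_train). A minimal invocation sketch, assuming the fastNLP v0.2 API shown in the diffs below (train_set, dev_set and model are hypothetical stand-ins):

    from fastNLP.core.trainer import Trainer
    from fastNLP.core.losses import BCELoss
    from fastNLP.core.metrics import AccuracyMetric
    from fastNLP.core.optimizer import SGD

    trainer = Trainer(train_set, model,
                      losser=BCELoss(pred="predict", target="y"),
                      metrics=AccuracyMetric(pred="predict", target="y"),
                      n_epochs=10, batch_size=32, validate_every=10,
                      dev_data=dev_set, optimizer=SGD(lr=0.1),
                      use_tqdm=True)    # use_tqdm=False switches to the plain-print loop
    trainer.train()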

fastNLP/core/dataset.py (+46 -21)

@@ -1,7 +1,9 @@
+import _pickle as pickle
 import numpy as np
 
 from fastNLP.core.fieldarray import FieldArray
 from fastNLP.core.instance import Instance
+from fastNLP.core.utils import get_func_signature
 
 
 _READERS = {}

@@ -26,24 +28,6 @@ class DataSet(object):
     However, it stores data in a different way: Field-first, Instance-second.
 
     """
-
-    class DataSetIter(object):
-        def __init__(self, data_set, idx=-1, **fields):
-            self.data_set = data_set
-            self.idx = idx
-            self.fields = fields
-
-        def __next__(self):
-            self.idx += 1
-            if self.idx >= len(self.data_set):
-                raise StopIteration
-            # this returns a copy
-            return self.data_set[self.idx]
-
-        def __repr__(self):
-            return "\n".join(['{}: {}'.format(name, repr(self.data_set[name][self.idx])) for name
-                              in self.data_set.get_fields().keys()])
 
     def __init__(self, data=None):
         """


@@ -72,7 +56,27 @@ class DataSet(object):
         return item in self.field_arrays
 
     def __iter__(self):
-        return self.DataSetIter(self)
+        def iter_func():
+            for idx in range(len(self)):
+                yield self[idx]
+        return iter_func()
+
+    def _inner_iter(self):
+        class Iter_ptr:
+            def __init__(self, dataset, idx):
+                self.dataset = dataset
+                self.idx = idx
+
+            def __getitem__(self, item):
+                assert self.idx < len(self.dataset), "index:{} out of range".format(self.idx)
+                assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx])
+                return self.dataset.field_arrays[item][self.idx]
+
+            def __repr__(self):
+                return self.dataset[self.idx].__repr__()
+
+        def inner_iter_func():
+            for idx in range(len(self)):
+                yield Iter_ptr(self, idx)
+        return inner_iter_func()
 
     def __getitem__(self, idx):
         """Fetch Instance(s) at the `idx` position(s) in the dataset.
@@ -110,6 +114,15 @@ class DataSet(object):
         field = iter(self.field_arrays.values()).__next__()
         return len(field)
 
+    def __inner_repr__(self):
+        if len(self) < 20:
+            return ",\n".join([ins.__repr__() for ins in self])
+        else:
+            return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__()
+
+    def __repr__(self):
+        return "DataSet(" + self.__inner_repr__() + ")"
+
     def append(self, ins):
         """Add an instance to the DataSet.
         If the DataSet is not empty, the instance must have the same field names as the rest instances in the DataSet.
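The new __iter__/_inner_iter pair above exists so that apply() and drop() can scan rows without materializing a full Instance per row: each Iter_ptr is a lightweight view that indexes the underlying FieldArray directly. A minimal sketch of the difference, assuming the fastNLP v0.2 DataSet from this diff:

    from fastNLP.core.dataset import DataSet

    ds = DataSet({"x": [[1, 2], [3, 4]], "y": [0, 1]})
    for ins in ds:                # __iter__: builds an Instance copy per row
        print(ins["x"])
    for ptr in ds._inner_iter():  # Iter_ptr view: reads the FieldArray in place, no copy
        print(ptr["x"])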
@@ -226,7 +239,10 @@ class DataSet(object):
             (2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target.
         :return results: if new_field_name is not passed, returned values of the function over all instances.
         """
-        results = [func(ins) for ins in self]
+        results = [func(ins) for ins in self._inner_iter()]
+        if len(list(filter(lambda x: x is not None, results))) == 0:  # all None
+            raise ValueError("{} always return None.".format(get_func_signature(func=func)))
+
         extra_param = {}
         if 'is_input' in kwargs:
             extra_param['is_input'] = kwargs['is_input']
@@ -250,7 +266,7 @@ class DataSet(object):
         return results
 
     def drop(self, func):
-        results = [ins for ins in self if not func(ins)]
+        results = [ins for ins in self._inner_iter() if not func(ins)]
         for name, old_field in self.field_arrays.items():
             self.field_arrays[name].content = [ins[name] for ins in results]

@@ -317,3 +333,12 @@ class DataSet(object):
         for header, content in zip(headers, contents):
             _dict[header].append(content)
         return cls(_dict)
+
+    def save(self, path):
+        with open(path, 'wb') as f:
+            pickle.dump(self, f)
+
+    @staticmethod
+    def load(path):
+        with open(path, 'rb') as f:
+            return pickle.load(f)
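For reference, a round-trip sketch of the new persistence helpers (hypothetical path; assumes the DataSet contents are pickle-serializable):

    from fastNLP.core.dataset import DataSet

    ds = DataSet({"x": [[1, 2, 3]] * 4, "y": [1] * 4})
    ds.save("/tmp/ds.pkl")             # pickles the whole DataSet
    ds2 = DataSet.load("/tmp/ds.pkl")
    assert len(ds2) == len(ds)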

fastNLP/core/instance.py (+2 -3)

@@ -1,5 +1,3 @@
-
-
 class Instance(object):
     """An Instance is an example of data. It is the collection of Fields.
 

@@ -33,4 +31,5 @@ class Instance(object):
         return self.add_field(name, field)
 
     def __repr__(self):
-        return self.fields.__repr__()
+        return "{" + ",\n".join(
+            "'" + field_name + "': " + str(self.fields[field_name]) for field_name in self.fields) + "}"

fastNLP/core/losses.py (+28 -15)

@@ -70,13 +70,23 @@ class LossBase(object):
             raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use "
                             f"positional argument.).")
 
-    def __call__(self, output_dict, target_dict, force_check=False):
+    def _fast_param_map(self, pred_dict, target_dict):
+        if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
+            return tuple(pred_dict.values())[0], tuple(target_dict.values())[0]
+        return None
+
+    def __call__(self, pred_dict, target_dict, check=False):
         """
-        :param output_dict: A dict from forward function of the network.
+        :param pred_dict: A dict from forward function of the network.
         :param target_dict: A dict from DataSet.batch_y.
-        :param force_check: Boolean. Force to check the mapping functions when it is running.
+        :param check: Boolean. Force to check the mapping functions when it is running.
         :return:
         """
+        fast_param = self._fast_param_map(pred_dict, target_dict)
+        if fast_param is not None:
+            loss = self.get_loss(*fast_param)
+            return loss
+
         args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss)
         if varargs is not None:
             raise RuntimeError(
@@ -88,7 +98,8 @@ class LossBase(object):
             raise RuntimeError(
                 f"There is not any param in function{get_func_signature(self.get_loss)}"
             )
-        self._checked = self._checked and not force_check
+
+        self._checked = self._checked and not check
         if not self._checked:
             for keys in args:
                 if keys not in param_map:
@@ -105,12 +116,12 @@ class LossBase(object):
         duplicated = []
         missing = []
         if not self._checked:
-            for keys, val in output_dict.items():
+            for keys, val in pred_dict.items():
                 if keys in target_dict.keys():
                     duplicated.append(keys)
 
         param_val_dict = {}
-        for keys, val in output_dict.items():
+        for keys, val in pred_dict.items():
             param_val_dict.update({keys: val})
         for keys, val in target_dict.items():
             param_val_dict.update({keys: val})
@@ -131,7 +142,6 @@ class LossBase(object):
 
         param_map_val = _map_args(reversed_param_map, **param_val_dict)
         param_value = _build_args(self.get_loss, **param_map_val)
-
         loss = self.get_loss(**param_value)
 
         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
@@ -158,29 +168,31 @@ class LossFunc(LossBase):
 
 
 class CrossEntropyLoss(LossBase):
-    def __init__(self, input=None, target=None):
+    def __init__(self, pred=None, target=None):
         super(CrossEntropyLoss, self).__init__()
         self.get_loss = F.cross_entropy
-        self._init_param_map(input=input, target=target)
+        self._init_param_map(input=pred, target=target)
 
 
 class L1Loss(LossBase):
-    def __init__(self):
+    def __init__(self, pred=None, target=None):
         super(L1Loss, self).__init__()
         self.get_loss = F.l1_loss
+        self._init_param_map(input=pred, target=target)
 
 
 class BCELoss(LossBase):
-    def __init__(self, input=None, target=None):
+    def __init__(self, pred=None, target=None):
         super(BCELoss, self).__init__()
         self.get_loss = F.binary_cross_entropy
-        self._init_param_map(input=input, target=target)
+        self._init_param_map(input=pred, target=target)
 
 
 class NLLLoss(LossBase):
-    def __init__(self):
+    def __init__(self, pred=None, target=None):
         super(NLLLoss, self).__init__()
         self.get_loss = F.nll_loss
+        self._init_param_map(input=pred, target=target)
 
 
 class LossInForward(LossBase):
@@ -199,10 +211,11 @@ class LossInForward(LossBase):
                                  all_needed=[],
                                  varargs=[])
             raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss))
+        return kwargs[self.loss_key]
 
-    def __call__(self, output_dict, predict_dict, force_check=False):
+    def __call__(self, pred_dict, target_dict, check=False):
 
-        loss = self.get_loss(**output_dict)
+        loss = self.get_loss(**pred_dict)
 
         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
             if not isinstance(loss, torch.Tensor):
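Net effect on the public API: every loss now takes pred=/target= field names, and single-key dicts skip the checking machinery via _fast_param_map. A usage sketch, mirroring the updated tests later in this diff:

    import torch
    import torch.nn.functional as F
    from fastNLP.core.losses import CrossEntropyLoss

    ce = CrossEntropyLoss(pred="my_predict", target="my_truth")
    a = torch.randn(3, 5)
    b = torch.empty(3, dtype=torch.long).random_(5)
    loss = ce({"my_predict": a}, {"my_truth": b})   # one key per dict -> fast path
    assert loss == F.cross_entropy(a, b)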


fastNLP/core/metrics.py (+35 -32)

@@ -1,4 +1,3 @@
-
 import inspect
 import warnings
 from collections import defaultdict
@@ -7,11 +6,12 @@ import numpy as np
 import torch
 
 from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import CheckRes
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import seq_lens_to_masks
-from fastNLP.core.utils import CheckRes
 
 
 class MetricBase(object):
     def __init__(self):
@@ -59,9 +59,10 @@ class MetricBase(object):
         func_args = [arg for arg in func_spect.args if arg != 'self']
         for func_param, input_param in self.param_map.items():
             if func_param not in func_args:
-                raise NameError(f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
-                                f"initialization parameters, or change the signature of"
-                                f" {get_func_signature(self.evaluate)}.")
+                raise NameError(
+                    f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
+                    f"initialization parameters, or change the signature of"
+                    f" {get_func_signature(self.evaluate)}.")
 
         # evaluate should not have varargs.
         if func_spect.varargs:
@@ -71,7 +72,7 @@ class MetricBase(object):
     def get_metric(self, reset=True):
         raise NotImplemented
 
-    def _fast_call_evaluate(self, pred_dict, target_dict):
+    def _fast_param_map(self, pred_dict, target_dict):
         """
 
         Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
@@ -80,7 +81,9 @@ class MetricBase(object):
         :param target_dict:
         :return: boolean, whether to go on codes in self.__call__(). When False, don't go on.
         """
-        return False
+        if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
+            return tuple(pred_dict.values())[0], tuple(target_dict.values())[0]
+        return None
 
     def __call__(self, pred_dict, target_dict, check=False):
         """
@@ -103,13 +106,15 @@ class MetricBase(object):
             raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")
 
         if not check:
-            if self._fast_call_evaluate(pred_dict=pred_dict, target_dict=target_dict):
+            fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict)
+            if fast_param is not None:
+                self.evaluate(*fast_param)
                 return
 
         if not self._checked:
             # 1. check consistence between signature and param_map
             func_spect = inspect.getfullargspec(self.evaluate)
-            func_args = set([arg for arg in func_spect.args if arg!='self'])
+            func_args = set([arg for arg in func_spect.args if arg != 'self'])
             for func_arg, input_arg in self.param_map.items():
                 if func_arg not in func_args:
                     raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.")
@@ -117,7 +122,7 @@ class MetricBase(object):
             # 2. only part of the param_map are passed, left are not
             for arg in func_args:
                 if arg not in self.param_map:
-                    self.param_map[arg] = arg #This param does not need mapping.
+                    self.param_map[arg] = arg  # This param does not need mapping.
             self._evaluate_args = func_args
             self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}


@@ -149,14 +154,14 @@ class MetricBase(object):
             replaced_missing = list(missing)
             for idx, func_arg in enumerate(missing):
                 replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
-                    f"in `{self.__class__.__name__}`)"
+                                        f"in `{self.__class__.__name__}`)"
 
             check_res = CheckRes(missing=replaced_missing,
-                        unused=check_res.unused,
-                        duplicated=duplicated,
-                        required=check_res.required,
-                        all_needed=check_res.all_needed,
-                        varargs=check_res.varargs)
+                                 unused=check_res.unused,
+                                 duplicated=duplicated,
+                                 required=check_res.required,
+                                 all_needed=check_res.all_needed,
+                                 varargs=check_res.varargs)
 
         if check_res.missing or check_res.duplicated or check_res.varargs:
             raise CheckError(check_res=check_res,
@@ -168,6 +173,7 @@ class MetricBase(object):
 
         return
 
+
 class AccuracyMetric(MetricBase):
     def __init__(self, pred=None, target=None, masks=None, seq_lens=None):
         super().__init__()
@@ -187,7 +193,7 @@ class AccuracyMetric(MetricBase):
         :param target_dict:
         :return: boolean, whether to go on codes in self.__call__(). When False, don't go on.
         """
-        if len(pred_dict)==1 and len(target_dict)==1:
+        if len(pred_dict) == 1 and len(target_dict) == 1:
             pred = list(pred_dict.values())[0]
             target = list(target_dict.values())[0]
             self.evaluate(pred=pred, target=target)
@@ -207,7 +213,7 @@ class AccuracyMetric(MetricBase):
             None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
         :return: dict({'acc': float})
         """
-        #TODO this error message needs rewording: the user has no idea what `pred` is; report the actual value instead
+        # TODO this error message needs rewording: the user has no idea what `pred` is; report the actual value instead
         if not isinstance(pred, torch.Tensor):
             raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(pred)}.")
@@ -220,14 +226,14 @@ class AccuracyMetric(MetricBase):
                             f"got {type(masks)}.")
         elif seq_lens is not None and not isinstance(seq_lens, torch.Tensor):
             raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
-                             f"got {type(seq_lens)}.")
+                            f"got {type(seq_lens)}.")
 
         if masks is None and seq_lens is not None:
             masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)
 
-        if pred.size()==target.size():
+        if pred.size() == target.size():
             pass
-        elif len(pred.size())==len(target.size())+1:
+        elif len(pred.size()) == len(target.size()) + 1:
             pred = pred.argmax(dim=-1)
         else:
             raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
@@ -241,18 +247,17 @@ class AccuracyMetric(MetricBase):
             self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item()
             self.total += torch.sum(masks.float()).item()
         else:
-            self.acc_count +=  torch.sum(torch.eq(pred, target).float()).item()
+            self.acc_count += torch.sum(torch.eq(pred, target).float()).item()
             self.total += np.prod(list(pred.size()))
 
     def get_metric(self, reset=True):
-        evaluate_result = {'acc': round(self.acc_count/self.total, 6)}
+        evaluate_result = {'acc': round(self.acc_count / self.total, 6)}
         if reset:
             self.acc_count = 0
             self.total = 0
         return evaluate_result
 
 
-
 def _prepare_metrics(metrics):
     """


@@ -274,7 +279,8 @@ def _prepare_metrics(metrics):
                 raise TypeError(f"{metric_name}.get_metric must be callable, got {type(metric.get_metric)}.")
             _metrics.append(metric)
         else:
-            raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
+            raise TypeError(
+                f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
     elif isinstance(metrics, MetricBase):
         _metrics = [metrics]
     else:
@@ -296,6 +302,7 @@ class Evaluator(object):
         """
         raise NotImplementedError
 
+
 class ClassifyEvaluator(Evaluator):
     def __init__(self):
         super(ClassifyEvaluator, self).__init__()
@@ -331,6 +338,7 @@ class SeqLabelEvaluator(Evaluator):
         accuracy = total_correct / total_count
         return {"accuracy": float(accuracy)}
 
+
 class SeqLabelEvaluator2(Evaluator):
     # the evaluator above is believed to be wrong
     def __init__(self, seq_lens_field_name='word_seq_origin_len'):
@@ -363,7 +371,7 @@ class SeqLabelEvaluator2(Evaluator):
                     if x_i in self.end_tagidx_set:
                         truth_count += 1
                         for j in range(start, idx_i + 1):
-                            if y_[j]!=x_[j]:
+                            if y_[j] != x_[j]:
                                 flag = False
                                 break
                         if flag:
@@ -376,8 +384,7 @@ class SeqLabelEvaluator2(Evaluator):
         R = corr_count / (float(truth_count) + 1e-6)
         F = 2 * P * R / (P + R + 1e-6)
 
-        return {"P": P, 'R':R, 'F': F}
-
+        return {"P": P, 'R': R, 'F': F}
 
 
 class SNLIEvaluator(Evaluator):
@@ -559,10 +566,6 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
     return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
 
 
-def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2):
-    raise NotImplementedError
-
-
 def accuracy_topk(y_true, y_prob, k=1):
     """Compute accuracy of y_true matching top-k probable
     labels in y_prob.
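AccuracyMetric keeps its own single-key shortcut alongside the renamed base-class hook. A small sketch of the accumulate-then-read cycle, assuming the fastNLP v0.2 API above:

    import torch
    from fastNLP.core.metrics import AccuracyMetric

    metric = AccuracyMetric(pred="predict", target="y")
    pred = torch.tensor([[0.1, 0.9], [0.8, 0.2]])   # (B, n_classes): argmax is taken internally
    target = torch.tensor([1, 1])
    metric({"predict": pred}, {"y": target})
    print(metric.get_metric())                      # {'acc': 0.5}; reset=True zeroes the counters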


fastNLP/core/optimizer.py (+3 -53)

@@ -4,40 +4,13 @@ import torch
 class Optimizer(object):
     def __init__(self, model_params, **kwargs):
         if model_params is not None and not hasattr(model_params, "__next__"):
-            raise RuntimeError("model parameters should be a generator, rather than {}".format(type(model_params)))
+            raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params)))
         self.model_params = model_params
         self.settings = kwargs
 
 
 class SGD(Optimizer):
-    def __init__(self, *args, **kwargs):
-        model_params, lr, momentum = None, 0.01, 0.9
-        if len(args) == 0 and len(kwargs) == 0:
-            # SGD()
-            pass
-        elif len(args) == 1 and len(kwargs) == 0:
-            if isinstance(args[0], float) or isinstance(args[0], int):
-                # SGD(0.001)
-                lr = args[0]
-            elif hasattr(args[0], "__next__"):
-                # SGD(model.parameters()) args[0] is a generator
-                model_params = args[0]
-            else:
-                raise RuntimeError("Not supported type {}.".format(type(args[0])))
-        elif 2 >= len(kwargs) > 0 and len(args) <= 1:
-            # SGD(lr=0.01), SGD(lr=0.01, momentum=0.9), SGD(model.parameters(), lr=0.1, momentum=0.9)
-            if len(args) == 1:
-                if hasattr(args[0], "__next__"):
-                    model_params = args[0]
-                else:
-                    raise RuntimeError("Not supported type {}.".format(type(args[0])))
-            if not all(key in ("lr", "momentum") for key in kwargs):
-                raise RuntimeError("Invalid SGD arguments. Expect {}, got {}.".format(("lr", "momentum"), kwargs))
-            lr = kwargs.get("lr", 0.01)
-            momentum = kwargs.get("momentum", 0.9)
-        else:
-            raise RuntimeError("SGD only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))
-
+    def __init__(self, model_params=None, lr=0.01, momentum=0):
         super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)
 
     def construct_from_pytorch(self, model_params):
@@ -49,30 +22,7 @@ class SGD(Optimizer):
 
 
 class Adam(Optimizer):
-    def __init__(self, *args, **kwargs):
-        model_params, lr, weight_decay = None, 0.01, 0.9
-        if len(args) == 0 and len(kwargs) == 0:
-            pass
-        elif len(args) == 1 and len(kwargs) == 0:
-            if isinstance(args[0], float) or isinstance(args[0], int):
-                lr = args[0]
-            elif hasattr(args[0], "__next__"):
-                model_params = args[0]
-            else:
-                raise RuntimeError("Not supported type {}.".format(type(args[0])))
-        elif 2 >= len(kwargs) > 0 and len(args) <= 1:
-            if len(args) == 1:
-                if hasattr(args[0], "__next__"):
-                    model_params = args[0]
-                else:
-                    raise RuntimeError("Not supported type {}.".format(type(args[0])))
-            if not all(key in ("lr", "weight_decay") for key in kwargs):
-                raise RuntimeError("Invalid Adam arguments. Expect {}, got {}.".format(("lr", "weight_decay"), kwargs))
-            lr = kwargs.get("lr", 0.01)
-            weight_decay = kwargs.get("weight_decay", 0.9)
-        else:
-            raise RuntimeError("Adam only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))
-
+    def __init__(self, model_params=None, lr=0.01, weight_decay=0):
         super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)
 
     def construct_from_pytorch(self, model_params):
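The *args dispatch is gone: SGD and Adam are now plain keyword constructors, so positional forms like SGD(0.001) no longer work (the corresponding tests are deleted below). For example:

    import torch
    from fastNLP.core.optimizer import SGD, Adam

    sgd = SGD(lr=0.01, momentum=0.9)   # model_params is filled in later by the Trainer
    adam = Adam(torch.nn.Linear(10, 3).parameters(), lr=0.001, weight_decay=0)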


fastNLP/core/trainer.py (+67 -24)

@@ -1,6 +1,7 @@
 import os
 import time
 from datetime import datetime
+from datetime import timedelta
 from tqdm import tqdm
 
 import torch
@@ -22,17 +23,16 @@ from fastNLP.core.utils import _check_forward_error
 from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _relocate_pbar
 
 
 class Trainer(object):
     """Main Training Loop
 
     """
-
     def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, update_every=50,
                  validate_every=-1, dev_data=None, use_cuda=False, save_path=None,
                  optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
-                 metric_key=None, sampler=RandomSampler()):
+                 metric_key=None, sampler=RandomSampler(), use_tqdm=True):
         """
 
         :param DataSet train_data: the training data
@@ -54,6 +54,7 @@ class Trainer(object):
         ::
             metric_key="-PPL" # language model gets better as perplexity gets smaller
         :param sampler: method used to generate batch data.
+        :param use_tqdm: boolean, use tqdm to show train progress.
 
         """
         super(Trainer, self).__init__()
@@ -117,19 +118,23 @@ class Trainer(object):
         else:
             self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
 
+        self.use_tqdm = use_tqdm
+        if self.use_tqdm:
+            tester_verbose = 0
+        else:
+            tester_verbose = 1
+
         if self.dev_data is not None:
             self.tester = Tester(model=self.model,
                                  data=self.dev_data,
                                  metrics=self.metrics,
                                  batch_size=self.batch_size,
                                  use_cuda=self.use_cuda,
-                                 verbose=0)
+                                 verbose=tester_verbose)
 
         self.step = 0
         self.start_time = None  # start timestamp
 
+        # print(self.__dict__)
+
     def train(self):
         """Start Training.


@@ -155,8 +160,10 @@ class Trainer(object):
             else:
                 path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
             self._summary_writer = SummaryWriter(path)
-
-            self._tqdm_train()
+            if self.use_tqdm:
+                self._tqdm_train()
+            else:
+                self._print_train()
 
         finally:
             self._summary_writer.close()
@@ -196,31 +203,67 @@ class Trainer(object):
                     eval_res = self._do_validation()
                     eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                                self.tester._format_eval_results(eval_res)
-                    pbar = self._relocate_pbar(pbar, print_str=eval_str, total=total_steps, initial=self.step)
+                    time.sleep(0.1)
+                    pbar = _relocate_pbar(pbar, print_str=eval_str)
             if self.validate_every < 0 and self.dev_data:
                 eval_res = self._do_validation()
                 eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                            self.tester._format_eval_results(eval_res)
-                pbar = self._relocate_pbar(pbar, print_str=eval_str, total=total_steps, initial=self.step)
+                pbar = _relocate_pbar(pbar, print_str=eval_str)
             if epoch != self.n_epochs:
                 data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                       as_numpy=False)
         pbar.close()
 
-    def _relocate_pbar(self, pbar, total, initial, print_str=None):
-        postfix = pbar.postfix
-        desc = pbar.desc
-        pbar.close()
-        avg_time = pbar.avg_time
-        start_t = pbar.start_t
-        if print_str:
-            print(print_str)
-        pbar = tqdm(total=total, postfix=postfix, desc=desc, leave=False, initial=initial, dynamic_ncols=True)
-        pbar.start_t = start_t
-        pbar.avg_time = avg_time
-        pbar.sp(pbar.__repr__())
-        return pbar
+    def _print_train(self):
+        epoch = 1
+        start = time.time()
+        while epoch <= self.n_epochs:
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
+                                  as_numpy=False)
+
+            for batch_x, batch_y in data_iterator:
+                # TODO: this may break if the user moves `prediction` to a different device inside the model
+                _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
+                prediction = self._data_forward(self.model, batch_x)
+                loss = self._compute_loss(prediction, batch_y)
+                self._grad_backward(loss)
+                self._update()
+                self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
+                for name, param in self.model.named_parameters():
+                    if param.requires_grad:
+                        self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
+                        # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
+                        # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
+                if self.print_every > 0 and self.step % self.print_every == 0:
+                    end = time.time()
+                    diff = timedelta(seconds=round(end - start))
+                    print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
+                        epoch, self.step, loss.data, diff)
+                    print(print_output)
+
+                if (self.validate_every > 0 and self.step % self.validate_every == 0 and
+                        self.dev_data is not None):
+                    self._do_validation()
+
+                self.step += 1
+
+            # validate_every overrides the end-of-epoch validation
+            if self.dev_data and self.validate_every <= 0:
+                self._do_validation()
+            epoch += 1
 
     def _do_validation(self):
         res = self.tester.test()
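With use_tqdm=False the trainer falls back to _print_train, which emits a line every print_every steps in the format built above; the output looks roughly like this (loss and timing values illustrative):

    [epoch:   1 step:   50] train loss: 0.693147 time: 0:00:02
    [epoch:   1 step:  100] train loss: 0.512301 time: 0:00:05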


fastNLP/core/utils.py (+74 -15)

@@ -7,9 +7,12 @@ from collections import namedtuple
 
 import numpy as np
 import torch
+from tqdm import tqdm
 
 CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                    'varargs'], verbose=False)
 
+
 def save_pickle(obj, pickle_path, file_name):
     """Save an object into a pickle file.


@@ -53,6 +56,7 @@ def pickle_exist(pickle_path, pickle_name):
     else:
         return False
 
+
 def _build_args(func, **kwargs):
     spect = inspect.getfullargspec(func)
     if spect.varkw is not None:
@@ -108,7 +112,7 @@ def _check_arg_dict_list(func, args):
     assert callable(func) and isinstance(arg_dict_list, (list, tuple))
     assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
     spect = inspect.getfullargspec(func)
-    all_args = set([arg for arg in spect.args if arg!='self'])
+    all_args = set([arg for arg in spect.args if arg != 'self'])
     defaults = []
     if spect.defaults is not None:
         defaults = [arg for arg in spect.defaults]
@@ -130,6 +134,7 @@ def _check_arg_dict_list(func, args):
                     all_needed=list(all_args),
                     varargs=varargs)
 
+
 def get_func_signature(func):
     """

@@ -153,7 +158,7 @@ def get_func_signature(func):
         class_name = func.__self__.__class__.__name__
         signature = inspect.signature(func)
         signature_str = str(signature)
-        if len(signature_str)>2:
+        if len(signature_str) > 2:
             _self = '(self, '
         else:
             _self = '(self'
@@ -176,12 +181,13 @@ def _is_function_or_method(func):
         return False
     return True
 
+
 def _check_function_or_method(func):
     if not _is_function_or_method(func):
         raise TypeError(f"{type(func)} is not a method or function.")
 
 
-def _move_dict_value_to_device(*args, device:torch.device):
+def _move_dict_value_to_device(*args, device: torch.device):
     """
 
     move data to model's device, element in *args should be dict. This is a inplace change.
@@ -206,7 +212,8 @@ class CheckError(Exception):
 
     CheckError. Used in losses.LossBase, metrics.MetricBase.
     """
-    def __init__(self, check_res:CheckRes, func_signature:str):
+
+    def __init__(self, check_res: CheckRes, func_signature: str):
         errs = [f'The following problems occurred when calling `{func_signature}`']
 
         if check_res.varargs:
@@ -228,8 +235,9 @@ IGNORE_CHECK_LEVEL = 0
 WARNING_CHECK_LEVEL = 1
 STRICT_CHECK_LEVEL = 2
 
-def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes,
-                         pred_dict:dict, target_dict:dict, dataset, check_level=0):
+
+def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
+                         pred_dict: dict, target_dict: dict, dataset, check_level=0):
     errs = []
     unuseds = []
     _unused_field = []
@@ -268,8 +276,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
                     f"target is {list(target_dict.keys())}).")
         if _miss_out_dataset:
             _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
-                   f"target is {list(target_dict.keys())}) or output it "
-                   f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
+                    f"target is {list(target_dict.keys())}) or output it "
+                    f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
             if _unused_field:
                 _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
             suggestions.append(_tmp)
@@ -277,15 +285,15 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
     if check_res.duplicated:
         errs.append(f"\tduplicated param: {check_res.duplicated}.")
         suggestions.append(f"Delete {check_res.duplicated} in the output of "
-                       f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
+                           f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
 
     if check_level == STRICT_CHECK_LEVEL:
         errs.extend(unuseds)
 
-    if len(errs)>0:
+    if len(errs) > 0:
         errs.insert(0, f'The following problems occurred when calling {func_signature}')
         sugg_str = ""
-        if len(suggestions)>1:
+        if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
                 sugg_str += f'({idx+1}). {sugg}'
         else:
@@ -332,10 +340,10 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
     if check_level == STRICT_CHECK_LEVEL:
         errs.extend(_unused)
 
-    if len(errs)>0:
+    if len(errs) > 0:
         errs.insert(0, f'The following problems occurred when calling {func_signature}')
         sugg_str = ""
-        if len(suggestions)>1:
+        if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
                 sugg_str += f'({idx+1}). {sugg}'
         else:
@@ -357,11 +365,11 @@ def seq_lens_to_masks(seq_lens, float=True):
     :return: list, np.ndarray or torch.Tensor, shape will be (B, max_length)
     """
     if isinstance(seq_lens, np.ndarray):
-        assert len(np.shape(seq_lens))==1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
+        assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
         assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}."
         raise NotImplemented
     elif isinstance(seq_lens, torch.LongTensor):
-        assert len(seq_lens.size())==1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}."
+        assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}."
         batch_size = seq_lens.size(0)
         max_len = seq_lens.max()
         indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
@@ -376,3 +384,54 @@ def seq_lens_to_masks(seq_lens, float=True):
     else:
         raise NotImplemented
 
+
+def seq_mask(seq_len, max_len):
+    """Create sequence mask.
+
+    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
+    :param max_len: int, the maximum sequence length in a batch.
+    :return mask: torch.LongTensor, [batch_size, max_len]
+
+    """
+    if not isinstance(seq_len, torch.Tensor):
+        seq_len = torch.LongTensor(seq_len)
+    seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
+    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
+    return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
+
+
+def _relocate_pbar(pbar: tqdm, print_str: str):
+    """
+    You cannot print while a tqdm bar is active, or the bar gets duplicated. This helper closes the old bar,
+    prints `print_str` above it, and rebuilds an equivalent bar in place.
+
+    :param pbar: tqdm
+    :param print_str:
+    :return:
+    """
+    params = ['desc', 'total', 'leave', 'file', 'ncols', 'mininterval', 'maxinterval', 'miniters', 'ascii',
+              'disable', 'unit', 'unit_scale', 'dynamic_ncols', 'smoothing', 'bar_format', 'initial', 'position',
+              'postfix', 'unit_divisor', 'gui']
+
+    attr_map = {'file': 'fp', 'initial': 'n', 'position': 'pos'}
+
+    param_dict = {}
+    for param in params:
+        attr_name = param
+        if param in attr_map:
+            attr_name = attr_map[param]
+        value = getattr(pbar, attr_name)
+        if attr_name == 'pos':
+            value = abs(value)
+        param_dict[param] = value
+
+    # capture timing state before closing, then rebuild the bar with the same settings
+    avg_time = pbar.avg_time
+    start_t = pbar.start_t
+    pbar.close()
+    print(print_str)
+    pbar = tqdm(**param_dict)
+    pbar.start_t = start_t
+    pbar.avg_time = avg_time
+    pbar.sp(pbar.__repr__())
+    return pbar
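A quick check of the new seq_mask helper (the docstring says LongTensor, but torch.gt actually yields a uint8/bool mask depending on the torch version):

    import torch
    from fastNLP.core.utils import seq_mask

    print(seq_mask([3, 1], max_len=4))
    # tensor([[1, 1, 1, 0],
    #         [1, 0, 0, 0]], dtype=torch.uint8)   # bool on newer torch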

fastNLP/io/embed_loader.py (+3 -3)

@@ -105,9 +105,9 @@ class EmbedLoader(BaseLoader):
 
         if np.sum(hit_flags) < len(vocab):
             # some words from vocab are missing in pre-trained embedding
-            # we normally sample them
+            # we normally sample each dimension
             vocab_embed = embedding_matrix[np.where(hit_flags)]
-            mean, cov = vocab_embed.mean(axis=0), np.cov(vocab_embed.T)
-            sampled_vectors = np.random.multivariate_normal(mean, cov, size=(len(vocab) - np.sum(hit_flags),))
+            sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
+                                               size=(len(vocab) - np.sum(hit_flags), emb_dim))
             embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
         return embedding_matrix
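The swap from np.random.multivariate_normal to per-dimension np.random.normal sidesteps the D x D covariance estimate (np.cov was segfaulting on linux, per the test change below) and treats each embedding dimension as independent. A standalone sketch of the sampling step, with toy stand-in values:

    import numpy as np

    vocab_embed = np.random.rand(100, 50)   # stand-in for the vectors of in-vocabulary words
    missing, emb_dim = 7, 50
    sampled = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
                               size=(missing, emb_dim))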

fastNLP/modules/encoder/char_embedding.py (+1 -1)

@@ -43,7 +43,7 @@ class ConvCharEmbedding(nn.Module):
             # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1]
             y = torch.squeeze(y, 2)
             # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1]
-            y = F.tanh(y)
+            y = torch.tanh(y)
             y, __ = torch.max(y, 2)
             # [batch_size*sent_length, feature_maps[i]]
             feats.append(y)
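For context, F.tanh has been deprecated since PyTorch 0.4.1 in favor of torch.tanh; the two are numerically identical where both exist:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3)
    assert torch.equal(torch.tanh(x), F.tanh(x))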


test/core/test_dataset.py (+61 -0)

@@ -44,6 +44,9 @@ class TestDataSet(unittest.TestCase):
         self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10)
         self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10)
 
+        with self.assertRaises(RuntimeError):
+            dd.add_field("??", [[1, 2]] * 40)
+
     def test_delete_field(self):
         dd = DataSet()
         dd.add_field("x", [[1, 2, 3]] * 10)
@@ -65,8 +68,66 @@ class TestDataSet(unittest.TestCase):
         self.assertTrue(isinstance(sub_ds, DataSet))
         self.assertEqual(len(sub_ds), 10)
 
+    def test_get_item_error(self):
+        with self.assertRaises(RuntimeError):
+            ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+            _ = ds[40:]
+
+        with self.assertRaises(KeyError):
+            ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+            _ = ds["kom"]
+
+    def test_len_(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
+        self.assertEqual(len(ds), 40)
+
+        ds = DataSet()
+        self.assertEqual(len(ds), 0)
+
     def test_apply(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
         ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx")
         self.assertTrue("rx" in ds.field_arrays)
         self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])
 
+    def test_contains(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
+        self.assertTrue("x" in ds)
+        self.assertTrue("y" in ds)
+        self.assertFalse("z" in ds)
+
+    def test_rename_field(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+        ds.rename_field("x", "xx")
+        self.assertTrue("xx" in ds)
+        self.assertFalse("x" in ds)
+
+        with self.assertRaises(KeyError):
+            ds.rename_field("yyy", "oo")
+
+    def test_input_target(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+        ds.set_input("x")
+        ds.set_target("y")
+        self.assertTrue(ds.field_arrays["x"].is_input)
+        self.assertTrue(ds.field_arrays["y"].is_target)
+
+        with self.assertRaises(KeyError):
+            ds.set_input("xxx")
+        with self.assertRaises(KeyError):
+            ds.set_input("yyy")
+
+    def test_get_input_name(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+        self.assertEqual(ds.get_input_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_input])
+
+    def test_get_target_name(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+        self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target])
+
+
+class TestDataSetIter(unittest.TestCase):
+    def test__repr__(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+        for iter in ds:
+            self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}")

test/core/test_instance.py (+6 -0)

@@ -27,3 +27,9 @@ class TestCase(unittest.TestCase):
         self.assertEqual(ins["x"], [1, 2, 3])
         self.assertEqual(ins["y"], [4, 5, 6])
         self.assertEqual(ins["z"], [1, 1, 1])
+
+    def test_repr(self):
+        fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
+        ins = Instance(**fields)
+        # simple print, that is enough.
+        print(ins)

test/core/test_loss.py (+16 -24)

@@ -271,40 +271,32 @@ class TestLoss(unittest.TestCase):
         loss3 = get_loss_3({'predict': predict}, {'truth': truth})
         assert loss1 == loss2 and loss1 == loss3
 
-        """
-        get_loss_4 = LossFunc(func4)
-        loss4 = get_loss_4({'a': 1, 'b': 3}, {})
-        print(loss4)
-        assert loss4 == (1 + 3) * 2
-
-        get_loss_5 = LossFunc(func4)
-        loss5 = get_loss_5({'a': 1, 'b': 3}, {'c': 4})
-        print(loss5)
-        assert loss5 == (1 + 3) * 4
-
-        get_loss_6 = LossFunc(func6)
-        loss6 = get_loss_6({'a': 1, 'b': 3}, {'c': 4})
-        print(loss6)
-        assert loss6 == (1 + 3) * 4
-
-        get_loss_7 = LossFunc(func6, c='cc')
-        loss7 = get_loss_7({'a': 1, 'b': 3}, {'cc': 4})
-        print(loss7)
-        assert loss7 == (1 + 3) * 4
-        """
 
 class TestLoss_v2(unittest.TestCase):
     def test_CrossEntropyLoss(self):
-        ce = loss.CrossEntropyLoss(input="my_predict", target="my_truth")
+        ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth")
         a = torch.randn(3, 5, requires_grad=False)
         b = torch.empty(3, dtype=torch.long).random_(5)
         ans = ce({"my_predict": a}, {"my_truth": b})
         self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b))
 
     def test_BCELoss(self):
-        bce = loss.BCELoss(input="my_predict", target="my_truth")
+        bce = loss.BCELoss(pred="my_predict", target="my_truth")
        a = torch.sigmoid(torch.randn((3, 5), requires_grad=False))
         b = torch.randn((3, 5), requires_grad=False)
         ans = bce({"my_predict": a}, {"my_truth": b})
         self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b))
+
+    def test_L1Loss(self):
+        l1 = loss.L1Loss(pred="my_predict", target="my_truth")
+        a = torch.randn(3, 5, requires_grad=False)
+        b = torch.randn(3, 5)
+        ans = l1({"my_predict": a}, {"my_truth": b})
+        self.assertEqual(ans, torch.nn.functional.l1_loss(a, b))
+
+    def test_NLLLoss(self):
+        l1 = loss.NLLLoss(pred="my_predict", target="my_truth")
+        a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0)
+        b = torch.tensor([1, 0, 4])
+        ans = l1({"my_predict": a}, {"my_truth": b})
+        self.assertEqual(ans, torch.nn.functional.nll_loss(a, b))

test/core/test_optimizer.py (+0 -8)

@@ -11,9 +11,6 @@ class TestOptim(unittest.TestCase):
         self.assertTrue("lr" in optim.__dict__["settings"])
         self.assertTrue("momentum" in optim.__dict__["settings"])
 
-        optim = SGD(0.001)
-        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
-
         optim = SGD(lr=0.001)
         self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
 
@@ -25,17 +22,12 @@ class TestOptim(unittest.TestCase):
             _ = SGD("???")
         with self.assertRaises(RuntimeError):
             _ = SGD(0.001, lr=0.002)
-        with self.assertRaises(RuntimeError):
-            _ = SGD(lr=0.009, shit=9000)
 
     def test_Adam(self):
         optim = Adam(torch.nn.Linear(10, 3).parameters())
         self.assertTrue("lr" in optim.__dict__["settings"])
         self.assertTrue("weight_decay" in optim.__dict__["settings"])
 
-        optim = Adam(0.001)
-        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
-
         optim = Adam(lr=0.001)
         self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)




test/core/test_trainer.py (+6 -6)

@@ -32,14 +32,14 @@ class TrainerTestGround(unittest.TestCase):
         model = NaiveClassifier(2, 1)
 
         trainer = Trainer(train_set, model,
-                          losser=BCELoss(input="predict", target="y"),
+                          losser=BCELoss(pred="predict", target="y"),
                           metrics=AccuracyMetric(pred="predict", target="y"),
                           n_epochs=10,
                           batch_size=32,
                           update_every=1,
-                          validate_every=-1,
+                          validate_every=10,
                           dev_data=dev_set,
-                          optimizer=SGD(0.1),
-                          check_code_level=2
-                          )
-        trainer.train()
+                          optimizer=SGD(lr=0.1),
+                          check_code_level=2,
+                          use_tqdm=True)
+        trainer.train()

test/io/test_embed_loader.py (+3 -3)

@@ -1,12 +1,12 @@
 import unittest
 
 from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.io.embed_loader import EmbedLoader
 
 
 class TestEmbedLoader(unittest.TestCase):
     def test_case(self):
         vocab = Vocabulary()
         vocab.update(["the", "in", "I", "to", "of", "hahaha"])
-        # TODO: np.cov segfaults on linux, reason unknown
-        # embedding = EmbedLoader().fast_load_embedding(50, "../data_for_tests/glove.6B.50d_test.txt", vocab)
-        # self.assertEqual(tuple(embedding.shape), (len(vocab), 50))
+        embedding = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab)
+        self.assertEqual(tuple(embedding.shape), (len(vocab), 50))

test/test_tutorial.py (+8 -12)

@@ -71,20 +71,16 @@ class TestTutorial(unittest.TestCase):
 
         # instantiate a Trainer, pass in the model and the data, then train
         copy_model = deepcopy(model)
-        overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,
-                                  losser=CrossEntropyLoss(input="output", target="label_seq"),
-                                  metrics=AccuracyMetric(pred="predict", target="label_seq"),
-                                  save_path="./save",
-                                  batch_size=4,
-                                  n_epochs=10)
+        overfit_trainer = Trainer(train_data=test_data, model=copy_model,
+                                  losser=CrossEntropyLoss(pred="output", target="label_seq"),
+                                  metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10,
+                                  batch_size=4, dev_data=test_data, save_path="./save")
         overfit_trainer.train()
 
-        trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,
-                          losser=CrossEntropyLoss(input="output", target="label_seq"),
-                          metrics=AccuracyMetric(pred="predict", target="label_seq"),
-                          save_path="./save",
-                          batch_size=4,
-                          n_epochs=10)
+        trainer = Trainer(train_data=train_data, model=model,
+                          losser=CrossEntropyLoss(pred="output", target="label_seq"),
+                          metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
+                          dev_data=test_data, save_path="./save")
         trainer.train()
         print('Train finished!')
