Browse Source

Merge branch 'trainer' of github.com:FengZiYjun/fastNLP into trainer

# Conflicts:
#	fastNLP/core/dataset.py
#	fastNLP/core/trainer.py
#	test/core/test_trainer.py

Trainer support print_train and tqdm train.
tags/v0.2.0^2
yh 6 years ago
parent
commit
785c41ded5
16 changed files with 359 additions and 220 deletions
  1. +46
    -21
      fastNLP/core/dataset.py
  2. +2
    -3
      fastNLP/core/instance.py
  3. +28
    -15
      fastNLP/core/losses.py
  4. +35
    -32
      fastNLP/core/metrics.py
  5. +3
    -53
      fastNLP/core/optimizer.py
  6. +67
    -24
      fastNLP/core/trainer.py
  7. +74
    -15
      fastNLP/core/utils.py
  8. +3
    -3
      fastNLP/io/embed_loader.py
  9. +1
    -1
      fastNLP/modules/encoder/char_embedding.py
  10. +61
    -0
      test/core/test_dataset.py
  11. +6
    -0
      test/core/test_instance.py
  12. +16
    -24
      test/core/test_loss.py
  13. +0
    -8
      test/core/test_optimizer.py
  14. +6
    -6
      test/core/test_trainer.py
  15. +3
    -3
      test/io/test_embed_loader.py
  16. +8
    -12
      test/test_tutorial.py

+ 46
- 21
fastNLP/core/dataset.py View File

@@ -1,7 +1,9 @@
import _pickle as pickle
import numpy as np

from fastNLP.core.fieldarray import FieldArray
from fastNLP.core.instance import Instance
from fastNLP.core.utils import get_func_signature

_READERS = {}

@@ -26,24 +28,6 @@ class DataSet(object):
However, it stores data in a different way: Field-first, Instance-second.

"""

class DataSetIter(object):
def __init__(self, data_set, idx=-1, **fields):
self.data_set = data_set
self.idx = idx
self.fields = fields

def __next__(self):
self.idx += 1
if self.idx >= len(self.data_set):
raise StopIteration
# this returns a copy
return self.data_set[self.idx]

def __repr__(self):
return "\n".join(['{}: {}'.format(name, repr(self.data_set[name][self.idx])) for name
in self.data_set.get_fields().keys()])

def __init__(self, data=None):
"""

@@ -72,7 +56,27 @@ class DataSet(object):
return item in self.field_arrays

def __iter__(self):
return self.DataSetIter(self)
def iter_func():
for idx in range(len(self)):
yield self[idx]
return iter_func()

def _inner_iter(self):
class Iter_ptr:
def __init__(self, dataset, idx):
self.dataset = dataset
self.idx = idx
def __getitem__(self, item):
assert self.idx < len(self.dataset), "index:{} out of range".format(self.idx)
assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx])
return self.dataset.field_arrays[item][self.idx]
def __repr__(self):
return self.dataset[self.idx].__repr__()

def inner_iter_func():
for idx in range(len(self)):
yield Iter_ptr(self, idx)
return inner_iter_func()

def __getitem__(self, idx):
"""Fetch Instance(s) at the `idx` position(s) in the dataset.
@@ -110,6 +114,15 @@ class DataSet(object):
field = iter(self.field_arrays.values()).__next__()
return len(field)

def __inner_repr__(self):
if len(self) < 20:
return ",\n".join([ins.__repr__() for ins in self])
else:
return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__()

def __repr__(self):
return "DataSet(" + self.__inner_repr__() + ")"

def append(self, ins):
"""Add an instance to the DataSet.
If the DataSet is not empty, the instance must have the same field names as the rest instances in the DataSet.
@@ -226,7 +239,10 @@ class DataSet(object):
(2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target.
:return results: if new_field_name is not passed, returned values of the function over all instances.
"""
results = [func(ins) for ins in self]
results = [func(ins) for ins in self._inner_iter()]
if len(list(filter(lambda x: x is not None, results)))==0: # all None
raise ValueError("{} always return None.".format(get_func_signature(func=func)))

extra_param = {}
if 'is_input' in kwargs:
extra_param['is_input'] = kwargs['is_input']
@@ -250,7 +266,7 @@ class DataSet(object):
return results

def drop(self, func):
results = [ins for ins in self if not func(ins)]
results = [ins for ins in self._inner_iter() if not func(ins)]
for name, old_field in self.field_arrays.items():
self.field_arrays[name].content = [ins[name] for ins in results]

@@ -317,3 +333,12 @@ class DataSet(object):
for header, content in zip(headers, contents):
_dict[header].append(content)
return cls(_dict)

def save(self, path):
with open(path, 'wb') as f:
pickle.dump(self, f)

@staticmethod
def load(self, path):
with open(path, 'rb') as f:
return pickle.load(f)

+ 2
- 3
fastNLP/core/instance.py View File

@@ -1,5 +1,3 @@


class Instance(object):
"""An Instance is an example of data. It is the collection of Fields.

@@ -33,4 +31,5 @@ class Instance(object):
return self.add_field(name, field)

def __repr__(self):
return self.fields.__repr__()
return "{" + ",\n".join(
"\'" + field_name + "\': " + str(self.fields[field_name]) for field_name in self.fields) + "}"

+ 28
- 15
fastNLP/core/losses.py View File

@@ -70,13 +70,23 @@ class LossBase(object):
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use "
f"positional argument.).")

def __call__(self, output_dict, target_dict, force_check=False):
def _fast_param_map(self, pred_dict, target_dict):
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
return tuple(pred_dict.values())[0], tuple(target_dict.values())[0]
return None

def __call__(self, pred_dict, target_dict, check=False):
"""
:param output_dict: A dict from forward function of the network.
:param pred_dict: A dict from forward function of the network.
:param target_dict: A dict from DataSet.batch_y.
:param force_check: Boolean. Force to check the mapping functions when it is running.
:param check: Boolean. Force to check the mapping functions when it is running.
:return:
"""
fast_param = self._fast_param_map(pred_dict, target_dict)
if fast_param is not None:
loss = self.get_loss(*fast_param)
return loss

args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss)
if varargs is not None:
raise RuntimeError(
@@ -88,7 +98,8 @@ class LossBase(object):
raise RuntimeError(
f"There is not any param in function{get_func_signature(self.get_loss)}"
)
self._checked = self._checked and not force_check

self._checked = self._checked and not check
if not self._checked:
for keys in args:
if keys not in param_map:
@@ -105,12 +116,12 @@ class LossBase(object):
duplicated = []
missing = []
if not self._checked:
for keys, val in output_dict.items():
for keys, val in pred_dict.items():
if keys in target_dict.keys():
duplicated.append(keys)

param_val_dict = {}
for keys, val in output_dict.items():
for keys, val in pred_dict.items():
param_val_dict.update({keys: val})
for keys, val in target_dict.items():
param_val_dict.update({keys: val})
@@ -131,7 +142,6 @@ class LossBase(object):

param_map_val = _map_args(reversed_param_map, **param_val_dict)
param_value = _build_args(self.get_loss, **param_map_val)

loss = self.get_loss(**param_value)

if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
@@ -158,29 +168,31 @@ class LossFunc(LossBase):


class CrossEntropyLoss(LossBase):
def __init__(self, input=None, target=None):
def __init__(self, pred=None, target=None):
super(CrossEntropyLoss, self).__init__()
self.get_loss = F.cross_entropy
self._init_param_map(input=input, target=target)
self._init_param_map(input=pred, target=target)


class L1Loss(LossBase):
def __init__(self):
def __init__(self, pred=None, target=None):
super(L1Loss, self).__init__()
self.get_loss = F.l1_loss
self._init_param_map(input=pred, target=target)


class BCELoss(LossBase):
def __init__(self, input=None, target=None):
def __init__(self, pred=None, target=None):
super(BCELoss, self).__init__()
self.get_loss = F.binary_cross_entropy
self._init_param_map(input=input, target=target)
self._init_param_map(input=pred, target=target)


class NLLLoss(LossBase):
def __init__(self):
def __init__(self, pred=None, target=None):
super(NLLLoss, self).__init__()
self.get_loss = F.nll_loss
self._init_param_map(input=pred, target=target)


class LossInForward(LossBase):
@@ -199,10 +211,11 @@ class LossInForward(LossBase):
all_needed=[],
varargs=[])
raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss))
return kwargs[self.loss_key]

def __call__(self, output_dict, predict_dict, force_check=False):
def __call__(self, pred_dict, target_dict, check=False):

loss = self.get_loss(**output_dict)
loss = self.get_loss(**pred_dict)

if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
if not isinstance(loss, torch.Tensor):


+ 35
- 32
fastNLP/core/metrics.py View File

@@ -1,4 +1,3 @@

import inspect
import warnings
from collections import defaultdict
@@ -7,11 +6,12 @@ import numpy as np
import torch

from fastNLP.core.utils import CheckError
from fastNLP.core.utils import CheckRes
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import seq_lens_to_masks
from fastNLP.core.utils import CheckRes

class MetricBase(object):
def __init__(self):
@@ -59,9 +59,10 @@ class MetricBase(object):
func_args = [arg for arg in func_spect.args if arg != 'self']
for func_param, input_param in self.param_map.items():
if func_param not in func_args:
raise NameError(f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
f"initialization parameters, or change the signature of"
f" {get_func_signature(self.evaluate)}.")
raise NameError(
f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
f"initialization parameters, or change the signature of"
f" {get_func_signature(self.evaluate)}.")

# evaluate should not have varargs.
if func_spect.varargs:
@@ -71,7 +72,7 @@ class MetricBase(object):
def get_metric(self, reset=True):
raise NotImplemented

def _fast_call_evaluate(self, pred_dict, target_dict):
def _fast_param_map(self, pred_dict, target_dict):
"""

Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
@@ -80,7 +81,9 @@ class MetricBase(object):
:param target_dict:
:return: boolean, whether to go on codes in self.__call__(). When False, don't go on.
"""
return False
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
return pred_dict.values[0] and target_dict.values[0]
return None

def __call__(self, pred_dict, target_dict, check=False):
"""
@@ -103,13 +106,15 @@ class MetricBase(object):
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

if not check:
if self._fast_call_evaluate(pred_dict=pred_dict, target_dict=target_dict):
fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict)
if fast_param is not None:
self.evaluate(*fast_param)
return

if not self._checked:
# 1. check consistence between signature and param_map
func_spect = inspect.getfullargspec(self.evaluate)
func_args = set([arg for arg in func_spect.args if arg!='self'])
func_args = set([arg for arg in func_spect.args if arg != 'self'])
for func_arg, input_arg in self.param_map.items():
if func_arg not in func_args:
raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.")
@@ -117,7 +122,7 @@ class MetricBase(object):
# 2. only part of the param_map are passed, left are not
for arg in func_args:
if arg not in self.param_map:
self.param_map[arg] = arg #This param does not need mapping.
self.param_map[arg] = arg # This param does not need mapping.
self._evaluate_args = func_args
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}

@@ -149,14 +154,14 @@ class MetricBase(object):
replaced_missing = list(missing)
for idx, func_arg in enumerate(missing):
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
f"in `{self.__class__.__name__}`)"
f"in `{self.__class__.__name__}`)"

check_res = CheckRes(missing=replaced_missing,
unused=check_res.unused,
duplicated=duplicated,
required=check_res.required,
all_needed=check_res.all_needed,
varargs=check_res.varargs)
unused=check_res.unused,
duplicated=duplicated,
required=check_res.required,
all_needed=check_res.all_needed,
varargs=check_res.varargs)

if check_res.missing or check_res.duplicated or check_res.varargs:
raise CheckError(check_res=check_res,
@@ -168,6 +173,7 @@ class MetricBase(object):

return


class AccuracyMetric(MetricBase):
def __init__(self, pred=None, target=None, masks=None, seq_lens=None):
super().__init__()
@@ -187,7 +193,7 @@ class AccuracyMetric(MetricBase):
:param target_dict:
:return: boolean, whether to go on codes in self.__call__(). When False, don't go on.
"""
if len(pred_dict)==1 and len(target_dict)==1:
if len(pred_dict) == 1 and len(target_dict) == 1:
pred = list(pred_dict.values())[0]
target = list(target_dict.values())[0]
self.evaluate(pred=pred, target=target)
@@ -207,7 +213,7 @@ class AccuracyMetric(MetricBase):
None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
:return: dict({'acc': float})
"""
#TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value
# TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value
if not isinstance(pred, torch.Tensor):
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(pred)}.")
@@ -220,14 +226,14 @@ class AccuracyMetric(MetricBase):
f"got {type(masks)}.")
elif seq_lens is not None and not isinstance(seq_lens, torch.Tensor):
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_lens)}.")
f"got {type(seq_lens)}.")

if masks is None and seq_lens is not None:
masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)

if pred.size()==target.size():
if pred.size() == target.size():
pass
elif len(pred.size())==len(target.size())+1:
elif len(pred.size()) == len(target.size()) + 1:
pred = pred.argmax(dim=-1)
else:
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
@@ -241,18 +247,17 @@ class AccuracyMetric(MetricBase):
self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item()
self.total += torch.sum(masks.float()).item()
else:
self.acc_count += torch.sum(torch.eq(pred, target).float()).item()
self.acc_count += torch.sum(torch.eq(pred, target).float()).item()
self.total += np.prod(list(pred.size()))

def get_metric(self, reset=True):
evaluate_result = {'acc': round(self.acc_count/self.total, 6)}
evaluate_result = {'acc': round(self.acc_count / self.total, 6)}
if reset:
self.acc_count = 0
self.total = 0
return evaluate_result



def _prepare_metrics(metrics):
"""

@@ -274,7 +279,8 @@ def _prepare_metrics(metrics):
raise TypeError(f"{metric_name}.get_metric must be callable, got {type(metric.get_metric)}.")
_metrics.append(metric)
else:
raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
raise TypeError(
f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
elif isinstance(metrics, MetricBase):
_metrics = [metrics]
else:
@@ -296,6 +302,7 @@ class Evaluator(object):
"""
raise NotImplementedError


class ClassifyEvaluator(Evaluator):
def __init__(self):
super(ClassifyEvaluator, self).__init__()
@@ -331,6 +338,7 @@ class SeqLabelEvaluator(Evaluator):
accuracy = total_correct / total_count
return {"accuracy": float(accuracy)}


class SeqLabelEvaluator2(Evaluator):
# 上面的evaluator应该是错误的
def __init__(self, seq_lens_field_name='word_seq_origin_len'):
@@ -363,7 +371,7 @@ class SeqLabelEvaluator2(Evaluator):
if x_i in self.end_tagidx_set:
truth_count += 1
for j in range(start, idx_i + 1):
if y_[j]!=x_[j]:
if y_[j] != x_[j]:
flag = False
break
if flag:
@@ -376,8 +384,7 @@ class SeqLabelEvaluator2(Evaluator):
R = corr_count / (float(truth_count) + 1e-6)
F = 2 * P * R / (P + R + 1e-6)

return {"P": P, 'R':R, 'F': F}

return {"P": P, 'R': R, 'F': F}


class SNLIEvaluator(Evaluator):
@@ -559,10 +566,6 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0


def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2):
raise NotImplementedError


def accuracy_topk(y_true, y_prob, k=1):
"""Compute accuracy of y_true matching top-k probable
labels in y_prob.


+ 3
- 53
fastNLP/core/optimizer.py View File

@@ -4,40 +4,13 @@ import torch
class Optimizer(object):
def __init__(self, model_params, **kwargs):
if model_params is not None and not hasattr(model_params, "__next__"):
raise RuntimeError("model parameters should be a generator, rather than {}".format(type(model_params)))
raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params)))
self.model_params = model_params
self.settings = kwargs


class SGD(Optimizer):
def __init__(self, *args, **kwargs):
model_params, lr, momentum = None, 0.01, 0.9
if len(args) == 0 and len(kwargs) == 0:
# SGD()
pass
elif len(args) == 1 and len(kwargs) == 0:
if isinstance(args[0], float) or isinstance(args[0], int):
# SGD(0.001)
lr = args[0]
elif hasattr(args[0], "__next__"):
# SGD(model.parameters()) args[0] is a generator
model_params = args[0]
else:
raise RuntimeError("Not supported type {}.".format(type(args[0])))
elif 2 >= len(kwargs) > 0 and len(args) <= 1:
# SGD(lr=0.01), SGD(lr=0.01, momentum=0.9), SGD(model.parameters(), lr=0.1, momentum=0.9)
if len(args) == 1:
if hasattr(args[0], "__next__"):
model_params = args[0]
else:
raise RuntimeError("Not supported type {}.".format(type(args[0])))
if not all(key in ("lr", "momentum") for key in kwargs):
raise RuntimeError("Invalid SGD arguments. Expect {}, got {}.".format(("lr", "momentum"), kwargs))
lr = kwargs.get("lr", 0.01)
momentum = kwargs.get("momentum", 0.9)
else:
raise RuntimeError("SGD only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))

def __init__(self, model_params=None, lr=0.01, momentum=0):
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)

def construct_from_pytorch(self, model_params):
@@ -49,30 +22,7 @@ class SGD(Optimizer):


class Adam(Optimizer):
def __init__(self, *args, **kwargs):
model_params, lr, weight_decay = None, 0.01, 0.9
if len(args) == 0 and len(kwargs) == 0:
pass
elif len(args) == 1 and len(kwargs) == 0:
if isinstance(args[0], float) or isinstance(args[0], int):
lr = args[0]
elif hasattr(args[0], "__next__"):
model_params = args[0]
else:
raise RuntimeError("Not supported type {}.".format(type(args[0])))
elif 2 >= len(kwargs) > 0 and len(args) <= 1:
if len(args) == 1:
if hasattr(args[0], "__next__"):
model_params = args[0]
else:
raise RuntimeError("Not supported type {}.".format(type(args[0])))
if not all(key in ("lr", "weight_decay") for key in kwargs):
raise RuntimeError("Invalid Adam arguments. Expect {}, got {}.".format(("lr", "weight_decay"), kwargs))
lr = kwargs.get("lr", 0.01)
weight_decay = kwargs.get("weight_decay", 0.9)
else:
raise RuntimeError("Adam only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))

def __init__(self, model_params=None, lr=0.01, weight_decay=0):
super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)

def construct_from_pytorch(self, model_params):


+ 67
- 24
fastNLP/core/trainer.py View File

@@ -1,6 +1,7 @@
import os
import time
from datetime import datetime
from datetime import timedelta
from tqdm import tqdm

import torch
@@ -22,17 +23,16 @@ from fastNLP.core.utils import _check_forward_error
from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _relocate_pbar

class Trainer(object):
"""Main Training Loop

"""

def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, update_every=50,
validate_every=-1, dev_data=None, use_cuda=False, save_path=None,
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
metric_key=None, sampler=RandomSampler()):
metric_key=None, sampler=RandomSampler(), use_tqdm=True):
"""

:param DataSet train_data: the training data
@@ -54,6 +54,7 @@ class Trainer(object):
::
metric_key="-PPL" # language model gets better as perplexity gets smaller
:param sampler: method used to generate batch data.
:param use_tqdm: boolean, use tqdm to show train progress.

"""
super(Trainer, self).__init__()
@@ -117,19 +118,23 @@ class Trainer(object):
else:
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())

self.use_tqdm = use_tqdm
if self.use_tqdm:
tester_verbose = 0
else:
tester_verbose = 1

if self.dev_data is not None:
self.tester = Tester(model=self.model,
data=self.dev_data,
metrics=self.metrics,
batch_size=self.batch_size,
use_cuda=self.use_cuda,
verbose=0)
verbose=tester_verbose)

self.step = 0
self.start_time = None # start timestamp

# print(self.__dict__)

def train(self):
"""Start Training.

@@ -155,8 +160,10 @@ class Trainer(object):
else:
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
self._summary_writer = SummaryWriter(path)

self._tqdm_train()
if self.use_tqdm:
self._tqdm_train()
else:
self._print_train()

finally:
self._summary_writer.close()
@@ -196,31 +203,67 @@ class Trainer(object):
eval_res = self._do_validation()
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
self.tester._format_eval_results(eval_res)
pbar = self._relocate_pbar(pbar, print_str=eval_str, total=total_steps, initial=self.step)
time.sleep(0.1)
pbar = _relocate_pbar(pbar, print_str=eval_str)
if self.validate_every < 0 and self.dev_data:
eval_res = self._do_validation()
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
self.tester._format_eval_results(eval_res)
pbar = self._relocate_pbar(pbar, print_str=eval_str, total=total_steps, initial=self.step)
pbar = _relocate_pbar(pbar, print_str=eval_str)
if epoch!=self.n_epochs:
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
as_numpy=False)
pbar.close()

def _relocate_pbar(self, pbar, total, initial, print_str=None):
postfix = pbar.postfix
desc = pbar.desc
pbar.close()
avg_time = pbar.avg_time
start_t = pbar.start_t
if print_str:
print(print_str)
pbar = tqdm(total=total, postfix=postfix, desc=desc, leave=False, initial=initial, dynamic_ncols=True)
pbar.start_t = start_t
pbar.avg_time = avg_time
pbar.sp(pbar.__repr__())
return pbar

def _print_train(self):
"""

:param data_iterator:
:param model:
:param epoch:
:param start:
:return:
"""
epoch = 1
start = time.time()
while epoch <= self.n_epochs:

data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
as_numpy=False)

for batch_x, batch_y in data_iterator:
# TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
prediction = self._data_forward(self.model, batch_x)
loss = self._compute_loss(prediction, batch_y)
self._grad_backward(loss)
self._update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
for name, param in self.model.named_parameters():
if param.requires_grad:
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
if self.print_every > 0 and self.step % self.print_every == 0:
end = time.time()
diff = timedelta(seconds=round(end - start))
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
epoch, self.step, loss.data, diff)
print(print_output)

if (self.validate_every > 0 and self.step % self.validate_every == 0 and
self.dev_data is not None):
self._do_validation()

self.step += 1

# validate_every override validation at end of epochs
if self.dev_data and self.validate_every <= 0:
self._do_validation()
epoch += 1




def _do_validation(self):
res = self.tester.test()


+ 74
- 15
fastNLP/core/utils.py View File

@@ -7,9 +7,12 @@ from collections import namedtuple

import numpy as np
import torch
from tqdm import tqdm

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
'varargs'], verbose=False)


def save_pickle(obj, pickle_path, file_name):
"""Save an object into a pickle file.

@@ -53,6 +56,7 @@ def pickle_exist(pickle_path, pickle_name):
else:
return False


def _build_args(func, **kwargs):
spect = inspect.getfullargspec(func)
if spect.varkw is not None:
@@ -108,7 +112,7 @@ def _check_arg_dict_list(func, args):
assert callable(func) and isinstance(arg_dict_list, (list, tuple))
assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
spect = inspect.getfullargspec(func)
all_args = set([arg for arg in spect.args if arg!='self'])
all_args = set([arg for arg in spect.args if arg != 'self'])
defaults = []
if spect.defaults is not None:
defaults = [arg for arg in spect.defaults]
@@ -130,6 +134,7 @@ def _check_arg_dict_list(func, args):
all_needed=list(all_args),
varargs=varargs)


def get_func_signature(func):
"""

@@ -153,7 +158,7 @@ def get_func_signature(func):
class_name = func.__self__.__class__.__name__
signature = inspect.signature(func)
signature_str = str(signature)
if len(signature_str)>2:
if len(signature_str) > 2:
_self = '(self, '
else:
_self = '(self'
@@ -176,12 +181,13 @@ def _is_function_or_method(func):
return False
return True


def _check_function_or_method(func):
if not _is_function_or_method(func):
raise TypeError(f"{type(func)} is not a method or function.")


def _move_dict_value_to_device(*args, device:torch.device):
def _move_dict_value_to_device(*args, device: torch.device):
"""

move data to model's device, element in *args should be dict. This is a inplace change.
@@ -206,7 +212,8 @@ class CheckError(Exception):

CheckError. Used in losses.LossBase, metrics.MetricBase.
"""
def __init__(self, check_res:CheckRes, func_signature:str):

def __init__(self, check_res: CheckRes, func_signature: str):
errs = [f'The following problems occurred when calling `{func_signature}`']

if check_res.varargs:
@@ -228,8 +235,9 @@ IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2

def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes,
pred_dict:dict, target_dict:dict, dataset, check_level=0):

def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
pred_dict: dict, target_dict: dict, dataset, check_level=0):
errs = []
unuseds = []
_unused_field = []
@@ -268,8 +276,8 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:
f"target is {list(target_dict.keys())}).")
if _miss_out_dataset:
_tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
f"target is {list(target_dict.keys())}) or output it "
f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
f"target is {list(target_dict.keys())}) or output it "
f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
if _unused_field:
_tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
suggestions.append(_tmp)
@@ -277,15 +285,15 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:
if check_res.duplicated:
errs.append(f"\tduplicated param: {check_res.duplicated}.")
suggestions.append(f"Delete {check_res.duplicated} in the output of "
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")

if check_level == STRICT_CHECK_LEVEL:
errs.extend(unuseds)

if len(errs)>0:
if len(errs) > 0:
errs.insert(0, f'The following problems occurred when calling {func_signature}')
sugg_str = ""
if len(suggestions)>1:
if len(suggestions) > 1:
for idx, sugg in enumerate(suggestions):
sugg_str += f'({idx+1}). {sugg}'
else:
@@ -332,10 +340,10 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
if check_level == STRICT_CHECK_LEVEL:
errs.extend(_unused)

if len(errs)>0:
if len(errs) > 0:
errs.insert(0, f'The following problems occurred when calling {func_signature}')
sugg_str = ""
if len(suggestions)>1:
if len(suggestions) > 1:
for idx, sugg in enumerate(suggestions):
sugg_str += f'({idx+1}). {sugg}'
else:
@@ -357,11 +365,11 @@ def seq_lens_to_masks(seq_lens, float=True):
:return: list, np.ndarray or torch.Tensor, shape will be (B, max_length)
"""
if isinstance(seq_lens, np.ndarray):
assert len(np.shape(seq_lens))==1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}."
raise NotImplemented
elif isinstance(seq_lens, torch.LongTensor):
assert len(seq_lens.size())==1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}."
assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}."
batch_size = seq_lens.size(0)
max_len = seq_lens.max()
indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
@@ -376,3 +384,54 @@ def seq_lens_to_masks(seq_lens, float=True):
else:
raise NotImplemented


def seq_mask(seq_len, max_len):
"""Create sequence mask.

:param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
:param max_len: int, the maximum sequence length in a batch.
:return mask: torch.LongTensor, [batch_size, max_len]

"""
if not isinstance(seq_len, torch.Tensor):
seq_len = torch.LongTensor(seq_len)
seq_len = seq_len.view(-1, 1).long() # [batch_size, 1]
seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len]
return torch.gt(seq_len, seq_range) # [batch_size, max_len]


def _relocate_pbar(pbar:tqdm, print_str:str):
"""

When using tqdm, you cannot print. If you print, the tqdm will duplicate. By using this function, print_str will
show above tqdm.
:param pbar: tqdm
:param print_str:
:return:
"""

params = ['desc', 'total', 'leave', 'file', 'ncols', 'mininterval', 'maxinterval', 'miniters', 'ascii', 'disable',
'unit', 'unit_scale', 'dynamic_ncols', 'smoothing', 'bar_format', 'initial', 'position', 'postfix', 'unit_divisor',
'gui']

attr_map = {'file': 'fp', 'initial':'n', 'position':'pos'}

param_dict = {}
for param in params:
attr_name = param
if param in attr_map:
attr_name = attr_map[param]
value = getattr(pbar, attr_name)
if attr_name == 'pos':
value = abs(value)
param_dict[param] = value

pbar.close()
avg_time = pbar.avg_time
start_t = pbar.start_t
print(print_str)
pbar = tqdm(**param_dict)
pbar.start_t = start_t
pbar.avg_time = avg_time
pbar.sp(pbar.__repr__())
return pbar

+ 3
- 3
fastNLP/io/embed_loader.py View File

@@ -105,9 +105,9 @@ class EmbedLoader(BaseLoader):

if np.sum(hit_flags) < len(vocab):
# some words from vocab are missing in pre-trained embedding
# we normally sample them
# we normally sample each dimension
vocab_embed = embedding_matrix[np.where(hit_flags)]
mean, cov = vocab_embed.mean(axis=0), np.cov(vocab_embed.T)
sampled_vectors = np.random.multivariate_normal(mean, cov, size=(len(vocab) - np.sum(hit_flags),))
sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
size=(len(vocab) - np.sum(hit_flags), emb_dim))
embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
return embedding_matrix

+ 1
- 1
fastNLP/modules/encoder/char_embedding.py View File

@@ -43,7 +43,7 @@ class ConvCharEmbedding(nn.Module):
# [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1]
y = torch.squeeze(y, 2)
# [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1]
y = F.tanh(y)
y = torch.tanh(y)
y, __ = torch.max(y, 2)
# [batch_size*sent_length, feature_maps[i]]
feats.append(y)


+ 61
- 0
test/core/test_dataset.py View File

@@ -44,6 +44,9 @@ class TestDataSet(unittest.TestCase):
self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10)
self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10)

with self.assertRaises(RuntimeError):
dd.add_field("??", [[1, 2]] * 40)

def test_delete_field(self):
dd = DataSet()
dd.add_field("x", [[1, 2, 3]] * 10)
@@ -65,8 +68,66 @@ class TestDataSet(unittest.TestCase):
self.assertTrue(isinstance(sub_ds, DataSet))
self.assertEqual(len(sub_ds), 10)

def test_get_item_error(self):
with self.assertRaises(RuntimeError):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
_ = ds[40:]

with self.assertRaises(KeyError):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
_ = ds["kom"]

def test_len_(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertEqual(len(ds), 40)

ds = DataSet()
self.assertEqual(len(ds), 0)

def test_apply(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx")
self.assertTrue("rx" in ds.field_arrays)
self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])

def test_contains(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertTrue("x" in ds)
self.assertTrue("y" in ds)
self.assertFalse("z" in ds)

def test_rename_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
ds.rename_field("x", "xx")
self.assertTrue("xx" in ds)
self.assertFalse("x" in ds)

with self.assertRaises(KeyError):
ds.rename_field("yyy", "oo")

def test_input_target(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
ds.set_input("x")
ds.set_target("y")
self.assertTrue(ds.field_arrays["x"].is_input)
self.assertTrue(ds.field_arrays["y"].is_target)

with self.assertRaises(KeyError):
ds.set_input("xxx")
with self.assertRaises(KeyError):
ds.set_input("yyy")

def test_get_input_name(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
self.assertEqual(ds.get_input_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_input])

def test_get_target_name(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target])


class TestDataSetIter(unittest.TestCase):
def test__repr__(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
for iter in ds:
self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}")

+ 6
- 0
test/core/test_instance.py View File

@@ -27,3 +27,9 @@ class TestCase(unittest.TestCase):
self.assertEqual(ins["x"], [1, 2, 3])
self.assertEqual(ins["y"], [4, 5, 6])
self.assertEqual(ins["z"], [1, 1, 1])

def test_repr(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
ins = Instance(**fields)
# simple print, that is enough.
print(ins)

+ 16
- 24
test/core/test_loss.py View File

@@ -271,40 +271,32 @@ class TestLoss(unittest.TestCase):
loss3 = get_loss_3({'predict': predict}, {'truth': truth})
assert loss1 == loss2 and loss1 == loss3

"""
get_loss_4 = LossFunc(func4)
loss4 = get_loss_4({'a': 1, 'b': 3}, {})
print(loss4)
assert loss4 == (1 + 3) * 2

get_loss_5 = LossFunc(func4)
loss5 = get_loss_5({'a': 1, 'b': 3}, {'c': 4})
print(loss5)
assert loss5 == (1 + 3) * 4

get_loss_6 = LossFunc(func6)
loss6 = get_loss_6({'a': 1, 'b': 3}, {'c': 4})
print(loss6)
assert loss6 == (1 + 3) * 4

get_loss_7 = LossFunc(func6, c='cc')
loss7 = get_loss_7({'a': 1, 'b': 3}, {'cc': 4})
print(loss7)
assert loss7 == (1 + 3) * 4
"""


class TestLoss_v2(unittest.TestCase):
def test_CrossEntropyLoss(self):
ce = loss.CrossEntropyLoss(input="my_predict", target="my_truth")
ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth")
a = torch.randn(3, 5, requires_grad=False)
b = torch.empty(3, dtype=torch.long).random_(5)
ans = ce({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b))

def test_BCELoss(self):
bce = loss.BCELoss(input="my_predict", target="my_truth")
bce = loss.BCELoss(pred="my_predict", target="my_truth")
a = torch.sigmoid(torch.randn((3, 5), requires_grad=False))
b = torch.randn((3, 5), requires_grad=False)
ans = bce({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b))

def test_L1Loss(self):
l1 = loss.L1Loss(pred="my_predict", target="my_truth")
a = torch.randn(3, 5, requires_grad=False)
b = torch.randn(3, 5)
ans = l1({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.l1_loss(a, b))

def test_NLLLoss(self):
l1 = loss.NLLLoss(pred="my_predict", target="my_truth")
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0)
b = torch.tensor([1, 0, 4])
ans = l1({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.nll_loss(a, b))

+ 0
- 8
test/core/test_optimizer.py View File

@@ -11,9 +11,6 @@ class TestOptim(unittest.TestCase):
self.assertTrue("lr" in optim.__dict__["settings"])
self.assertTrue("momentum" in optim.__dict__["settings"])

optim = SGD(0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)

optim = SGD(lr=0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)

@@ -25,17 +22,12 @@ class TestOptim(unittest.TestCase):
_ = SGD("???")
with self.assertRaises(RuntimeError):
_ = SGD(0.001, lr=0.002)
with self.assertRaises(RuntimeError):
_ = SGD(lr=0.009, shit=9000)

def test_Adam(self):
optim = Adam(torch.nn.Linear(10, 3).parameters())
self.assertTrue("lr" in optim.__dict__["settings"])
self.assertTrue("weight_decay" in optim.__dict__["settings"])

optim = Adam(0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)

optim = Adam(lr=0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)



+ 6
- 6
test/core/test_trainer.py View File

@@ -32,14 +32,14 @@ class TrainerTestGround(unittest.TestCase):
model = NaiveClassifier(2, 1)

trainer = Trainer(train_set, model,
losser=BCELoss(input="predict", target="y"),
losser=BCELoss(pred="predict", target="y"),
metrics=AccuracyMetric(pred="predict", target="y"),
n_epochs=10,
batch_size=32,
update_every=1,
validate_every=-1,
validate_every=10,
dev_data=dev_set,
optimizer=SGD(0.1),
check_code_level=2
)
trainer.train()
optimizer=SGD(lr=0.1),
check_code_level=2,
use_tqdm=True)
trainer.train()

+ 3
- 3
test/io/test_embed_loader.py View File

@@ -1,12 +1,12 @@
import unittest

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader


class TestEmbedLoader(unittest.TestCase):
def test_case(self):
vocab = Vocabulary()
vocab.update(["the", "in", "I", "to", "of", "hahaha"])
# TODO: np.cov在linux上segment fault,原因未知
# embedding = EmbedLoader().fast_load_embedding(50, "../data_for_tests/glove.6B.50d_test.txt", vocab)
# self.assertEqual(tuple(embedding.shape), (len(vocab), 50))
embedding = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab)
self.assertEqual(tuple(embedding.shape), (len(vocab), 50))

+ 8
- 12
test/test_tutorial.py View File

@@ -71,20 +71,16 @@ class TestTutorial(unittest.TestCase):

# 实例化Trainer,传入模型和数据,进行训练
copy_model = deepcopy(model)
overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,
losser=CrossEntropyLoss(input="output", target="label_seq"),
metrics=AccuracyMetric(pred="predict", target="label_seq"),
save_path="./save",
batch_size=4,
n_epochs=10)
overfit_trainer = Trainer(train_data=test_data, model=copy_model,
losser=CrossEntropyLoss(pred="output", target="label_seq"),
metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
dev_data=test_data, save_path="./save")
overfit_trainer.train()

trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,
losser=CrossEntropyLoss(input="output", target="label_seq"),
metrics=AccuracyMetric(pred="predict", target="label_seq"),
save_path="./save",
batch_size=4,
n_epochs=10)
trainer = Trainer(train_data=train_data, model=model,
losser=CrossEntropyLoss(pred="output", target="label_seq"),
metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
dev_data=test_data, save_path="./save")
trainer.train()
print('Train finished!')



Loading…
Cancel
Save