[new] Make batching compatible with PyTorch's DataLoader: replace Batch with DataSetIter (tags/v0.4.10)
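The old `Batch` class, which drove iteration itself and hand-rolled a prefetch process, is replaced by thin wrappers over `torch.utils.data.DataLoader`: `DataSetIter` for fastNLP `DataSet` objects and `TorchLoaderIter` for existing torch loaders. A minimal usage sketch of the new entry point, mirroring the example from the removed `Batch` docstring (the toy dataset is illustrative only)::

    from fastNLP import DataSet, DataSetIter, Instance, SequentialSampler

    # a toy dataset with one input field "x" and one target field "y"
    data_set = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6]) for _ in range(16)])
    data_set.set_input("x")
    data_set.set_target("y")

    # drop-in replacement for Batch(data_set, batch_size=4, sampler=SequentialSampler())
    batch = DataSetIter(data_set, batch_size=4, sampler=SequentialSampler())
    num_batch = len(batch)  # 4
    for batch_x, batch_y in batch:
        pass  # batch_x["x"] and batch_y["y"] arrive padded, as torch.Tensor

`num_workers`, `pin_memory`, `drop_last`, `timeout` and `worker_init_fn` are forwarded unchanged to the underlying DataLoader.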
@@ -12,7 +12,11 @@ The most commonly used fastNLP components can be imported directly from the fastNLP package; their
 __all__ = [
     "Instance",
     "FieldArray",
-    "Batch",
+    "DataSetIter",
+    "BatchIter",
+    "TorchLoaderIter",
     "Vocabulary",
     "DataSet",
     "Const",
@@ -14,7 +14,7 @@ The core module implements fastNLP's core framework; common functionality can be imported from fa
 An introduction to the division of labor among core's submodules seems unnecessary
 """
-from .batch import Batch
+from .batch import DataSetIter, BatchIter, TorchLoaderIter
 from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC
 from .const import Const
 from .dataset import DataSet
@@ -3,7 +3,9 @@ The batch module implements the Batch classes that fastNLP needs.
 """
 __all__ = [
-    "Batch"
+    "BatchIter",
+    "DataSetIter",
+    "TorchLoaderIter",
 ]

 import atexit
@@ -12,9 +14,11 @@ from queue import Empty, Full
 import numpy as np
 import torch
 import torch.multiprocessing as mp
+import torch.utils.data
 from numbers import Number

-from .sampler import RandomSampler
+from .sampler import SequentialSampler
+from .dataset import DataSet

 _python_is_exit = False
@@ -27,162 +31,151 @@ def _set_python_is_exit():
 atexit.register(_set_python_is_exit)

-class Batch(object):
-    """
-    Alias: :class:`fastNLP.Batch` :class:`fastNLP.core.batch.Batch`
-
-    Batch draws data from a `DataSet` in a given order, ``batch_size`` items at a time,
-    and assembles them into `x` and `y`::
-
-        batch = Batch(data_set, batch_size=16, sampler=SequentialSampler())
-        num_batch = len(batch)
-        for batch_x, batch_y in batch:
-            # do stuff ...
-
-    :param dataset: a :class:`~fastNLP.DataSet` object, the dataset
-    :param int batch_size: the size of each batch drawn
-    :param sampler: the :class:`~fastNLP.Sampler` to use. If ``None``, use :class:`~fastNLP.RandomSampler`.
-        Default: ``None``
-    :param bool as_numpy: if ``True``, batches are output as numpy.array, otherwise as :class:`torch.Tensor`.
-        Default: ``False``
-    :param bool prefetch: if ``True``, use an extra process to prefetch the next batch.
-        Default: ``False``
-    """
-
-    def __init__(self, dataset, batch_size, sampler=None, as_numpy=False, prefetch=False):
+class DataSetGetter:
+    def __init__(self, dataset: DataSet, as_numpy=False):
         self.dataset = dataset
-        self.batch_size = batch_size
-        if sampler is None:
-            sampler = RandomSampler()
-        self.sampler = sampler
+        self.inputs = {n: f for n, f in dataset.get_all_fields().items() if f.is_input}
+        self.targets = {n: f for n, f in dataset.get_all_fields().items() if f.is_target}
         self.as_numpy = as_numpy
-        self.idx_list = None
-        self.curidx = 0
-        self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
-        self.cur_batch_indices = None
-        self.prefetch = prefetch
-        self.lengths = 0
-
-    def fetch_one(self):
-        if self.curidx >= len(self.idx_list):
-            return None
-        else:
-            endidx = min(self.curidx + self.batch_size, len(self.idx_list))
-            batch_x, batch_y = {}, {}
-
-            indices = self.idx_list[self.curidx:endidx]
-            self.cur_batch_indices = indices
-
-            for field_name, field in self.dataset.get_all_fields().items():
-                if field.is_target or field.is_input:
-                    batch = field.get(indices)
-                    if not self.as_numpy and \
-                            field.dtype is not None and \
-                            issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor):
-                        batch = _to_tensor(batch)
-                    if field.is_target:
-                        batch_y[field_name] = batch
-                    if field.is_input:
-                        batch_x[field_name] = batch
-
-            self.curidx = endidx
-            return batch_x, batch_y
+        self.idx_list = list(range(len(dataset)))
+
+    def __getitem__(self, idx: int):
+        # mapping idx to sampled idx
+        idx = self.idx_list[idx]
+        inputs = {n: f.get(idx) for n, f in self.inputs.items()}
+        targets = {n: f.get(idx) for n, f in self.targets.items()}
+        return idx, inputs, targets
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def collate_fn(self, batch: list):
+        batch_x = {n: [] for n in self.inputs.keys()}
+        batch_y = {n: [] for n in self.targets.keys()}
+        indices = []
+        for idx, x, y in batch:
+            indices.append(idx)
+            for n, v in x.items():
+                batch_x[n].append(v)
+            for n, v in y.items():
+                batch_y[n].append(v)
+
+        def pad_batch(batch_dict, field_array):
+            for n, vlist in batch_dict.items():
+                f = field_array[n]
+                if f.padder is None:
+                    batch_dict[n] = np.array(vlist)
+                else:
+                    data = f.pad(vlist)
+                    if not self.as_numpy:
+                        data, flag = _to_tensor(data, f.dtype)
+                    batch_dict[n] = data
+            return batch_dict
+
+        return (indices,
+                pad_batch(batch_x, self.inputs),
+                pad_batch(batch_y, self.targets))
+
+    def set_idx_list(self, idx_list):
+        if len(idx_list) != len(self.idx_list):
+            raise ValueError
+        self.idx_list = idx_list
+
+
+class SamplerAdapter(torch.utils.data.Sampler):
+    def __init__(self, sampler, dataset):
+        self.sampler = sampler
+        self.dataset = dataset

     def __iter__(self):
-        """
-        Iterate on dataset, fetch batch data. Fetch process don't block the iterate process
-        :return:
-        """
-        if self.prefetch:
-            return self._run_batch_iter(self)
-
-        def batch_iter():
-            self.init_iter()
-            while 1:
-                res = self.fetch_one()
-                if res is None:
-                    break
-                yield res
-
-        return batch_iter()
+        return iter(self.sampler(self.dataset))
+
+
+class BatchIter:
+    def __init__(self):
+        self.dataiter = None
+        self.num_batches = None
+        self.cur_batch_indices = None
+        self.batch_size = None

     def init_iter(self):
-        self.idx_list = self.sampler(self.dataset)
-        self.curidx = 0
-        self.lengths = self.dataset.get_length()
+        pass
+
+    @staticmethod
+    def get_num_batches(num_samples, batch_size, drop_last):
+        num_batches = num_samples // batch_size
+        if not drop_last and (num_samples % batch_size > 0):
+            num_batches += 1
+        return num_batches
+
+    def __iter__(self):
+        self.init_iter()
+        for indices, batch_x, batch_y in self.dataiter:
+            self.cur_batch_indices = indices
+            yield batch_x, batch_y
+
+    def get_batch_indices(self):
+        return self.cur_batch_indices

     def __len__(self):
         return self.num_batches

-    def get_batch_indices(self):
-        """
-        Get the sequence of DataSet indices from which the current batch was drawn
-
-        :return list(int) indexes: the index sequence
-        """
-        return self.cur_batch_indices
-
-    @staticmethod
-    def _run_fetch(batch, q):
-        try:
-            global _python_is_exit
-            batch.init_iter()
-            # print('start fetch')
-            while 1:
-                res = batch.fetch_one()
-                # print('fetch one')
-                while 1:
-                    try:
-                        q.put(res, timeout=3)
-                        break
-                    except Full:
-                        if _python_is_exit:
-                            return
-                if res is None:
-                    # print('fetch done, waiting processing')
-                    break
-            # print('fetch exit')
-        except Exception as e:
-            q.put(e)
-        finally:
-            q.join()
-
-    @staticmethod
-    def _run_batch_iter(batch):
-        q = mp.JoinableQueue(maxsize=10)
-        fetch_p = mp.Process(target=Batch._run_fetch, args=(batch, q))
-        fetch_p.daemon = True
-        fetch_p.start()
-        # print('fork fetch process')
-        while 1:
-            try:
-                res = q.get(timeout=1)
-                q.task_done()
-                # print('get fetched')
-                if res is None:
-                    break
-                elif isinstance(res, Exception):
-                    raise res
-                yield res
-            except Empty as e:
-                if fetch_p.is_alive():
-                    continue
-                else:
-                    break
-        fetch_p.terminate()
-        fetch_p.join()
-        # print('iter done')
+    @property
+    def dataset(self):
+        return self.dataiter.dataset
+
+
+class DataSetIter(BatchIter):
+    def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False,
+                 num_workers=0, pin_memory=False, drop_last=False,
+                 timeout=0, worker_init_fn=None):
+        super().__init__()
+        assert isinstance(dataset, DataSet)
+        dataset = DataSetGetter(dataset, as_numpy)
+        collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None
+        sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
+        self.dataiter = torch.utils.data.DataLoader(
+            dataset=dataset, batch_size=batch_size, sampler=sampler,
+            collate_fn=collate_fn, num_workers=num_workers,
+            pin_memory=pin_memory, drop_last=drop_last,
+            timeout=timeout, worker_init_fn=worker_init_fn)
+        self.num_batches = self.get_num_batches(len(dataset), batch_size, drop_last)
+        self.batch_size = batch_size
+
+
+class TorchLoaderIter(BatchIter):
+    def __init__(self, dataset):
+        super().__init__()
+        assert isinstance(dataset, torch.utils.data.DataLoader)
+        self.dataiter = dataset
+        self.num_batches = self.get_num_batches(len(dataset), dataset.batch_size, dataset.drop_last)
+        self.batch_size = dataset.batch_size


-def _to_tensor(batch):
+class OnlineDataGettter:
+    # TODO
+    pass
+
+
+class OnlineDataIter(BatchIter):
+    # TODO
+    def __init__(self, dataset, batch_size=1, buffer_size=10000, sampler=None, as_numpy=False,
+                 num_workers=0, pin_memory=False, drop_last=False,
+                 timeout=0, worker_init_fn=None, **kwargs):
+        super().__init__()
+
+
+def _to_tensor(batch, field_dtype):
     try:
-        if issubclass(batch.dtype.type, np.floating):
-            batch = torch.as_tensor(batch).float()  # use float32 by default
-        else:
-            batch = torch.as_tensor(batch)  # reuse the memory address to avoid a copy
+        if field_dtype is not None \
+                and issubclass(field_dtype, Number) \
+                and not isinstance(batch, torch.Tensor):
+            if issubclass(batch.dtype.type, np.floating):
+                new_batch = torch.as_tensor(batch).float()  # use float32 by default
+            else:
+                new_batch = torch.as_tensor(batch)  # reuse the memory address to avoid a copy
+            return new_batch, True
+        else:
+            return batch, False
     except:
-        pass
-    return batch
+        return batch, False
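`TorchLoaderIter` lets an existing `torch.utils.data.DataLoader` stand in wherever fastNLP expects a `BatchIter`. Note that `BatchIter.__iter__` unpacks `(indices, batch_x, batch_y)` from every batch, so the wrapped loader's `collate_fn` has to produce that triple. A hedged sketch; the dataset and field names are made up::

    import torch
    from torch.utils.data import DataLoader, Dataset
    from fastNLP import TorchLoaderIter

    class ToyDataset(Dataset):
        def __len__(self):
            return 100
        def __getitem__(self, idx):
            return idx, {'x': torch.randn(4)}, {'y': torch.tensor(0)}

    def collate(samples):
        # BatchIter.__iter__ expects each batch to be (indices, batch_x, batch_y)
        indices = [s[0] for s in samples]
        batch_x = {'x': torch.stack([s[1]['x'] for s in samples])}
        batch_y = {'y': torch.stack([s[2]['y'] for s in samples])}
        return indices, batch_x, batch_y

    loader = DataLoader(ToyDataset(), batch_size=8, collate_fn=collate)
    batch = TorchLoaderIter(loader)  # len(batch) == 13 since drop_last=False
    for batch_x, batch_y in batch:
        pass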
@@ -176,7 +176,12 @@ class FieldArray:
         if self.padder is None or pad is False:
             return np.array(contents)
         else:
-            return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)
+            return self.pad(contents)
+
+    def pad(self, contents):
+        if self.padder is None:
+            raise RuntimeError
+        return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)

     def set_padder(self, padder):
         """
@@ -6,7 +6,7 @@ from collections import defaultdict

 import torch

-from . import Batch
+from . import DataSetIter
 from . import DataSet
 from . import SequentialSampler
 from .utils import _build_args
@@ -44,8 +44,7 @@ class Predictor(object):
         self.network.eval()
         batch_output = defaultdict(list)
-        data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False,
-                              prefetch=False)
+        data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)

         if hasattr(self.network, "predict"):
             predict_func = self.network.predict
@@ -37,7 +37,7 @@ import warnings

 import torch
 import torch.nn as nn

-from .batch import Batch
+from .batch import BatchIter, DataSetIter
 from .dataset import DataSet
 from .metrics import _prepare_metrics
 from .sampler import SequentialSampler
@@ -82,7 +82,7 @@ class Tester(object):
     :param int verbose: if 0, print nothing; if 1, print the evaluation results.
     """

-    def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1):
+    def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1):
         super(Tester, self).__init__()

         if not isinstance(data, DataSet):
@@ -96,6 +96,14 @@ class Tester(object):
         self._model = _move_model_to_device(model, device=device)
         self.batch_size = batch_size
         self.verbose = verbose
+
+        if isinstance(data, DataSet):
+            self.data_iterator = DataSetIter(
+                dataset=data, batch_size=batch_size, num_workers=num_workers)
+        elif isinstance(data, BatchIter):
+            self.data_iterator = data
+        else:
+            raise TypeError("data type {} not support".format(type(data)))

         # the predict method cannot be used if the model is a DataParallel
         if isinstance(self._model, nn.DataParallel):
@@ -124,7 +132,7 @@ class Tester(object):
         self._model_device = _get_model_device(self._model)
         network = self._model
         self._mode(network, is_test=True)
-        data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
+        data_iterator = self.data_iterator
         eval_results = {}
         try:
             with torch.no_grad():
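With this change `Tester` builds its iterator once in `__init__`: a `DataSet` is wrapped in a `DataSetIter` honoring the new `num_workers` argument, a ready-made `BatchIter` is taken as-is, and `test()` simply reuses `self.data_iterator`. A sketch; the model, metric and dataset names are placeholders::

    from fastNLP import Tester

    tester = Tester(data=test_dataset, model=model, metrics=metric,
                    batch_size=16, num_workers=2, verbose=1)
    eval_results = tester.test()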
@@ -311,8 +311,9 @@ try:
     from tqdm.auto import tqdm
 except:
     from .utils import _pseudo_tqdm as tqdm
+import warnings

-from .batch import Batch
+from .batch import DataSetIter, BatchIter
 from .callback import CallbackManager, CallbackException
 from .dataset import DataSet
 from .losses import _prepare_losser
@@ -320,7 +321,6 @@ from .metrics import _prepare_metrics
 from .optimizer import Optimizer
 from .sampler import Sampler
 from .sampler import RandomSampler
-from .sampler import SequentialSampler
 from .tester import Tester
 from .utils import _CheckError
 from .utils import _build_args
@@ -351,6 +351,8 @@ class Trainer(object):
     :param int batch_size: the batch size used for both training and validation.
     :param loss: the :class:`~fastNLP.core.losses.LossBase` object to use. When None, defaults to :class:`~fastNLP.LossInForward`
     :param sampler: the order in which batches are generated, a :class:`~fastNLP.Sampler` type. If None, defaults to :class:`~fastNLP.RandomSampler`
+    :param drop_last: if the last batch does not hold exactly batch_size samples, drop it
+    :param num_workers: int, how many workers are used to pad the batch data.
     :param update_every: int, update the gradients once every this many steps. Meant for gradient accumulation, e.g. when a batch_size of 128
         does not fit in memory; setting batch_size=32 with update_every=4 achieves the same effect. Has no effect when optimizer is None.
     :param int n_epochs: how many epochs to train for.
@@ -367,7 +369,6 @@ class Trainer(object):
     :param int validate_every: validate on the dev set every this many steps; if -1, validate once at the end of each epoch. Only effective when dev_data is passed.
     :param str,None save_path: the path to save the model to. If None, the model is not saved. If dev_data is None, the model from the last iteration is saved.
         Not only the parameters but also the model structure is saved. Even when DataParallel is used, only the model itself is saved here.
-    :param prefetch: bool, whether to use an extra process to produce the batch data. In theory this makes batch iteration faster.
     :param bool use_tqdm: whether to use tqdm to show training progress; if False, the loss is printed to the terminal.
     :param str,int,torch.device,list(int) device: which device to move the model to. Defaults to None, i.e. the Trainer does not manage
         where the model computes. The following inputs are supported:
@@ -394,16 +395,17 @@ class Trainer(object):
     """

     def __init__(self, train_data, model, optimizer=None, loss=None,
-                 batch_size=32, sampler=None, update_every=1,
-                 n_epochs=10, print_every=5,
+                 batch_size=32, sampler=None, drop_last=False, update_every=1,
+                 num_workers=0, n_epochs=10, print_every=5,
                  dev_data=None, metrics=None, metric_key=None,
-                 validate_every=-1, save_path=None,
-                 prefetch=False, use_tqdm=True, device=None,
-                 callbacks=None,
-                 check_code_level=0):
+                 validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False,
+                 callbacks=None, check_code_level=0):
+        if prefetch and num_workers==0:
+            num_workers = 1
+        if prefetch:
+            warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.")
+
         super(Trainer, self).__init__()
-        if not isinstance(train_data, DataSet):
-            raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.")
         if not isinstance(model, nn.Module):
             raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")
@@ -439,9 +441,22 @@ class Trainer(object):
         # sampler check
         if sampler is not None and not isinstance(sampler, Sampler):
             raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))
+
+        if sampler is None:
+            sampler = RandomSampler()
+
+        if isinstance(train_data, DataSet):
+            self.data_iterator = DataSetIter(
+                dataset=train_data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, drop_last=drop_last)
+        elif isinstance(train_data, BatchIter):
+            self.data_iterator = train_data
+        else:
+            raise TypeError("train_data type {} not support".format(type(train_data)))

-        if check_code_level > -1:
-            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
+        if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter):
+            # TODO: consider how to check the different dataset types
+            _check_code(data_iterator=self.data_iterator,
+                        model=model, losser=losser, metrics=metrics, dev_data=dev_data,
                         metric_key=metric_key, check_level=check_code_level,
                         batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))
         # _check_code is fastNLP's way of helping you check whether your code is correct. If you see this line in an error stack, please check your code carefully
@@ -460,8 +475,6 @@ class Trainer(object):
         self.best_dev_epoch = None
         self.best_dev_step = None
         self.best_dev_perf = None
-        self.sampler = sampler if sampler is not None else RandomSampler()
-        self.prefetch = prefetch
         self.n_steps = (len(self.train_data) // self.batch_size + int(
             len(self.train_data) % self.batch_size != 0)) * self.n_epochs
@@ -493,7 +506,7 @@ class Trainer(object):
         self.callback_manager = CallbackManager(env={"trainer": self},
                                                 callbacks=callbacks)

     def train(self, load_best_model=True, on_exception='auto'):
         """
         Call this function to make the Trainer start training.
@@ -572,8 +585,7 @@ class Trainer(object):
         with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             self.pbar = pbar
             avg_loss = 0
-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
-                                  prefetch=self.prefetch)
+            data_iterator = self.data_iterator
             self.batch_per_epoch = data_iterator.num_batches
             for epoch in range(1, self.n_epochs + 1):
                 self.epoch = epoch
@@ -786,13 +798,14 @@ def _get_value_info(_dict):
     return strs

-def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
+def _check_code(data_iterator, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None, metric_key=None,
                 check_level=0):
     # check the get_loss method
     model_devcie = model.parameters().__next__().device

-    batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
+    batch = data_iterator
+    dataset = data_iterator.dataset
     for batch_count, (batch_x, batch_y) in enumerate(batch):
         _move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
         # forward check
@@ -15,7 +15,7 @@ from ...io.file_utils import cached_path, _get_base_url
 from ._bert import _WordBertModel
 from typing import List

-from ... import DataSet, Batch, SequentialSampler
+from ... import DataSet, DataSetIter, SequentialSampler
 from ...core.utils import _move_model_to_device, _get_model_device
@@ -234,7 +234,7 @@ class ContextualEmbedding(TokenEmbedding):
         with torch.no_grad():
             for index, dataset in enumerate(datasets):
                 try:
-                    batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), prefetch=False)
+                    batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler())
                     for batch_x, batch_y in batch:
                         words = batch_x['words'].to(device)
                         words_list = words.tolist()
@@ -184,11 +184,8 @@ def train(path):
             m.weight.requires_grad = True

     # Trainer
-    trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data,
-                      loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
-                      **train_args.data,
-                      optimizer=fastNLP.Adam(**optim_args.data),
-                      save_path=path,
+    trainer = Trainer(train_data=train_data, model=model, optimizer=fastNLP.Adam(**optim_args.data), loss=ParserLoss(),
+                      dev_data=dev_data, metrics=ParserMetric(), metric_key='UAS', save_path=path,
                       callbacks=[MyCallback()])

     # Start training
@@ -89,11 +89,11 @@ def train(train_data_path, dev_data_path, checkpoint=None, save=None):
         model = torch.load(checkpoint)

     # call trainer to train
-    trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
-                                                                           target="truth",
-                                                                           seq_lens="word_seq_origin_len"),
-                      dev_data=dev_data, metric_key="f",
-                      use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path=save)
+    trainer = Trainer(dataset, model, loss=None, n_epochs=20, print_every=10, dev_data=dev_data,
+                      metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
+                                                target="truth",
+                                                seq_lens="word_seq_origin_len"), metric_key="f", save_path=save,
+                      use_tqdm=True)
     trainer.train(load_best_model=True)

     # save model & pipeline
@@ -149,14 +149,10 @@ def train():
              ) if x.requires_grad and x.size(0) != len(word_v)]
     optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1},
                  {'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ]
-    trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data,
-                         loss=loss, metrics=metric, metric_key=metric_key,
-                         optimizer=torch.optim.Adam(optim_cfg),
-                         n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=10, validate_every=3000,
-                         device=device,
-                         use_tqdm=False, prefetch=False,
-                         save_path=g_args.log,
-                         callbacks=[MyCallback()])
+    trainer = FN.Trainer(train_data=train_data, model=model, optimizer=torch.optim.Adam(optim_cfg), loss=loss,
+                         batch_size=g_args.bsz, n_epochs=g_args.ep, print_every=10, dev_data=dev_data, metrics=metric,
+                         metric_key=metric_key, validate_every=3000, save_path=g_args.log, use_tqdm=False,
+                         device=device, callbacks=[MyCallback()])
     trainer.train()

     tester = FN.Tester(data=test_data, model=model, metrics=metric,
@@ -70,19 +70,10 @@ test_data = preprocess_data(test_data, bert_dirs)

 model = BertForNLI(bert_dir=bert_dirs)

-trainer = Trainer(
-    train_data=train_data,
-    model=model,
-    optimizer=Adam(lr=2e-5, model_params=model.parameters()),
-    batch_size=torch.cuda.device_count() * 12,
-    n_epochs=4,
-    print_every=-1,
-    dev_data=dev_data,
-    metrics=AccuracyMetric(),
-    metric_key='acc',
-    device=[i for i in range(torch.cuda.device_count())],
-    check_code_level=-1
-)
+trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(lr=2e-5, model_params=model.parameters()),
+                  batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, dev_data=dev_data,
+                  metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())],
+                  check_code_level=-1)
 trainer.train(load_best_model=True)

 tester = Tester(
@@ -57,12 +57,8 @@ callbacks = [clipper]
 # if pretrain:
 #     fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until)
 #     callbacks.append(fixer)
-trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None,
-                  batch_size=32, sampler=sampler, update_every=5,
-                  n_epochs=3, print_every=5,
-                  dev_data=data.datasets['dev'], metrics=RelayMetric(), metric_key='f',
-                  validate_every=-1, save_path=None,
-                  prefetch=True, use_tqdm=True, device=device,
-                  callbacks=callbacks,
+trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None, batch_size=32, sampler=sampler,
+                  update_every=5, n_epochs=3, print_every=5, dev_data=data.datasets['dev'], metrics=RelayMetric(),
+                  metric_key='f', validate_every=-1, save_path=None, use_tqdm=True, device=device, callbacks=callbacks,
                   check_code_level=0)
 trainer.train()
@@ -13,7 +13,8 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
             }
     If paths is invalid, the corresponding error is raised directly

-    :param paths: the paths
+    :param paths: the paths. May be a single file path (that file is then taken to be the train file); may be a directory,
+        in which train.txt, test.txt and dev.txt are looked for; or a dict whose keys are user-defined file names and whose values are the file paths.
     :return:
     """
     if isinstance(paths, str):
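Concretely, the three accepted forms look like this; the paths and the normalized return value are illustrative assumptions::

    check_dataloader_paths('data/train.txt')      # single file, treated as the train file
    check_dataloader_paths('data/')               # directory, searched for train.txt/dev.txt/test.txt
    check_dataloader_paths({'train': 'a.txt',     # dict with user-defined names as keys
                            'dev': 'b.txt'})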
@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 import torch

-from fastNLP import Batch
+from fastNLP import DataSetIter
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP import SequentialSampler
@@ -57,7 +57,7 @@ class TestCase1(unittest.TestCase):
         dataset = construct_dataset(
             [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)])
         dataset.set_target()
-        batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
+        batch = DataSetIter(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True)

         cnt = 0
         for _, _ in batch:
@@ -68,7 +68,7 @@ class TestCase1(unittest.TestCase):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
         ds.set_input("x")
         ds.set_target("y")
-        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
+        iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
         for x, y in iter:
             self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray))
             self.assertEqual(len(x["x"]), 4)
@@ -81,7 +81,7 @@ class TestCase1(unittest.TestCase):
                       "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10})
         ds.set_input("x")
         ds.set_target("y")
-        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
+        iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
         for x, y in iter:
             self.assertEqual(x["x"].shape, (4, 4))
             self.assertEqual(y["y"].shape, (4, 4))
@@ -91,7 +91,7 @@ class TestCase1(unittest.TestCase):
                       "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)})
         ds.set_input("x")
         ds.set_target("y")
-        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
+        iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
         for x, y in iter:
             self.assertEqual(x["x"].shape, (4, 4))
             self.assertEqual(y["y"].shape, (4, 4))
@@ -101,7 +101,7 @@ class TestCase1(unittest.TestCase):
                       "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10})
         ds.set_input("x")
         ds.set_target("y")
-        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
+        iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
         for x, y in iter:
             self.assertTrue(isinstance(x["x"], torch.Tensor))
             self.assertEqual(tuple(x["x"].shape), (4, 4))
@@ -113,7 +113,7 @@ class TestCase1(unittest.TestCase):
                       "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)})
         ds.set_input("x")
         ds.set_target("y")
-        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
+        iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
         for x, y in iter:
             self.assertTrue(isinstance(x["x"], torch.Tensor))
             self.assertEqual(tuple(x["x"].shape), (4, 4))
@@ -125,7 +125,7 @@ class TestCase1(unittest.TestCase):
             [Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)])
         ds.set_input("x")
         ds.set_target("y")
-        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
+        iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
         for x, y in iter:
             self.assertTrue(isinstance(x["x"], torch.Tensor))
             self.assertEqual(tuple(x["x"].shape), (4, 4))
@@ -137,7 +137,7 @@ class TestCase1(unittest.TestCase):
             [Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)])
         ds.set_input("x")
         ds.set_target("y")
-        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
+        iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
         for x, y in iter:
             print(x, y)
@@ -146,7 +146,7 @@ class TestCase1(unittest.TestCase):
         num_samples = 1000
         dataset = generate_fake_dataset(num_samples)

-        batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler())
+        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler())
         for batch_x, batch_y in batch:
             pass
@@ -40,89 +40,50 @@ class TestCallback(unittest.TestCase):
     def test_gradient_clip(self):
         data_set, model = prepare_env()
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=20,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=SGD(lr=0.1),
-                          check_code_level=2,
-                          use_tqdm=False,
-                          dev_data=data_set,
-                          metrics=AccuracyMetric(pred="predict", target="y"),
-                          callbacks=[GradientClipCallback(model.parameters(), clip_value=2)])
+        trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=20, print_every=50, dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False,
+                          callbacks=[GradientClipCallback(model.parameters(), clip_value=2)], check_code_level=2)
         trainer.train()

     def test_early_stop(self):
         data_set, model = prepare_env()
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=20,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=SGD(lr=0.01),
-                          check_code_level=2,
-                          use_tqdm=False,
-                          dev_data=data_set,
-                          metrics=AccuracyMetric(pred="predict", target="y"),
-                          callbacks=[EarlyStopCallback(5)])
+        trainer = Trainer(data_set, model, optimizer=SGD(lr=0.01), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=20, print_every=50, dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False,
+                          callbacks=[EarlyStopCallback(5)], check_code_level=2)
         trainer.train()

     def test_lr_scheduler(self):
         data_set, model = prepare_env()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=5,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=optimizer,
-                          check_code_level=2,
-                          use_tqdm=False,
-                          dev_data=data_set,
-                          metrics=AccuracyMetric(pred="predict", target="y"),
-                          callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))])
+        trainer = Trainer(data_set, model, optimizer=optimizer, loss=BCELoss(pred="predict", target="y"), batch_size=32,
+                          n_epochs=5, print_every=50, dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False,
+                          callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))],
+                          check_code_level=2)
         trainer.train()

     def test_KeyBoardInterrupt(self):
         data_set, model = prepare_env()
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=5,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=SGD(lr=0.1),
-                          check_code_level=2,
-                          use_tqdm=False,
-                          callbacks=[ControlC(False)])
+        trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=5, print_every=50, use_tqdm=False, callbacks=[ControlC(False)],
+                          check_code_level=2)
         trainer.train()

     def test_LRFinder(self):
         data_set, model = prepare_env()
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=5,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=SGD(lr=0.1),
-                          check_code_level=2,
-                          use_tqdm=False,
-                          callbacks=[LRFinder(len(data_set) // 32)])
+        trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=5, print_every=50, use_tqdm=False,
+                          callbacks=[LRFinder(len(data_set) // 32)], check_code_level=2)
         trainer.train()

     def test_TensorboardCallback(self):
         data_set, model = prepare_env()
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=5,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=SGD(lr=0.1),
-                          check_code_level=2,
-                          use_tqdm=False,
-                          dev_data=data_set,
-                          metrics=AccuracyMetric(pred="predict", target="y"),
-                          callbacks=[TensorboardCallback("loss", "metric")])
+        trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=5, print_every=50, dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False,
+                          callbacks=[TensorboardCallback("loss", "metric")], check_code_level=2)
         trainer.train()

     def test_readonly_property(self):
@@ -141,16 +102,9 @@ class TestCallback(unittest.TestCase):
                 print(self.optimizer)

         data_set, model = prepare_env()
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=total_epochs,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=SGD(lr=0.1),
-                          check_code_level=2,
-                          use_tqdm=False,
-                          dev_data=data_set,
-                          metrics=AccuracyMetric(pred="predict", target="y"),
-                          callbacks=[MyCallback()])
+        trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=total_epochs, print_every=50, dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, callbacks=[MyCallback()],
+                          check_code_level=2)
         trainer.train()
         assert passed_epochs == list(range(1, total_epochs + 1))
@@ -46,18 +46,10 @@ class TrainerTestGround(unittest.TestCase):
         model = NaiveClassifier(2, 1)

-        trainer = Trainer(train_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          metrics=AccuracyMetric(pred="predict", target="y"),
-                          n_epochs=10,
-                          batch_size=32,
-                          print_every=50,
-                          validate_every=-1,
-                          dev_data=dev_set,
-                          optimizer=SGD(lr=0.1),
-                          check_code_level=2,
-                          use_tqdm=True,
-                          save_path=None)
+        trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+                          use_tqdm=True, check_code_level=2)
         trainer.train()
         """
         # should run correctly
@@ -83,10 +75,7 @@ class TrainerTestGround(unittest.TestCase):
         model = Model()

         with self.assertRaises(RuntimeError):
-            trainer = Trainer(
-                train_data=dataset,
-                model=model
-            )
+            trainer = Trainer(train_data=dataset, model=model)
         """
         # the expected error message
         NameError:
@@ -116,12 +105,7 @@ class TrainerTestGround(unittest.TestCase):
                 return {'loss': loss}

         model = Model()
-        trainer = Trainer(
-            train_data=dataset,
-            model=model,
-            use_tqdm=False,
-            print_every=2
-        )
+        trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False)
         trainer.train()
         """
         # should run correctly
@@ -147,12 +131,7 @@ class TrainerTestGround(unittest.TestCase):
         model = Model()

         with self.assertRaises(NameError):
-            trainer = Trainer(
-                train_data=dataset,
-                model=model,
-                use_tqdm=False,
-                print_every=2
-            )
+            trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False)
             trainer.train()

     def test_trainer_suggestion4(self):
@@ -175,12 +154,7 @@ class TrainerTestGround(unittest.TestCase):
         model = Model()

         with self.assertRaises(NameError):
-            trainer = Trainer(
-                train_data=dataset,
-                model=model,
-                use_tqdm=False,
-                print_every=2
-            )
+            trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False)

     def test_trainer_suggestion5(self):
         # check whether the error message correctly alerts the user
@@ -203,12 +177,7 @@ class TrainerTestGround(unittest.TestCase):
                 return {'loss': loss}

         model = Model()
-        trainer = Trainer(
-            train_data=dataset,
-            model=model,
-            use_tqdm=False,
-            print_every=2
-        )
+        trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False)

     def test_trainer_suggestion6(self):
         # check whether the error message correctly alerts the user
@@ -233,14 +202,8 @@ class TrainerTestGround(unittest.TestCase):
         model = Model()

         with self.assertRaises(NameError):
-            trainer = Trainer(
-                train_data=dataset,
-                model=model,
-                dev_data=dataset,
-                loss=CrossEntropyLoss(),
-                metrics=AccuracyMetric(),
-                use_tqdm=False,
-                print_every=2)
+            trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(), print_every=2, dev_data=dataset,
+                              metrics=AccuracyMetric(), use_tqdm=False)
         """

     def test_trainer_multiprocess(self):
@@ -130,11 +130,8 @@ class ModelRunner():
         tester = Tester(data=data, model=model, metrics=metrics,
                         batch_size=BATCH_SIZE, verbose=0)
         before_train = tester.test()
-        trainer = Trainer(model=model, train_data=data, dev_data=None,
-                          n_epochs=N_EPOCHS, batch_size=BATCH_SIZE,
-                          loss=loss,
-                          save_path=None,
-                          use_tqdm=False)
+        trainer = Trainer(train_data=data, model=model, loss=loss, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS,
+                          dev_data=None, save_path=None, use_tqdm=False)
         trainer.train(load_best_model=False)
         after_train = tester.test()
         for metric_name, v1 in before_train.items():
@@ -60,10 +60,10 @@ class TestTutorial(unittest.TestCase):
         print(test_data[0])

         # these data-preprocessing tools can also be used for projects such as reinforcement learning or GANs
-        from fastNLP.core.batch import Batch
+        from fastNLP.core.batch import DataSetIter
         from fastNLP.core.sampler import RandomSampler

-        batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())
+        batch_iterator = DataSetIter(dataset=train_data, batch_size=2, sampler=RandomSampler())
         for batch_x, batch_y in batch_iterator:
             print("batch_x has: ", batch_x)
             print("batch_y has: ", batch_y)
@@ -85,21 +85,14 @@ class TestTutorial(unittest.TestCase):
         # instantiate a Trainer, pass in the model and data, and train
         # first overfit on test_data (to make sure the model implementation is correct)
         copy_model = deepcopy(model)
-        overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,
-                                  loss=loss,
-                                  metrics=metric,
-                                  save_path=None,
-                                  batch_size=32,
-                                  n_epochs=5)
+        overfit_trainer = Trainer(train_data=test_data, model=copy_model, loss=loss, batch_size=32, n_epochs=5,
+                                  dev_data=test_data, metrics=metric, save_path=None)
         overfit_trainer.train()

         # train on train_data, validate on test_data
-        trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,
-                          loss=CrossEntropyLoss(pred="output", target="label_seq"),
-                          metrics=AccuracyMetric(pred="predict", target="label_seq"),
-                          save_path=None,
-                          batch_size=32,
-                          n_epochs=5)
+        trainer = Trainer(train_data=train_data, model=model, loss=CrossEntropyLoss(pred="output", target="label_seq"),
+                          batch_size=32, n_epochs=5, dev_data=test_data,
+                          metrics=AccuracyMetric(pred="predict", target="label_seq"), save_path=None)
         trainer.train()
         print('Train finished!')

@@ -147,13 +140,8 @@ class TestTutorial(unittest.TestCase):
         from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam

-        trainer = Trainer(model=model,
-                          train_data=train_data,
-                          dev_data=dev_data,
-                          loss=CrossEntropyLoss(),
-                          optimizer= Adam(),
-                          metrics=AccuracyMetric(target='target')
-                          )
+        trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(), loss=CrossEntropyLoss(),
+                          dev_data=dev_data, metrics=AccuracyMetric(target='target'))
         trainer.train()
         print('Train finished!')