@@ -73,7 +73,7 @@ class DistTrainer():
                  dev_data=None, metrics=None, metric_key=None,
                  update_every=1, print_every=10, validate_every=-1,
                  save_path=None, device='auto',
-                 fp16='', use_tqdm=True):
+                 fp16='', use_tqdm=True, **kwargs):
         r"""

         :param train_data: the training set, a :class:`~fastNLP.DataSet` object.
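With `**kwargs` in the signature, evaluation-only options can ride along on the constructor call without widening the positional signature. A minimal usage sketch, assuming the `DistTrainer` arguments shown above (the model and the two `DataSet` objects are placeholders built elsewhere):

```python
from fastNLP import AccuracyMetric, CrossEntropyLoss
from fastNLP.core.dist_trainer import DistTrainer

# train_ds, dev_ds (fastNLP DataSet) and model are assumed to exist already.
trainer = DistTrainer(
    train_data=train_ds, model=model, loss=CrossEntropyLoss(),
    dev_data=dev_ds, metrics=AccuracyMetric(),
    device='auto', use_tqdm=True,
    test_use_tqdm=False,  # forwarded via **kwargs: silence tqdm during dev evaluation
)
trainer.train()
```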
@@ -106,6 +106,9 @@ class DistTrainer():
         :param str device: the device to run on, one of 'cuda', 'cpu' or 'auto'.
         :param str fp16: the optimization level for mixed-precision training, one of 'O1', 'O2' or 'O3'; an empty string disables mixed precision.
         :param bool use_tqdm: whether to display training progress with tqdm; if False, the loss is printed to the terminal instead.
+        :param kwargs: supported optional arguments:
+            bool test_use_tqdm: whether to enable tqdm when validating on the dev set
+            Sampler test_sampler: the sampler to use during evaluation
         """
         assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in ['auto', 'cuda', 'cpu']"
         if device == 'auto':
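For context, resolving `'auto'` presumably amounts to checking GPU visibility. A self-contained sketch of that logic (the helper name `_resolve_device` is illustrative, not fastNLP API):

```python
import torch

def _resolve_device(device: str) -> str:
    # Same guard as in the diff, then map 'auto' to whatever is available.
    assert device in ['auto', 'cuda', 'cpu'], \
        "Please set correct device in ['auto', 'cuda', 'cpu']"
    if device == 'auto':
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    return device

print(_resolve_device('auto'))  # 'cuda' on a GPU machine, otherwise 'cpu'
```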
@@ -163,16 +166,23 @@ class DistTrainer():
         self.model = self.ddp_model.module

         self.optimizer = optimizer
-        self.sampler = DistributedSampler(self.train_data)
+        if isinstance(self.train_data, DataSet):
+            self.sampler = DistributedSampler(self.train_data)
         self.data_iterator = self._get_data_iter(self.train_data)
         self.batch_size = self.world_size * self.batch_size_per_gpu
         self.n_steps = self._get_n_steps()

+        if 'test_use_tqdm' in kwargs:
+            test_use_tqdm = kwargs.get('test_use_tqdm')
+        else:
+            test_use_tqdm = self.use_tqdm
+
         # for evaluation, only run eval on master proc
         if dev_data and metrics:
             cb = _TesterCallback(
                 dev_data, model, metrics,
-                batch_size=batch_size_per_gpu, num_workers=num_workers)
+                batch_size=batch_size_per_gpu, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
+                use_tqdm=test_use_tqdm)
             self.test_manager.add_callback([cb], master=True)

         # Setup logging
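The new `isinstance(self.train_data, DataSet)` guard exists because `DistributedSampler` needs a map-style dataset with a known length, whereas an already-built batch iterator carries its own sampling and must not be wrapped again. A standalone sketch of the wrapped case, with plain PyTorch objects standing in for fastNLP's (`num_replicas` and `rank` are hard-coded so the snippet runs without an initialized process group):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100).float())  # map-style: defines __len__
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
print(len(sampler))  # 50: this replica's disjoint share of the 100 examples
```

As a side note, the `if 'test_use_tqdm' in kwargs` branch above could be collapsed to `test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)` with identical behavior.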
@@ -232,8 +242,10 @@ class DistTrainer():
         elif optimizer is None:
             return torch.optim.Adam(self.ddp_model.parameters(), lr=4e-3)
         else:
-            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
+            if not (hasattr(optimizer, 'step') and callable(optimizer.step)):
+                raise TypeError("optimizer must have a callable step() function.")
+            else:
+                return optimizer

     @property
     def is_master(self):
         r"""Whether this is the master process."""
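The duck-typing branch now accepts any object exposing a callable `step()` instead of rejecting everything that is not a `torch.optim.Optimizer`. A hypothetical minimal optimizer that passes the new check (`PlainSGD` is an illustration, not part of fastNLP or PyTorch):

```python
import torch

class PlainSGD:
    """Not a torch.optim.Optimizer subclass, but it has a callable
    step(), so the duck-typing check in the diff accepts it."""

    def __init__(self, params, lr=0.1):
        self.params = list(params)
        self.lr = lr

    def step(self):
        with torch.no_grad():
            for p in self.params:
                if p.grad is not None:
                    p -= self.lr * p.grad  # plain SGD update

    def zero_grad(self):
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()

opt = PlainSGD(torch.nn.Linear(4, 2).parameters())
assert hasattr(opt, 'step') and callable(opt.step)  # the check in the diff
```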