From ded4228f93d1c90ac62b77ee74a05b657b37ec1b Mon Sep 17 00:00:00 2001 From: yh_cc Date: Fri, 26 Apr 2019 22:34:05 +0800 Subject: [PATCH] =?UTF-8?q?1.=E5=A2=9E=E5=8A=A0=E5=AF=B9Trainer=E5=92=8CTe?= =?UTF-8?q?ster=E7=9A=84=E5=A4=9A=E5=8D=A1=E6=94=AF=E6=8C=81;?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/metrics.py | 8 ++-- fastNLP/core/tester.py | 29 ++++++++----- fastNLP/core/trainer.py | 51 +++++++++++++--------- fastNLP/core/utils.py | 84 ++++++++++++++++++++++++++---------- test/core/test_metrics.py | 2 +- test/core/test_utils.py | 89 ++++++++++++++++++++++++++++++++++++++- 6 files changed, 205 insertions(+), 58 deletions(-) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 206904ca..938a67be 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -108,8 +108,8 @@ class MetricBase(object): 如果kwargs是self.evaluate的参数,则不会检测 - self.evaluate将计算一个批次(batch)的评价指标,并累计 - self.get_metric将统计当前的评价指标并返回评价结果 + self.evaluate将计算一个批次(batch)的评价指标,并累计。 没有返回值 + self.get_metric将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称,value是指标的值 """ def __init__(self): @@ -302,7 +302,7 @@ class AccuracyMetric(MetricBase): if seq_len is not None and not isinstance(seq_len, torch.Tensor): raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," - f"got {type(seq_lens)}.") + f"got {type(seq_len)}.") if seq_len is not None: masks = seq_lens_to_masks(seq_lens=seq_len) @@ -320,7 +320,7 @@ class AccuracyMetric(MetricBase): target = target.to(pred) if masks is not None: - self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks, 0)).item() + self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item() self.total += torch.sum(masks).item() else: self.acc_count += torch.sum(torch.eq(pred, target)).item() diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 6e3f98b5..c2aae37b 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -10,7 +10,8 @@ from fastNLP.core.utils import _build_args from fastNLP.core.utils import _check_loss_evaluate from fastNLP.core.utils import _move_dict_value_to_device from fastNLP.core.utils import _get_func_signature -from fastNLP.core.utils import _get_device +from fastNLP.core.utils import _get_model_device +from fastNLP.core.utils import _move_model_to_device class Tester(object): @@ -57,9 +58,20 @@ class Tester(object): :param torch.nn.module model: 使用的模型 :param MetricBase metrics: 一个Metric或者一个列表的metric对象 :param int batch_size: evaluation时使用的batch_size有多大。 - :param str,torch.device,None device: 将模型load到哪个设备。默认为None,即Trainer不对模型的计算位置进行管理。支持 - 以下的输入str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, - 可见的第二个GPU中; torch.device,将模型装载到torch.device上。 + :param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 + 的计算位置进行管理。支持以下的输入: + + 1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, + 可见的第二个GPU中; + + 2. torch.device:将模型装载到torch.device上。 + + 3. int: 将使用device_id为该值的gpu进行训练 + + 4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 + + 5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 + :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 """ @@ -74,16 +86,10 @@ class Tester(object): self.metrics = _prepare_metrics(metrics) self.data = data - self.device = _get_device(device, check_exist=False) + self._model = _move_model_to_device(model, device=device) self.batch_size = batch_size self.verbose = verbose - if self.device is not None: - self._model = model.to(self.device) - else: - self._model = model - self._model_device = model.parameters().__next__().device - # check predict if hasattr(self._model, 'predict'): self._predict_func = self._model.predict @@ -101,6 +107,7 @@ class Tester(object): 一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。 """ # turn on the testing mode; clean up the history + self._model_device = _get_model_device(self._model) network = self._model self._mode(network, is_test=True) data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 48733652..b6c282b4 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -321,11 +321,12 @@ from fastNLP.core.utils import _check_forward_error from fastNLP.core.utils import _check_loss_evaluate from fastNLP.core.utils import _move_dict_value_to_device from fastNLP.core.utils import _get_func_signature -from fastNLP.core.utils import _get_device +from fastNLP.core.utils import _get_model_device from fastNLP.core.optimizer import Optimizer +from fastNLP.core.utils import _move_model_to_device class Trainer(object): - def __init__(self, train_data, model, optimizer, loss=None, + def __init__(self, train_data, model, optimizer=None, loss=None, batch_size=32, sampler=None, update_every=1, n_epochs=10, print_every=5, dev_data=None, metrics=None, metric_key=None, @@ -336,7 +337,7 @@ class Trainer(object): """ :param DataSet train_data: 训练集 :param nn.modules model: 待训练的模型 - :param torch.optim.Optimizer,None optimizer: 优化器。如果为None,则Trainer不会更新模型,请确保已在callback中进行了更新。 + :param torch.optim.Optimizer,None optimizer: 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 :param int batch_size: 训练和验证的时候的batch大小。 :param LossBase loss: 使用的Loss对象。 详见 LossBase_ 。当loss为None时,默认使用 LossInForward_ 。 :param Sampler sampler: Batch数据生成的顺序。详见 Sampler_ 。如果为None,默认使用 RandomSampler_ 。 @@ -354,12 +355,23 @@ class Trainer(object): :param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有 效。 :param str,None save_path: 将模型保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模 - 型。保存的时候不仅保存了参数,还保存了模型结构。 + 型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 :param prefetch: bool, 是否使用额外的进程对产生batch数据。理论上会使得Batch迭代更快。 :param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 - :param str,torch.device,None device: 将模型load到哪个设备。默认为None,即Trainer不对模型的计算位置进行管理。支持 - 以下的输入str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, - 可见的第二个GPU中; torch.device,将模型装载到torch.device上。 + :param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 + 的计算位置进行管理。支持以下的输入: + + 1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, + 可见的第二个GPU中; + + 2. torch.device:将模型装载到torch.device上。 + + 3. int: 将使用device_id为该值的gpu进行训练 + + 4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 + + 5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 + :param list(callbacks) callbacks: 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 通过callback机制实现。 可使用的callback参见 Callback_ 。 :param int check_code_level: 模型检查等级. -1: 不进行检查; 0: 仅出现错误时停止; 1: 如果有field没有被使用, @@ -432,17 +444,15 @@ class Trainer(object): self.n_steps = (len(self.train_data) // self.batch_size + int( len(self.train_data) % self.batch_size != 0)) * self.n_epochs - check_exist = check_code_level>-1 - self.device = _get_device(device, check_exist=check_exist) + # 是否一开始就是DataParallel的。 + self.model = _move_model_to_device(self.model, device=device) if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer elif isinstance(optimizer, Optimizer): self.optimizer = optimizer.construct_from_pytorch(model.parameters()) elif optimizer is None: - warnings.warn("The optimizer is set to None, Trainer will update your model. Make sure you update the model" - " in the callback.") - self.optimizer = None + self.optimizer = torch.optim.Adam(model.parameters(), lr=4e-3) else: raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer))) @@ -455,7 +465,7 @@ class Trainer(object): data=self.dev_data, metrics=self.metrics, batch_size=self.batch_size, - device=self.device, + device=None, # 由上面的部分处理device verbose=0) self.step = 0 @@ -486,11 +496,9 @@ class Trainer(object): results['seconds'] = 0. return results try: - if self.device is not None: - self.model = self.model.to(self.device) - self._model_device = self.model.parameters().__next__().device + self._model_device = _get_model_device(self.model) self._mode(self.model, is_test=False) - + self._load_best_model = load_best_model self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) start_time = time.time() print("training epochs started " + self.start_time, flush=True) @@ -605,7 +613,7 @@ class Trainer(object): if self.save_path is not None: self._save_model(self.model, "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) - else: + elif self._load_best_model: self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()} self.best_dev_perf = res self.best_dev_epoch = epoch @@ -672,6 +680,8 @@ class Trainer(object): model_path = os.path.join(self.save_path, model_name) if not os.path.exists(self.save_path): os.makedirs(self.save_path, exist_ok=True) + if isinstance(model, nn.DataParallel): + model = model.module if only_param: state_dict = model.state_dict() for key in state_dict: @@ -690,7 +700,10 @@ class Trainer(object): states = torch.load(model_path) else: states = torch.load(model_path).state_dict() - model.load_state_dict(states) + if isinstance(model, nn.DataParallel): + model.module.load_state_dict(states) + else: + model.load_state_dict(states) elif hasattr(self, "_best_model_states"): model.load_state_dict(self._best_model_states) else: diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index f34092df..efb4faa7 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -168,33 +168,73 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): # else: # return False -def _get_device(device, check_exist=False): +def _move_model_to_device(model, device): """ - 传入一个device,返回None或者torch.device。当不为None时,且被设置为使用gpu, 但机器没有gpu时,会返回torch.device('cpu') + 将model移动到device - :param str,None,torch.device device: str, None或者torch.device。 - :param bool check_exist: 检查该device是否存在,不存在的话报错 - :return: None,torch.device - """ - if device is not None: - if isinstance(device, str): - device = torch.device(device) - elif isinstance(device, torch.device): - device = device - else: - raise ValueError("device does not support {} type.".format(type(device))) + :param model: torch.nn.DataParallel or torch.nn.Module. 当为torch.nn.DataParallel, 则只是调用一次cuda。device必须为 + None。 + :param str,int,torch.device,list(int),list(torch.device) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 + 的计算位置进行管理。支持以下的输入: - if device.type=='cuda' and not torch.cuda.is_available(): - device = torch.device('cpu') + 1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, + 可见的第二个GPU中; - if check_exist: - tensor = torch.zeros(0).to(device) - tensor = tensor.to('cpu') - del tensor - else: - device = None + 2. torch.device:将模型装载到torch.device上。 + + 3. int: 将使用device_id为该值的gpu进行训练 - return device + 4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 + + 5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 + + :return: torch.nn.DataParallel or torch.nn.Module + """ + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") + + if not torch.cuda.is_available() and (device!='cpu' or (isinstance(device, torch.device) and device.type!='cpu')): + raise ValueError("There is no usable gpu. set `device` as `cpu`.") + + if device is None: + if isinstance(model, torch.nn.DataParallel): + model.cuda() + return model + + if isinstance(model, torch.nn.DataParallel): + raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.") + + if isinstance(device, int): + assert device>-1, "device can only be positive integer" + assert torch.cuda.device_count()>device, "Only has {} gpus, cannot use device {}.".format(torch.cuda.device_count(), + device) + device = torch.device('cuda:{}'.format(device)) + elif isinstance(device, str): + device = torch.device(device) + if device.type == 'cuda' and device.index is not None: + assert device.index-1, "Only positive device id allowed." + if len(device)>1: + output_device = device[0] + model = nn.DataParallel(model, device_ids=device, output_device=output_device) + device = torch.device(device[0]) + else: + raise TypeError("Unsupported device type.") + model = model.to(device) + return model def _get_model_device(model): diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index a0e8f1a5..3f37c495 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -123,7 +123,7 @@ class TestAccuracyMetric(unittest.TestCase): # (10) check _fast_metric try: metric = AccuracyMetric() - pred_dict = {"predictions": torch.zeros(4, 3, 2), "masks": torch.zeros(4, 3)} + pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3)*3} target_dict = {'targets': torch.zeros(4, 3)} metric(pred_dict=pred_dict, target_dict=target_dict) self.assertDictEqual(metric.get_metric(), {'acc': 1}) diff --git a/test/core/test_utils.py b/test/core/test_utils.py index 11bb0f22..33202364 100644 --- a/test/core/test_utils.py +++ b/test/core/test_utils.py @@ -7,6 +7,94 @@ from fastNLP import DataSet from fastNLP import Instance import time import os +import torch +from torch import nn +from fastNLP.core.utils import _move_model_to_device, _get_model_device + +class Model(nn.Module): + def __init__(self): + super().__init__() + self.param = nn.Parameter(torch.zeros(0)) + +class TestMoveModelDeivce(unittest.TestCase): + def test_case1(self): + # 测试str + model = Model() + model = _move_model_to_device(model, 'cpu') + assert model.param.device == torch.device('cpu') + # 测试不存在的device报错 + with self.assertRaises(Exception): + _move_model_to_device(model, 'cpuu') + # 测试gpu + if torch.cuda.is_available(): + model = _move_model_to_device(model, 'cuda') + assert model.param.is_cuda + model = _move_model_to_device(model, 'cuda:0') + assert model.param.device == torch.device('cuda:0') + with self.assertRaises(Exception): + _move_model_to_device(model, 'cuda:1000') + + def test_case2(self): + # 测试使用int初始化 + model = Model() + if torch.cuda.is_available(): + model = _move_model_to_device(model, 0) + assert model.param.device == torch.device('cuda:0') + assert model.param.device==torch.device('cuda:0'), "The model should be in " + with self.assertRaises(Exception): + _move_model_to_device(model, 100) + with self.assertRaises(Exception): + _move_model_to_device(model, -1) + + def test_case3(self): + # 测试None + model = Model() + device = _get_model_device(model) + model = _move_model_to_device(model, None) + assert device==_get_model_device(model), "The device should not change." + if torch.cuda.is_available(): + model.cuda() + device = _get_model_device(model) + model = _move_model_to_device(model, None) + assert device==_get_model_device(model), "The device should not change." + + model = nn.DataParallel(model, device_ids=[0]) + _move_model_to_device(model, None) + with self.assertRaises(Exception): + _move_model_to_device(model, 'cpu') + + def test_case4(self): + # 测试传入list的内容 + model = Model() + device = ['cpu'] + with self.assertRaises(Exception): + _move_model_to_device(model, device) + if torch.cuda.is_available(): + device = [0] + _model = _move_model_to_device(model, device) + assert isinstance(_model, nn.DataParallel) + device = [torch.device('cuda:0'), torch.device('cuda:0')] + with self.assertRaises(Exception): + _model = _move_model_to_device(model, device) + if torch.cuda.device_count()>1: + device = [0, 1] + _model = _move_model_to_device(model, device) + assert isinstance(_model, nn.DataParallel) + device = ['cuda', 'cuda:1'] + with self.assertRaises(Exception): + _move_model_to_device(model, device) + + def test_case5(self): + # torch.device() + device = torch.device('cpu') + model = Model() + _move_model_to_device(model, device) + device = torch.device('cuda') + model = _move_model_to_device(model, device) + assert model.param.device == torch.device('cuda:0') + with self.assertRaises(Exception): + _move_model_to_device(model, torch.device('cuda:100')) + @cache_results('test/demo1.pkl') def process_data_1(embed_file, cws_train): @@ -20,7 +108,6 @@ def process_data_1(embed_file, cws_train): d.append(Instance(raw=line)) return embed, vocab, d - class TestCache(unittest.TestCase): def test_cache_save(self): try: