Browse Source

Merge branch 'trainer' of https://github.com/FengZiYjun/fastNLP into check

tags/v0.2.0^2
xuyige 6 years ago
parent
commit
84eb50a810
5 changed files with 326 additions and 70 deletions
  1. +23
    -0
      fastNLP/core/losses.py
  2. +128
    -1
      fastNLP/core/metrics.py
  3. +45
    -17
      fastNLP/core/tester.py
  4. +71
    -39
      fastNLP/core/trainer.py
  5. +59
    -13
      fastNLP/core/utils.py

+ 23
- 0
fastNLP/core/losses.py View File

@@ -73,6 +73,29 @@ class NewLoss(LossBase):
raise RuntimeError("") raise RuntimeError("")




class LossInForward(LossBase):
    """Loss object for models that compute their loss inside ``forward``.

    Instead of computing anything itself, this losser picks the value stored
    under ``loss_key`` in the dict returned by the model.

    :param loss_key: key under which the model output dict stores the loss.
    """

    def __init__(self, loss_key='loss'):
        super().__init__()

        self.loss_key = loss_key

    def get_loss(self, *args, **kwargs):
        """Return the loss found under ``self.loss_key`` in ``kwargs``.

        Positional arguments are accepted for signature compatibility and
        are ignored.

        :raises KeyError: if ``self.loss_key`` is not among ``kwargs``.
        """
        if self.loss_key not in kwargs:
            raise KeyError(f"Cannot find `{self.loss_key}` in the model output, "
                           f"available keys are {list(kwargs.keys())}.")
        return kwargs[self.loss_key]

    def __call__(self, output_dict, predict_dict):
        """Fetch the loss from the model output.

        :param output_dict: dict returned by the model.
        :param predict_dict: second dict passed by the caller; unused here
            because the loss is expected in ``output_dict``.
            NOTE(review): callers pass the target batch here - confirm naming.
        :return: the value stored under ``self.loss_key`` in ``output_dict``.
        """
        return self.get_loss(**output_dict)


def _prepare_losser(losser):
    """Normalize the ``losser`` argument into a loss object.

    :param losser: ``None`` or an instance of ``LossBase``.
    :return: ``losser`` itself when it is a ``LossBase``; a fresh
        ``LossInForward()`` when it is ``None``.
    :raises TypeError: for any other input.
    """
    if isinstance(losser, LossBase):
        return losser
    if losser is None:
        # No losser given: fall back to the default LossInForward.
        return LossInForward()
    raise TypeError(f"Type of losser should be `fastNLP.LossBase`, got {type(losser)}")


def squash(predict, truth, **kwargs): def squash(predict, truth, **kwargs):
'''To reshape tensors in order to fit Loss functions in pytorch '''To reshape tensors in order to fit Loss functions in pytorch




+ 128
- 1
fastNLP/core/metrics.py View File

@@ -1,8 +1,136 @@

import warnings import warnings
import inspect


import numpy as np import numpy as np
import torch import torch


from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args

class MetricBase(object):
    """Base class of metrics.

    Subclasses implement ``evaluate``. Calling the instance maps entries of
    the output/target dicts onto the parameters of ``evaluate`` according to
    ``param_map`` and returns the dict produced by ``evaluate``.
    """

    def __init__(self):
        self.param_map = {}  # key is param in function, value is input param.
        self._checked = False

    def evaluate(self, *args, **kwargs):
        raise NotImplementedError

    def _init_param_map(self, key_map, **kwargs):
        """Build ``self.param_map`` from ``key_map`` and keyword overrides.

        :param key_map: dict (or None) mapping a parameter name of
            ``evaluate`` to the key used in the input dicts.
        :param kwargs: extra ``param=input_key`` pairs; they are applied
            after ``key_map`` and therefore win on conflicts.
        :raises TypeError: if any key or value is not a ``str``.
        """
        self.param_map = {}
        for key, value in (key_map or {}).items():
            # The checks must reject NON-strings; strings are the valid case.
            if not isinstance(key, str):
                raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
            if not isinstance(value, str):
                raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
            self.param_map[key] = value
        for key, value in kwargs.items():
            if not isinstance(value, str):
                raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
            self.param_map[key] = value

    def __call__(self, output_dict, target_dict, force_check=False):
        """Map the inputs onto ``evaluate`` and run it.

        :param output_dict: dict of model outputs.
        :param target_dict: dict of targets.
        :param force_check: run the missing/duplicated argument check even
            if a previous call already passed it.
        :return: the dict returned by ``evaluate``.
        :raises TypeError: if ``evaluate`` is not callable or does not
            return a dict.
        :raises NameError: if ``param_map`` names a parameter that
            ``evaluate`` does not have.
        :raises CheckError: if required arguments are missing or an argument
            is supplied by both dicts.
        """
        if not callable(self.evaluate):
            raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

        if not self._checked:
            # 1. check consistence between signature and param_map
            func_spect = inspect.getfullargspec(self.evaluate)
            # getfullargspec keeps the bound first parameter of a method;
            # it is never supplied through the input dicts, so drop it.
            func_args = [arg for arg in func_spect.args if arg != 'self']
            for func_param in self.param_map:
                if func_param not in func_args:
                    raise NameError(f"{func_param} not in {get_func_signature(self.evaluate)}.")
            # 2. only part of the param_map are passed, left are not
            for arg in func_args:
                if arg not in self.param_map:
                    self.param_map[arg] = arg  # This param does not need mapping.
            self._evaluate_args = func_args
            # input key -> evaluate parameter
            self._reverse_param_map = {value: key for key, value in self.param_map.items()}

        # need to wrap inputs in dict.
        mapped_output_dict = {}
        mapped_target_dict = {}
        for func_arg in self._evaluate_args:
            input_arg = self.param_map[func_arg]
            if input_arg in output_dict:
                mapped_output_dict[func_arg] = output_dict[input_arg]
            if input_arg in target_dict:
                mapped_target_dict[func_arg] = target_dict[input_arg]

        # check duplicated, unused, missing
        if force_check or not self._checked:
            from types import SimpleNamespace

            check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
            # Accept both a plain dict and a namedtuple-style result.
            fields = check_res if isinstance(check_res, dict) else check_res._asdict()
            # Report names as the user supplied them (input keys) rather
            # than as the parameter names of ``evaluate``.
            reported = {
                name: [self.param_map.get(func_param, func_param) for func_param in params]
                for name, params in fields.items()
            }
            if reported.get('missing') or reported.get('duplicated'):
                raise CheckError(check_res=SimpleNamespace(**reported))
        refined_args = _build_args(self.evaluate, **mapped_output_dict, **mapped_target_dict)

        metrics = self.evaluate(**refined_args)

        if not isinstance(metrics, dict):
            raise TypeError(f"The return value of {get_func_signature(self.evaluate)} must be `dict`, "
                            f"got {type(metrics)}.")
        self._checked = True

        return metrics





class CheckError(Exception):
    """Raised when the arguments available for a function do not match it.

    :param check_res: result of the argument check; either a dict or an
        object exposing ``missing`` and ``duplicated``. It is kept on the
        exception as ``self.check_res``.
    """

    def __init__(self, check_res):
        missing = self._get_field(check_res, 'missing')
        duplicated = self._get_field(check_res, 'duplicated')

        err = ''
        if missing:
            err += f'Missing: {missing}\n'
        if duplicated:
            err += f'Duplicated: {duplicated}\n'
        # Hand the message to Exception so it shows up in tracebacks.
        super().__init__(err)
        self.err = err
        self.check_res = check_res

    @staticmethod
    def _get_field(check_res, name):
        """Read ``name`` from a dict or attribute-style check result."""
        if isinstance(check_res, dict):
            return check_res.get(name)
        return getattr(check_res, name, None)

    def __str__(self):
        # __str__ must return a str; returning None raises TypeError.
        return self.err


class Metric(MetricBase):
    """Placeholder metric built from a function.

    ``func``, ``key_map`` and ``kwargs`` are accepted but not used yet; only
    the base-class initialisation runs.
    """

    def __init__(self, func, key_map, **kwargs):
        super().__init__()

def _prepare_metrics(metrics):
    """Normalize ``metrics`` into a list of ``MetricBase`` instances.

    :param metrics: a falsy value, a single ``MetricBase`` instance, or a
        list whose items are ``MetricBase`` instances or classes (classes are
        instantiated without arguments).
    :return: List[fastNLP.MetricBase]; empty when ``metrics`` is falsy.
    :raises TypeError: if ``metrics`` or one of its items has another type.
    """
    if not metrics:
        return []

    if isinstance(metrics, list):
        prepared = []
        for candidate in metrics:
            # A bare class is instantiated before the type check.
            instance = candidate() if isinstance(candidate, type) else candidate
            if not isinstance(instance, MetricBase):
                raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(instance)}`.")
            prepared.append(instance)
        return prepared

    if isinstance(metrics, MetricBase):
        return [metrics]

    raise TypeError("The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, got {}."
                    .format(type(metrics)))



class Evaluator(object): class Evaluator(object):
def __init__(self): def __init__(self):
@@ -17,7 +145,6 @@ class Evaluator(object):
""" """
raise NotImplementedError raise NotImplementedError



class ClassifyEvaluator(Evaluator): class ClassifyEvaluator(Evaluator):
def __init__(self): def __init__(self):
super(ClassifyEvaluator, self).__init__() super(ClassifyEvaluator, self).__init__()


+ 45
- 17
fastNLP/core/tester.py View File

@@ -2,42 +2,61 @@ import itertools
from collections import defaultdict from collections import defaultdict


import torch import torch
from torch import nn


from fastNLP.core.batch import Batch from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler from fastNLP.core.sampler import RandomSampler
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import _build_args from fastNLP.core.utils import _build_args
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.metrics import _prepare_metrics


class Tester(object): class Tester(object):
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """


def __init__(self, data, model, batch_size=16, use_cuda=False):
def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=0):
    """Set up a tester for ``model`` on ``data``.

    :param data: a ``fastNLP.DataSet`` to evaluate on.
    :param model: a ``torch.nn.Module``; its ``predict`` attribute is used
        for inference when present, otherwise the model itself is called.
    :param metrics: passed through ``_prepare_metrics``.
    :param batch_size: batch size used when iterating ``data``.
    :param use_cuda: move the model to CUDA when CUDA is available.
    :param verbose: results are printed by ``test`` when ``verbose >= 0``.
    :raises TypeError: if ``data`` or ``model`` has the wrong type, or if
        ``model.predict`` exists but is not callable.
    """
    super(Tester, self).__init__()

    if not isinstance(data, DataSet):
        raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.")
    if not isinstance(model, nn.Module):
        raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")

    self.metrics = _prepare_metrics(metrics)

    # Plain attributes first: the CUDA branch below reads self.use_cuda.
    self.data = data
    self.use_cuda = use_cuda
    self.batch_size = batch_size
    self.verbose = verbose

    if torch.cuda.is_available() and self.use_cuda:
        self._model = model.cuda()
    else:
        self._model = model

    # check predict (needs self._model, so it must come after the assignment above)
    if hasattr(self._model, 'predict'):
        self._predict_func = self._model.predict
        if not callable(self._predict_func):
            _model_name = model.__class__.__name__
            raise TypeError(f"`{_model_name}.predict` must be callable to be used "
                            f"for evaluation, not `{type(self._predict_func)}`.")
    else:
        self._predict_func = self._model

    # Device of the first parameter; batches are moved there in ``test``.
    self._model_device = next(self._model.parameters()).device



def test(self): def test(self):
# turn on the testing mode; clean up the history # turn on the testing mode; clean up the history
network = self._model network = self._model
self.mode(network, is_test=True) self.mode(network, is_test=True)
self.eval_history.clear()
output, truths = defaultdict(list), defaultdict(list) output, truths = defaultdict(list), defaultdict(list)
data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False) data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False)


with torch.no_grad(): with torch.no_grad():
for batch_x, batch_y in data_iterator: for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(self._model_device, batch_x, batch_y)
prediction = self.data_forward(network, batch_x) prediction = self.data_forward(network, batch_x)
assert isinstance(prediction, dict) assert isinstance(prediction, dict)
for k, v in prediction.items(): for k, v in prediction.items():
@@ -48,9 +67,13 @@ class Tester(object):
output[k] = itertools.chain(*v) output[k] = itertools.chain(*v)
for k, v in truths.items(): for k, v in truths.items():
truths[k] = itertools.chain(*v) truths[k] = itertools.chain(*v)
args = _build_args(self._evaluator, **output, **truths)
eval_results = self._evaluator(**args)
print("[tester] {}".format(self.print_eval_results(eval_results)))
eval_results = {}
for metric in self.metrics:
eval_result = metric(output, truths)
metric_name = metric.__class__.__name__
eval_results[metric_name] = eval_result
if self.verbose >= 0:
print("[tester] \n{}".format(self.format_eval_results(eval_results)))
self.mode(network, is_test=False) self.mode(network, is_test=False)
return eval_results return eval_results


@@ -72,10 +95,15 @@ class Tester(object):
y = self._predict_func(**x) y = self._predict_func(**x)
return y return y


def print_eval_results(self, results):
def format_eval_results(self, results):
    """Override this method to support more print formats.

    Each metric is rendered as its name on one line followed by a
    tab-indented ``key=value`` list of that metric's own results.

    :param results: dict, (str: dict) is (metric name: that metric's result
        dict of value name -> value)
    :return: the formatted multi-line string ('' for empty ``results``).
    """
    blocks = []
    for metric_name, metric_result in results.items():
        # Join the entries of THIS metric's result, not of the outer dict.
        pairs = ", ".join(str(key) + "=" + str(value) for key, value in metric_result.items())
        blocks.append(metric_name + '\n\t' + pairs + '\n')
    return "".join(blocks)

+ 71
- 39
fastNLP/core/trainer.py View File

@@ -7,6 +7,7 @@ from datetime import datetime
from datetime import timedelta from datetime import timedelta


import torch import torch
from torch import nn
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter


from fastNLP.core.batch import Batch from fastNLP.core.batch import Batch
@@ -16,23 +17,50 @@ from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester from fastNLP.core.tester import Tester
from fastNLP.core.utils import _build_args from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _syn_model_data
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature from fastNLP.core.utils import get_func_signature
from fastNLP.core.dataset import DataSet

from fastNLP.core.losses import LossBase
from fastNLP.core.metrics import MetricBase
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics




class Trainer(object): class Trainer(object):
"""Main Training Loop """Main Training Loop


""" """
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
dev_data=None, use_cuda=False, save_path="./save", dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True, optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
**kwargs): **kwargs):
super(Trainer, self).__init__() super(Trainer, self).__init__()


if not isinstance(train_data, DataSet):
raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.")
if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")

# check metrics and dev_data
if (not metrics) and dev_data is not None:
raise ValueError("No metric for dev_data evaluation.")
if metrics and (dev_data is None):
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

# prepare evaluate
metrics = _prepare_metrics(metrics)
# prepare loss
losser = _prepare_losser(losser)

if need_check_code:
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)

self.train_data = train_data self.train_data = train_data
self.dev_data = dev_data # If None, No validation. self.dev_data = dev_data # If None, No validation.
self.model = model self.model = model
self.losser = losser
self.metrics = metrics
self.n_epochs = int(n_epochs) self.n_epochs = int(n_epochs)
self.batch_size = int(batch_size) self.batch_size = int(batch_size)
self.use_cuda = bool(use_cuda) self.use_cuda = bool(use_cuda)
@@ -41,23 +69,19 @@ class Trainer(object):
self.validate_every = int(validate_every) self.validate_every = int(validate_every)
self._best_accuracy = 0 self._best_accuracy = 0


if need_check_code:
_check_code(dataset=train_data, model=model, dev_data=dev_data)
self._model_device = model.parameters().__next__().device

# TODO self._best_accuracy不能表现出当前的metric多种的情况


model_name = model.__class__.__name__
assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name)
self.loss_func = self.model.get_loss
if isinstance(optimizer, torch.optim.Optimizer): if isinstance(optimizer, torch.optim.Optimizer):
self.optimizer = optimizer self.optimizer = optimizer
else: else:
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())


assert hasattr(self.model, 'evaluate'), "model {} has to have a 'evaluate' function.".format(model_name)
self.evaluator = self.model.evaluate

if self.dev_data is not None: if self.dev_data is not None:
self.tester = Tester(model=self.model, self.tester = Tester(model=self.model,
data=self.dev_data, data=self.dev_data,
metrics=self.metrics,
batch_size=self.batch_size, batch_size=self.batch_size,
use_cuda=self.use_cuda) use_cuda=self.use_cuda)


@@ -118,8 +142,9 @@ class Trainer(object):
- epoch: int, - epoch: int,
""" """
for batch_x, batch_y in data_iterator: for batch_x, batch_y in data_iterator:
# TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题
_move_dict_value_to_device(self._model_device, batch_x, batch_y)
prediction = self.data_forward(model, batch_x) prediction = self.data_forward(model, batch_x)

loss = self.get_loss(prediction, batch_y) loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss) self.grad_backward(loss)
self.update() self.update()
@@ -169,6 +194,8 @@ class Trainer(object):
def data_forward(self, network, x):
    """Run ``network`` on the batch ``x`` and return its output dict.

    :param network: the model; the arguments it receives are selected from
        ``x`` by ``_build_args`` against ``network.forward``.
    :param x: dict of batch inputs.
    :return: the dict returned by ``network``.
    :raises TypeError: if ``network`` does not return a dict.
    """
    forward_kwargs = _build_args(network.forward, **x)
    output = network(**forward_kwargs)
    if isinstance(output, dict):
        return output
    raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(output)}.")


def grad_backward(self, loss): def grad_backward(self, loss):
@@ -229,11 +256,11 @@ IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1 WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2 STRICT_CHECK_LEVEL = 2


def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None,
check_level=WARNING_CHECK_LEVEL):
# check get_loss 方法 # check get_loss 方法
model_name = model.__class__.__name__ model_name = model.__class__.__name__
if not hasattr(model, 'get_loss'):
raise AttributeError("{} has to have a 'get_loss' function.".format(model_name))


batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch): for batch_count, (batch_x, batch_y) in enumerate(batch):
@@ -246,23 +273,26 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
refined_batch_x = _build_args(model.forward, **batch_x) refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x) output = model(**refined_batch_x)
func_signature = get_func_signature(model.forward) func_signature = get_func_signature(model.forward)
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature)
if not isinstance(output, dict):
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.")


# loss check # loss check
if batch_count == 0:
_check_loss_evaluate(prev_func=model.forward, func=model.get_loss, check_level=check_level,
output=output, batch_y=batch_y)
loss_input = _build_args(model.get_loss, **output, **batch_y)
loss = model.get_loss(**loss_input)
if isinstance(losser, type): # 这种情况,用户传的是losser.CE这种未初始化的loss
# 需要保证output与batch_y是无歧义的?
# (1) output和batch_y长度为1
# (2) output和batch_y的key是和losser接受的完全一致
pass

loss = losser(output, batch_y)


# check loss output # check loss output
if batch_count == 0: if batch_count == 0:
if not isinstance(loss, torch.Tensor): if not isinstance(loss, torch.Tensor):
raise ValueError("The return value of {}.get_loss() should be torch.Tensor, but {} got.".
format(model_name, type(loss)))
raise ValueError("The return value of {} should be `torch.Tensor`, but got `{}`.".
format(type(losser), type(loss)))
if len(loss.size())!=0: if len(loss.size())!=0:
raise ValueError("The size of return value of {}.get_loss() is {}, should be torch.size([])".format(
model_name, loss.size()
raise ValueError("The size of return value of {} is {}, should be torch.size([])".format(
type(losser), loss.size()
)) ))
loss.backward() loss.backward()
model.zero_grad() model.zero_grad()
@@ -270,26 +300,29 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
break break


if dev_data is not None: if dev_data is not None:
if not hasattr(model, 'evaluate'):
raise AttributeError("{} has to have a 'evaluate' function to do evaluation. Or set"
"dev_data to 'None'."
.format(model_name))
outputs, truths = defaultdict(list), defaultdict(list) outputs, truths = defaultdict(list), defaultdict(list)
dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
# TODO 这里修改为使用tester
tester = Tester(data=dataset, model=model, metrics=metrics, batch_size=batch_size, )

with torch.no_grad(): with torch.no_grad():
for batch_count, (batch_x, batch_y) in enumerate(dev_batch): for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
_syn_model_data(model, batch_x, batch_y) _syn_model_data(model, batch_x, batch_y)


if hasattr(model, 'predict'): if hasattr(model, 'predict'):
if not callable(model.predict):
raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used "
f"for evaluation.")
refined_batch_x = _build_args(model.predict, **batch_x) refined_batch_x = _build_args(model.predict, **batch_x)
prev_func = model.predict prev_func = model.predict
output = prev_func(**refined_batch_x) output = prev_func(**refined_batch_x)
func_signature = get_func_signature(model.predict)
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature)
else: else:
refined_batch_x = _build_args(model.forward, **batch_x) refined_batch_x = _build_args(model.forward, **batch_x)
prev_func = model.forward prev_func = model.forward
output = prev_func(**refined_batch_x) output = prev_func(**refined_batch_x)
func_signature = get_func_signature(prev_func)
if not isinstance(output, dict):
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`")
for k, v in output.items(): for k, v in output.items():
outputs[k].append(v) outputs[k].append(v)
for k, v in batch_y.items(): for k, v in batch_y.items():
@@ -297,16 +330,15 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
if batch_count+1>DEFAULT_CHECK_NUM_BATCH: if batch_count+1>DEFAULT_CHECK_NUM_BATCH:
break break
for k, v in outputs.items(): for k, v in outputs.items():
outputs[k] = itertools.chain(*v)
outputs[k] = tuple(itertools.chain(*v))
for k, v in truths.items(): for k, v in truths.items():
truths[k] = itertools.chain(*v)
_check_loss_evaluate(prev_func=prev_func, func=model.evaluate, check_level=check_level,
output=outputs, batch_y=truths)
refined_input = _build_args(model.evaluate, **outputs, **truths)
metrics = model.evaluate(**refined_input)
func_signature = get_func_signature(model.evaluate)
assert isinstance(metrics, dict), "The return value of {} should be dict.". \
format(func_signature)
truths[k] = tuple(itertools.chain(*v))
#TODO 这里需要根据新版的metrics做修改,另外这里需要捕获来自metric的报错,因为需要指导用户debug









def _check_forward_error(model_func, check_level, batch_x): def _check_forward_error(model_func, check_level, batch_x):


+ 59
- 13
fastNLP/core/utils.py View File

@@ -3,9 +3,8 @@ import inspect
import os import os
from collections import Counter from collections import Counter
from collections import namedtuple from collections import namedtuple

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)

from collections import defaultdict
import torch


def save_pickle(obj, pickle_path, file_name): def save_pickle(obj, pickle_path, file_name):
"""Save an object into a pickle file. """Save an object into a pickle file.
@@ -121,14 +120,35 @@ def _check_arg_dict_list(func, args):
input_args = set(input_arg_count.keys()) input_args = set(input_arg_count.keys())
missing = list(require_args - input_args) missing = list(require_args - input_args)
unused = list(input_args - all_args) unused = list(input_args - all_args)
return CheckRes(missing=missing,
unused=unused,
duplicated=duplicated,
required=list(require_args),
all_needed=list(all_args))

check_res = {}
check_res['missing'] = missing
check_res['unused'] = unused
check_res['duplicated'] = duplicated
check_res['required'] = list(require_args)
check_res['all_needed'] = list(all_args)

return check_res


def get_func_signature(func): def get_func_signature(func):
# can only be used in function or class method
"""

Given a function or method, return its signature.
For example:
(1) function
def func(a, b='a', *args):
xxxx
get_func_signature(func) # 'func(a, b='a', *args)'
(2) method
class Demo:
def __init__(self):
xxx
def forward(self, a, b='a', **args)
demo = Demo()
get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)'
:param func: a function or a method
:return: str or None
"""
if inspect.ismethod(func): if inspect.ismethod(func):
class_name = func.__self__.__class__.__name__ class_name = func.__self__.__class__.__name__
signature = inspect.signature(func) signature = inspect.signature(func)
@@ -146,10 +166,16 @@ def get_func_signature(func):
return signature_str return signature_str




# move data to model's device
import torch
def _syn_model_data(model, *args): def _syn_model_data(model, *args):
assert len(model.state_dict())!=0, "This model has no parameter."
"""

move data to model's device, element in *args should be dict. This is a inplace change.
:param model:
:param args:
:return:
"""
if len(model.state_dict())==0:
raise ValueError("model has no parameter.")
device = model.parameters().__next__().device device = model.parameters().__next__().device
for arg in args: for arg in args:
if isinstance(arg, dict): if isinstance(arg, dict):
@@ -157,4 +183,24 @@ def _syn_model_data(model, *args):
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
arg[key] = value.to(device) arg[key] = value.to(device)
else: else:
raise ValueError("Only support dict type right now.")
raise TypeError("Only support `dict` type right now.")

def _move_dict_value_to_device(device, *args):
    """Move every tensor value of the given dicts to ``device``, in place.

    Values that are not ``torch.Tensor`` are left untouched.

    :param device: torch.device
    :param args: dicts whose tensor values should be moved.
    :return: None; the dicts are modified in place.
    :raises TypeError: if ``device`` is not a ``torch.device`` or an element
        of ``args`` is not a dict.
    """
    if not isinstance(device, torch.device):
        raise TypeError(f"device must be `torch.device`, got `{type(device)}`")

    for container in args:
        if not isinstance(container, dict):
            raise TypeError("Only support `dict` type right now.")
        for name, item in container.items():
            if isinstance(item, torch.Tensor):
                # Re-assigning an existing key keeps the dict size unchanged.
                container[name] = item.to(device)


Loading…
Cancel
Save