Browse Source

尝试提供check parameter的功能

tags/v0.2.0
yh yunfan 6 years ago
parent
commit
0836ce006f
3 changed files with 237 additions and 55 deletions
  1. +8
    -1
      fastNLP/core/fieldarray.py
  2. +198
    -49
      fastNLP/core/trainer.py
  3. +31
    -5
      fastNLP/core/utils.py

+ 8
- 1
fastNLP/core/fieldarray.py View File

@@ -47,7 +47,7 @@ class FieldArray(object):
assert self.is_input is True or self.is_target is True
batch_size = len(indices)
# TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下
if isinstance(self.content[0], int) or isinstance(self.content[0], float):
if not isiterable(self.content[0]):
if self.dtype is None:
self.dtype = np.int64 if isinstance(self.content[0], int) else np.double
array = np.array([self.content[i] for i in indices], dtype=self.dtype)
@@ -63,3 +63,10 @@ class FieldArray(object):

def __len__(self):
return len(self.content)

def isiterable(content):
    """Tell whether ``content`` can be iterated over.

    Strings count as iterable, exactly like lists, tuples or arrays, because
    the check relies solely on ``iter()`` accepting the object.

    :param content: any object.
    :return: True if ``iter(content)`` succeeds, False if it raises TypeError.
    """
    try:
        # iter() performs the same probe the generator expression did, without
        # building a generator object that is thrown away immediately.
        iter(content)
    except TypeError:
        return False
    return True

+ 198
- 49
fastNLP/core/trainer.py View File

@@ -1,5 +1,9 @@
import time
from datetime import timedelta, datetime
from datetime import timedelta
from datetime import datetime

import warnings
from collections import defaultdict

import torch
from tensorboardX import SummaryWriter
@@ -12,13 +16,17 @@ from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester

from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _syn_model_data
from fastNLP.core.utils import get_func_signature

class Trainer(object):
"""Main Training Loop

"""

def __init__(self, train_data, model, n_epochs, batch_size, n_print,
def __init__(self, train_data, model, n_epochs=1, batch_size=32, print_every=-1,
dev_data=None, use_cuda=False, loss=Loss(None), save_path="./save",
optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
evaluator=Evaluator(),
@@ -32,7 +40,7 @@ class Trainer(object):
self.batch_size = int(batch_size)
self.use_cuda = bool(use_cuda)
self.save_path = str(save_path)
self.n_print = int(n_print)
self.print_every = int(print_every)

self.loss_func = self.model.loss if hasattr(self.model, "loss") else loss.get()
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
@@ -51,7 +59,7 @@ class Trainer(object):
self.step = 0
self.start_time = None # start timestamp

print(self.__dict__)
# print(self.__dict__)

def train(self):
"""Start Training.
@@ -70,17 +78,16 @@ class Trainer(object):
epoch = 1
while epoch <= self.n_epochs:

data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
use_cuda=self.use_cuda)
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler())

self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start, self.n_print)
self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)

if self.dev_data:
self.do_validation()
self.save_model(self.model, 'training_model_' + self.start_time)
epoch += 1

def _train_epoch(self, data_iterator, model, epoch, dev_data, start, n_print, **kwargs):
def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs):
"""Training process in one epoch.

kwargs should contain:
@@ -103,7 +110,7 @@ class Trainer(object):
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
if kwargs["n_print"] > 0 and self.step % kwargs["n_print"] == 0:
if self.print_every > 0 and self.step % self.print_every == 0:
end = time.time()
diff = timedelta(seconds=round(end - kwargs["start"]))
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
@@ -197,9 +204,6 @@ def best_eval_result(self, metrics):
return False


from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _build_args

DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2

@@ -207,64 +211,209 @@ IGNORE_CHECK_LEVEL=0
WARNING_CHECK_LEVEL=1
STRICT_CHECK_LEVEL=2

def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
# check the loss method
def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1):
# Dry-run check: feeds a few batches of `dataset` through model.forward and
# model.get_loss, verifying that forward returns a dict and that the loss is a
# 0-dim torch.Tensor; when dev_data is not None it also checks the arguments
# of model.evaluate.
# NOTE(review): this span comes from a rendered diff whose +/- markers and
# indentation were lost; the duplicated `def`/`raise` lines and the lines that
# use model.loss, _dict and _info_str look like the removed side of the diff --
# confirm against the repository before relying on any single line here.
# check the get_loss method
model_name = model.__class__.__name__
if not hasattr(model, 'get_loss'):
raise AttributeError("{} has to have a 'get_loss' function.".format(type(model)))
raise AttributeError("{} has to have a 'get_loss' function.".format(model_name))

batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch):
_syn_model_data(model, batch_x, batch_y)
# forward check
if batch_count==0:
check_res = _check_arg_dict_list(model.forward, batch_x)
_info_str = ''
if len(check_res.missing)>0:
if check_level == WARNING_CHECK_LEVEL:
for field_name in check_res.missing:
if hasattr(dataset, field_name):
_info_str += "{} "
_info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n"
_info_str += ""
print("")
if len(check_res.unused)>0:
if check_level == WARNING_CHECK_LEVEL:
_info_str += ""
_check_forward_error(model=model, model_func=model.forward, check_level=check_level,
batch_x=batch_x)

refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x)

assert isinstance(output, dict), "The return value of {}.forward() should be dict.".format(model_name)

# loss check
if batch_count == 0:
_dict = _check_arg_dict_list(model.loss, [output, batch_y])
if len(_dict)!=0:
pass
loss_input = _build_args(model.loss, **output, **batch_y)
loss = model.loss(**loss_input)
if batch_count == 0:
if isinstance(loss, torch.Tensor):
pass
_check_loss_evaluate(model=model, model_func=model.get_loss, check_level=check_level,
output=output, batch_y=batch_y)
loss_input = _build_args(model.get_loss, **output, **batch_y)
loss = model.get_loss(**loss_input)

# check loss output
if batch_count == 0:
if not isinstance(loss, torch.Tensor):
raise ValueError("The return value of {}.get_loss() should be torch.Tensor, but {} got.".
format(model_name, type(loss)))
if len(loss.size())!=0:
raise ValueError("The size of return value of {}.get_loss() is {}, should be torch.size([])".format(
model_name, loss.size()
))
loss.backward()

if batch_count+1>=DEFAULT_CHECK_BATCH_SIZE:
model.zero_grad()
if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
break
if check_level > IGNORE_CHECK_LEVEL:
print('Finish checking training process.', flush=True)


dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
if dev_data is not None:
if not hasattr(model, 'evaluate'):
raise AttributeError("If {} wants to do evaluation, {} has to have a 'evaluate' function. Or you can set"
raise AttributeError("{} has to have a 'evaluate' function to do evaluation. Or set"
"dev_data to 'None'."
.format(type(model), type(model)))
.format(model_name))
outputs, truths = defaultdict(list), defaultdict(list)
# NOTE(review): this batches `dataset`, not `dev_data`, although the branch is
# gated on dev_data -- looks like it should iterate dev_data; confirm.
dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
with torch.no_grad():
for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
_syn_model_data(model, batch_x, batch_y)

refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x)
for k, v in output.items():
outputs[k].append(v)
for k, v in batch_y.items():
truths[k].append(v)
# NOTE(review): uses `>` while the training loop above uses `>=`, so this loop
# consumes one batch more than DEFAULT_CHECK_NUM_BATCH -- confirm intended.
if batch_count+1>DEFAULT_CHECK_NUM_BATCH:
break
_check_loss_evaluate(model=model, model_func=model.evaluate, check_level=check_level,
output=outputs, batch_y=truths)
print("Finish checking evaluate process.", flush=True)


def _check_forward_error(model, model_func, check_level, batch_x):
    """Check that the fields in ``batch_x`` match the arguments of ``model_func``.

    :param model: the model; only its class name is used, to build messages.
    :param model_func: the callable whose arguments are checked (reported as
        ``<ModelClass>.forward`` in every message).
    :param check_level: one of IGNORE_CHECK_LEVEL, WARNING_CHECK_LEVEL or
        STRICT_CHECK_LEVEL; decides how unused fields are reported.
    :param batch_x: dict mapping input field names to values.
    :raises TypeError: if ``model_func`` needs arguments that ``batch_x`` lacks.
        Under STRICT_CHECK_LEVEL the message also lists unused fields.
    :raises ValueError: if fields are unused and ``check_level`` is
        STRICT_CHECK_LEVEL.
    """
    check_res = _check_arg_dict_list(model_func, batch_x)
    # str(signature) looks like "(a, b=1)"; drop the surrounding parentheses.
    signature_str = get_func_signature(model_func)
    func_signature = '{}.forward(self, {})'.format(model.__class__.__name__, signature_str[1:-1])

    _missing = ''
    if len(check_res.missing) != 0:
        _missing = "Function {} misses {}, only provided with {}.\n".format(
            func_signature, check_res.missing, list(batch_x.keys()))

    _unused = ''
    if len(check_res.unused) != 0:
        verb = "are" if len(check_res.unused) > 1 else "is"
        _unused = "{} {} not used in function {}.\n".format(check_res.unused, verb, func_signature)

    if _missing:
        # Missing arguments are always fatal. Unused fields are appended only
        # when they exist AND the caller asked for strict checking (the
        # original tested `not _unused and STRICT_CHECK_LEVEL`, which inverted
        # the condition and ignored check_level altogether).
        if _unused and check_level == STRICT_CHECK_LEVEL:
            _error_str = "(1).{} (2).{}".format(_missing, _unused)
        else:
            _error_str = _missing
        # TODO: a dedicated error type may be needed here
        raise TypeError(_error_str)

    if _unused:
        if check_level == STRICT_CHECK_LEVEL:
            # TODO: a dedicated error type may be needed here
            raise ValueError(_unused)
        elif check_level == WARNING_CHECK_LEVEL:
            warnings.warn(message=_unused)

def _check_loss_evaluate(model, model_func, check_level, output, batch_y):
    """Check that ``output`` plus ``batch_y`` match the arguments of ``model_func``.

    :param model: the model; its class name and ``forward`` signature are used
        in the messages.
    :param model_func: the callable being checked (``model.get_loss`` or
        ``model.evaluate`` at the call sites in this file).
    :param check_level: one of IGNORE_CHECK_LEVEL, WARNING_CHECK_LEVEL or
        STRICT_CHECK_LEVEL; decides how unused keys are reported.
    :param output: dict produced by the model's forward pass.
    :param batch_y: dict of target fields.
    :raises ValueError: if arguments are missing, if a key appears in both
        ``output`` and ``batch_y``, or if keys are unused under
        STRICT_CHECK_LEVEL.
    """
    check_res = _check_arg_dict_list(model_func, [output, batch_y])
    model_name = model.__class__.__name__
    # str(signature) looks like "(a, b=1)"; drop the surrounding parentheses.
    func_signature = "{}.{}(self, {})".format(model_name, model_func.__name__,
                                              get_func_signature(model_func)[1:-1])
    # The duplicated-key message refers to forward(), so it must show forward's
    # own signature rather than the signature of the function being checked.
    forward_func_signature = "{}.forward(self, {})".format(model_name,
                                                           get_func_signature(model.forward)[1:-1])

    errors = []
    if len(check_res.missing) > 0:
        errors.append("Function {} misses argument {}, only provided with {}(from {}) and "
                      "{}.\n".format(func_signature, check_res.missing,
                                     list(output.keys()), model_name,
                                     list(batch_y.keys())))

    if len(check_res.duplicated) > 0:
        if len(check_res.duplicated) > 1:
            template = "Duplicated keys: {} are detected in function {}. Don't set {} as target and output " \
                       "them in {} at the same time.\n"
        else:
            template = "Duplicated key: {} is detected in function {}. Don't set {} as target and output " \
                       "it in {} at the same time.\n"
        errors.append(template.format(check_res.duplicated, func_signature,
                                      check_res.duplicated, forward_func_signature))

    if len(check_res.unused) > 0:
        verb = "are" if len(check_res.unused) > 1 else "is"
        _unused = "{} {} not used in function {}.\n".format(check_res.unused, verb, func_signature)
        if check_level == STRICT_CHECK_LEVEL:
            # TODO: a dedicated error type may be needed here
            errors.append(_unused)
        elif check_level == WARNING_CHECK_LEVEL:
            # Warn even when other problems will raise below, so the unused
            # keys are never silently dropped from the report.
            warnings.warn(_unused.strip())

    if not errors:
        return
    if len(errors) == 1:
        _error_str = errors[0]
    else:
        # Number the problems only when more than one is actually reported.
        _error_str = ''.join('({}).{}'.format(count, message)
                             for count, message in enumerate(errors, start=1))
    raise ValueError(_error_str)


if __name__ == '__main__':
# Ad-hoc manual check: builds a toy model and a small random dataset, then
# runs _check_code on them.
import torch
from torch import nn
from fastNLP.core.dataset import DataSet
import numpy as np

# Toy model exposing the forward / get_loss / evaluate methods that
# _check_code looks for.
class Model(nn.Module):
def __init__(self):
super().__init__()

self. fc1 = nn.Linear(10, 2)

def forward(self, words, chars):
output = {}
output['prediction'] = torch.randn(3, 4)
output['words'] = words
return output

def get_loss(self, prediction, labels, words):
return torch.mean(self.fc1.weight)

def evaluate(self, prediction, labels, demo=2):
return 0

model = Model()

num_samples = 4
fake_data_dict = {'words': np.random.randint(num_samples, size=(4, 3)), 'chars': np.random.randn(num_samples, 6),
'labels': np.random.randint(2, size=(num_samples,))}


dataset = DataSet(fake_data_dict)
dataset.set_input(words=True, chars=True)
dataset.set_target(labels=True)

# NOTE(review): the lines below that use dev_batch, _dict, refined_batch_x and
# output reference names never defined in this script; they look like
# removed-side lines of the rendered diff -- confirm against the repository.
for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
if batch_count == 0:
_dict = _check_arg_dict_list(model.evaluate, [output, batch_y])
# trainer = Trainer(dataset, model)

if len(_dict)!=0:
pass
refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x)
# check_level=2 corresponds to STRICT_CHECK_LEVEL; the same dataset is reused
# as dev_data.
_check_code(dataset=dataset, model=model, dev_data=dataset, check_level=2)

# _check_forward_error(model=model, model_func=model.forward, check_level=1,
# batch_x=fake_data_dict)

# import inspect
# print(inspect.getfullargspec(model.forward))





+ 31
- 5
fastNLP/core/utils.py View File

@@ -4,7 +4,7 @@ import inspect
from collections import namedtuple
from collections import Counter

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=True)
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)


def save_pickle(obj, pickle_path, file_name):
@@ -55,8 +55,11 @@ def _build_args(func, **kwargs):
if spect.varkw is not None:
return kwargs
needed_args = set(spect.args)
start_idx = len(spect.args) - len(spect.defaults)
output = {name: default for name, default in zip(spect.args[start_idx:], spect.defaults)}
defaults = []
if spect.defaults is not None:
defaults = [arg for arg in spect.defaults]
start_idx = len(spect.args) - len(defaults)
output = {name: default for name, default in zip(spect.args[start_idx:], defaults)}
output.update({name: val for name, val in kwargs.items() if name in needed_args})
return output

@@ -71,8 +74,11 @@ def _check_arg_dict_list(func, args):
assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
spect = inspect.getfullargspec(func)
assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs)
all_args = set(spect.args)
start_idx = len(spect.args) - len(spect.defaults)
all_args = set([arg for arg in spect.args if arg!='self'])
defaults = []
if spect.defaults is not None:
defaults = [arg for arg in spect.defaults]
start_idx = len(spect.args) - len(defaults)
default_args = set(spect.args[start_idx:])
require_args = all_args - default_args
input_arg_count = Counter()
@@ -87,3 +93,23 @@ def _check_arg_dict_list(func, args):
duplicated=duplicated,
required=list(require_args),
all_needed=list(all_args))

def get_func_signature(func):
    """Return the signature of ``func`` as text, e.g. ``'(a, b=1)'``.

    For a bound method the result does not include ``self``, since
    ``inspect.signature`` leaves out the already-bound first argument.

    :param func: any callable accepted by ``inspect.signature``.
    :return: the string form of the signature, parentheses included.
    """
    return str(inspect.signature(func))


# move data to model's device
import torch
def _syn_model_data(model, *args):
assert len(model.state_dict())!=0, "This model has no parameter."
device = model.parameters().__next__().device
for arg in args:
if isinstance(arg, dict):
for key, value in arg.items():
if isinstance(value, torch.Tensor):
arg[key] = value.to(device)
else:
raise ValueError("Only support dict type right now.")

Loading…
Cancel
Save