
Attempt to provide parameter-checking functionality

tags/v0.2.0
yh yunfan 6 years ago
parent
commit
0836ce006f
3 changed files with 237 additions and 55 deletions
  1. fastNLP/core/fieldarray.py  (+8, -1)
  2. fastNLP/core/trainer.py  (+198, -49)
  3. fastNLP/core/utils.py  (+31, -5)

fastNLP/core/fieldarray.py  (+8, -1)

@@ -47,7 +47,7 @@ class FieldArray(object):
         assert self.is_input is True or self.is_target is True
         batch_size = len(indices)
         # TODO when this FieldArray holds one-dimensional content such as seq_length, padding is unnecessary; needs further discussion
-        if isinstance(self.content[0], int) or isinstance(self.content[0], float):
+        if not isiterable(self.content[0]):
             if self.dtype is None:
                 self.dtype = np.int64 if isinstance(self.content[0], int) else np.double
             array = np.array([self.content[i] for i in indices], dtype=self.dtype)
@@ -63,3 +63,10 @@ class FieldArray(object):


     def __len__(self):
         return len(self.content)
 
+
+def isiterable(content):
+    try:
+        _ = (e for e in content)
+    except TypeError:
+        return False
+    return True

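As a quick aside (not part of the commit), the new helper simply reports whether the field content supports iteration, which is what lets FieldArray skip the padding branch for scalar fields such as sequence lengths. A minimal sketch, assuming isiterable is imported from fastNLP.core.fieldarray:

from fastNLP.core.fieldarray import isiterable

print(isiterable(5))          # False -> scalar content (e.g. seq_length), padding is skipped
print(isiterable([1, 2, 3]))  # True  -> sequence content, still goes through padding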
fastNLP/core/trainer.py  (+198, -49)

@@ -1,5 +1,9 @@
 import time
-from datetime import timedelta, datetime
+from datetime import timedelta
+from datetime import datetime
+
+import warnings
+from collections import defaultdict
 
 import torch
 from tensorboardX import SummaryWriter
@@ -12,13 +16,17 @@ from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
+from fastNLP.core.utils import _check_arg_dict_list
+from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _syn_model_data
+from fastNLP.core.utils import get_func_signature
 
 
 class Trainer(object):
     """Main Training Loop
 
     """
 
-    def __init__(self, train_data, model, n_epochs, batch_size, n_print,
+    def __init__(self, train_data, model, n_epochs=1, batch_size=32, print_every=-1,
                  dev_data=None, use_cuda=False, loss=Loss(None), save_path="./save",
                  optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
                  evaluator=Evaluator(),
@@ -32,7 +40,7 @@ class Trainer(object):
         self.batch_size = int(batch_size)
         self.use_cuda = bool(use_cuda)
         self.save_path = str(save_path)
-        self.n_print = int(n_print)
+        self.print_every = int(print_every)
 
         self.loss_func = self.model.loss if hasattr(self.model, "loss") else loss.get()
         self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
@@ -51,7 +59,7 @@ class Trainer(object):
         self.step = 0
         self.start_time = None  # start timestamp
 
-        print(self.__dict__)
+        # print(self.__dict__)
 
     def train(self):
         """Start Training.
@@ -70,17 +78,16 @@ class Trainer(object):
         epoch = 1
         while epoch <= self.n_epochs:
 
-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
-                                  use_cuda=self.use_cuda)
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler())
 
-            self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start, self.n_print)
+            self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)
 
             if self.dev_data:
                 self.do_validation()
             self.save_model(self.model, 'training_model_' + self.start_time)
             epoch += 1
 
-    def _train_epoch(self, data_iterator, model, epoch, dev_data, start, n_print, **kwargs):
+    def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs):
         """Training process in one epoch.
 
            kwargs should contain:
@@ -103,7 +110,7 @@ class Trainer(object):
                     self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
                     # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
                     # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
-            if kwargs["n_print"] > 0 and self.step % kwargs["n_print"] == 0:
+            if self.print_every > 0 and self.step % self.print_every == 0:
                 end = time.time()
                 diff = timedelta(seconds=round(end - kwargs["start"]))
                 print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
@@ -197,9 +204,6 @@ def best_eval_result(self, metrics):
         return False
 
 
-from fastNLP.core.utils import _check_arg_dict_list
-from fastNLP.core.utils import _build_args
-
 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2
 
@@ -207,64 +211,209 @@ IGNORE_CHECK_LEVEL=0
 WARNING_CHECK_LEVEL=1
 STRICT_CHECK_LEVEL=2
 
-def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
-    # check the loss method
+def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1):
+    # check the get_loss method
+    model_name = model.__class__.__name__
     if not hasattr(model, 'get_loss'):
-        raise AttributeError("{} has to have a 'get_loss' function.".format(type(model)))
+        raise AttributeError("{} has to have a 'get_loss' function.".format(model_name))
 
     batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
     batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
     for batch_count, (batch_x, batch_y) in enumerate(batch):
+        _syn_model_data(model, batch_x, batch_y)
+        # forward check
         if batch_count==0:
-            check_res = _check_arg_dict_list(model.forward, batch_x)
-            _info_str = ''
-            if len(check_res.missing)>0:
-                if check_level == WARNING_CHECK_LEVEL:
-                    for field_name in check_res.missing:
-                        if hasattr(dataset, field_name):
-                            _info_str += "{} "
-                    _info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n"
-                    _info_str += ""
-                    print("")
-            if len(check_res.unused)>0:
-                if check_level == WARNING_CHECK_LEVEL:
-                    _info_str += ""
+            _check_forward_error(model=model, model_func=model.forward, check_level=check_level,
+                                 batch_x=batch_x)
 
         refined_batch_x = _build_args(model.forward, **batch_x)
         output = model(**refined_batch_x)
+        assert isinstance(output, dict), "The return value of {}.forward() should be dict.".format(model_name)
 
+        # loss check
         if batch_count == 0:
-            _dict = _check_arg_dict_list(model.loss, [output, batch_y])
-            if len(_dict)!=0:
-                pass
-        loss_input = _build_args(model.loss, **output, **batch_y)
-        loss = model.loss(**loss_input)
-        if batch_count == 0:
-            if isinstance(loss, torch.Tensor):
-                pass
+            _check_loss_evaluate(model=model, model_func=model.get_loss, check_level=check_level,
+                                 output=output, batch_y=batch_y)
+        loss_input = _build_args(model.get_loss, **output, **batch_y)
+        loss = model.get_loss(**loss_input)
 
+        # check loss output
+        if batch_count == 0:
+            if not isinstance(loss, torch.Tensor):
+                raise ValueError("The return value of {}.get_loss() should be torch.Tensor, but {} got.".
+                                 format(model_name, type(loss)))
+            if len(loss.size())!=0:
+                raise ValueError("The size of return value of {}.get_loss() is {}, should be torch.size([])".format(
+                    model_name, loss.size()
+                ))
         loss.backward()
-        if batch_count+1>=DEFAULT_CHECK_BATCH_SIZE:
+        model.zero_grad()
+        if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
             break
+    if check_level > IGNORE_CHECK_LEVEL:
+        print('Finish checking training process.', flush=True)
 
-    dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
     if dev_data is not None:
         if not hasattr(model, 'evaluate'):
-            raise AttributeError("If {} wants to do evaluation, {} has to have a 'evaluate' function. Or you can set"
+            raise AttributeError("{} has to have a 'evaluate' function to do evaluation. Or set"
                                  "dev_data to 'None'."
-                                 .format(type(model), type(model)))
-        for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
-            if batch_count == 0:
-                _dict = _check_arg_dict_list(model.evaluate, [output, batch_y])
-                if len(_dict)!=0:
-                    pass
-            refined_batch_x = _build_args(model.forward, **batch_x)
-            output = model(**refined_batch_x)
+                                 .format(model_name))
+        outputs, truths = defaultdict(list), defaultdict(list)
+        dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
+        with torch.no_grad():
+            for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
+                _syn_model_data(model, batch_x, batch_y)
+
+                refined_batch_x = _build_args(model.forward, **batch_x)
+                output = model(**refined_batch_x)
+                for k, v in output.items():
+                    outputs[k].append(v)
+                for k, v in batch_y.items():
+                    truths[k].append(v)
+                if batch_count+1>DEFAULT_CHECK_NUM_BATCH:
+                    break
+            _check_loss_evaluate(model=model, model_func=model.evaluate, check_level=check_level,
+                                 output=outputs, batch_y=truths)
+        print("Finish checking evaluate process.", flush=True)


+def _check_forward_error(model, model_func, check_level, batch_x):
+    check_res = _check_arg_dict_list(model_func, batch_x)
+    _missing = ''
+    _unused = ''
+    signature_str = get_func_signature(model_func)
+    func_signature = '{}.forward(self, {})'.format(model.__class__.__name__, signature_str[1:-1])
+    if len(check_res.missing)!=0:
+        _missing = "Function {} misses {}, only provided with {}, " \
+                   ".\n".format(func_signature, check_res.missing,
+                                list(batch_x.keys()))
+    if len(check_res.unused)!=0:
+        if len(check_res.unused) > 1:
+            _unused = "{} are not used ".format(check_res.unused)
+        else:
+            _unused = "{} is not used ".format(check_res.unused)
+        _unused += "in function {}.\n".format(func_signature)
+    if _missing:
+        if not _unused and STRICT_CHECK_LEVEL:
+            _error_str = "(1).{} (2).{}".format(_missing, _unused)
+        else:
+            _error_str = _missing
+        # TODO some custom Error types may be needed here
+        raise TypeError(_error_str)
+    if _unused:
+        if check_level == STRICT_CHECK_LEVEL:
+            # TODO some custom Error types may be needed here
+            raise ValueError(_unused)
+        elif check_level == WARNING_CHECK_LEVEL:
+            warnings.warn(message=_unused, )

+def _check_loss_evaluate(model, model_func, check_level, output, batch_y):
+    check_res = _check_arg_dict_list(model_func, [output, batch_y])
+    _missing = ''
+    _unused = ''
+    _duplicated = ''
+    signature_str = get_func_signature(model_func)
+    func_signature = "{}.{}(self, {})".format(model.__class__.__name__, model_func.__name__, signature_str[1:-1])
+    forward_func_signature = "{}.forward(self, {})".format(model.__class__.__name__, signature_str[1:-1])
+    model_name = model.__class__.__name__
+    if len(check_res.missing)>0:
+        _missing = "Function {} misses argument {}, only provided with {}(from {}) and " \
+                   "{}." \
+            .format(func_signature, check_res.missing,
+                    list(output.keys()), model_name,
+                    list(batch_y.keys()))
+    if len(check_res.unused)>0:
+        if len(check_res.unused) > 1:
+            _unused = "{} are not used ".format(check_res.unused)
+        else:
+            _unused = "{} is not used ".format(check_res.unused)
+        _unused += "in function {}.\n".format(func_signature)
+    if len(check_res.duplicated)>0:
+        if len(check_res.duplicated) > 1:
+            _duplicated = "Duplicated keys: {} are detected in function {}. Don't set {} as target and output " \
+                          "them in {} at the same time.\n".format(check_res.duplicated,
+                                                                  func_signature,
+                                                                  check_res.duplicated,
+                                                                  forward_func_signature)
+        else:
+            _duplicated = "Duplicated key: {} is detected in function {}. Don't set {} as target and output " \
+                          "it in {} at the same time.\n".format(check_res.duplicated,
+                                                                func_signature,
+                                                                check_res.duplicated,
+                                                                forward_func_signature)
+    _number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0)
+    if _number_errs > 0:
+        _error_str = ''
+        if _number_errs > 1:
+            count = 1
+            if _missing:
+                _error_str += '({}).{}'.format(count, _missing)
+                count += 1
+            if _duplicated:
+                _error_str += '({}).{}'.format(count, _duplicated)
+                count += 1
+            if _unused and check_level == STRICT_CHECK_LEVEL:
+                _error_str += '({}).{}'.format(count, _unused)
+        else:
+            if _unused:
+                if check_level == STRICT_CHECK_LEVEL:
+                    # TODO some custom Error types may be needed here
+                    _error_str = _unused
+                elif check_level == WARNING_CHECK_LEVEL:
+                    _unused = _unused.strip()
+                    warnings.warn(_unused)
+            else:
+                _error_str = _missing + _duplicated
+        if _error_str:
+            raise ValueError(_error_str)


+if __name__ == '__main__':
+    import torch
+    from torch import nn
+    from fastNLP.core.dataset import DataSet
+    import numpy as np
+
+    class Model(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+            self. fc1 = nn.Linear(10, 2)
+
+        def forward(self, words, chars):
+            output = {}
+            output['prediction'] = torch.randn(3, 4)
+            output['words'] = words
+            return output
+
+        def get_loss(self, prediction, labels, words):
+            return torch.mean(self.fc1.weight)
+
+        def evaluate(self, prediction, labels, demo=2):
+            return 0
+
+    model = Model()
+
+    num_samples = 4
+    fake_data_dict = {'words': np.random.randint(num_samples, size=(4, 3)), 'chars': np.random.randn(num_samples, 6),
+                      'labels': np.random.randint(2, size=(num_samples,))}
+
+    dataset = DataSet(fake_data_dict)
+    dataset.set_input(words=True, chars=True)
+    dataset.set_target(labels=True)
+
+    # trainer = Trainer(dataset, model)
+
+    _check_code(dataset=dataset, model=model, dev_data=dataset, check_level=2)
+
+    # _check_forward_error(model=model, model_func=model.forward, check_level=1,
+    #                      batch_x=fake_data_dict)
+
+    # import inspect
+    # print(inspect.getfullargspec(model.forward))

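To see how the check levels differ in practice, here is a minimal sketch that could be appended to the __main__ block above; it reuses the demo Model defined there, and the extra 'seq_len' key is hypothetical (Model.forward only accepts words and chars). At WARNING_CHECK_LEVEL an unused input field only triggers a warning, while STRICT_CHECK_LEVEL raises an error:

    # sketch only: 'seq_len' is a made-up field that forward() does not accept
    fake_batch_x = {'words': np.zeros((2, 3)), 'chars': np.zeros((2, 6)), 'seq_len': np.ones(2)}
    _check_forward_error(model=model, model_func=model.forward,
                         check_level=WARNING_CHECK_LEVEL, batch_x=fake_batch_x)  # warns about 'seq_len'
    _check_forward_error(model=model, model_func=model.forward,
                         check_level=STRICT_CHECK_LEVEL, batch_x=fake_batch_x)   # raises ValueError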
fastNLP/core/utils.py  (+31, -5)

@@ -4,7 +4,7 @@ import inspect
 from collections import namedtuple
 from collections import Counter
 
-CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=True)
+CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)
 
 
 def save_pickle(obj, pickle_path, file_name):
@@ -55,8 +55,11 @@ def _build_args(func, **kwargs):
     if spect.varkw is not None:
         return kwargs
     needed_args = set(spect.args)
-    start_idx = len(spect.args) - len(spect.defaults)
-    output = {name: default for name, default in zip(spect.args[start_idx:], spect.defaults)}
+    defaults = []
+    if spect.defaults is not None:
+        defaults = [arg for arg in spect.defaults]
+    start_idx = len(spect.args) - len(defaults)
+    output = {name: default for name, default in zip(spect.args[start_idx:], defaults)}
     output.update({name: val for name, val in kwargs.items() if name in needed_args})
     return output


@@ -71,8 +74,11 @@ def _check_arg_dict_list(func, args):
     assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
     spect = inspect.getfullargspec(func)
     assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs)
-    all_args = set(spect.args)
-    start_idx = len(spect.args) - len(spect.defaults)
+    all_args = set([arg for arg in spect.args if arg!='self'])
+    defaults = []
+    if spect.defaults is not None:
+        defaults = [arg for arg in spect.defaults]
+    start_idx = len(spect.args) - len(defaults)
     default_args = set(spect.args[start_idx:])
     require_args = all_args - default_args
     input_arg_count = Counter()
@@ -87,3 +93,23 @@ def _check_arg_dict_list(func, args):
                     duplicated=duplicated,
                     required=list(require_args),
                     all_needed=list(all_args))
+
+
+def get_func_signature(func):
+    # function signature, does not include self.
+    signature = inspect.signature(func)
+    signature_str = str(signature)
+    return signature_str
+
+
+# move data to model's device
+import torch
+def _syn_model_data(model, *args):
+    assert len(model.state_dict())!=0, "This model has no parameter."
+    device = model.parameters().__next__().device
+    for arg in args:
+        if isinstance(arg, dict):
+            for key, value in arg.items():
+                if isinstance(value, torch.Tensor):
+                    arg[key] = value.to(device)
+        else:
+            raise ValueError("Only support dict type right now.")

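To get a feel for the two helpers after this change, here is a minimal sketch; the demo function below is made up for illustration, and it assumes the single-dict calling form of _check_arg_dict_list that the trainer code above relies on. _build_args keeps only the keyword arguments the target function accepts and fills in its declared defaults, while _check_arg_dict_list reports which required arguments are missing and which provided keys go unused:

from fastNLP.core.utils import _build_args, _check_arg_dict_list

def demo(words, seq_len, dropout=0.1):  # hypothetical target function
    return words, seq_len, dropout

print(_build_args(demo, words=[1, 2], seq_len=2, chars=[3]))
# -> {'dropout': 0.1, 'words': [1, 2], 'seq_len': 2}   ('chars' is filtered out)

res = _check_arg_dict_list(demo, {'words': [1, 2], 'chars': [3]})
print(res.missing, res.unused)
# -> roughly: ['seq_len'] ['chars']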