
check code changes

tags/v0.2.0
yh yunfan 6 years ago
commit a3bf647713
2 changed files with 70 additions and 135 deletions
  1. fastNLP/core/trainer.py  (+54, -131)
  2. fastNLP/core/utils.py  (+16, -4)

fastNLP/core/trainer.py  (+54, -131)

@@ -1,7 +1,11 @@
import time
from datetime import timedelta, datetime
from datetime import timedelta
from datetime import datetime
import warnings
from collections import defaultdict
import os
import torch
import itertools

from tensorboardX import SummaryWriter

from fastNLP.core.batch import Batch
@@ -221,30 +225,20 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No

batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch):
if batch_count == 0:
check_res = _check_arg_dict_list(model.forward, batch_x)
_info_str = ''
if len(check_res.missing) > 0:
if check_level == WARNING_CHECK_LEVEL:
for field_name in check_res.missing:
if hasattr(dataset, field_name):
_info_str += "{} "
_info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n"
_info_str += ""
print("")
if len(check_res.unused) > 0:
if check_level == WARNING_CHECK_LEVEL:
_info_str += ""
_syn_model_data(model, batch_x, batch_y)
# forward check
if batch_count==0:
_check_forward_error(model_func=model.forward, check_level=check_level,
batch_x=batch_x)

refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x)
signature_str = get_func_signature(model.forward)
func_signature = '{}.forward(self, {})'.format(model.__class__.__name__, signature_str[1:-1])
func_signature = get_func_signature(model.forward)
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature)

# loss check
if batch_count == 0:
_check_loss(model=model, model_func=model.get_loss, check_level=check_level,
_check_loss_evaluate(prev_func=model.forward, func=model.get_loss, check_level=check_level,
output=output, batch_y=batch_y)
loss_input = _build_args(model.get_loss, **output, **batch_y)
loss = model.get_loss(**loss_input)
@@ -276,32 +270,42 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
_syn_model_data(model, batch_x, batch_y)

refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x)
if hasattr(model, 'predict'):
refined_batch_x = _build_args(model.predict, **batch_x)
prev_func = model.predict
output = prev_func(**refined_batch_x)
func_signature = get_func_signature(model.predict)
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature)
else:
refined_batch_x = _build_args(model.forward, **batch_x)
prev_func = model.forward
output = prev_func(**refined_batch_x)
for k, v in output.items():
outputs[k].append(v)
for k, v in batch_y.items():
truths[k].append(v)
if batch_count+1>DEFAULT_CHECK_NUM_BATCH:
break
_check_loss(model=model, model_func=model.evaluate, check_level=check_level,
for k, v in outputs.items():
outputs[k] = itertools.chain(*v)
for k, v in truths.items():
truths[k] = itertools.chain(*v)
_check_loss_evaluate(prev_func=prev_func, func=model.evaluate, check_level=check_level,
output=outputs, batch_y=truths)
refined_input = _build_args(model.evaluate, **outputs, **truths)
metrics = model.evaluate(**refined_input)
signature_str = get_func_signature(model.evaluate)
func_signature = '{}.evaluate(self, {})'.format(model.__class__.__name__, signature_str[1:-1])
func_signature = get_func_signature(model.evaluate)
assert isinstance(metrics, dict), "The return value of {} should be dict.". \
format(func_signature)
if check_level > IGNORE_CHECK_LEVEL:
print("Finish checking evaluate process.", flush=True)


def _check_forward_error(model, model_func, check_level, batch_x):
def _check_forward_error(model_func, check_level, batch_x):
check_res = _check_arg_dict_list(model_func, batch_x)
_missing = ''
_unused = ''
signature_str = get_func_signature(model_func)
func_signature = '{}.forward(self, {})'.format(model.__class__.__name__, signature_str[1:-1])
func_signature = get_func_signature(model_func)
if len(check_res.missing)!=0:
_missing = "Function {} misses {}, only provided with {}, " \
".\n".format(func_signature, check_res.missing,
@@ -313,8 +317,8 @@ def _check_forward_error(model, model_func, check_level, batch_x):
_unused = "{} is not used ".format(check_res.unused)
_unused += "in function {}.\n".format(func_signature)
if _missing:
if not _unused and STRICT_CHECK_LEVEL:
_error_str = "(1).{} (2).{}".format(_missing, _unused)
if len(_unused)>0 and STRICT_CHECK_LEVEL:
_error_str = "(1).{}\n(2).{}".format(_missing, _unused)
else:
_error_str = _missing
# TODO: some custom Error types may be needed here
@@ -326,91 +330,19 @@ def _check_forward_error(model, model_func, check_level, batch_x):
elif check_level == WARNING_CHECK_LEVEL:
warnings.warn(message=_unused)

def _check_loss(model, model_func, check_level, output, batch_y):
check_res = _check_arg_dict_list(model_func, [output, batch_y])
_missing = ''
_unused = ''
_duplicated = ''
signature_str = get_func_signature(model_func)
model_name = model.__class__.__name__
model_func_name = model_func.__name__
func_signature = "{}.{}(self, {})".format(model_name, model_func_name, signature_str[1:-1])
forward_signature_str = get_func_signature(model.forward)
forward_func_signature = "{}.forward(self, {})".format(model_name, forward_signature_str[1:-1])
if len(check_res.missing)>0:
_missing = "Function {} misses argument {}, only provided with {}(from {}) and " \
"{}." \
.format(func_signature, check_res.missing,
list(output.keys()), model_name,
list(batch_y.keys()))
if len(check_res.unused)>0:
if len(check_res.unused) > 1:
_unused = "{} are not used ".format(check_res.unused)
else:
_unused = "{} is not used ".format(check_res.unused)
_unused += "in function {}.\n".format(func_signature)
if len(check_res.duplicated)>0:
if len(check_res.duplicated) > 1:
_duplicated = "Duplicated keys {} are detected when calling function {}. \nDon't set {} as target and output " \
"them in {} at the same time.\n".format(check_res.duplicated,
func_signature,
check_res.duplicated,
forward_func_signature)
else:
_duplicated = "Duplicated key {} is detected when calling function {}. \nDon't set {} as target and output " \
"it in {} at the same time.\n".format(check_res.duplicated,
func_signature,
check_res.duplicated,
forward_func_signature)
_number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0)
if _number_errs > 0:
_error_str = ''
if _number_errs > 1:
count = 1
if _missing:
_error_str += '({}).{}'.format(count, _missing)
count += 1
if _duplicated:
_error_str += '({}).{}'.format(count, _duplicated)
count += 1
if _unused and check_level == STRICT_CHECK_LEVEL:
_error_str += '({}).{}'.format(count, _unused)
else:
if _unused:
if check_level == STRICT_CHECK_LEVEL:
# TODO: some custom Error types may be needed here
_error_str = _unused
elif check_level == WARNING_CHECK_LEVEL:
_unused = _unused.strip()
warnings.warn(_unused)
else:
_error_str = _missing + _duplicated
if _error_str:
raise ValueError(_error_str)

def _check_evaluate(model, model_func, check_level, output, batch_y):
def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):

check_res = _check_arg_dict_list(model_func, [output, batch_y])
check_res = _check_arg_dict_list(func, [output, batch_y])
_missing = ''
_unused = ''
_duplicated = ''
signature_str = get_func_signature(model_func)
model_name = model.__class__.__name__
model_func_name = model_func.__name__
func_signature = "{}.{}(self, {})".format(model_name, model_func_name, signature_str[1:-1])
if hasattr(model, 'predict'):
previous_func = model.predict
previous_func_name = 'predict'
else:
previous_func = model.forward
previous_func_name = 'forward'
previous_signature_str = get_func_signature(previous_func)
previous_func_signature = "{}.{}(self, {})".format(model_name, previous_func_name, previous_signature_str[1:-1])
func_signature = get_func_signature(func)
prev_func_signature = get_func_signature(prev_func)
if len(check_res.missing)>0:
_missing = "Function {} misses argument {}, only provided with {}(from {}) and " \
"{}." \
_missing = "Function {} misses argument {}, \n only provided with {}(from {}) and " \
"{}(from target in Dataset)." \
.format(func_signature, check_res.missing,
list(output.keys()), previous_func_signature,
list(output.keys()), prev_func_signature,
list(batch_y.keys()))
if len(check_res.unused)>0:
if len(check_res.unused) > 1:
@@ -424,40 +356,38 @@ def _check_evaluate(model, model_func, check_level, output, batch_y):
"them in {} at the same time.\n".format(check_res.duplicated,
func_signature,
check_res.duplicated,
previous_func_signature)
prev_func_signature)
else:
_duplicated = "Duplicated key {} is detected when calling function {}. \nDon't set {} as target and output " \
"it in {} at the same time.\n".format(check_res.duplicated,
func_signature,
check_res.duplicated,
previous_func_signature)
prev_func_signature)
_number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0)
if _number_errs > 0:
_error_str = ''
_error_strs = []
if _number_errs > 1:
count = 1
if _missing:
_error_str += '({}).{}'.format(count, _missing)
_error_strs.append('({}).{}'.format(count, _missing))
count += 1
if _duplicated:
_error_str += '({}).{}'.format(count, _duplicated)
_error_strs.append('({}).{}'.format(count, _duplicated))
count += 1
if _unused and check_level == STRICT_CHECK_LEVEL:
_error_str += '({}).{}'.format(count, _unused)
_error_strs.append('({}).{}'.format(count, _unused))
else:
if _unused:
if check_level == STRICT_CHECK_LEVEL:
# TODO: some custom Error types may be needed here
_error_str = _unused
_error_strs.append(_unused)
elif check_level == WARNING_CHECK_LEVEL:
_unused = _unused.strip()
warnings.warn(_unused)
else:
_error_str = _missing + _duplicated
if _error_str:
raise ValueError(_error_str)


_error_strs = [_missing, _duplicated]
if _error_strs:
raise ValueError('\n'.join(_error_strs))


if __name__ == '__main__':
@@ -478,11 +408,12 @@ if __name__ == '__main__':
output['words'] = words
return output

def get_loss(self, prediction, labels, words):
def get_loss(self, prediction, labels, words, seq_lens):
return torch.mean(self.fc1.weight)

def evaluate(self, prediction, labels, demo=2):
return 0
return {}


model = Model()

@@ -493,7 +424,7 @@ if __name__ == '__main__':

dataset = DataSet(fake_data_dict)
dataset.set_input(words=True, chars=True)
dataset.set_target(labels=True)
dataset.set_target(labels=True, words=True)

# trainer = Trainer(dataset, model)

@@ -505,13 +436,5 @@ if __name__ == '__main__':
# import inspect
# print(inspect.getfullargspec(model.forward))

import numpy as np

a = [1, 3]
np.asarray(a)

import pandas
df = pandas.DataFrame(fake_data_dict)
df.infer_objects()
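
For reference, here is a minimal sketch of roughly what a check such as _check_arg_dict_list reports for the calls above. It is an illustrative approximation, not fastNLP's actual implementation; only the check_res fields used in this diff (missing, unused, duplicated, all_needed) are assumed.

import inspect
from collections import Counter, namedtuple

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'all_needed'])

def check_arg_dict_list_sketch(func, args):
    # `args` may be a single dict (batch_x) or a list of dicts ([output, batch_y]).
    arg_dicts = [args] if isinstance(args, dict) else list(args)
    params = inspect.signature(func).parameters
    all_args = [name for name in params if name != 'self']
    required = [name for name, p in params.items()
                if name != 'self'
                and p.default is inspect.Parameter.empty
                and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)]
    key_counts = Counter(key for d in arg_dicts for key in d)
    missing = [name for name in required if name not in key_counts]   # needed but not provided
    unused = [key for key in key_counts if key not in all_args]       # provided but never consumed
    duplicated = [key for key, cnt in key_counts.items() if cnt > 1]  # e.g. set as target and also output by forward
    return CheckRes(missing=missing, unused=unused,
                    duplicated=duplicated, all_needed=all_args)

For example, check_arg_dict_list_sketch(model.get_loss, [output, batch_y]) would report any get_loss parameter that neither forward's output nor the target fields supply, which is the case _check_loss_evaluate turns into a "Function ... misses argument ..." error.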



fastNLP/core/utils.py  (+16, -4)

@@ -95,10 +95,22 @@ def _check_arg_dict_list(func, args):
all_needed=list(all_args))

def get_func_signature(func):
# function signature, does not include self.
signature = inspect.signature(func)
signature_str = str(signature)
return signature_str
# can only be used in function or class method
if inspect.ismethod(func):
class_name = func.__self__.__class__.__name__
signature = inspect.signature(func)
signature_str = str(signature)
if len(signature_str)>2:
_self = '(self, '
else:
_self = '(self'
signature_str = class_name + '.' + func.__name__ + _self + signature_str[1:]
return signature_str
elif inspect.isfunction(func):
signature = inspect.signature(func)
signature_str = str(signature)
signature_str = func.__name__ + signature_str
return signature_str


# move data to model's device
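
For reference, a small usage sketch of the rewritten helper; the Toy class and standalone function below are made up for illustration:

from fastNLP.core.utils import get_func_signature

class Toy:
    def forward(self, words, chars):
        return {'pred': words}

def standalone(x, y=1):
    return x + y

# Bound method: the class name and an explicit `self` are prepended.
print(get_func_signature(Toy().forward))   # -> "Toy.forward(self, words, chars)"

# Plain function: only the function name is prepended.
print(get_func_signature(standalone))      # -> "standalone(x, y=1)"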

