|
|
@@ -1,7 +1,11 @@ |
|
|
|
import time |
|
|
|
from datetime import timedelta, datetime
|
|
|
from datetime import timedelta |
|
|
|
from datetime import datetime |
|
|
|
import warnings |
|
|
|
from collections import defaultdict |
|
|
|
import os |
|
|
|
import torch |
|
|
|
import itertools |
|
|
|
|
|
|
|
from tensorboardX import SummaryWriter |
|
|
|
|
|
|
|
from fastNLP.core.batch import Batch |
|
|
@@ -221,30 +225,20 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No |
|
|
|
|
|
|
|
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) |
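# SequentialSampler keeps the check reproducible: the same leading batches are inspected on every run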
|
|
|
for batch_count, (batch_x, batch_y) in enumerate(batch): |
|
|
|
if batch_count == 0: |
|
|
|
check_res = _check_arg_dict_list(model.forward, batch_x) |
|
|
|
_info_str = '' |
|
|
|
if len(check_res.missing) > 0: |
|
|
|
if check_level == WARNING_CHECK_LEVEL: |
|
|
|
for field_name in check_res.missing: |
|
|
|
if hasattr(dataset, field_name): |
|
|
|
_info_str += "{} " |
|
|
|
_info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n" |
|
|
|
_info_str += "" |
|
|
|
print("") |
|
|
|
if len(check_res.unused) > 0: |
|
|
|
if check_level == WARNING_CHECK_LEVEL: |
|
|
|
_info_str += "" |
|
|
|
_syn_model_data(model, batch_x, batch_y) |
|
|
|
# forward check |
|
|
|
if batch_count == 0:
|
|
|
_check_forward_error(model_func=model.forward, check_level=check_level, |
|
|
|
batch_x=batch_x) |
|
|
|
|
|
|
|
refined_batch_x = _build_args(model.forward, **batch_x) |
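# _build_args is assumed to keep only the entries of batch_x that forward() actually accepts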
|
|
|
output = model(**refined_batch_x) |
|
|
|
signature_str = get_func_signature(model.forward) |
|
|
|
func_signature = '{}.forward(self, {})'.format(model.__class__.__name__, signature_str[1:-1]) |
|
|
|
func_signature = get_func_signature(model.forward) |
|
|
|
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature) |
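# forward() has to return a dict so its keys can be matched against get_loss()/evaluate() parameters later on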
|
|
|
|
|
|
|
# loss check |
|
|
|
if batch_count == 0: |
|
|
|
_check_loss(model=model, model_func=model.get_loss, check_level=check_level, |
|
|
|
_check_loss_evaluate(prev_func=model.forward, func=model.get_loss, check_level=check_level, |
|
|
|
output=output, batch_y=batch_y) |
|
|
|
loss_input = _build_args(model.get_loss, **output, **batch_y) |
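# get_loss receives both the forward outputs and the target fields; overlapping keys are what the duplicated-key check above reports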
|
|
|
loss = model.get_loss(**loss_input) |
|
|
@@ -276,32 +270,42 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No |
|
|
|
for batch_count, (batch_x, batch_y) in enumerate(dev_batch): |
|
|
|
_syn_model_data(model, batch_x, batch_y) |
|
|
|
|
|
|
|
refined_batch_x = _build_args(model.forward, **batch_x) |
|
|
|
output = model(**refined_batch_x) |
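# prefer model.predict() for the evaluation pass when the model defines one; otherwise fall back to forward()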
|
|
|
if hasattr(model, 'predict'): |
|
|
|
refined_batch_x = _build_args(model.predict, **batch_x) |
|
|
|
prev_func = model.predict |
|
|
|
output = prev_func(**refined_batch_x) |
|
|
|
func_signature = get_func_signature(model.predict) |
|
|
|
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature) |
|
|
|
else: |
|
|
|
refined_batch_x = _build_args(model.forward, **batch_x) |
|
|
|
prev_func = model.forward |
|
|
|
output = prev_func(**refined_batch_x) |
|
|
|
for k, v in output.items(): |
|
|
|
outputs[k].append(v) |
|
|
|
for k, v in batch_y.items(): |
|
|
|
truths[k].append(v) |
|
|
|
if batch_count + 1 > DEFAULT_CHECK_NUM_BATCH:
|
|
|
break |
|
|
|
_check_loss(model=model, model_func=model.evaluate, check_level=check_level, |
|
|
|
for k, v in outputs.items(): |
|
|
|
outputs[k] = itertools.chain(*v) |
|
|
|
for k, v in truths.items(): |
|
|
|
truths[k] = itertools.chain(*v) |
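# itertools.chain flattens the per-batch value lists lazily before they are passed on to evaluate()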
|
|
|
_check_loss_evaluate(prev_func=prev_func, func=model.evaluate, check_level=check_level, |
|
|
|
output=outputs, batch_y=truths) |
|
|
|
refined_input = _build_args(model.evaluate, **outputs, **truths) |
|
|
|
metrics = model.evaluate(**refined_input) |
|
|
|
signature_str = get_func_signature(model.evaluate) |
|
|
|
func_signature = '{}.evaluate(self, {})'.format(model.__class__.__name__, signature_str[1:-1]) |
|
|
|
func_signature = get_func_signature(model.evaluate) |
|
|
|
assert isinstance(metrics, dict), "The return value of {} should be dict.". \ |
|
|
|
format(func_signature) |
|
|
|
if check_level > IGNORE_CHECK_LEVEL: |
|
|
|
print("Finish checking evaluate process.", flush=True) |
|
|
|
|
|
|
|
|
|
|
|
def _check_forward_error(model, model_func, check_level, batch_x): |
|
|
|
def _check_forward_error(model_func, check_level, batch_x): |
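# checks batch_x against model_func's signature: missing arguments are treated as errors, unused ones only warn unless the check level is strict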
|
|
|
check_res = _check_arg_dict_list(model_func, batch_x) |
|
|
|
_missing = '' |
|
|
|
_unused = '' |
|
|
|
signature_str = get_func_signature(model_func) |
|
|
|
func_signature = '{}.forward(self, {})'.format(model.__class__.__name__, signature_str[1:-1]) |
|
|
|
func_signature = get_func_signature(model_func) |
|
|
|
if len(check_res.missing)!=0: |
|
|
|
_missing = "Function {} misses {}, only provided with {}, " \ |
|
|
|
".\n".format(func_signature, check_res.missing, |
|
|
@@ -313,8 +317,8 @@ def _check_forward_error(model, model_func, check_level, batch_x): |
|
|
|
_unused = "{} is not used ".format(check_res.unused) |
|
|
|
_unused += "in function {}.\n".format(func_signature) |
|
|
|
if _missing: |
|
|
|
if not _unused and STRICT_CHECK_LEVEL: |
|
|
|
_error_str = "(1).{} (2).{}".format(_missing, _unused) |
|
|
|
if len(_unused) > 0 and check_level == STRICT_CHECK_LEVEL:
|
|
|
_error_str = "(1).{}\n(2).{}".format(_missing, _unused) |
|
|
|
else: |
|
|
|
_error_str = _missing |
|
|
|
# TODO: we may need to define some custom Error types here
|
|
@@ -326,91 +330,19 @@ def _check_forward_error(model, model_func, check_level, batch_x): |
|
|
|
elif check_level == WARNING_CHECK_LEVEL: |
|
|
|
warnings.warn(message=_unused) |
|
|
|
|
|
|
|
def _check_loss(model, model_func, check_level, output, batch_y): |
|
|
|
check_res = _check_arg_dict_list(model_func, [output, batch_y]) |
|
|
|
_missing = '' |
|
|
|
_unused = '' |
|
|
|
_duplicated = '' |
|
|
|
signature_str = get_func_signature(model_func) |
|
|
|
model_name = model.__class__.__name__ |
|
|
|
model_func_name = model_func.__name__ |
|
|
|
func_signature = "{}.{}(self, {})".format(model_name, model_func_name, signature_str[1:-1]) |
|
|
|
forward_signature_str = get_func_signature(model.forward) |
|
|
|
forward_func_signature = "{}.forward(self, {})".format(model_name, forward_signature_str[1:-1]) |
|
|
|
if len(check_res.missing)>0: |
|
|
|
_missing = "Function {} misses argument {}, only provided with {}(from {}) and " \ |
|
|
|
"{}." \ |
|
|
|
.format(func_signature, check_res.missing, |
|
|
|
list(output.keys()), model_name, |
|
|
|
list(batch_y.keys())) |
|
|
|
if len(check_res.unused)>0: |
|
|
|
if len(check_res.unused) > 1: |
|
|
|
_unused = "{} are not used ".format(check_res.unused) |
|
|
|
else: |
|
|
|
_unused = "{} is not used ".format(check_res.unused) |
|
|
|
_unused += "in function {}.\n".format(func_signature) |
|
|
|
if len(check_res.duplicated)>0: |
|
|
|
if len(check_res.duplicated) > 1: |
|
|
|
_duplicated = "Duplicated keys {} are detected when calling function {}. \nDon't set {} as target and output " \ |
|
|
|
"them in {} at the same time.\n".format(check_res.duplicated, |
|
|
|
func_signature, |
|
|
|
check_res.duplicated, |
|
|
|
forward_func_signature) |
|
|
|
else: |
|
|
|
_duplicated = "Duplicated key {} is detected when calling function {}. \nDon't set {} as target and output " \ |
|
|
|
"it in {} at the same time.\n".format(check_res.duplicated, |
|
|
|
func_signature, |
|
|
|
check_res.duplicated, |
|
|
|
forward_func_signature) |
|
|
|
_number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0) |
|
|
|
if _number_errs > 0: |
|
|
|
_error_str = '' |
|
|
|
if _number_errs > 1: |
|
|
|
count = 1 |
|
|
|
if _missing: |
|
|
|
_error_str += '({}).{}'.format(count, _missing) |
|
|
|
count += 1 |
|
|
|
if _duplicated: |
|
|
|
_error_str += '({}).{}'.format(count, _duplicated) |
|
|
|
count += 1 |
|
|
|
if _unused and check_level == STRICT_CHECK_LEVEL: |
|
|
|
_error_str += '({}).{}'.format(count, _unused) |
|
|
|
else: |
|
|
|
if _unused: |
|
|
|
if check_level == STRICT_CHECK_LEVEL: |
|
|
|
# TODO: we may need to define some custom Error types here
|
|
|
_error_str = _unused |
|
|
|
elif check_level == WARNING_CHECK_LEVEL: |
|
|
|
_unused = _unused.strip() |
|
|
|
warnings.warn(_unused) |
|
|
|
else: |
|
|
|
_error_str = _missing + _duplicated |
|
|
|
if _error_str: |
|
|
|
raise ValueError(_error_str) |
|
|
|
|
|
|
|
def _check_evaluate(model, model_func, check_level, output, batch_y): |
|
|
|
def _check_loss_evaluate(prev_func, func, check_level, output, batch_y): |
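# shared checker for get_loss and evaluate: 'output' comes from prev_func (forward or predict), 'batch_y' from the dataset's target fields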
|
|
|
|
|
|
|
check_res = _check_arg_dict_list(model_func, [output, batch_y]) |
|
|
|
check_res = _check_arg_dict_list(func, [output, batch_y]) |
|
|
|
_missing = '' |
|
|
|
_unused = '' |
|
|
|
_duplicated = '' |
|
|
|
signature_str = get_func_signature(model_func) |
|
|
|
model_name = model.__class__.__name__ |
|
|
|
model_func_name = model_func.__name__ |
|
|
|
func_signature = "{}.{}(self, {})".format(model_name, model_func_name, signature_str[1:-1]) |
|
|
|
if hasattr(model, 'predict'): |
|
|
|
previous_func = model.predict |
|
|
|
previous_func_name = 'predict' |
|
|
|
else: |
|
|
|
previous_func = model.forward |
|
|
|
previous_func_name = 'forward' |
|
|
|
previous_signature_str = get_func_signature(previous_func) |
|
|
|
previous_func_signature = "{}.{}(self, {})".format(model_name, previous_func_name, previous_signature_str[1:-1]) |
|
|
|
func_signature = get_func_signature(func) |
|
|
|
prev_func_signature = get_func_signature(prev_func) |
|
|
|
if len(check_res.missing)>0: |
|
|
|
_missing = "Function {} misses argument {}, only provided with {}(from {}) and " \ |
|
|
|
"{}." \ |
|
|
|
_missing = "Function {} misses argument {}, \n only provided with {}(from {}) and " \ |
|
|
|
"{}(from target in Dataset)." \ |
|
|
|
.format(func_signature, check_res.missing, |
|
|
|
list(output.keys()), previous_func_signature, |
|
|
|
list(output.keys()), prev_func_signature, |
|
|
|
list(batch_y.keys())) |
|
|
|
if len(check_res.unused)>0: |
|
|
|
if len(check_res.unused) > 1: |
|
|
@@ -424,40 +356,38 @@ def _check_evaluate(model, model_func, check_level, output, batch_y): |
|
|
|
"them in {} at the same time.\n".format(check_res.duplicated, |
|
|
|
func_signature, |
|
|
|
check_res.duplicated, |
|
|
|
previous_func_signature) |
|
|
|
prev_func_signature) |
|
|
|
else: |
|
|
|
_duplicated = "Duplicated key {} is detected when calling function {}. \nDon't set {} as target and output " \ |
|
|
|
"it in {} at the same time.\n".format(check_res.duplicated, |
|
|
|
func_signature, |
|
|
|
check_res.duplicated, |
|
|
|
previous_func_signature) |
|
|
|
prev_func_signature) |
|
|
|
_number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0) |
|
|
|
if _number_errs > 0: |
|
|
|
_error_str = '' |
|
|
|
_error_strs = [] |
|
|
|
if _number_errs > 1: |
|
|
|
count = 1 |
|
|
|
if _missing: |
|
|
|
_error_str += '({}).{}'.format(count, _missing) |
|
|
|
_error_strs.append('({}).{}'.format(count, _missing)) |
|
|
|
count += 1 |
|
|
|
if _duplicated: |
|
|
|
_error_str += '({}).{}'.format(count, _duplicated) |
|
|
|
_error_strs.append('({}).{}'.format(count, _duplicated)) |
|
|
|
count += 1 |
|
|
|
if _unused and check_level == STRICT_CHECK_LEVEL: |
|
|
|
_error_str += '({}).{}'.format(count, _unused) |
|
|
|
_error_strs.append('({}).{}'.format(count, _unused)) |
|
|
|
else: |
|
|
|
if _unused: |
|
|
|
if check_level == STRICT_CHECK_LEVEL: |
|
|
|
# TODO: we may need to define some custom Error types here
|
|
|
_error_str = _unused |
|
|
|
_error_strs.append(_unused) |
|
|
|
elif check_level == WARNING_CHECK_LEVEL: |
|
|
|
_unused = _unused.strip() |
|
|
|
warnings.warn(_unused) |
|
|
|
else: |
|
|
|
_error_str = _missing + _duplicated |
|
|
|
if _error_str: |
|
|
|
raise ValueError(_error_str) |
|
|
|
|
|
|
|
|
|
|
|
_error_strs = [_missing, _duplicated] |
|
|
|
if _error_strs: |
|
|
|
raise ValueError('\n'.join(_error_strs)) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
@@ -478,11 +408,12 @@ if __name__ == '__main__': |
|
|
|
output['words'] = words |
|
|
|
return output |
|
|
|
|
|
|
|
def get_loss(self, prediction, labels, words): |
|
|
|
def get_loss(self, prediction, labels, words, seq_lens): |
|
|
|
return torch.mean(self.fc1.weight) |
|
|
|
|
|
|
|
def evaluate(self, prediction, labels, demo=2): |
|
|
|
return 0 |
|
|
|
return {} |
|
|
|
|
|
|
|
|
|
|
|
model = Model() |
|
|
|
|
|
|
@@ -493,7 +424,7 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
dataset = DataSet(fake_data_dict) |
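# declare which fields are model inputs and which are targets; the target fields become batch_y in the checks above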
|
|
|
dataset.set_input(words=True, chars=True) |
|
|
|
dataset.set_target(labels=True) |
|
|
|
dataset.set_target(labels=True, words=True) |
|
|
|
|
|
|
|
# trainer = Trainer(dataset, model) |
|
|
|
|
|
|
@@ -505,13 +436,5 @@ if __name__ == '__main__': |
|
|
|
# import inspect |
|
|
|
# print(inspect.getfullargspec(model.forward)) |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
a = [1, 3] |
|
|
|
np.asarray(a) |
|
|
|
|
|
|
|
import pandas |
|
|
|
df = pandas.DataFrame(fake_data_dict) |
|
|
|
df.infer_objects() |
|
|
|
|
|
|
|
|