|
|
@@ -344,7 +344,7 @@ def _check_loss_evaluate(prev_func, func, check_level, output, batch_y): |
|
|
|
func_signature = get_func_signature(func) |
|
|
|
prev_func_signature = get_func_signature(prev_func) |
|
|
|
if len(check_res.missing)>0: |
|
|
|
_missing = "Function {} misses argument {}, \n only provided with {}(from {}) and " \ |
|
|
|
_missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \ |
|
|
|
"{}(from target in Dataset)." \ |
|
|
|
.format(func_signature, check_res.missing, |
|
|
|
list(output.keys()), prev_func_signature, |
|
|
@@ -357,14 +357,14 @@ def _check_loss_evaluate(prev_func, func, check_level, output, batch_y): |
|
|
|
_unused += "in function {}.\n".format(func_signature) |
|
|
|
if len(check_res.duplicated)>0: |
|
|
|
if len(check_res.duplicated) > 1: |
|
|
|
_duplicated = "Duplicated keys {} are detected when calling function {}. \nDon't set {} as target and output " \ |
|
|
|
"them in {} at the same time.\n".format(check_res.duplicated, |
|
|
|
_duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \ |
|
|
|
"them in {} at the same time.".format(check_res.duplicated, |
|
|
|
func_signature, |
|
|
|
check_res.duplicated, |
|
|
|
prev_func_signature) |
|
|
|
else: |
|
|
|
_duplicated = "Duplicated key {} is detected when calling function {}. \nDon't set {} as target and output " \ |
|
|
|
"it in {} at the same time.\n".format(check_res.duplicated, |
|
|
|
_duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \ |
|
|
|
"it in {} at the same time.".format(check_res.duplicated, |
|
|
|
func_signature, |
|
|
|
check_res.duplicated, |
|
|
|
prev_func_signature) |
|
|
@@ -372,15 +372,16 @@ def _check_loss_evaluate(prev_func, func, check_level, output, batch_y): |
|
|
|
if _number_errs > 0: |
|
|
|
_error_strs = [] |
|
|
|
if _number_errs > 1: |
|
|
|
count = 1 |
|
|
|
count = 0 |
|
|
|
order_words = ['Firstly', 'Secondly', 'Thirdly'] |
|
|
|
if _missing: |
|
|
|
_error_strs.append('({}).{}'.format(count, _missing)) |
|
|
|
_error_strs.append('{}, {}'.format(order_words[count], _missing)) |
|
|
|
count += 1 |
|
|
|
if _duplicated: |
|
|
|
_error_strs.append('({}).{}'.format(count, _duplicated)) |
|
|
|
_error_strs.append('{}, {}'.format(order_words[count], _duplicated)) |
|
|
|
count += 1 |
|
|
|
if _unused and check_level == STRICT_CHECK_LEVEL: |
|
|
|
_error_strs.append('({}).{}'.format(count, _unused)) |
|
|
|
_error_strs.append('{}, {}'.format(order_words[count], _unused)) |
|
|
|
else: |
|
|
|
if _unused: |
|
|
|
if check_level == STRICT_CHECK_LEVEL: |
|
|
@@ -390,9 +391,13 @@ def _check_loss_evaluate(prev_func, func, check_level, output, batch_y): |
|
|
|
_unused = _unused.strip() |
|
|
|
warnings.warn(_unused) |
|
|
|
else: |
|
|
|
_error_strs = [_missing, _duplicated] |
|
|
|
if _missing: |
|
|
|
_error_strs.append(_missing) |
|
|
|
if _duplicated: |
|
|
|
_error_strs.append(_duplicated) |
|
|
|
|
|
|
|
if _error_strs: |
|
|
|
raise ValueError('\n'.join(_error_strs)) |
|
|
|
raise ValueError('\n' + '\n'.join(_error_strs)) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
@@ -410,10 +415,10 @@ if __name__ == '__main__': |
|
|
|
def forward(self, words, chars): |
|
|
|
output = {} |
|
|
|
output['prediction'] = torch.randn(3, 4) |
|
|
|
output['words'] = words |
|
|
|
# output['words'] = words |
|
|
|
return output |
|
|
|
|
|
|
|
def get_loss(self, prediction, labels, words, seq_lens): |
|
|
|
def get_loss(self, prediction, labels, words): |
|
|
|
return torch.mean(self.fc1.weight) |
|
|
|
|
|
|
|
def evaluate(self, prediction, labels, demo=2): |
|
|
@@ -424,7 +429,7 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
num_samples = 4 |
|
|
|
fake_data_dict = {'words': np.random.randint(num_samples, size=(4, 3)), 'chars': np.random.randn(num_samples, 6), |
|
|
|
'labels': np.random.randint(2, size=(num_samples,))} |
|
|
|
'labels': np.random.randint(2, size=(num_samples,)), 'seq_lens': [1, 3, 4, 6]} |
|
|
|
|
|
|
|
|
|
|
|
dataset = DataSet(fake_data_dict) |
|
|
@@ -441,5 +446,7 @@ if __name__ == '__main__': |
|
|
|
# import inspect |
|
|
|
# print(inspect.getfullargspec(model.forward)) |
|
|
|
|
|
|
|
import pandas |
|
|
|
df = pandas.DataFrame({'a':0}) |
|
|
|
|
|
|
|
|