
Modify dataframe.read_csv

tags/v0.2.0
yh yunfan committed 6 years ago
commit ffc963190e
2 changed files with 29 additions and 17 deletions
  1. fastNLP/core/dataset.py  (+8, -3)
  2. fastNLP/core/trainer.py  (+21, -14)

fastNLP/core/dataset.py  (+8, -3)

@@ -293,7 +293,7 @@ class DataSet(object):
         return train_set, dev_set
 
     @classmethod
-    def read_csv(cls, csv_path, headers=None, sep='\t'):
+    def read_csv(cls, csv_path, headers=None, sep='\t', dropna=True):
         with open(csv_path, 'r') as f:
             start_idx = 0
             if headers is None:
@@ -307,8 +307,13 @@ class DataSet(object):
                 _dict[col] = []
             for line_idx, line in enumerate(f, start_idx):
                 contents = line.split(sep)
-                assert len(contents)==len(headers), "Line {} has {} parts, while header has {}."\
-                    .format(line_idx, len(contents), len(headers))
+                if len(contents)!=len(headers):
+                    if dropna:
+                        continue
+                    else:
+                        #TODO change error type
+                        raise ValueError("Line {} has {} parts, while header has {} parts."\
+                            .format(line_idx, len(contents), len(headers)))
                 for header, content in zip(headers, contents):
                     _dict[header].append(content)
             return cls(_dict)
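
A minimal usage sketch of the changed method. Only the read_csv signature above comes from this commit; the file path and header names below are made up for illustration:

    from fastNLP.core.dataset import DataSet

    # dropna=True (the new default): any row whose column count does not match
    # the header is silently skipped.
    ds = DataSet.read_csv('data/train.tsv', headers=('raw_sentence', 'label'), sep='\t', dropna=True)

    # dropna=False: the same malformed row raises ValueError
    # ("Line {} has {} parts, while header has {} parts.") instead of being dropped.
    strict_ds = DataSet.read_csv('data/train.tsv', headers=('raw_sentence', 'label'), sep='\t', dropna=False)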

fastNLP/core/trainer.py  (+21, -14)

@@ -344,7 +344,7 @@ def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):
     func_signature = get_func_signature(func)
     prev_func_signature = get_func_signature(prev_func)
     if len(check_res.missing)>0:
-        _missing = "Function {} misses argument {}, \n only provided with {}(from {}) and " \
+        _missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \
                    "{}(from target in Dataset)." \
            .format(func_signature, check_res.missing,
                    list(output.keys()), prev_func_signature,
@@ -357,14 +357,14 @@ def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):
         _unused += "in function {}.\n".format(func_signature)
     if len(check_res.duplicated)>0:
         if len(check_res.duplicated) > 1:
-            _duplicated = "Duplicated keys {} are detected when calling function {}. \nDon't set {} as target and output " \
-                          "them in {} at the same time.\n".format(check_res.duplicated,
+            _duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \
+                          "them in {} at the same time.".format(check_res.duplicated,
                                                                  func_signature,
                                                                  check_res.duplicated,
                                                                  prev_func_signature)
         else:
-            _duplicated = "Duplicated key {} is detected when calling function {}. \nDon't set {} as target and output " \
-                          "it in {} at the same time.\n".format(check_res.duplicated,
+            _duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \
+                          "it in {} at the same time.".format(check_res.duplicated,
                                                                func_signature,
                                                                check_res.duplicated,
                                                                prev_func_signature)
@@ -372,15 +372,16 @@ def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):
     if _number_errs > 0:
         _error_strs = []
         if _number_errs > 1:
-            count = 1
+            count = 0
+            order_words = ['Firstly', 'Secondly', 'Thirdly']
             if _missing:
-                _error_strs.append('({}).{}'.format(count, _missing))
+                _error_strs.append('{}, {}'.format(order_words[count], _missing))
                 count += 1
             if _duplicated:
-                _error_strs.append('({}).{}'.format(count, _duplicated))
+                _error_strs.append('{}, {}'.format(order_words[count], _duplicated))
                 count += 1
             if _unused and check_level == STRICT_CHECK_LEVEL:
-                _error_strs.append('({}).{}'.format(count, _unused))
+                _error_strs.append('{}, {}'.format(order_words[count], _unused))
         else:
             if _unused:
                 if check_level == STRICT_CHECK_LEVEL:
@@ -390,9 +391,13 @@ def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):
                     _unused = _unused.strip()
                     warnings.warn(_unused)
             else:
-                _error_strs = [_missing, _duplicated]
+                if _missing:
+                    _error_strs.append(_missing)
+                if _duplicated:
+                    _error_strs.append(_duplicated)
 
         if _error_strs:
-            raise ValueError('\n'.join(_error_strs))
+            raise ValueError('\n' + '\n'.join(_error_strs))


@@ -410,10 +415,10 @@ if __name__ == '__main__':
         def forward(self, words, chars):
             output = {}
             output['prediction'] = torch.randn(3, 4)
-            output['words'] = words
+            # output['words'] = words
             return output
 
-        def get_loss(self, prediction, labels, words, seq_lens):
+        def get_loss(self, prediction, labels, words):
             return torch.mean(self.fc1.weight)
 
         def evaluate(self, prediction, labels, demo=2):
@@ -424,7 +429,7 @@ if __name__ == '__main__':
 
     num_samples = 4
     fake_data_dict = {'words': np.random.randint(num_samples, size=(4, 3)), 'chars': np.random.randn(num_samples, 6),
-                      'labels': np.random.randint(2, size=(num_samples,))}
+                      'labels': np.random.randint(2, size=(num_samples,)), 'seq_lens': [1, 3, 4, 6]}
 
 
     dataset = DataSet(fake_data_dict)
@@ -441,5 +446,7 @@ if __name__ == '__main__':
     # import inspect
     # print(inspect.getfullargspec(model.forward))
 
+    import pandas
+    df = pandas.DataFrame({'a':0})
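
To see how the reworked numbering in _check_loss_evaluate composes its final error text, here is a standalone sketch of just that logic, following the hunk above; the two message strings are placeholders, not real fastNLP output:

    # Illustration of the new 'Firstly/Secondly/Thirdly' ordering; the messages are fake.
    order_words = ['Firstly', 'Secondly', 'Thirdly']
    _missing = "function get_loss misses argument ['seq_lens'], only provided with ['prediction'] ..."
    _duplicated = "duplicated key ['words'] is detected when calling function get_loss. ..."

    _error_strs = []
    count = 0
    if _missing:
        _error_strs.append('{}, {}'.format(order_words[count], _missing))
        count += 1
    if _duplicated:
        _error_strs.append('{}, {}'.format(order_words[count], _duplicated))
        count += 1

    # The new raise prepends '\n' so the first numbered item starts on its own line.
    try:
        raise ValueError('\n' + '\n'.join(_error_strs))
    except ValueError as e:
        print(e)  # Firstly, function get_loss misses ... / Secondly, duplicated key ...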


