From ffc963190e1fa4cfa06b265ff8b1034c062234e2 Mon Sep 17 00:00:00 2001
From: yh
Date: Mon, 26 Nov 2018 20:43:16 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9dataframe.read=5Fcsv?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/dataset.py | 11 ++++++++---
 fastNLP/core/trainer.py | 35 +++++++++++++++++++++--------------
 2 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 49c2add4..ee0e5590 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -293,7 +293,7 @@ class DataSet(object):
         return train_set, dev_set
 
     @classmethod
-    def read_csv(cls, csv_path, headers=None, sep='\t'):
+    def read_csv(cls, csv_path, headers=None, sep='\t', dropna=True):
         with open(csv_path, 'r') as f:
             start_idx = 0
             if headers is None:
@@ -307,8 +307,13 @@ class DataSet(object):
                 _dict[col] = []
             for line_idx, line in enumerate(f, start_idx):
                 contents = line.split(sep)
-                assert len(contents)==len(headers), "Line {} has {} parts, while header has {}."\
-                    .format(line_idx, len(contents), len(headers))
+                if len(contents)!=len(headers):
+                    if dropna:
+                        continue
+                    else:
+                        #TODO change error type
+                        raise ValueError("Line {} has {} parts, while header has {} parts."\
+                            .format(line_idx, len(contents), len(headers)))
                 for header, content in zip(headers, contents):
                     _dict[header].append(content)
         return cls(_dict)
\ No newline at end of file
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index e5499767..26602dc9 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -344,7 +344,7 @@ def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):
     func_signature = get_func_signature(func)
     prev_func_signature = get_func_signature(prev_func)
     if len(check_res.missing)>0:
-        _missing = "Function {} misses argument {}, \n only provided with {}(from {}) and " \
+        _missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \
                    "{}(from target in Dataset)." \
             .format(func_signature, check_res.missing,
                     list(output.keys()), prev_func_signature,
@@ -357,14 +357,14 @@
         _unused += "in function {}.\n".format(func_signature)
     if len(check_res.duplicated)>0:
         if len(check_res.duplicated) > 1:
-            _duplicated = "Duplicated keys {} are detected when calling function {}. \nDon't set {} as target and output " \
-                          "them in {} at the same time.\n".format(check_res.duplicated,
+            _duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \
+                          "them in {} at the same time.".format(check_res.duplicated,
                                                                  func_signature,
                                                                  check_res.duplicated,
                                                                  prev_func_signature)
         else:
-            _duplicated = "Duplicated key {} is detected when calling function {}. \nDon't set {} as target and output " \
-                          "it in {} at the same time.\n".format(check_res.duplicated,
+            _duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \
+                          "it in {} at the same time.".format(check_res.duplicated,
                                                                func_signature,
                                                                check_res.duplicated,
                                                                prev_func_signature)
@@ -372,15 +372,16 @@
     if _number_errs > 0:
         _error_strs = []
         if _number_errs > 1:
-            count = 1
+            count = 0
+            order_words = ['Firstly', 'Secondly', 'Thirdly']
             if _missing:
-                _error_strs.append('({}).{}'.format(count, _missing))
+                _error_strs.append('{}, {}'.format(order_words[count], _missing))
                 count += 1
             if _duplicated:
-                _error_strs.append('({}).{}'.format(count, _duplicated))
+                _error_strs.append('{}, {}'.format(order_words[count], _duplicated))
                 count += 1
             if _unused and check_level == STRICT_CHECK_LEVEL:
-                _error_strs.append('({}).{}'.format(count, _unused))
+                _error_strs.append('{}, {}'.format(order_words[count], _unused))
         else:
             if _unused:
                 if check_level == STRICT_CHECK_LEVEL:
@@ -390,9 +391,13 @@
                 _unused = _unused.strip()
                 warnings.warn(_unused)
         else:
-            _error_strs = [_missing, _duplicated]
+            if _missing:
+                _error_strs.append(_missing)
+            if _duplicated:
+                _error_strs.append(_duplicated)
+
     if _error_strs:
-        raise ValueError('\n'.join(_error_strs))
+        raise ValueError('\n' + '\n'.join(_error_strs))
 
 
 if __name__ == '__main__':
@@ -410,10 +415,10 @@ if __name__ == '__main__':
         def forward(self, words, chars):
             output = {}
             output['prediction'] = torch.randn(3, 4)
-            output['words'] = words
+            # output['words'] = words
             return output
 
-        def get_loss(self, prediction, labels, words, seq_lens):
+        def get_loss(self, prediction, labels, words):
             return torch.mean(self.fc1.weight)
 
         def evaluate(self, prediction, labels, demo=2):
@@ -424,7 +429,7 @@ if __name__ == '__main__':
     num_samples = 4
     fake_data_dict = {'words': np.random.randint(num_samples, size=(4, 3)),
                       'chars': np.random.randn(num_samples, 6),
-                      'labels': np.random.randint(2, size=(num_samples,))}
+                      'labels': np.random.randint(2, size=(num_samples,)), 'seq_lens': [1, 3, 4, 6]}
 
     dataset = DataSet(fake_data_dict)
 
@@ -441,5 +446,7 @@ if __name__ == '__main__':
 
     # import inspect
     # print(inspect.getfullargspec(model.forward))
 
 
+    import pandas
+    df = pandas.DataFrame({'a':0})