From 5824b7f4c73788738baa0d39c01ec0d12bc4ba0e Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Mon, 3 Dec 2018 00:08:59 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B7=91=E9=80=9Atutorial,=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E4=B8=80=E4=BA=9Bbugs:=20*=20dataset=E6=A3=80=E6=9F=A5slice?= =?UTF-8?q?=E5=BC=80=E5=A7=8B=E4=BD=8D=E7=BD=AE=EF=BC=8C=E7=A1=AE=E4=BF=9D?= =?UTF-8?q?=E7=BB=93=E6=9E=9C=E4=B8=8D=E4=B8=BA=E7=A9=BA=20*=20fieldarray?= =?UTF-8?q?=E6=A3=80=E6=9F=A5content=E4=B8=8D=E4=B8=BA=E7=A9=BA=20*=20opti?= =?UTF-8?q?mizer=E6=8E=A5=E5=8F=97=E7=9A=84model=20params=E6=98=AF?= =?UTF-8?q?=E4=B8=80=E4=B8=AAgenerator=EF=BC=8C=E4=B8=8D=E8=83=BD=E8=B5=8B?= =?UTF-8?q?=E5=80=BC=20*=20code=20style=20refine?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataset.py | 7 +- fastNLP/core/fieldarray.py | 3 + fastNLP/core/optimizer.py | 12 ++- fastNLP/models/cnn_text_classification.py | 7 +- test/io/__init__.py | 0 test/test_tutorial.py | 95 +++++++++++++++++++++++ 6 files changed, 115 insertions(+), 9 deletions(-) delete mode 100644 test/io/__init__.py create mode 100644 test/test_tutorial.py diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 6d2a94d6..e93333a0 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -87,6 +87,8 @@ class DataSet(object): if isinstance(idx, int): return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays}) elif isinstance(idx, slice): + if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)): + raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self)-1}") data_set = DataSet() for field in self.field_arrays.values(): data_set.add_field(name=field.name, @@ -135,7 +137,9 @@ class DataSet(object): :param bool is_target: whether this field is label or target. """ if len(self.field_arrays) != 0: - assert len(self) == len(fields) + if len(self) != len(fields): + raise RuntimeError(f"The field to append must have the same size as dataset. " + f"Dataset size {len(self)} != field size {len(fields)}") self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target, is_input=is_input) @@ -168,6 +172,7 @@ class DataSet(object): """ if old_name in self.field_arrays: self.field_arrays[new_name] = self.field_arrays.pop(old_name) + self.field_arrays[new_name].name = new_name else: raise KeyError("{} is not a valid name. ".format(old_name)) diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index 976dc2c6..14c52829 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -33,7 +33,10 @@ class FieldArray(object): type_set = set([type(item) for item in content[0]]) else: # 1-D list + if len(content) == 0: + raise RuntimeError("Cannot create FieldArray with an empty list.") type_set = set([type(item) for item in content]) + if len(type_set) == 1 and any(basic_type in type_set for basic_type in (str, int, float)): return type_set.pop() elif len(type_set) == 2 and float in type_set and int in type_set: diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py index 4cb21462..5075fa02 100644 --- a/fastNLP/core/optimizer.py +++ b/fastNLP/core/optimizer.py @@ -42,8 +42,10 @@ class SGD(Optimizer): def construct_from_pytorch(self, model_params): if self.model_params is None: - self.model_params = model_params - return torch.optim.SGD(self.model_params, **self.settings) + # careful! generator cannot be assigned. + return torch.optim.SGD(model_params, **self.settings) + else: + return torch.optim.SGD(self.model_params, **self.settings) class Adam(Optimizer): @@ -75,5 +77,7 @@ class Adam(Optimizer): def construct_from_pytorch(self, model_params): if self.model_params is None: - self.model_params = model_params - return torch.optim.Adam(self.model_params, **self.settings) + # careful! generator cannot be assigned. + return torch.optim.Adam(model_params, **self.settings) + else: + return torch.optim.Adam(self.model_params, **self.settings) diff --git a/fastNLP/models/cnn_text_classification.py b/fastNLP/models/cnn_text_classification.py index 04b76fba..9aa07e66 100644 --- a/fastNLP/models/cnn_text_classification.py +++ b/fastNLP/models/cnn_text_classification.py @@ -18,8 +18,8 @@ class CNNText(torch.nn.Module): def __init__(self, embed_num, embed_dim, num_classes, - kernel_nums=(3,4,5), - kernel_sizes=(3,4,5), + kernel_nums=(3, 4, 5), + kernel_sizes=(3, 4, 5), padding=0, dropout=0.5): super(CNNText, self).__init__() @@ -45,7 +45,7 @@ class CNNText(torch.nn.Module): x = self.conv_pool(x) # [N,L,C] -> [N,C] x = self.dropout(x) x = self.fc(x) # [N,C] -> [N, N_class] - return {'output':x} + return {'output': x} def predict(self, word_seq): """ @@ -78,4 +78,3 @@ class CNNText(torch.nn.Module): correct = (predict == label_seq).long().sum().item() total = label_seq.size(0) return {'acc': 1.0 * correct / total} - diff --git a/test/io/__init__.py b/test/io/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/test/test_tutorial.py b/test/test_tutorial.py new file mode 100644 index 00000000..05338514 --- /dev/null +++ b/test/test_tutorial.py @@ -0,0 +1,95 @@ +import unittest + +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP import Tester +from fastNLP import Vocabulary +from fastNLP.core.losses import CrossEntropyLoss +from fastNLP.core.metrics import AccuracyMetric +from fastNLP.models import CNNText + + +class TestTutorial(unittest.TestCase): + def test_tutorial(self): + # 从csv读取数据到DataSet + dataset = DataSet.read_csv("./data_for_tests/tutorial_sample_dataset.csv", headers=('raw_sentence', 'label'), + sep='\t') + print(len(dataset)) + print(dataset[0]) + + dataset.append(Instance(raw_sentence='fake data', label='0')) + dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') + # label转int + dataset.apply(lambda x: int(x['label']), new_field_name='label') + + # 使用空格分割句子 + def split_sent(ins): + return ins['raw_sentence'].split() + + dataset.apply(split_sent, new_field_name='words') + # 增加长度信息 + dataset.apply(lambda x: len(x['words']), new_field_name='seq_len') + print(len(dataset)) + print(dataset[0]) + + # DataSet.drop(func)筛除数据 + dataset.drop(lambda x: x['seq_len'] <= 3) + print(len(dataset)) + + # 设置DataSet中,哪些field要转为tensor + # set target,loss或evaluate中的golden,计算loss,模型评估时使用 + dataset.set_target("label") + # set input,模型forward时使用 + dataset.set_input("words") + + # 分出测试集、训练集 + test_data, train_data = dataset.split(0.5) + print(len(test_data)) + print(len(train_data)) + + # 构建词表, Vocabulary.add(word) + vocab = Vocabulary(min_freq=2) + train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) + vocab.build_vocab() + + # index句子, Vocabulary.to_index(word) + train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') + test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') + print(test_data[0]) + + model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) + + from fastNLP import Trainer + from copy import deepcopy + + # 更改DataSet中对应field的名称,要以模型的forward等参数名一致 + train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 + train_data.rename_field('label', 'label_seq') + test_data.rename_field('words', 'word_seq') + test_data.rename_field('label', 'label_seq') + + # 实例化Trainer,传入模型和数据,进行训练 + copy_model = deepcopy(model) + overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data, + losser=CrossEntropyLoss(input="output", target="label_seq"), + metrics=AccuracyMetric(pred="predict", target="label_seq"), + save_path="./save", + batch_size=4, + n_epochs=10) + overfit_trainer.train() + + trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, + losser=CrossEntropyLoss(input="output", target="label_seq"), + metrics=AccuracyMetric(pred="predict", target="label_seq"), + save_path="./save", + batch_size=4, + n_epochs=10) + trainer.train() + print('Train finished!') + + # 使用fastNLP的Tester测试脚本 + + tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), + batch_size=4) + acc = tester.test() + print(acc)