From 3510700d7333000b768cda21fd07729493879975 Mon Sep 17 00:00:00 2001
From: Yige Xu
Date: Fri, 15 Nov 2019 23:21:52 +0800
Subject: [PATCH 1/2] [bugfix] Fix several documentation errors
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../tutorials/tutorial_9_seq_labeling.rst  |  2 +-
 .../text_classification/train_char_cnn.py  | 17 +++------------
 2 files changed, 6 insertions(+), 13 deletions(-)

diff --git a/docs/source/tutorials/tutorial_9_seq_labeling.rst b/docs/source/tutorials/tutorial_9_seq_labeling.rst
index 60bc1440..0272fc1c 100644
--- a/docs/source/tutorials/tutorial_9_seq_labeling.rst
+++ b/docs/source/tutorials/tutorial_9_seq_labeling.rst
@@ -146,7 +146,7 @@ fastNLP的数据载入主要是由Loader与Pipe两个基类衔接完成的,您
                       target_vocab=data_bundle.get_vocab('target'))

     from fastNLP import SpanFPreRecMetric
-    from torch import Adam
+    from torch.optim import Adam
     from fastNLP import LossInForward
     metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))
     optimizer = Adam(model.parameters(), lr=2e-5)
diff --git a/reproduction/text_classification/train_char_cnn.py b/reproduction/text_classification/train_char_cnn.py
index 55d830e6..a4a97dc4 100644
--- a/reproduction/text_classification/train_char_cnn.py
+++ b/reproduction/text_classification/train_char_cnn.py
@@ -94,11 +94,7 @@ ops=Config
 # print('RNG SEED: {}'.format(ops.seed))


-##1.task相关信息:利用dataloader载入dataInfo
-#dataloader=SST2Loader()
-#dataloader=IMDBLoader()
-# dataloader=YelpLoader(fine_grained=True)
-# datainfo=dataloader.process(ops.datapath,char_level_op=True,split_dev_op=False)
+##1.task相关信息:利用dataloader载入DataBundle
 char_vocab=ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
 ops.number_of_characters=len(char_vocab)
 ops.embedding_dim=ops.number_of_characters
@@ -155,10 +151,8 @@ for dataset in data_bundle.datasets.values():
 # print(data_bundle.datasets['train'][0]['chars'])
 # print(data_bundle.datasets['train'][0]['raw_words'])

-data_bundle.datasets['train'].set_input('chars')
-data_bundle.datasets['test'].set_input('chars')
-data_bundle.datasets['train'].set_target('target')
-data_bundle.datasets['test'].set_target('target')
+data_bundle.set_input('chars')
+data_bundle.set_target('target')

 ##2. 定义/组装模型,这里可以随意,就如果是fastNLP封装好的,类似CNNText就直接用初始化调用就好了,这里只是给出一个伪框架表示占位,在这里建立符合fastNLP输入输出规范的model
 class ModelFactory(nn.Module):
@@ -180,8 +174,7 @@ class ModelFactory(nn.Module):
         return {C.OUTPUT:None}

 ## 2.或直接复用fastNLP的模型
-#vocab=datainfo.vocabs['words']
-vocab_label=data_bundle.vocabs['target']
+vocab_label=data_bundle.get_vocab('target')
 '''
 # emded_char=CNNCharEmbedding(vocab)
 # embed_word = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
@@ -199,7 +192,7 @@ embedding=nn.Embedding(num_embeddings=len(char_vocab)+1,embedding_dim=len(char_v
 for para in embedding.parameters():
     para.requires_grad=False
 #CNNText太过于简单
-#model=CNNText(init_embed=embedding, num_classes=ops.num_classes)
+#model=CNNText(embed=embedding, num_classes=ops.num_classes)
 model=CharacterLevelCNN(ops,embedding)

 ## 3. 声明loss,metric,optimizer
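For readers applying the same fixes in their own scripts, here is a minimal sketch collecting the corrected calls from the patch above in one place. `data_bundle` and `model` are placeholders for an already-built fastNLP DataBundle and a compatible model (neither is constructed here), so treat this as an illustration of the intended API usage rather than a complete program.

# Sketch only: `data_bundle` and `model` are assumed to already exist.
from torch.optim import Adam                      # Adam lives in torch.optim, not torch
from fastNLP import SpanFPreRecMetric, LossInForward

# DataBundle-level flags replace per-dataset calls such as
# data_bundle.datasets['train'].set_input('chars')
data_bundle.set_input('chars')
data_bundle.set_target('target')

# Prefer the accessor over indexing the internal dict (data_bundle.vocabs['target'])
vocab_label = data_bundle.get_vocab('target')

metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))
optimizer = Adam(model.parameters(), lr=2e-5)
loss = LossInForward()
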
From 81c71aceb8853f591bb069f8880c3697eead06ca Mon Sep 17 00:00:00 2001
From: Yige Xu
Date: Sun, 17 Nov 2019 01:27:54 +0800
Subject: [PATCH 2/2] [bugfix] Fix incorrect loss display in Trainer when
 update_every is greater than 1
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index a39362e2..721478f7 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -666,8 +666,8 @@ class Trainer(object):
                     # edit prediction
                     self.callback_manager.on_loss_begin(batch_y, prediction)
                     loss = self._compute_loss(prediction, batch_y).mean()
-                    avg_loss += loss.item()
                     loss = loss / self.update_every
+                    avg_loss += loss.item()

                     # Is loss NaN or inf? requires_grad = False
                     self.callback_manager.on_backward_begin(loss)
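The second patch moves the running-loss bookkeeping to after the division by update_every, so the accumulated number matches the scaled loss that is actually backpropagated during gradient accumulation. Below is a minimal, self-contained sketch of that pattern in plain PyTorch; it is not fastNLP's Trainer, and the toy model, data, and reporting interval are made up purely for illustration.

# Minimal gradient-accumulation sketch (not fastNLP's Trainer): accumulate the
# running loss *after* dividing by update_every, mirroring the patched order.
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)                                    # toy model, illustration only
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
loss_fn = nn.MSELoss()

update_every = 4       # one optimizer step per 4 batches
print_every = 8        # report the running loss every 8 batches
avg_loss = 0.0

for step in range(1, 33):                                  # 32 random toy batches
    x, y = torch.randn(16, 4), torch.randn(16, 1)
    loss = loss_fn(model(x), y).mean()
    loss = loss / update_every                             # scale for gradient accumulation
    avg_loss += loss.item()                                # accumulate the same scaled value
    loss.backward()
    if step % update_every == 0:
        optimizer.step()
        optimizer.zero_grad()
    if step % print_every == 0:
        print('loss: {:<6.5f}'.format(avg_loss / print_every))
        avg_loss = 0.0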