Browse Source

[bugfix]修复了一些文档错误

tags/v0.5.5
Yige Xu 5 years ago
parent
commit
3510700d73
2 changed files with 6 additions and 13 deletions
  1. +1
    -1
      docs/source/tutorials/tutorial_9_seq_labeling.rst
  2. +5
    -12
      reproduction/text_classification/train_char_cnn.py

+ 1
- 1
docs/source/tutorials/tutorial_9_seq_labeling.rst View File

@@ -146,7 +146,7 @@ fastNLP的数据载入主要是由Loader与Pipe两个基类衔接完成的,您
target_vocab=data_bundle.get_vocab('target'))

from fastNLP import SpanFPreRecMetric
-from torch import Adam
+from torch.optim import Adam
from fastNLP import LossInForward
metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))
optimizer = Adam(model.parameters(), lr=2e-5)


+ 5
- 12
reproduction/text_classification/train_char_cnn.py View File

@@ -94,11 +94,7 @@ ops=Config
# print('RNG SEED: {}'.format(ops.seed))


-##1.task相关信息:利用dataloader载入dataInfo
-#dataloader=SST2Loader()
-#dataloader=IMDBLoader()
-# dataloader=YelpLoader(fine_grained=True)
-# datainfo=dataloader.process(ops.datapath,char_level_op=True,split_dev_op=False)
+##1.task相关信息:利用dataloader载入DataBundle
char_vocab=ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
ops.number_of_characters=len(char_vocab)
ops.embedding_dim=ops.number_of_characters
@@ -155,10 +151,8 @@ for dataset in data_bundle.datasets.values():
# print(data_bundle.datasets['train'][0]['chars'])
# print(data_bundle.datasets['train'][0]['raw_words'])

-data_bundle.datasets['train'].set_input('chars')
-data_bundle.datasets['test'].set_input('chars')
-data_bundle.datasets['train'].set_target('target')
-data_bundle.datasets['test'].set_target('target')
+data_bundle.set_input('chars')
+data_bundle.set_target('target')

##2. 定义/组装模型,这里可以随意,就如果是fastNLP封装好的,类似CNNText就直接用初始化调用就好了,这里只是给出一个伪框架表示占位,在这里建立符合fastNLP输入输出规范的model
class ModelFactory(nn.Module):
@@ -180,8 +174,7 @@ class ModelFactory(nn.Module):
return {C.OUTPUT:None}

## 2.或直接复用fastNLP的模型
-#vocab=datainfo.vocabs['words']
-vocab_label=data_bundle.vocabs['target']
+vocab_label=data_bundle.get_vocab('target')
'''
# emded_char=CNNCharEmbedding(vocab)
# embed_word = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
@@ -199,7 +192,7 @@ embedding=nn.Embedding(num_embeddings=len(char_vocab)+1,embedding_dim=len(char_v
for para in embedding.parameters():
    para.requires_grad=False
#CNNText太过于简单
-#model=CNNText(init_embed=embedding, num_classes=ops.num_classes)
+#model=CNNText(embed=embedding, num_classes=ops.num_classes)
model=CharacterLevelCNN(ops,embedding)

## 3. 声明loss,metric,optimizer


Loading…
Cancel
Save