|
|
@@ -94,11 +94,7 @@ ops=Config |
|
|
|
# print('RNG SEED: {}'.format(ops.seed)) |
|
|
|
|
|
|
|
|
|
|
|
##1.task相关信息:利用dataloader载入dataInfo |
|
|
|
#dataloader=SST2Loader() |
|
|
|
#dataloader=IMDBLoader() |
|
|
|
# dataloader=YelpLoader(fine_grained=True) |
|
|
|
# datainfo=dataloader.process(ops.datapath,char_level_op=True,split_dev_op=False) |
|
|
|
##1.task相关信息:利用dataloader载入DataBundle |
|
|
|
char_vocab=ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"] |
|
|
|
ops.number_of_characters=len(char_vocab) |
|
|
|
ops.embedding_dim=ops.number_of_characters |
|
|
@@ -155,10 +151,8 @@ for dataset in data_bundle.datasets.values(): |
|
|
|
# print(data_bundle.datasets['train'][0]['chars']) |
|
|
|
# print(data_bundle.datasets['train'][0]['raw_words']) |
|
|
|
|
|
|
|
data_bundle.datasets['train'].set_input('chars') |
|
|
|
data_bundle.datasets['test'].set_input('chars') |
|
|
|
data_bundle.datasets['train'].set_target('target') |
|
|
|
data_bundle.datasets['test'].set_target('target') |
|
|
|
data_bundle.set_input('chars') |
|
|
|
data_bundle.set_target('target') |
|
|
|
|
|
|
|
##2. 定义/组装模型,这里可以随意,就如果是fastNLP封装好的,类似CNNText就直接用初始化调用就好了,这里只是给出一个伪框架表示占位,在这里建立符合fastNLP输入输出规范的model |
|
|
|
class ModelFactory(nn.Module): |
|
|
@@ -180,8 +174,7 @@ class ModelFactory(nn.Module): |
|
|
|
return {C.OUTPUT:None} |
|
|
|
|
|
|
|
## 2.或直接复用fastNLP的模型 |
|
|
|
#vocab=datainfo.vocabs['words'] |
|
|
|
vocab_label=data_bundle.vocabs['target'] |
|
|
|
vocab_label=data_bundle.get_vocab('target') |
|
|
|
''' |
|
|
|
# emded_char=CNNCharEmbedding(vocab) |
|
|
|
# embed_word = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True) |
|
|
@@ -199,7 +192,7 @@ embedding=nn.Embedding(num_embeddings=len(char_vocab)+1,embedding_dim=len(char_v |
|
|
|
for para in embedding.parameters(): |
|
|
|
para.requires_grad=False |
|
|
|
#CNNText太过于简单 |
|
|
|
#model=CNNText(init_embed=embedding, num_classes=ops.num_classes) |
|
|
|
#model=CNNText(embed=embedding, num_classes=ops.num_classes) |
|
|
|
model=CharacterLevelCNN(ops,embedding) |
|
|
|
|
|
|
|
## 3. 声明loss,metric,optimizer |
|
|
|