[verify] char_cnn use pipe

tags/v0.4.10
wyg, 5 years ago
parent commit: ed6fd60aa9

2 changed files with 60 additions and 9 deletions:
  1. reproduction/text_classification/model/BertTC.py (+24, -0)
  2. reproduction/text_classification/train_char_cnn.py (+36, -9)

reproduction/text_classification/model/BertTC.py (+24, -0)

@@ -0,0 +1,24 @@
from fastNLP.embeddings import BertEmbedding
import torch
import torch.nn as nn
from fastNLP.core.const import Const as C

class BertTC(nn.Module):
    def __init__(self, vocab, num_class, bert_model_dir_or_name, fine_tune=False):
        super(BertTC, self).__init__()
        self.embed = BertEmbedding(vocab, requires_grad=fine_tune,
                                   model_dir_or_name=bert_model_dir_or_name,
                                   include_cls_sep=True)
        self.classifier = nn.Linear(self.embed.embedding_dim, num_class)

    def forward(self, words):
        # use the embedding of the leading [CLS] token as the sentence representation
        embedding_cls = self.embed(words)[:, 0]
        output = self.classifier(embedding_cls)
        return {C.OUTPUT: output}

    def predict(self, words):
        return self.forward(words)

if __name__ == "__main__":
    # sanity check that [:, 0] picks the first column (the [CLS] position) of each row
    ta = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    tb = ta[:, 0]
    print(tb)  # tensor([1, 4, 7])
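
A minimal usage sketch for the new class (not part of the commit): build a tiny vocabulary, instantiate BertTC, and run a toy batch through it. The 'en-base-uncased' shortcut, the toy sentences, and the import path are illustrative assumptions.

# hypothetical usage sketch, assuming fastNLP's Vocabulary API and an
# 'en-base-uncased' BERT shortcut; adjust names/paths to your setup
import torch
from fastNLP import Vocabulary
from model.BertTC import BertTC

vocab = Vocabulary()
vocab.add_word_lst("the movie was great".split())
vocab.add_word_lst("the plot was thin".split())

model = BertTC(vocab, num_class=2, bert_model_dir_or_name='en-base-uncased')

# a single-sentence batch of word indices, shape (1, seq_len)
words = torch.tensor([[vocab.to_index(w) for w in "the movie was great".split()]])
print(model(words))  # {'pred': tensor of shape (1, 2)}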

reproduction/text_classification/train_char_cnn.py (+36, -9)

@@ -8,6 +8,7 @@ sys.path.append('../..')
from fastNLP.core.const import Const as C
import torch.nn as nn
from fastNLP.io.data_loader import YelpLoader
from fastNLP.io.pipe.classification import YelpFullPipe, YelpPolarityPipe, SST2Pipe, IMDBPipe
#from data.sstLoader import sst2Loader
from model.char_cnn import CharacterLevelCNN
from fastNLP import CrossEntropyLoss, AccuracyMetric
@@ -46,6 +47,8 @@ class Config():
    extra_characters = ''
    max_length = 1014
    weight_decay = 1e-5
    to_lower = True
    tokenizer = 'spacy'  # tokenize with spaCy

    char_cnn_config = {
        "alphabet": {
@@ -111,12 +114,35 @@ ops=Config
##1. Task info: load the dataInfo via a dataloader
#dataloader=SST2Loader()
#dataloader=IMDBLoader()
dataloader=YelpLoader(fine_grained=True)
datainfo=dataloader.process(ops.datapath,char_level_op=True,split_dev_op=False)
# dataloader=YelpLoader(fine_grained=True)
# datainfo=dataloader.process(ops.datapath,char_level_op=True,split_dev_op=False)
char_vocab = ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
ops.number_of_characters = len(char_vocab)
ops.embedding_dim = ops.number_of_characters

# load the data set with the pipe matching ops.task
if ops.task == 'yelp_p':
    data_bundle = YelpPolarityPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
elif ops.task == 'yelp_f':
    data_bundle = YelpFullPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
elif ops.task == 'imdb':
    data_bundle = IMDBPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
elif ops.task == 'sst-2':
    data_bundle = SST2Pipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
else:
    raise RuntimeError(f'Task {ops.task} is not supported yet!')


def wordtochar(words):
    # flatten a word list into a char list, inserting '' as a word separator
    chars = []
    for word in words:
        # word = word.lower()
        for char in word:
            chars.append(char)
        chars.append('')
    chars.pop()  # drop the trailing separator
    return chars

# chartoindex
def chartoindex(chars):
    max_seq_len = ops.max_length
@@ -136,13 +162,14 @@ def chartoindex(chars):
    char_index_list = [zero_index] * max_seq_len
    return char_index_list

for dataset in datainfo.datasets.values():
for dataset in data_bundle.datasets.values():
    dataset.apply_field(wordtochar, field_name="raw_words", new_field_name='chars')
    dataset.apply_field(chartoindex, field_name='chars', new_field_name='chars')

datainfo.datasets['train'].set_input('chars')
datainfo.datasets['test'].set_input('chars')
datainfo.datasets['train'].set_target('target')
datainfo.datasets['test'].set_target('target')
data_bundle.datasets['train'].set_input('chars')
data_bundle.datasets['test'].set_input('chars')
data_bundle.datasets['train'].set_target('target')
data_bundle.datasets['test'].set_target('target')

##2. Define/assemble the model. Anything goes here: if it is a model fastNLP already packages, such as CNNText, just instantiate it directly. ModelFactory below is only a pseudo-framework placeholder, showing how to build a model that follows fastNLP's input/output conventions.
class ModelFactory(nn.Module):
@@ -165,7 +192,7 @@ class ModelFactory(nn.Module):

## 2. Or directly reuse a fastNLP model
#vocab=datainfo.vocabs['words']
vocab_label=datainfo.vocabs['target']
vocab_label=data_bundle.vocabs['target']
'''
# emded_char=CNNCharEmbedding(vocab)
# embed_word = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
@@ -212,5 +239,5 @@ if __name__=="__main__":
#print(vocab_label)

#print(datainfo.datasets["train"])
train(model,datainfo,loss,metric,optimizer,num_epochs=ops.train_epoch)
train(model,data_bundle,loss,metric,optimizer,num_epochs=ops.train_epoch)
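
For reference, here is a self-contained sketch of the character pipeline this script wires up: words are flattened to characters, then mapped to fixed-length index lists over the config alphabet. The body of chartoindex is a plausible reconstruction (the diff elides its middle), and char_vocab below is an illustrative subset of the alphabet in char_cnn_config.

# standalone sketch; char_vocab and the chartoindex body are assumptions
char_vocab = list("abcdefghijklmnopqrstuvwxyz0123456789")
max_length = 1014  # same cap as Config.max_length

def wordtochar(words):
    chars = []
    for word in words:
        for char in word:
            chars.append(char)
        chars.append('')  # '' separates words
    chars.pop()
    return chars

def chartoindex(chars):
    # assumed behavior: padding and out-of-alphabet chars share index len(char_vocab)
    zero_index = len(char_vocab)
    char_index_list = [zero_index] * max_length
    for i, char in enumerate(chars[:max_length]):
        if char in char_vocab:
            char_index_list[i] = char_vocab.index(char)
    return char_index_list

indices = chartoindex(wordtochar(["the", "cat"]))
print(indices[:7])  # [19, 7, 4, 36, 2, 0, 19]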
