|
|
@@ -8,6 +8,7 @@ sys.path.append('../..') |
|
|
|
from fastNLP.core.const import Const as C |
|
|
|
import torch.nn as nn |
|
|
|
from fastNLP.io.data_loader import YelpLoader |
|
|
|
from fastNLP.io.pipe.classification import YelpFullPipe,YelpPolarityPipe,SST2Pipe,IMDBPipe |
|
|
|
#from data.sstLoader import sst2Loader |
|
|
|
from model.char_cnn import CharacterLevelCNN |
|
|
|
from fastNLP import CrossEntropyLoss, AccuracyMetric |
|
|
@@ -46,6 +47,8 @@ class Config(): |
|
|
|
extra_characters='' |
|
|
|
max_length=1014 |
|
|
|
weight_decay = 1e-5 |
|
|
|
to_lower=True |
|
|
|
tokenizer = 'spacy' # 使用spacy进行分词 |
|
|
|
|
|
|
|
char_cnn_config={ |
|
|
|
"alphabet": { |
|
|
@@ -111,12 +114,35 @@ ops=Config |
|
|
|
##1.task相关信息:利用dataloader载入dataInfo |
|
|
|
#dataloader=SST2Loader() |
|
|
|
#dataloader=IMDBLoader() |
|
|
|
dataloader=YelpLoader(fine_grained=True) |
|
|
|
datainfo=dataloader.process(ops.datapath,char_level_op=True,split_dev_op=False) |
|
|
|
# dataloader=YelpLoader(fine_grained=True) |
|
|
|
# datainfo=dataloader.process(ops.datapath,char_level_op=True,split_dev_op=False) |
|
|
|
char_vocab=ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"] |
|
|
|
ops.number_of_characters=len(char_vocab) |
|
|
|
ops.embedding_dim=ops.number_of_characters |
|
|
|
|
|
|
|
# load data set |
|
|
|
if ops.task == 'yelp_p': |
|
|
|
data_bundle = YelpPolarityPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file() |
|
|
|
elif ops.task == 'yelp_f': |
|
|
|
data_bundle = YelpFullPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file() |
|
|
|
elif ops.task == 'imdb': |
|
|
|
data_bundle = IMDBPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file() |
|
|
|
elif ops.task == 'sst-2': |
|
|
|
data_bundle = SST2Pipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file() |
|
|
|
else: |
|
|
|
raise RuntimeError(f'NOT support {ops.task} task yet!') |
|
|
|
|
|
|
|
|
|
|
|
def wordtochar(words):
    """Flatten a list of word tokens into a flat character list.

    The characters of consecutive words are separated by a single
    empty-string element, which serves as the char-level word-boundary
    marker consumed downstream (it maps to the zero/padding index).

    :param words: iterable of word strings (raw tokens of one sample).
    :return: list of single-character strings with ``''`` between words.
        An empty input yields ``[]`` (the previous implementation raised
        ``IndexError`` from ``chars.pop()`` on empty input).
    """
    chars = []
    for word in words:
        # Keep original casing here; lower-casing is handled upstream by
        # the Pipe's ``lower`` option, so we do not lower() per word.
        chars.extend(word)
        chars.append('')  # word-boundary marker between words
    if chars:
        chars.pop()  # drop the trailing boundary marker after the last word
    return chars
|
|
|
|
|
|
|
#chartoindex |
|
|
|
def chartoindex(chars): |
|
|
|
max_seq_len=ops.max_length |
|
|
@@ -136,13 +162,14 @@ def chartoindex(chars): |
|
|
|
char_index_list=[zero_index]*max_seq_len |
|
|
|
return char_index_list |
|
|
|
|
|
|
|
for dataset in datainfo.datasets.values(): |
|
|
|
for dataset in data_bundle.datasets.values(): |
|
|
|
dataset.apply_field(wordtochar, field_name="raw_words", new_field_name='chars') |
|
|
|
dataset.apply_field(chartoindex,field_name='chars',new_field_name='chars') |
|
|
|
|
|
|
|
datainfo.datasets['train'].set_input('chars') |
|
|
|
datainfo.datasets['test'].set_input('chars') |
|
|
|
datainfo.datasets['train'].set_target('target') |
|
|
|
datainfo.datasets['test'].set_target('target') |
|
|
|
data_bundle.datasets['train'].set_input('chars') |
|
|
|
data_bundle.datasets['test'].set_input('chars') |
|
|
|
data_bundle.datasets['train'].set_target('target') |
|
|
|
data_bundle.datasets['test'].set_target('target') |
|
|
|
|
|
|
|
##2. 定义/组装模型,这里可以随意,就如果是fastNLP封装好的,类似CNNText就直接用初始化调用就好了,这里只是给出一个伪框架表示占位,在这里建立符合fastNLP输入输出规范的model |
|
|
|
class ModelFactory(nn.Module): |
|
|
@@ -165,7 +192,7 @@ class ModelFactory(nn.Module): |
|
|
|
|
|
|
|
## 2.或直接复用fastNLP的模型 |
|
|
|
#vocab=datainfo.vocabs['words'] |
|
|
|
vocab_label=datainfo.vocabs['target'] |
|
|
|
vocab_label=data_bundle.vocabs['target'] |
|
|
|
''' |
|
|
|
# emded_char=CNNCharEmbedding(vocab) |
|
|
|
# embed_word = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True) |
|
|
@@ -212,5 +239,5 @@ if __name__=="__main__": |
|
|
|
#print(vocab_label) |
|
|
|
|
|
|
|
#print(datainfo.datasets["train"]) |
|
|
|
train(model,datainfo,loss,metric,optimizer,num_epochs=ops.train_epoch) |
|
|
|
train(model,data_bundle,loss,metric,optimizer,num_epochs=ops.train_epoch) |
|
|
|
|