|
|
@@ -1,15 +1,8 @@ |
|
|
|
# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 |
|
|
|
import os |
|
|
|
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' |
|
|
|
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' |
|
|
|
|
|
|
|
import sys |
|
|
|
sys.path.append('../..') |
|
|
|
from fastNLP.core.const import Const as C |
|
|
|
import torch.nn as nn |
|
|
|
from fastNLP.io.data_loader import YelpLoader |
|
|
|
from fastNLP.io.pipe.classification import YelpFullPipe,YelpPolarityPipe,SST2Pipe,IMDBPipe |
|
|
|
#from data.sstLoader import sst2Loader |
|
|
|
from model.char_cnn import CharacterLevelCNN |
|
|
|
from fastNLP import CrossEntropyLoss, AccuracyMetric |
|
|
|
from fastNLP.core.trainer import Trainer |
|
|
@@ -27,9 +20,9 @@ class Config(): |
|
|
|
model_dir_or_name="en-base-uncased" |
|
|
|
embedding_grad= False, |
|
|
|
bert_embedding_larers= '4,-2,-1' |
|
|
|
train_epoch= 50 |
|
|
|
train_epoch= 100 |
|
|
|
num_classes=2 |
|
|
|
task= "yelp_p" |
|
|
|
task= "sst-2" |
|
|
|
#yelp_p |
|
|
|
datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv", |
|
|
|
"test": "/remote-home/ygwang/yelp_polarity/test.csv"} |
|
|
@@ -132,15 +125,17 @@ elif ops.task == 'sst-2': |
|
|
|
else: |
|
|
|
raise RuntimeError(f'NOT support {ops.task} task yet!') |
|
|
|
|
|
|
|
print(data_bundle) |
|
|
|
|
|
|
|
def wordtochar(words):
    """Return the elements of ``words`` as a flat list, one entry per element.

    The script applies this to the ``raw_words`` field (presumably a plain
    string -- confirm against the loader), so each entry of the result is a
    single character of the original text, spaces included.

    Args:
        words: Any iterable; iterated exactly once, in order.

    Returns:
        A new list holding the elements of ``words`` in their original order.
        An empty input yields an empty list.
    """
    # The word-by-word walk with '' separators between words (and a trailing
    # pop) is no longer active; the input is walked element by element, which
    # is exactly what list() does.
    return list(words)
|
|
|
|
|
|
|
#chartoindex |
|
|
@@ -162,10 +157,14 @@ def chartoindex(chars): |
|
|
|
char_index_list=[zero_index]*max_seq_len |
|
|
|
return char_index_list |
|
|
|
|
|
|
|
|
|
|
|
# Build the character-level input for every dataset in the bundle:
# first derive a 'chars' field from 'raw_words', then overwrite that same
# field with the result of chartoindex.
for dataset in data_bundle.datasets.values():
    dataset.apply_field(
        wordtochar,
        field_name="raw_words",
        new_field_name='chars',
    )
    dataset.apply_field(
        chartoindex,
        field_name='chars',
        new_field_name='chars',
    )

# Debug aid: inspect the first training instance before/after conversion.
# print(data_bundle.datasets['train'][0]['chars'])
# print(data_bundle.datasets['train'][0]['raw_words'])

# Mark 'chars' as the model input on both splits, and 'target' as the
# label field on the training split.
data_bundle.datasets['train'].set_input('chars')
data_bundle.datasets['test'].set_input('chars')
data_bundle.datasets['train'].set_target('target')
|
|
@@ -216,7 +215,6 @@ model=CharacterLevelCNN(ops,embedding) |
|
|
|
## 3. Declare loss, metric, optimizer
# Both names are bound to the imported objects as-is (they are not called here).
loss = CrossEntropyLoss
metric = AccuracyMetric

# Only parameters that still require gradients are handed to the optimizer.
optimizer = SGD(
    [p for p in model.parameters() if p.requires_grad == True],
    lr=ops.lr,
    momentum=0.9,
    weight_decay=ops.weight_decay,
)

callbacks = []
|
|
@@ -236,8 +234,4 @@ def train(model,datainfo,loss,metrics,optimizer,num_epochs=100): |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run training with the module-level model, data
    # bundle, loss, metric and optimizer; the epoch count comes from the
    # configuration object.
    train(
        model,
        data_bundle,
        loss,
        metric,
        optimizer,
        num_epochs=ops.train_epoch,
    )
|
|
|
|