|
# NOTE(review): this import region was corrupted diff output — every line was
# duplicated (old/new diff sides), "@@" hunk headers and trailing "|" pipes
# remained, and indentation was stripped.  Reconstructed below as one clean,
# grouped import block; no import visible in the original has been dropped.
# `import sys` is inferred from the hunk-header context line
# "sys.path.append('../..')" (it must already be imported above) — re-importing
# is idempotent and safe.
import sys

sys.path.append('../..')

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

from fastNLP import BucketSampler
from fastNLP.core import LRScheduler
from fastNLP.core.const import Const as C
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.data_loader.sst import SST2Loader

from data.yelpLoader import yelpLoader
# from data.sstLoader import sst2Loader  # superseded by fastNLP's SST2Loader
from data.IMDBLoader import IMDBLoader
from model.char_cnn import CharacterLevelCNN
from utils.util_init import set_rng_seeds
|
## hyper-parameters
# TODO: add fastNLP experiment logging/recording here
|
|
class Config():
    """Hyper-parameter namespace for the char-CNN text-classification run.

    NOTE(review): reconstructed from corrupted diff output (duplicated lines,
    hunk headers, stripped indentation).  Two hunks of the original class were
    elided by the diff (see the "elided" markers below); fields that the rest
    of the script reads but that are not visible here (e.g. ``lr``,
    ``batch_size``) live in those elided regions and must be restored from the
    original file before this script can run.
    """

    # seed = 7777                       # RNG seed (commented out upstream)
    model_dir_or_name = "en-base-uncased"   # pretrained BERT identifier
    # NOTE(review): the original read "embedding_grad= False," — the trailing
    # comma made this the 1-tuple (False,); a plain bool was clearly intended.
    embedding_grad = False
    bert_embedding_larers = '4,-2,-1'   # [sic] likely "layers": BERT layer ids
    train_epoch = 50
    num_classes = 2
    # task = "IMDB"                     # previous value, superseded below
    task = "yelp_p"
    # yelp_p
    datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv",
                "test": "/remote-home/ygwang/yelp_polarity/test.csv"}

    # --- elided hunk (@@ -46,6 +51,7 @@): additional fields not captured ---
    number_of_characters = 69           # overwritten later from the alphabet
    extra_characters = ''
    max_length = 1014                   # max characters per document
    weight_decay = 1e-5

    char_cnn_config = {
        "alphabet": {
            # --- elided hunk (@@ -104,12 +110,15 @@): the per-language
            # alphabet table ("en" -> "lower" -> "alphabet", read later by the
            # script) was not captured and must be restored. ---
        }
    }
|
|
## 1. Task setup: use the dataset loader to build the DataInfo.
# NOTE: the Config *class itself* is used as a mutable namespace (it is never
# instantiated) — attributes are read and rebound on the class below.
ops = Config

# set_rng_seeds(ops.seed)
# print('RNG SEED: {}'.format(ops.seed))

# Alternative loaders kept from upstream experiments:
# dataloader = sst2Loader()
# dataloader = SST2Loader()
# dataloader = IMDBLoader()
dataloader = yelpLoader(fine_grained=True)
datainfo = dataloader.process(ops.datapath, char_level_op=True, split_dev_op=False)

# Derive the character-vocabulary size from the configured alphabet; the
# embedding dim equals the alphabet size (presumably one-hot character
# embeddings — TODO confirm against CharacterLevelCNN).
char_vocab = ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
ops.number_of_characters = len(char_vocab)
ops.embedding_dim = ops.number_of_characters
|
## 3. Declare loss, metric and optimizer.
# (diff context: model = CharacterLevelCNN(ops, embedding) is built just above
#  this region in the full file.)
loss = CrossEntropyLoss        # class, not instance — instantiated in train()
metric = AccuracyMetric        # class, not instance — instantiated in train()

# Previous plain-SGD variant kept for reference:
# optimizer = SGD([param for param in model.parameters() if param.requires_grad == True], lr=ops.lr)
optimizer = SGD([param for param in model.parameters() if param.requires_grad],
                lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)

callbacks = []
# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
# Step schedule: hold the base lr for the first 80% of epochs, then decay x10.
callbacks.append(
    LRScheduler(LambdaLR(optimizer,
                         lambda epoch: ops.lr if epoch < ops.train_epoch * 0.8
                         else ops.lr * 0.1))
)
# NOTE(review): `callbacks` is never passed to the Trainer in the visible
# code — verify it is forwarded (callbacks=callbacks) elsewhere.
|
|
## 4. Training entry point.
def train(model, datainfo, loss, metrics, optimizer, num_epochs=100):
    """Build a fastNLP Trainer and run the training loop.

    ``loss`` and ``metrics`` are *classes* and are instantiated here with
    target='target'.  NOTE(review): the *test* split is used as dev_data —
    there is no held-out dev set — and device=[0, 1, 2] requests 3-GPU data
    parallelism; the batch size is read from the module-level ``ops``.
    """
    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer,
                      loss=loss(target='target'), batch_size=ops.batch_size,
                      metrics=[metrics(target='target')],
                      dev_data=datainfo.datasets['test'], device=[0, 1, 2],
                      check_code_level=-1, n_epochs=num_epochs)
    print(trainer.train())
|
|
|
|
|
|
|
|