
[verify] train_char_cnn optimization

tags/v0.4.10
wyg committed 5 years ago (parent commit efea6ceaf5)
2 changed files with 29 additions and 9 deletions
  1. reproduction/text_classification/data/yelpLoader.py (+5, -2)
  2. reproduction/text_classification/train_char_cnn.py (+24, -7)

reproduction/text_classification/data/yelpLoader.py (+5, -2)

```diff
@@ -131,7 +131,9 @@ class yelpLoader(DataSetLoader):
                 src_vocab_op: VocabularyOption = None,
                 tgt_vocab_op: VocabularyOption = None,
                 embed_opt: EmbeddingOption = None,
-                char_level_op=False):
+                char_level_op=False,
+                split_dev_op=True
+                ):
         paths = check_dataloader_paths(paths)
         datasets = {}
         info = DataBundle(datasets=self.load(paths))
@@ -172,7 +174,8 @@ class yelpLoader(DataSetLoader):
 
         info.vocabs[target_name]=tgt_vocab
 
-        info.datasets['train'],info.datasets['dev']=info.datasets['train'].split(0.1, shuffle=False)
+        if split_dev_op:
+            info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False)
 
         for name, dataset in info.datasets.items():
             dataset.set_input("words")
```
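The new split_dev_op flag preserves the old default (carving 10% of train off as a dev set) while letting callers opt out. A short usage sketch, with hypothetical local paths:

```python
from data.yelpLoader import yelpLoader

paths = {"train": "train.csv", "test": "test.csv"}  # hypothetical paths

# Default (split_dev_op=True): process() still carves 10% of 'train'
# off into a 'dev' dataset, exactly as before this commit.
datainfo = yelpLoader(fine_grained=True).process(paths)

# Opt out: keep 'train' intact and supply dev data yourself, as
# train_char_cnn.py below now does (it evaluates on the test split).
datainfo = yelpLoader(fine_grained=True).process(paths, split_dev_op=False)
```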


reproduction/text_classification/train_char_cnn.py (+24, -7)

```diff
@@ -8,7 +8,8 @@ sys.path.append('../..')
 from fastNLP.core.const import Const as C
 import torch.nn as nn
 from data.yelpLoader import yelpLoader
-from data.sstLoader import sst2Loader
+#from data.sstLoader import sst2Loader
+from fastNLP.io.data_loader.sst import SST2Loader
 from data.IMDBLoader import IMDBLoader
 from model.char_cnn import CharacterLevelCNN
 from fastNLP.core.vocabulary import Vocabulary
@@ -20,16 +21,20 @@ from torch.optim import SGD
 from torch.autograd import Variable
 import torch
 from fastNLP import BucketSampler
+from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
+from fastNLP.core import LRScheduler
+from utils.util_init import set_rng_seeds
 
 ##hyper
 #todo: add fastNLP logging here
 class Config():
+    #seed=7777
     model_dir_or_name="en-base-uncased"
     embedding_grad= False,
     bert_embedding_larers= '4,-2,-1'
     train_epoch= 50
     num_classes=2
-    task= "IMDB"
+    task= "yelp_p"
     #yelp_p
     datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv",
                 "test": "/remote-home/ygwang/yelp_polarity/test.csv"}
```
```diff
@@ -46,6 +51,7 @@ class Config():
     number_of_characters=69
     extra_characters=''
     max_length=1014
+    weight_decay = 1e-5
 
     char_cnn_config={
         "alphabet": {
@@ -104,12 +110,15 @@ class Config():
 }
 ops=Config
 
+# set_rng_seeds(ops.seed)
+# print('RNG SEED: {}'.format(ops.seed))
+
 
 ##1. Task setup: load the dataInfo with a dataloader
-#dataloader=sst2Loader()
+#dataloader=SST2Loader()
 #dataloader=IMDBLoader()
 dataloader=yelpLoader(fine_grained=True)
-datainfo=dataloader.process(ops.datapath,char_level_op=True)
+datainfo=dataloader.process(ops.datapath,char_level_op=True,split_dev_op=False)
 char_vocab=ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
 ops.number_of_characters=len(char_vocab)
 ops.embedding_dim=ops.number_of_characters
```
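Setting embedding_dim equal to number_of_characters suggests a one-hot character embedding, in the spirit of Zhang et al.'s character-level CNN (69-character alphabet, max length 1014). How CharacterLevelCNN consumes it is not shown in this diff, so the identity-matrix construction below is only an illustrative sketch:

```python
import torch
import torch.nn as nn

# Illustrative only: each character index maps to a distinct basis vector,
# which is why embedding_dim can simply equal the alphabet size.
num_chars = 69  # len(char_vocab) for the default lowercase English alphabet
one_hot = nn.Embedding.from_pretrained(torch.eye(num_chars), freeze=True)
```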
```diff
@@ -186,12 +195,20 @@ model=CharacterLevelCNN(ops,embedding)
 ## 3. Declare the loss, metric, and optimizer
 loss=CrossEntropyLoss
 metric=AccuracyMetric
-optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], lr=ops.lr)
+#optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], lr=ops.lr)
+optimizer = SGD([param for param in model.parameters() if param.requires_grad == True],
+                lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)
+callbacks = []
+# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
+callbacks.append(
+    LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch <
+                         ops.train_epoch * 0.8 else ops.lr * 0.1))
+)
 
 
 ## 4. Define the train routine
 def train(model,datainfo,loss,metrics,optimizer,num_epochs=100):
-    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'),
-                      metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=0, check_code_level=-1,
+    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'),batch_size=ops.batch_size,
+                      metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=[0,1,2], check_code_level=-1,
                       n_epochs=num_epochs)
     print(trainer.train())
```
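One caveat worth noting about the new scheduler callback: torch.optim.lr_scheduler.LambdaLR multiplies the optimizer's initial lr by the lambda's return value rather than taking it as an absolute learning rate, so returning ops.lr here yields an effective rate of lr*lr. If the intent was lr for the first 80% of epochs and lr*0.1 afterwards (an assumption about the author's intent), the factor-style form would be:

```python
from torch.optim.lr_scheduler import LambdaLR

# LambdaLR computes lr = initial_lr * lr_lambda(epoch), so the lambda
# should return a multiplicative factor, not an absolute learning rate.
scheduler = LambdaLR(
    optimizer,
    lambda epoch: 1.0 if epoch < ops.train_epoch * 0.8 else 0.1,
)
```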



