diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py
index 04ae94f9..d2272a88 100644
--- a/reproduction/text_classification/data/yelpLoader.py
+++ b/reproduction/text_classification/data/yelpLoader.py
@@ -131,7 +131,9 @@ class yelpLoader(DataSetLoader):
                 src_vocab_op: VocabularyOption = None,
                 tgt_vocab_op: VocabularyOption = None,
                 embed_opt: EmbeddingOption = None,
-                char_level_op=False):
+                char_level_op=False,
+                split_dev_op=True
+                ):
         paths = check_dataloader_paths(paths)
         datasets = {}
         info = DataBundle(datasets=self.load(paths))
@@ -172,7 +174,8 @@ class yelpLoader(DataSetLoader):
 
         info.vocabs[target_name]=tgt_vocab
 
-        info.datasets['train'],info.datasets['dev']=info.datasets['train'].split(0.1, shuffle=False)
+        if split_dev_op:
+            info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False)
 
         for name, dataset in info.datasets.items():
             dataset.set_input("words")
diff --git a/reproduction/text_classification/train_char_cnn.py b/reproduction/text_classification/train_char_cnn.py
index 050527fe..e4bb9220 100644
--- a/reproduction/text_classification/train_char_cnn.py
+++ b/reproduction/text_classification/train_char_cnn.py
@@ -8,7 +8,8 @@ sys.path.append('../..')
 from fastNLP.core.const import Const as C
 import torch.nn as nn
 from data.yelpLoader import yelpLoader
-from data.sstLoader import sst2Loader
+#from data.sstLoader import sst2Loader
+from fastNLP.io.data_loader.sst import SST2Loader
 from data.IMDBLoader import IMDBLoader
 from model.char_cnn import CharacterLevelCNN
 from fastNLP.core.vocabulary import Vocabulary
@@ -20,16 +21,20 @@ from torch.optim import SGD
 from torch.autograd import Variable
 import torch
 from fastNLP import BucketSampler
+from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
+from fastNLP.core import LRScheduler
+from utils.util_init import set_rng_seeds
 
 ##hyper
 #todo add fastNLP experiment logging here
 class Config():
+    #seed=7777
     model_dir_or_name="en-base-uncased"
     embedding_grad= False,
     bert_embedding_larers= '4,-2,-1'
     train_epoch= 50
     num_classes=2
-    task= "IMDB"
+    task= "yelp_p"  #yelp_p
     datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv",
                 "test": "/remote-home/ygwang/yelp_polarity/test.csv"}
 
@@ -46,6 +51,7 @@ class Config():
     number_of_characters=69
     extra_characters=''
     max_length=1014
+    weight_decay = 1e-5
 
     char_cnn_config={
         "alphabet": {
@@ -104,12 +110,15 @@
 }
 
 ops=Config
+# set_rng_seeds(ops.seed)
+# print('RNG SEED: {}'.format(ops.seed))
+
 
 ##1. Task info: load the dataInfo with a dataloader
-#dataloader=sst2Loader()
+#dataloader=SST2Loader()
 #dataloader=IMDBLoader()
 dataloader=yelpLoader(fine_grained=True)
-datainfo=dataloader.process(ops.datapath,char_level_op=True)
+datainfo=dataloader.process(ops.datapath,char_level_op=True,split_dev_op=False)
 char_vocab=ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
 ops.number_of_characters=len(char_vocab)
 ops.embedding_dim=ops.number_of_characters
@@ -186,12 +195,20 @@ model=CharacterLevelCNN(ops,embedding)
 
 ## 3. Declare loss, metric, optimizer
 loss=CrossEntropyLoss
 metric=AccuracyMetric
-optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], lr=ops.lr)
+#optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], lr=ops.lr)
+optimizer = SGD([param for param in model.parameters() if param.requires_grad == True],
+                lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)
+callbacks = []
+# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
+callbacks.append(
+    LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch <
+                ops.train_epoch * 0.8 else ops.lr * 0.1))
+)
 
 ## 4. Define the train method
 def train(model,datainfo,loss,metrics,optimizer,num_epochs=100):
-    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'),
-                      metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=0, check_code_level=-1,
+    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'),batch_size=ops.batch_size,
+                      metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=[0,1,2], check_code_level=-1,
                       n_epochs=num_epochs)
     print(trainer.train())
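For reference, a minimal sketch of how the new `split_dev_op` flag is used (the file paths below are placeholders, not the repo's real data locations). With the default `split_dev_op=True`, `process()` still carves 10% of the training set off as a dev set via `DataSet.split(0.1, shuffle=False)`; `train_char_cnn.py` now passes `False` so the full training set is kept and evaluation runs on the test split instead:

```python
from data.yelpLoader import yelpLoader

loader = yelpLoader(fine_grained=True)
paths = {"train": "train.csv", "test": "test.csv"}  # placeholder paths

# Default behaviour: 10% of 'train' becomes a 'dev' split.
datainfo = loader.process(paths, char_level_op=True, split_dev_op=True)

# New behaviour used by train_char_cnn.py: no dev split; the script
# points dev_data at datainfo.datasets['test'] instead.
datainfo = loader.process(paths, char_level_op=True, split_dev_op=False)
```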
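One caveat on the new schedule: `torch.optim.lr_scheduler.LambdaLR` multiplies the optimizer's *initial* learning rate by whatever the lambda returns, so a lambda returning `ops.lr` yields an effective rate of `ops.lr * ops.lr` rather than `ops.lr` (and `ops.lr * ops.lr * 0.1` after the drop). Note also that, at least in the hunks shown, `callbacks` is never passed to the `Trainer`. A factor-returning sketch of the intended step schedule in plain PyTorch, with stand-in values for `ops.lr` and a toy model since neither appears in the diff:

```python
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

lr, train_epoch = 0.01, 50   # stand-ins for ops.lr / ops.train_epoch
model = nn.Linear(4, 2)      # stand-in for CharacterLevelCNN

optimizer = SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-5)
# The lambda returns a multiplicative factor: 1.0 for the first 80% of
# training, then a 10x drop for the remaining epochs.
scheduler = LambdaLR(optimizer,
                     lambda epoch: 1.0 if epoch < train_epoch * 0.8 else 0.1)

for epoch in range(train_epoch):
    # ... one training epoch ...
    scheduler.step()

print(optimizer.param_groups[0]["lr"])  # lr * 0.1 once past 80% of epochs
```

Wrapping the same scheduler object in fastNLP's `LRScheduler(...)` callback, as the patch does, does not change these semantics; the callback only steps the scheduler once per epoch.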