|
# First, the following locations have to be put into the environment: the
# fastNLP download server is currently only open for internal testing, so the
# paths must be declared manually. They are set before the fastNLP imports
# further down — presumably because fastNLP reads them when it is imported
# (TODO confirm).
#
# `setdefault` is used so that a value already exported in the shell
# environment takes precedence over these internal defaults.
import os
os.environ.setdefault('FASTNLP_BASE_URL', 'http://10.141.222.118:8888/file/download/')
os.environ.setdefault('FASTNLP_CACHE_DIR', '/remote-home/hyan01/fastnlp_caches')
-
-
- import torch.nn as nn
-
- from data.SSTLoader import SSTLoader
- from data.IMDBLoader import IMDBLoader
- from data.yelpLoader import yelpLoader
- from fastNLP.modules.encoder.embedding import StaticEmbedding
- from model.lstm_self_attention import BiLSTM_SELF_ATTENTION
-
- from fastNLP.core.const import Const as C
- from fastNLP import CrossEntropyLoss, AccuracyMetric
- from fastNLP import Trainer, Tester
- from torch.optim import Adam
- from fastNLP.io.model_io import ModelLoader, ModelSaver
-
- import argparse
-
-
class Config:
    """Hyper-parameters and file locations for the BiLSTM self-attention run.

    Every setting is a plain class attribute, so the class itself can be
    used as a read-only namespace without instantiating it.
    """

    # --- optimisation ---
    train_epoch = 10  # epoch count handed to the fastNLP Trainer
    lr = 0.001  # learning rate for the Adam optimiser

    # --- model architecture (forwarded to BiLSTM_SELF_ATTENTION) ---
    num_classes = 2
    hidden_dim = 256
    num_layers = 1
    attention_unit = 256
    attention_hops = 1
    nfc = 128

    # --- task selection and paths ---
    task_name = "IMDB"
    datapath = {
        "train": "IMDB_data/train.csv",
        "test": "IMDB_data/test.csv",
    }
    load_model_path = "./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51"
    save_model_path = "./result_IMDB_test/"
# The Config class itself (not an instance) is used as the options namespace.
opt = Config


# load data
# One loader per supported task, keyed by the value expected in
# `opt.task_name`.
dataloaders = {
    "IMDB": IMDBLoader(),
    "YELP": yelpLoader(),
    "SST-5": SSTLoader(subtree=True, fine_grained=True),
    "SST-3": SSTLoader(subtree=True, fine_grained=False),
}

# Validate against the dict itself so the set of accepted names (and the
# error message listing them) can never drift from the registered loaders.
if opt.task_name not in dataloaders:
    raise ValueError(
        "task name must be one of {}, got {!r}".format(list(dataloaders), opt.task_name))

dataloader = dataloaders[opt.task_name]
datainfo = dataloader.process(opt.datapath)
-
-
# define model
vocab = datainfo.vocabs['words']
# Embedding initialised from 'en-glove-840b-300' (presumably pretrained GloVe
# 840B 300-d vectors — TODO confirm); requires_grad=True lets it be fine-tuned.
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True)
model = BiLSTM_SELF_ATTENTION(
    init_embed=embed,
    num_classes=opt.num_classes,
    hidden_dim=opt.hidden_dim,
    num_layers=opt.num_layers,
    attention_unit=opt.attention_unit,
    attention_hops=opt.attention_hops,
    nfc=opt.nfc,
)


# define loss_function and metrics
loss = CrossEntropyLoss()
metrics = AccuracyMetric()
# Only trainable parameters are handed to the optimiser; frozen ones are skipped.
optimizer = Adam([param for param in model.parameters() if param.requires_grad], lr=opt.lr)
-
-
def train(datainfo, model, optimizer, loss, metrics, opt):
    """Fit ``model`` on the training split, validating on the dev split.

    Args:
        datainfo: processed data bundle whose ``datasets`` mapping provides
            the ``'train'`` and ``'dev'`` splits.
        model: network to optimise.
        optimizer: optimiser over the model's trainable parameters.
        loss: loss object handed to the fastNLP ``Trainer``.
        metrics: metric(s) evaluated on the dev split.
        opt: configuration namespace; ``train_epoch`` and ``save_model_path``
            are read.
    """
    splits = datainfo.datasets
    trainer = Trainer(
        splits['train'],
        model,
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
        dev_data=splits['dev'],
        # NOTE(review): device index is hard-coded (presumably cuda:0).
        device=0,
        # presumably skips fastNLP's pre-training code check — TODO confirm
        check_code_level=-1,
        n_epochs=opt.train_epoch,
        save_path=opt.save_model_path,
    )
    trainer.train()
-
-
def test(datainfo, metrics, opt):
    """Evaluate the saved model at ``opt.load_model_path`` on the test split.

    Prints a confirmation once the model has been loaded, then prints the
    value returned by ``Tester.test()``.

    Args:
        datainfo: processed data bundle whose ``datasets`` mapping provides
            the ``'test'`` split.
        metrics: metric(s) to evaluate.
        opt: configuration namespace; ``load_model_path`` is read.
    """
    # Restore the previously trained network from disk.
    restored = ModelLoader.load_pytorch_model(opt.load_model_path)
    print("model loaded!")

    # Run evaluation on the held-out test split.
    evaluator = Tester(datainfo.datasets['test'], restored, metrics, batch_size=4, device=0)
    result = evaluator.test()
    print("acc=", result)
-
-
-
def build_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with a required ``--mode`` option
        restricted to ``train`` or ``test``; any other value makes
        ``parse_args`` print an error and exit with status 2.
    """
    parser = argparse.ArgumentParser(
        description="Train or evaluate the BiLSTM self-attention classifier.")
    parser.add_argument(
        '--mode', required=True, dest="mode", choices=['train', 'test'],
        help="'train' to train a new model, 'test' to evaluate the model "
             "stored at Config.load_model_path")
    return parser


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()
    # argparse enforces `choices`, so mode is guaranteed to be one of these.
    if args.mode == 'train':
        train(datainfo, model, optimizer, loss, metrics, opt)
    else:
        test(datainfo, metrics, opt)
|