| @@ -5,7 +5,8 @@ from fastNLP.core.vocabulary import VocabularyOption, Vocabulary | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | |||
| import csv | |||
| from typing import Union, Dict | |||
| class SSTLoader(DataSetLoader): | |||
| URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | |||
| @@ -97,3 +98,90 @@ class SSTLoader(DataSetLoader): | |||
| return info | |||
class sst2Loader(DataSetLoader):
    """Loader for SST-2 sentiment data stored as tab-separated files.

    Data source "SST": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8'
    """

    def __init__(self):
        super(sst2Loader, self).__init__()

    def _load(self, path: str) -> DataSet:
        """Read one tab-separated file into a ``DataSet``.

        The first row is treated as a header and skipped. For every other
        non-blank row, column 0 is split on whitespace into the ``words``
        field and column 1 is stored unchanged in the ``target`` field.

        :param path: path of the TSV file to read (decoded as UTF-8).
        :return: a ``DataSet`` with one ``Instance`` per data row.
        :raises IndexError: if a non-blank row has fewer than two columns.
        """
        ds = DataSet()
        all_count = 0
        skip_row = 0  # index of the last header row to ignore
        # The context manager closes the handle even if a row fails to parse;
        # newline='' is what the csv module expects from its input file.
        # NOTE(review): the reader keeps csv's default quote handling, so a
        # stray '"' in a sentence can change how a row is split - confirm this
        # is acceptable for the data files in use.
        with open(path, encoding='utf-8', newline='') as tsv_file:
            csv_reader = csv.reader(tsv_file, delimiter='\t')
            for idx, row in enumerate(csv_reader):
                if idx <= skip_row:
                    continue
                if not row:
                    # A blank line yields an empty row; there is nothing to load.
                    continue
                ds.append(Instance(words=row[0].split(), target=row[1]))
                all_count += 1
        print("all count:", all_count)
        return ds

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                src_embed_opt: EmbeddingOption = None,
                char_level_op=False):
        """Load every split, build and apply the vocabularies, and bundle the result.

        :param paths: a path or a mapping from split name to file path; it is
            normalised by ``check_dataloader_paths``. The resulting mapping
            must contain a ``'train'`` split, which is the only split used to
            build the vocabularies.
        :param src_vocab_opt: keyword options for the ``words`` vocabulary;
            a default ``Vocabulary()`` is used when ``None``.
        :param tgt_vocab_opt: keyword options for the ``target`` vocabulary;
            when ``None`` a vocabulary without unknown/padding entries is used.
        :param src_embed_opt: keyword options passed to
            ``EmbedLoader.load_with_vocab``; no embedding is loaded when ``None``.
        :param char_level_op: when true, every dataset also gets a ``chars``
            field holding the lower-cased characters of all its words. The
            source vocabulary is still built from the ``words`` field.
        :return: a ``DataInfo`` carrying ``vocabs``, ``datasets`` and, if
            requested, the ``words`` embedding.
        :raises KeyError: if the loaded splits do not include ``'train'``.
        """
        paths = check_dataloader_paths(paths)
        info = DataInfo()
        datasets = {name: self.load(path) for name, path in paths.items()}

        def wordtochar(words):
            # Flatten a token list into a list of lower-cased characters.
            return [char for word in words for char in word.lower()]

        # Optionally add a character-level view of each sentence.
        if char_level_op:
            for dataset in datasets.values():
                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')

        # Vocabularies are fitted on the training split only, then used to
        # index every split.
        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name='words')
        src_vocab.index_dataset(*datasets.values(), field_name='words')

        if tgt_vocab_opt is None:
            tgt_vocab = Vocabulary(unknown=None, padding=None)
        else:
            tgt_vocab = Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name='target')
        tgt_vocab.index_dataset(*datasets.values(), field_name='target')

        info.vocabs = {
            "words": src_vocab,
            "target": tgt_vocab
        }
        info.datasets = datasets

        if src_embed_opt is not None:
            embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
            info.embeddings['words'] = embed

        return info
if __name__=="__main__":
    # Manual smoke test: load SST-2 with the character-level field enabled and
    # print the mean number of characters per training instance.
    datapath = {
        "train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
        "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv",
    }
    datainfo = sst2Loader().process(datapath, char_level_op=True)
    train_set = datainfo.datasets["train"]
    total_chars = sum(len(instance["chars"]) for instance in train_set)
    print(total_chars / len(train_set))
| @@ -8,9 +8,7 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||
| import torch.nn as nn | |||
| from data.SSTLoader import SSTLoader | |||
| from data.IMDBLoader import IMDBLoader | |||
| from data.yelpLoader import yelpLoader | |||
| from fastNLP.modules.encoder.embedding import StaticEmbedding | |||
| from model.awd_lstm import AWDLSTMSentiment | |||
| @@ -41,18 +39,9 @@ opt=Config | |||
| # load data | |||
| dataloaders = { | |||
| "IMDB":IMDBLoader(), | |||
| "YELP":yelpLoader(), | |||
| "SST-5":SSTLoader(subtree=True,fine_grained=True), | |||
| "SST-3":SSTLoader(subtree=True,fine_grained=False) | |||
| } | |||
| if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]: | |||
| raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']") | |||
| dataloader = dataloaders[opt.task_name] | |||
| dataloader=IMDBLoader() | |||
| datainfo=dataloader.process(opt.datapath) | |||
| # print(datainfo.datasets["train"]) | |||
| # print(datainfo) | |||
| @@ -71,32 +60,10 @@ optimizer= Adam([param for param in model.parameters() if param.requires_grad==T | |||
| def train(datainfo, model, optimizer, loss, metrics, opt): | |||
| trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||
| metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1, | |||
| metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1, | |||
| n_epochs=opt.train_epoch, save_path=opt.save_model_path) | |||
| trainer.train() | |||
| def test(datainfo, metrics, opt): | |||
| # load model | |||
| model = ModelLoader.load_pytorch_model(opt.load_model_path) | |||
| print("model loaded!") | |||
| # Tester | |||
| tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0) | |||
| acc = tester.test() | |||
| print("acc=",acc) | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model') | |||
| args = parser.parse_args() | |||
| if args.mode == 'train': | |||
| if __name__ == "__main__": | |||
| train(datainfo, model, optimizer, loss, metrics, opt) | |||
| elif args.mode == 'test': | |||
| test(datainfo, metrics, opt) | |||
| else: | |||
| print('no mode specified for model!') | |||
| parser.print_help() | |||
| @@ -6,9 +6,7 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||
| import torch.nn as nn | |||
| from data.SSTLoader import SSTLoader | |||
| from data.IMDBLoader import IMDBLoader | |||
| from data.yelpLoader import yelpLoader | |||
| from fastNLP.modules.encoder.embedding import StaticEmbedding | |||
| from model.lstm import BiLSTMSentiment | |||
| @@ -38,18 +36,9 @@ opt=Config | |||
| # load data | |||
| dataloaders = { | |||
| "IMDB":IMDBLoader(), | |||
| "YELP":yelpLoader(), | |||
| "SST-5":SSTLoader(subtree=True,fine_grained=True), | |||
| "SST-3":SSTLoader(subtree=True,fine_grained=False) | |||
| } | |||
| if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]: | |||
| raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']") | |||
| dataloader = dataloaders[opt.task_name] | |||
| dataloader=IMDBLoader() | |||
| datainfo=dataloader.process(opt.datapath) | |||
| # print(datainfo.datasets["train"]) | |||
| # print(datainfo) | |||
| @@ -68,32 +57,10 @@ optimizer= Adam([param for param in model.parameters() if param.requires_grad==T | |||
| def train(datainfo, model, optimizer, loss, metrics, opt): | |||
| trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||
| metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1, | |||
| metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1, | |||
| n_epochs=opt.train_epoch, save_path=opt.save_model_path) | |||
| trainer.train() | |||
| def test(datainfo, metrics, opt): | |||
| # load model | |||
| model = ModelLoader.load_pytorch_model(opt.load_model_path) | |||
| print("model loaded!") | |||
| # Tester | |||
| tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0) | |||
| acc = tester.test() | |||
| print("acc=",acc) | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model') | |||
| args = parser.parse_args() | |||
| if args.mode == 'train': | |||
| train(datainfo, model, optimizer, loss, metrics, opt) | |||
| elif args.mode == 'test': | |||
| test(datainfo, metrics, opt) | |||
| else: | |||
| print('no mode specified for model!') | |||
| parser.print_help() | |||
| if __name__ == "__main__": | |||
| train(datainfo, model, optimizer, loss, metrics, opt) | |||
| @@ -6,9 +6,7 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||
| import torch.nn as nn | |||
| from data.SSTLoader import SSTLoader | |||
| from data.IMDBLoader import IMDBLoader | |||
| from data.yelpLoader import yelpLoader | |||
| from fastNLP.modules.encoder.embedding import StaticEmbedding | |||
| from model.lstm_self_attention import BiLSTM_SELF_ATTENTION | |||
| @@ -40,18 +38,9 @@ opt=Config | |||
| # load data | |||
| dataloaders = { | |||
| "IMDB":IMDBLoader(), | |||
| "YELP":yelpLoader(), | |||
| "SST-5":SSTLoader(subtree=True,fine_grained=True), | |||
| "SST-3":SSTLoader(subtree=True,fine_grained=False) | |||
| } | |||
| if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]: | |||
| raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']") | |||
| dataloader = dataloaders[opt.task_name] | |||
| dataloader=IMDBLoader() | |||
| datainfo=dataloader.process(opt.datapath) | |||
| # print(datainfo.datasets["train"]) | |||
| # print(datainfo) | |||
| @@ -70,32 +59,10 @@ optimizer= Adam([param for param in model.parameters() if param.requires_grad==T | |||
| def train(datainfo, model, optimizer, loss, metrics, opt): | |||
| trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||
| metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1, | |||
| metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1, | |||
| n_epochs=opt.train_epoch, save_path=opt.save_model_path) | |||
| trainer.train() | |||
| def test(datainfo, metrics, opt): | |||
| # load model | |||
| model = ModelLoader.load_pytorch_model(opt.load_model_path) | |||
| print("model loaded!") | |||
| # Tester | |||
| tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0) | |||
| acc = tester.test() | |||
| print("acc=",acc) | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model') | |||
| args = parser.parse_args() | |||
| if args.mode == 'train': | |||
| if __name__ == "__main__": | |||
| train(datainfo, model, optimizer, loss, metrics, opt) | |||
| elif args.mode == 'test': | |||
| test(datainfo, metrics, opt) | |||
| else: | |||
| print('no mode specified for model!') | |||
| parser.print_help() | |||