From 5d9e064ec27595a1e09d7ddcbb27280c685b0701 Mon Sep 17 00:00:00 2001
From: lyhuang18 <42239874+lyhuang18@users.noreply.github.com>
Date: Mon, 8 Jul 2019 01:11:46 +0800
Subject: [PATCH] text_classification

---
 .../text_classification/data/SSTLoader.py | 90 ++++++++++++++++++-
 .../text_classification/train_awdlstm.py  | 41 +--------
 .../text_classification/train_lstm.py     | 43 ++-------
 .../text_classification/train_lstm_att.py | 41 +--------
 4 files changed, 102 insertions(+), 113 deletions(-)

diff --git a/reproduction/text_classification/data/SSTLoader.py b/reproduction/text_classification/data/SSTLoader.py
index b570994e..d8403b7a 100644
--- a/reproduction/text_classification/data/SSTLoader.py
+++ b/reproduction/text_classification/data/SSTLoader.py
@@ -5,7 +5,8 @@ from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
-
+import csv
+from typing import Union, Dict
 
 class SSTLoader(DataSetLoader):
     URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
@@ -97,3 +98,90 @@ class SSTLoader(DataSetLoader):
 
         return info
 
+
+class sst2Loader(DataSetLoader):
+    '''
+    Data source "SST-2": https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8
+    '''
+    def __init__(self):
+        super(sst2Loader, self).__init__()
+
+    def _load(self, path: str) -> DataSet:
+        ds = DataSet()
+        all_count = 0
+        csv_reader = csv.reader(open(path, encoding='utf-8'), delimiter='\t')
+        skip_row = 0
+        for idx, row in enumerate(csv_reader):
+            if idx <= skip_row:  # skip the TSV header row
+                continue
+            target = row[1]
+            words = row[0].split()
+            ds.append(Instance(words=words, target=target))
+            all_count += 1
+        print("all count:", all_count)
+        return ds
+
+    def process(self,
+                paths: Union[str, Dict[str, str]],
+                src_vocab_opt: VocabularyOption = None,
+                tgt_vocab_opt: VocabularyOption = None,
+                src_embed_opt: EmbeddingOption = None,
+                char_level_op=False):
+
+        paths = check_dataloader_paths(paths)
+        datasets = {}
+        info = DataInfo()
+        for name, path in paths.items():
+            dataset = self.load(path)
+            datasets[name] = dataset
+
+        def wordtochar(words):
+            chars = []
+            for word in words:
+                word = word.lower()
+                for char in word:
+                    chars.append(char)
+            return chars
+
+        input_name, target_name = 'words', 'target'
+        info.vocabs = {}
+
+        # optionally split the words into characters
+        if char_level_op:
+            for dataset in datasets.values():
+                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')
+
+        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
+        src_vocab.from_dataset(datasets['train'], field_name='words')
+        src_vocab.index_dataset(*datasets.values(), field_name='words')
+
+        tgt_vocab = Vocabulary(unknown=None, padding=None) \
+            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
+        tgt_vocab.from_dataset(datasets['train'], field_name='target')
+        tgt_vocab.index_dataset(*datasets.values(), field_name='target')
+
+        info.vocabs = {
+            "words": src_vocab,
+            "target": tgt_vocab
+        }
+
+        info.datasets = datasets
+
+        if src_embed_opt is not None:
+            embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
+            info.embeddings['words'] = embed
+
+        return info
+
+
+if __name__ == "__main__":
+    datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
+                "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"}
+    datainfo = sst2Loader().process(datapath, char_level_op=True)
+    # print(datainfo.datasets["train"])
+
+    len_count = 0
+    for instance in datainfo.datasets["train"]:
+        len_count += len(instance["chars"])
+
+    ave_len = len_count / len(datainfo.datasets["train"])
+    print(ave_len)
\ No newline at end of file
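Review note: a minimal sketch of how the new sst2Loader is driven (the TSV paths are placeholders; only calls that appear in this patch are used). One thing worth flagging: even with `char_level_op=True`, `process()` still builds and indexes the source vocabulary over the `words` field, not the new `chars` field.

```python
# Hypothetical usage sketch for the new sst2Loader; paths are placeholders
# for a local SST-2 download.
from data.SSTLoader import sst2Loader

paths = {"train": "SST-2/train.tsv", "dev": "SST-2/dev.tsv"}

# Word-level pipeline: the vocabulary is built from the train split only,
# then used to index every split.
datainfo = sst2Loader().process(paths, char_level_op=False)
print(len(datainfo.vocabs["words"]))   # source vocabulary size
print(len(datainfo.vocabs["target"]))  # 2 labels for binary SST-2

# Char-level pipeline: adds a 'chars' field via wordtochar, but as written
# the source vocabulary is still built over 'words'.
datainfo = sst2Loader().process(paths, char_level_op=True)
print(datainfo.datasets["train"][0]["chars"][:10])
```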
diff --git a/reproduction/text_classification/train_awdlstm.py b/reproduction/text_classification/train_awdlstm.py
index ce3e52bc..e67bd25b 100644
--- a/reproduction/text_classification/train_awdlstm.py
+++ b/reproduction/text_classification/train_awdlstm.py
@@ -8,9 +8,7 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 
 import torch.nn as nn
 
-from data.SSTLoader import SSTLoader
 from data.IMDBLoader import IMDBLoader
-from data.yelpLoader import yelpLoader
 from fastNLP.modules.encoder.embedding import StaticEmbedding
 from model.awd_lstm import AWDLSTMSentiment
 
@@ -41,18 +39,9 @@ opt=Config
 
 
 # load data
-dataloaders = {
-    "IMDB":IMDBLoader(),
-    "YELP":yelpLoader(),
-    "SST-5":SSTLoader(subtree=True,fine_grained=True),
-    "SST-3":SSTLoader(subtree=True,fine_grained=False)
-}
-
-if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]:
-    raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']")
-
-dataloader = dataloaders[opt.task_name]
+dataloader=IMDBLoader()
 datainfo=dataloader.process(opt.datapath)
+
 # print(datainfo.datasets["train"])
 # print(datainfo)
 
@@ -71,32 +60,10 @@ optimizer= Adam([param for param in model.parameters() if param.requires_grad==T
 
 def train(datainfo, model, optimizer, loss, metrics, opt):
     trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-                        metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1,
+                        metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1,
                         n_epochs=opt.train_epoch, save_path=opt.save_model_path)
     trainer.train()
 
 
-def test(datainfo, metrics, opt):
-    # load model
-    model = ModelLoader.load_pytorch_model(opt.load_model_path)
-    print("model loaded!")
-
-    # Tester
-    tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0)
-    acc = tester.test()
-    print("acc=",acc)
-
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model')
-
-
-args = parser.parse_args()
-if args.mode == 'train':
+if __name__ == "__main__":
     train(datainfo, model, optimizer, loss, metrics, opt)
-elif args.mode == 'test':
-    test(datainfo, metrics, opt)
-else:
-    print('no mode specified for model!')
-    parser.print_help()
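With the `--mode` switch gone, this script (and the two below) can no longer evaluate a saved model. If that is still needed, the deleted helper can be reproduced standalone; a sketch assuming the same ModelLoader/Tester calls the removed code relied on, with `load_model_path` as a placeholder:

```python
# Sketch: standalone evaluation, mirroring the test() helper removed above.
from fastNLP.io.model_io import ModelLoader
from fastNLP import Tester

model = ModelLoader.load_pytorch_model(load_model_path)  # placeholder path
tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0)
print(tester.test())  # dict of metric results
```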
diff --git a/reproduction/text_classification/train_lstm.py b/reproduction/text_classification/train_lstm.py
index b320e79c..b89abc14 100644
--- a/reproduction/text_classification/train_lstm.py
+++ b/reproduction/text_classification/train_lstm.py
@@ -6,9 +6,7 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 
 import torch.nn as nn
 
-from data.SSTLoader import SSTLoader
 from data.IMDBLoader import IMDBLoader
-from data.yelpLoader import yelpLoader
 from fastNLP.modules.encoder.embedding import StaticEmbedding
 from model.lstm import BiLSTMSentiment
 
@@ -38,18 +36,9 @@ opt=Config
 
 
 # load data
-dataloaders = {
-    "IMDB":IMDBLoader(),
-    "YELP":yelpLoader(),
-    "SST-5":SSTLoader(subtree=True,fine_grained=True),
-    "SST-3":SSTLoader(subtree=True,fine_grained=False)
-}
-
-if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]:
-    raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']")
-
-dataloader = dataloaders[opt.task_name]
+dataloader=IMDBLoader()
 datainfo=dataloader.process(opt.datapath)
+
 # print(datainfo.datasets["train"])
 # print(datainfo)
 
@@ -68,32 +57,10 @@ optimizer= Adam([param for param in model.parameters() if param.requires_grad==T
 
 def train(datainfo, model, optimizer, loss, metrics, opt):
     trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-                        metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1,
+                        metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1,
                         n_epochs=opt.train_epoch, save_path=opt.save_model_path)
     trainer.train()
 
 
-def test(datainfo, metrics, opt):
-    # load model
-    model = ModelLoader.load_pytorch_model(opt.load_model_path)
-    print("model loaded!")
-
-    # Tester
-    tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0)
-    acc = tester.test()
-    print("acc=",acc)
-
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model')
-
-
-args = parser.parse_args()
-if args.mode == 'train':
-    train(datainfo, model, optimizer, loss, metrics, opt)
-elif args.mode == 'test':
-    test(datainfo, metrics, opt)
-else:
-    print('no mode specified for model!')
-    parser.print_help()
+if __name__ == "__main__":
+    train(datainfo, model, optimizer, loss, metrics, opt)
\ No newline at end of file
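The same hunk appears in all three scripts: `dev_data` switches from `datainfo.datasets['dev']` to `['test']` (IMDBLoader yields no dev split), so early stopping and best-model selection now run against the test set. A possible alternative is to carve a dev set out of train; a sketch, under the assumption that this fastNLP version's `DataSet.split(ratio)` returns the `(1-ratio)` and `ratio` portions in that order:

```python
# Sketch: hold out 10% of train as a dev set instead of validating on test.
# Assumes DataSet.split(ratio) -> (larger_part, smaller_part); verify the
# return order in the installed fastNLP version before relying on it.
train_ds, dev_ds = datainfo.datasets['train'].split(0.1)

trainer = Trainer(train_ds, model, optimizer=optimizer, loss=loss,
                  metrics=metrics, dev_data=dev_ds, device=0, check_code_level=-1,
                  n_epochs=opt.train_epoch, save_path=opt.save_model_path)
trainer.train()
```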
diff --git a/reproduction/text_classification/train_lstm_att.py b/reproduction/text_classification/train_lstm_att.py
index 8db27d09..b4d37525 100644
--- a/reproduction/text_classification/train_lstm_att.py
+++ b/reproduction/text_classification/train_lstm_att.py
@@ -6,9 +6,7 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 
 import torch.nn as nn
 
-from data.SSTLoader import SSTLoader
 from data.IMDBLoader import IMDBLoader
-from data.yelpLoader import yelpLoader
 from fastNLP.modules.encoder.embedding import StaticEmbedding
 from model.lstm_self_attention import BiLSTM_SELF_ATTENTION
 
@@ -40,18 +38,9 @@ opt=Config
 
 
 # load data
-dataloaders = {
-    "IMDB":IMDBLoader(),
-    "YELP":yelpLoader(),
-    "SST-5":SSTLoader(subtree=True,fine_grained=True),
-    "SST-3":SSTLoader(subtree=True,fine_grained=False)
-}
-
-if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]:
-    raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']")
-
-dataloader = dataloaders[opt.task_name]
+dataloader=IMDBLoader()
 datainfo=dataloader.process(opt.datapath)
+
 # print(datainfo.datasets["train"])
 # print(datainfo)
 
@@ -70,32 +59,10 @@ optimizer= Adam([param for param in model.parameters() if param.requires_grad==T
 
 def train(datainfo, model, optimizer, loss, metrics, opt):
     trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-                        metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1,
+                        metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1,
                         n_epochs=opt.train_epoch, save_path=opt.save_model_path)
     trainer.train()
 
 
-def test(datainfo, metrics, opt):
-    # load model
-    model = ModelLoader.load_pytorch_model(opt.load_model_path)
-    print("model loaded!")
-
-    # Tester
-    tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0)
-    acc = tester.test()
-    print("acc=",acc)
-
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model')
-
-
-args = parser.parse_args()
-if args.mode == 'train':
+if __name__ == "__main__":
     train(datainfo, model, optimizer, loss, metrics, opt)
-elif args.mode == 'test':
-    test(datainfo, metrics, opt)
-else:
-    print('no mode specified for model!')
-    parser.print_help()
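Follow-up suggestion for `sst2Loader.process`: when `char_level_op` is set, the vocabulary could be built over the field the model will actually consume. A sketch using only names defined in this patch:

```python
# Sketch: pick the indexed field based on char_level_op, so the char-level
# pipeline gets a character vocabulary rather than a word vocabulary.
field = 'chars' if char_level_op else 'words'
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
src_vocab.from_dataset(datasets['train'], field_name=field)
src_vocab.index_dataset(*datasets.values(), field_name=field)
info.vocabs[field] = src_vocab
```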