diff --git a/reproduction/text_classification/README.md b/reproduction/text_classification/README.md
index a318cc61..09ca7d42 100644
--- a/reproduction/text_classification/README.md
+++ b/reproduction/text_classification/README.md
@@ -20,7 +20,7 @@ model name | yelp_p | yelp_f | sst-2|IMDB
 char_cnn | 93.80/95.12 | - | - |-
 dpcnn | 95.50/97.36 | - | - |-
 HAN |- | - | - |-
-LSTM| 95.74/- |- |- |88.52/-
-AWD-LSTM| 95.96/- |- |- |88.91/-
-LSTM+self_attention| 96.34/- | - | - |89.53/-
+LSTM| 95.74/- |64.16/- |- |88.52/-
+AWD-LSTM| 95.96/- |64.74/- |- |88.91/-
+LSTM+self_attention| 96.34/- | 65.78/- | - |89.53/-
 
diff --git a/reproduction/text_classification/data/SSTLoader.py b/reproduction/text_classification/data/SSTLoader.py
deleted file mode 100644
index d8403b7a..00000000
--- a/reproduction/text_classification/data/SSTLoader.py
+++ /dev/null
@@ -1,187 +0,0 @@
-from typing import Iterable
-from nltk import Tree
-from fastNLP.io.base_loader import DataInfo, DataSetLoader
-from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
-from fastNLP import DataSet
-from fastNLP import Instance
-from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
-import csv
-from typing import Union, Dict
-
-class SSTLoader(DataSetLoader):
-    URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
-    DATA_DIR = 'sst/'
-
-    """
-    别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`
-
-    读取SST数据集, DataSet包含fields::
-
-        words: list(str) 需要分类的文本
-        target: str 文本的标签
-
-    数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
-
-    :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False``
-    :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
-    """
-
-    def __init__(self, subtree=False, fine_grained=False):
-        self.subtree = subtree
-
-        tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
-                 '3': 'positive', '4': 'very positive'}
-        if not fine_grained:
-            tag_v['0'] = tag_v['1']
-            tag_v['4'] = tag_v['3']
-        self.tag_v = tag_v
-
-    def _load(self, path):
-        """
-
-        :param str path: 存储数据的路径
-        :return: 一个 :class:`~fastNLP.DataSet` 类型的对象
-        """
-        datalist = []
-        with open(path, 'r', encoding='utf-8') as f:
-            datas = []
-            for l in f:
-                datas.extend([(s, self.tag_v[t])
-                              for s, t in self._get_one(l, self.subtree)])
-        ds = DataSet()
-        for words, tag in datas:
-            ds.append(Instance(words=words, target=tag))
-        return ds
-
-    @staticmethod
-    def _get_one(data, subtree):
-        tree = Tree.fromstring(data)
-        if subtree:
-            return [(t.leaves(), t.label()) for t in tree.subtrees()]
-        return [(tree.leaves(), tree.label())]
-
-    def process(self,
-                paths,
-                train_ds: Iterable[str] = None,
-                src_vocab_op: VocabularyOption = None,
-                tgt_vocab_op: VocabularyOption = None,
-                src_embed_op: EmbeddingOption = None):
-        input_name, target_name = 'words', 'target'
-        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
-        tgt_vocab = Vocabulary(unknown=None, padding=None) \
-            if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
-
-        info = DataInfo(datasets=self.load(paths))
-        _train_ds = [info.datasets[name]
-                     for name in train_ds] if train_ds else info.datasets.values()
-        src_vocab.from_dataset(*_train_ds, field_name=input_name)
-        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
-        src_vocab.index_dataset(
-            *info.datasets.values(),
-            field_name=input_name, new_field_name=input_name)
-        tgt_vocab.index_dataset(
-            *info.datasets.values(),
-            field_name=target_name, new_field_name=target_name)
-        info.vocabs = {
-            input_name: src_vocab,
-            target_name: tgt_vocab
-        }
-
-        if src_embed_op is not None:
-            src_embed_op.vocab = src_vocab
-            init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
-            info.embeddings[input_name] = init_emb
-
-        for name, dataset in info.datasets.items():
-            dataset.set_input(input_name)
-            dataset.set_target(target_name)
-
-        return info
-
-class sst2Loader(DataSetLoader):
-    '''
-    数据来源"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
-    '''
-    def __init__(self):
-        super(sst2Loader, self).__init__()
-
-    def _load(self, path: str) -> DataSet:
-        ds = DataSet()
-        all_count=0
-        csv_reader = csv.reader(open(path, encoding='utf-8'),delimiter='\t')
-        skip_row = 0
-        for idx,row in enumerate(csv_reader):
-            if idx<=skip_row:
-                continue
-            target = row[1]
-            words = row[0].split()
-            ds.append(Instance(words=words,target=target))
-            all_count+=1
-        print("all count:", all_count)
-        return ds
-
-    def process(self,
-                paths: Union[str, Dict[str, str]],
-                src_vocab_opt: VocabularyOption = None,
-                tgt_vocab_opt: VocabularyOption = None,
-                src_embed_opt: EmbeddingOption = None,
-                char_level_op=False):
-
-        paths = check_dataloader_paths(paths)
-        datasets = {}
-        info = DataInfo()
-        for name, path in paths.items():
-            dataset = self.load(path)
-            datasets[name] = dataset
-
-        def wordtochar(words):
-            chars=[]
-            for word in words:
-                word=word.lower()
-                for char in word:
-                    chars.append(char)
-            return chars
-
-        input_name, target_name = 'words', 'target'
-        info.vocabs={}
-
-        # 就分隔为char形式
-        if char_level_op:
-            for dataset in datasets.values():
-                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')
-
-        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
-        src_vocab.from_dataset(datasets['train'], field_name='words')
-        src_vocab.index_dataset(*datasets.values(), field_name='words')
-
-        tgt_vocab = Vocabulary(unknown=None, padding=None) \
-            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
-        tgt_vocab.from_dataset(datasets['train'], field_name='target')
-        tgt_vocab.index_dataset(*datasets.values(), field_name='target')
-
-
-        info.vocabs = {
-            "words": src_vocab,
-            "target": tgt_vocab
-        }
-
-        info.datasets = datasets
-
-
-        if src_embed_opt is not None:
-            embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
-            info.embeddings['words'] = embed
-
-        return info
-
-if __name__=="__main__":
-    datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
-                "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"}
-    datainfo=sst2Loader().process(datapath,char_level_op=True)
-    #print(datainfo.datasets["train"])
-    len_count = 0
-    for instance in datainfo.datasets["train"]:
-        len_count += len(instance["chars"])
-
-    ave_len = len_count / len(datainfo.datasets["train"])
-    print(ave_len)
\ No newline at end of file
diff --git a/reproduction/text_classification/data/sstLoader.py b/reproduction/text_classification/data/sstloader.py
similarity index 97%
rename from reproduction/text_classification/data/sstLoader.py
rename to reproduction/text_classification/data/sstloader.py
index 14524ea5..97cd935e 100644
--- a/reproduction/text_classification/data/sstLoader.py
+++ b/reproduction/text_classification/data/sstloader.py
@@ -9,27 +9,22 @@ import csv
 from typing import Union, Dict
 from reproduction.utils import check_dataloader_paths, get_tokenizer
 
+
 class SSTLoader(DataSetLoader):
     URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
     DATA_DIR = 'sst/'
 
     """
     别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`
-
     读取SST数据集, DataSet包含fields::
-
         words: list(str) 需要分类的文本
         target: str 文本的标签
-
     数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
-
     :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False``
     :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
     """
-
     def __init__(self, subtree=False, fine_grained=False):
         self.subtree = subtree
-
         tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
                  '3': 'positive', '4': 'very positive'}
         if not fine_grained:
@@ -39,7 +34,6 @@ class SSTLoader(DataSetLoader):
 
     def _load(self, path):
         """
-
         :param str path: 存储数据的路径
         :return: 一个 :class:`~fastNLP.DataSet` 类型的对象
         """
@@ -54,6 +48,7 @@ class SSTLoader(DataSetLoader):
         ds.append(Instance(words=words, target=tag))
         return ds
 
+
     @staticmethod
     def _get_one(data, subtree):
         tree = Tree.fromstring(data)
@@ -61,6 +56,7 @@ class SSTLoader(DataSetLoader):
             return [(t.leaves(), t.label()) for t in tree.subtrees()]
         return [(tree.leaves(), tree.label())]
 
+
     def process(self,
                 paths,
                 train_ds: Iterable[str] = None,
@@ -88,25 +84,30 @@ class SSTLoader(DataSetLoader):
             target_name: tgt_vocab
         }
 
+
         if src_embed_op is not None:
             src_embed_op.vocab = src_vocab
             init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
             info.embeddings[input_name] = init_emb
 
+
         for name, dataset in info.datasets.items():
             dataset.set_input(input_name)
             dataset.set_target(target_name)
-
         return info
 
+
+
 class sst2Loader(DataSetLoader):
     '''
     数据来源"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
     '''
+
     def __init__(self):
         super(sst2Loader, self).__init__()
         self.tokenizer = get_tokenizer()
 
+
     def _load(self, path: str) -> DataSet:
         ds = DataSet()
         all_count=0
@@ -122,6 +123,8 @@ class sst2Loader(DataSetLoader):
         print("all count:", all_count)
         return ds
 
+
+
     def process(self,
                 paths: Union[str, Dict[str, str]],
                 src_vocab_opt: VocabularyOption = None,
@@ -153,7 +156,6 @@ class sst2Loader(DataSetLoader):
         if char_level_op:
             for dataset in datasets.values():
                 dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')
-
         src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
         src_vocab.from_dataset(datasets['train'], field_name='words')
         src_vocab.index_dataset(*datasets.values(), field_name='words')
@@ -171,21 +173,26 @@ class sst2Loader(DataSetLoader):
 
         info.datasets = datasets
 
-
         if src_embed_opt is not None:
             embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
             info.embeddings['words'] = embed
 
+        for name, dataset in info.datasets.items():
+            dataset.set_input("words")
+            dataset.set_target("target")
+
         return info
 
+
+
 if __name__=="__main__":
     datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
                 "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"}
     datainfo=sst2Loader().process(datapath,char_level_op=True)
     #print(datainfo.datasets["train"])
+
     len_count = 0
     for instance in datainfo.datasets["train"]:
         len_count += len(instance["chars"])
-
     ave_len = len_count / len(datainfo.datasets["train"])
     print(ave_len)
\ No newline at end of file
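
For context, a minimal usage sketch of the renamed loader, mirroring the `__main__` block in the patch. The import path `reproduction.text_classification.data.sstloader` follows the new file location, and the TSV paths are placeholders, not part of the patch.

```python
# Minimal sketch, assuming the fastNLP reproduction package is importable and
# a local copy of GLUE SST-2 TSVs exists at the placeholder paths below.
from reproduction.text_classification.data.sstloader import sst2Loader

paths = {"train": "path/to/SST-2/train.tsv",   # placeholder paths
         "dev": "path/to/SST-2/dev.tsv"}

# process() builds the word/target vocabularies from the train split, indexes
# every split, and (as of this patch) marks "words" as input and "target" as target.
datainfo = sst2Loader().process(paths, char_level_op=False)

print(datainfo.vocabs["words"])           # source vocabulary
print(len(datainfo.datasets["train"]))    # number of training instances
```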