diff --git a/fastNLP/io/data_loader/sst.py b/fastNLP/io/data_loader/sst.py index 021a79b7..8d0d005f 100644 --- a/fastNLP/io/data_loader/sst.py +++ b/fastNLP/io/data_loader/sst.py @@ -5,10 +5,8 @@ from ..base_loader import DataInfo, DataSetLoader from ...core.vocabulary import VocabularyOption, Vocabulary from ...core.dataset import DataSet from ...core.instance import Instance -from ..embed_loader import EmbeddingOption, EmbedLoader +from ..utils import check_dataloader_paths, get_tokenizer -spacy.prefer_gpu() -sptk = spacy.load('en') class SSTLoader(DataSetLoader): URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' @@ -37,6 +35,7 @@ class SSTLoader(DataSetLoader): tag_v['0'] = tag_v['1'] tag_v['4'] = tag_v['3'] self.tag_v = tag_v + self.tokenizer = get_tokenizer() def _load(self, path): """ @@ -55,29 +54,37 @@ class SSTLoader(DataSetLoader): ds.append(Instance(words=words, target=tag)) return ds - @staticmethod - def _get_one(data, subtree): + def _get_one(self, data, subtree): tree = Tree.fromstring(data) if subtree: - return [([x.text for x in sptk.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees() ] - return [([x.text for x in sptk.tokenizer(' '.join(tree.leaves()))], tree.label())] + return [([x.text for x in self.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees() ] + return [([x.text for x in self.tokenizer(' '.join(tree.leaves()))], tree.label())] def process(self, - paths, - train_ds: Iterable[str] = None, + paths, train_subtree=True, src_vocab_op: VocabularyOption = None, - tgt_vocab_op: VocabularyOption = None, - src_embed_op: EmbeddingOption = None): + tgt_vocab_op: VocabularyOption = None,): + paths = check_dataloader_paths(paths) input_name, target_name = 'words', 'target' src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) tgt_vocab = Vocabulary(unknown=None, padding=None) \ if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) - info = DataInfo(datasets=self.load(paths)) - _train_ds = [info.datasets[name] - for name in train_ds] if train_ds else info.datasets.values() - src_vocab.from_dataset(*_train_ds, field_name=input_name) - tgt_vocab.from_dataset(*_train_ds, field_name=target_name) + info = DataInfo() + origin_subtree = self.subtree + self.subtree = train_subtree + info.datasets['train'] = self._load(paths['train']) + self.subtree = origin_subtree + for n, p in paths.items(): + if n != 'train': + info.datasets[n] = self._load(p) + + src_vocab.from_dataset( + info.datasets['train'], + field_name=input_name, + no_create_entry_dataset=[ds for n, ds in info.datasets.items() if n != 'train']) + tgt_vocab.from_dataset(info.datasets['train'], field_name=target_name) + src_vocab.index_dataset( *info.datasets.values(), field_name=input_name, new_field_name=input_name) @@ -89,10 +96,5 @@ class SSTLoader(DataSetLoader): target_name: tgt_vocab } - if src_embed_op is not None: - src_embed_op.vocab = src_vocab - init_emb = EmbedLoader.load_with_vocab(**src_embed_op) - info.embeddings[input_name] = init_emb - return info diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py new file mode 100644 index 00000000..a7d2de85 --- /dev/null +++ b/fastNLP/io/utils.py @@ -0,0 +1,69 @@ +import os + +from typing import Union, Dict + + +def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: + """ + 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果 + { + 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 + 'test': 'xxx' # 可能有,也可能没有 + ... + } + 如果paths为不合法的,将直接进行raise相应的错误 + + :param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 + 中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 + :return: + """ + if isinstance(paths, str): + if os.path.isfile(paths): + return {'train': paths} + elif os.path.isdir(paths): + filenames = os.listdir(paths) + files = {} + for filename in filenames: + path_pair = None + if 'train' in filename: + path_pair = ('train', filename) + if 'dev' in filename: + if path_pair: + raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0])) + path_pair = ('dev', filename) + if 'test' in filename: + if path_pair: + raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0])) + path_pair = ('test', filename) + if path_pair: + files[path_pair[0]] = os.path.join(paths, path_pair[1]) + return files + else: + raise FileNotFoundError(f"{paths} is not a valid file path.") + + elif isinstance(paths, dict): + if paths: + if 'train' not in paths: + raise KeyError("You have to include `train` in your dict.") + for key, value in paths.items(): + if isinstance(key, str) and isinstance(value, str): + if not os.path.isfile(value): + raise TypeError(f"{value} is not a valid file.") + else: + raise TypeError("All keys and values in paths should be str.") + return paths + else: + raise ValueError("Empty paths is not allowed.") + else: + raise TypeError(f"paths only supports str and dict. not {type(paths)}.") + +def get_tokenizer(): + try: + import spacy + spacy.prefer_gpu() + en = spacy.load('en') + print('use spacy tokenizer') + return lambda x: [w.text for w in en.tokenizer(x)] + except Exception as e: + print('use raw tokenizer') + return lambda x: x.split() diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py index 4101b033..2bee7f2e 100644 --- a/fastNLP/modules/aggregator/attention.py +++ b/fastNLP/modules/aggregator/attention.py @@ -19,7 +19,7 @@ class DotAttention(nn.Module): 补上文档 """ - def __init__(self, key_size, value_size, dropout=0): + def __init__(self, key_size, value_size, dropout=0.0): super(DotAttention, self).__init__() self.key_size = key_size self.value_size = value_size @@ -37,7 +37,7 @@ class DotAttention(nn.Module): """ output = torch.matmul(Q, K.transpose(1, 2)) / self.scale if mask_out is not None: - output.masked_fill_(mask_out, -1e8) + output.masked_fill_(mask_out, -1e18) output = self.softmax(output) output = self.drop(output) return torch.matmul(output, V) @@ -67,9 +67,8 @@ class MultiHeadAttention(nn.Module): self.k_in = nn.Linear(input_size, in_size) self.v_in = nn.Linear(input_size, in_size) # follow the paper, do not apply dropout within dot-product - self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=0) + self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout) self.out = nn.Linear(value_size * num_head, input_size) - self.drop = TimestepDropout(dropout) self.reset_parameters() def reset_parameters(self): @@ -105,7 +104,7 @@ class MultiHeadAttention(nn.Module): # concat all heads, do output linear atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1) - output = self.drop(self.out(atte)) + output = self.out(atte) return output diff --git a/reproduction/Star_transformer/README.md b/reproduction/Star_transformer/README.md index 37c5f1e9..d07d5536 100644 --- a/reproduction/Star_transformer/README.md +++ b/reproduction/Star_transformer/README.md @@ -6,7 +6,7 @@ paper: [Star-Transformer](https://arxiv.org/abs/1902.09113) |Pos Tagging|CTB 9.0|-|ACC 92.31| |Pos Tagging|CONLL 2012|-|ACC 96.51| |Named Entity Recognition|CONLL 2012|-|F1 85.66| -|Text Classification|SST|-|49.18| +|Text Classification|SST|-|51.2| |Natural Language Inference|SNLI|-|83.76| ## Usage diff --git a/reproduction/seqence_labelling/ner/model/dilated_cnn.py b/reproduction/seqence_labelling/ner/model/dilated_cnn.py new file mode 100644 index 00000000..a4e02159 --- /dev/null +++ b/reproduction/seqence_labelling/ner/model/dilated_cnn.py @@ -0,0 +1,142 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from fastNLP.modules.decoder import ConditionalRandomField +from fastNLP.modules.encoder import Embedding +from fastNLP.core.utils import seq_len_to_mask +from fastNLP.core.const import Const as C + + +class IDCNN(nn.Module): + def __init__(self, + init_embed, + char_embed, + num_cls, + repeats, num_layers, num_filters, kernel_size, + use_crf=False, use_projection=False, block_loss=False, + input_dropout=0.3, hidden_dropout=0.2, inner_dropout=0.0): + super(IDCNN, self).__init__() + self.word_embeddings = Embedding(init_embed) + + if char_embed is None: + self.char_embeddings = None + embedding_size = self.word_embeddings.embedding_dim + else: + self.char_embeddings = Embedding(char_embed) + embedding_size = self.word_embeddings.embedding_dim + \ + self.char_embeddings.embedding_dim + + self.conv0 = nn.Sequential( + nn.Conv1d(in_channels=embedding_size, + out_channels=num_filters, + kernel_size=kernel_size, + stride=1, dilation=1, + padding=kernel_size//2, + bias=True), + nn.ReLU(), + ) + + block = [] + for layer_i in range(num_layers): + dilated = 2 ** layer_i if layer_i+1 < num_layers else 1 + block.append(nn.Conv1d( + in_channels=num_filters, + out_channels=num_filters, + kernel_size=kernel_size, + stride=1, dilation=dilated, + padding=(kernel_size//2) * dilated, + bias=True)) + block.append(nn.ReLU()) + self.block = nn.Sequential(*block) + + if use_projection: + self.projection = nn.Sequential( + nn.Conv1d( + in_channels=num_filters, + out_channels=num_filters//2, + kernel_size=1, + bias=True), + nn.ReLU(),) + encode_dim = num_filters // 2 + else: + self.projection = None + encode_dim = num_filters + + self.input_drop = nn.Dropout(input_dropout) + self.hidden_drop = nn.Dropout(hidden_dropout) + self.inner_drop = nn.Dropout(inner_dropout) + self.repeats = repeats + self.out_fc = nn.Conv1d( + in_channels=encode_dim, + out_channels=num_cls, + kernel_size=1, + bias=True) + self.crf = ConditionalRandomField( + num_tags=num_cls) if use_crf else None + self.block_loss = block_loss + self.reset_parameters() + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): + nn.init.xavier_normal_(m.weight, gain=1) + if m.bias is not None: + nn.init.normal_(m.bias, mean=0, std=0.01) + + def forward(self, words, seq_len, target=None, chars=None): + if self.char_embeddings is None: + x = self.word_embeddings(words) + else: + if chars is None: + raise ValueError('must provide chars for model with char embedding') + e1 = self.word_embeddings(words) + e2 = self.char_embeddings(chars) + x = torch.cat((e1, e2), dim=-1) # b,l,h + mask = seq_len_to_mask(seq_len) + + x = x.transpose(1, 2) # b,h,l + last_output = self.conv0(x) + output = [] + for repeat in range(self.repeats): + last_output = self.block(last_output) + hidden = self.projection(last_output) if self.projection is not None else last_output + output.append(self.out_fc(hidden)) + + def compute_loss(y, t, mask): + if self.crf is not None and target is not None: + loss = self.crf(y.transpose(1, 2), t, mask) + else: + t.masked_fill_(mask == 0, -100) + loss = F.cross_entropy(y, t, ignore_index=-100) + return loss + + if target is not None: + if self.block_loss: + losses = [compute_loss(o, target, mask) for o in output] + loss = sum(losses) + else: + loss = compute_loss(output[-1], target, mask) + else: + loss = None + + scores = output[-1] + if self.crf is not None: + pred, _ = self.crf.viterbi_decode(scores.transpose(1, 2), mask) + else: + pred = scores.max(1)[1] * mask.long() + + return { + C.LOSS: loss, + C.OUTPUT: pred, + } + + def predict(self, words, seq_len, chars=None): + res = self.forward( + words=words, + seq_len=seq_len, + chars=chars, + target=None + )[C.OUTPUT] + return { + C.OUTPUT: res + } diff --git a/reproduction/seqence_labelling/ner/train_idcnn.py b/reproduction/seqence_labelling/ner/train_idcnn.py new file mode 100644 index 00000000..1781c763 --- /dev/null +++ b/reproduction/seqence_labelling/ner/train_idcnn.py @@ -0,0 +1,99 @@ +from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader +from fastNLP.core.callback import FitlogCallback, LRScheduler +from fastNLP import GradientClipCallback +from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR +from torch.optim import SGD, Adam +from fastNLP import Const +from fastNLP import RandomSampler, BucketSampler +from fastNLP import SpanFPreRecMetric +from fastNLP import Trainer +from reproduction.seqence_labelling.ner.model.dilated_cnn import IDCNN +from fastNLP.core.utils import Option +from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding +from fastNLP.core.utils import cache_results +import sys +import torch.cuda +import os +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + +encoding_type = 'bioes' + + +def get_path(path): + return os.path.join(os.environ['HOME'], path) + +data_path = get_path('workdir/datasets/ontonotes-v4') + +ops = Option( + batch_size=128, + num_epochs=100, + lr=3e-4, + repeats=3, + num_layers=3, + num_filters=400, + use_crf=True, + gradient_clip=5, +) + +@cache_results('ontonotes-cache') +def load_data(): + + data = OntoNoteNERDataLoader(encoding_type=encoding_type).process(data_path, + lower=True) + + # char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30], + # kernel_sizes=[3]) + + word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], + model_dir_or_name='en-glove-840b-300', + requires_grad=True) + return data, [word_embed] + +data, embeds = load_data() +print(data.datasets['train'][0]) +print(list(data.vocabs.keys())) + +for ds in data.datasets.values(): + ds.rename_field('cap_words', 'chars') + ds.set_input('chars') + +word_embed = embeds[0] +char_embed = CNNCharEmbedding(data.vocabs['cap_words']) +# for ds in data.datasets: +# ds.rename_field('') + +print(data.vocabs[Const.TARGET].word2idx) + +model = IDCNN(init_embed=word_embed, + char_embed=char_embed, + num_cls=len(data.vocabs[Const.TARGET]), + repeats=ops.repeats, + num_layers=ops.num_layers, + num_filters=ops.num_filters, + kernel_size=3, + use_crf=ops.use_crf, use_projection=True, + block_loss=True, + input_dropout=0.33, hidden_dropout=0.2, inner_dropout=0.2) + +print(model) + +callbacks = [GradientClipCallback(clip_value=ops.gradient_clip, clip_type='norm'),] + +optimizer = Adam(model.parameters(), lr=ops.lr, weight_decay=0) +# scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))) +# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 15))) +# optimizer = SWATS(model.parameters(), verbose=True) +# optimizer = Adam(model.parameters(), lr=0.005) + +device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + +trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, + sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), + device=device, dev_data=data.datasets['dev'], batch_size=ops.batch_size, + metrics=SpanFPreRecMetric( + tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type), + check_code_level=-1, + callbacks=callbacks, num_workers=2, n_epochs=ops.num_epochs) +trainer.train() diff --git a/reproduction/text_classification/README.md b/reproduction/text_classification/README.md new file mode 100644 index 00000000..a318cc61 --- /dev/null +++ b/reproduction/text_classification/README.md @@ -0,0 +1,26 @@ +# text_classification任务模型复现 +这里使用fastNLP复现以下模型: + +char_cnn :论文链接[Character-level Convolutional Networks for Text Classification](https://arxiv.org/pdf/1509.01626v3.pdf) + +dpcnn:论文链接[Deep Pyramid Convolutional Neural Networks for TextCategorization](https://ai.tencent.com/ailab/media/publications/ACL3-Brady.pdf) + +HAN:论文链接[Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf) + +LSTM+self_attention:论文链接[A Structured Self-attentive Sentence Embedding]() + +AWD-LSTM:论文链接[Regularizing and Optimizing LSTM Language Models]() + +# 数据集及复现结果汇总 + +使用fastNLP复现的结果vs论文汇报结果(/前为fastNLP实现,后面为论文报道,-表示论文没有在该数据集上列出结果) + +model name | yelp_p | yelp_f | sst-2|IMDB +:---: | :---: | :---: | :---: |----- +char_cnn | 93.80/95.12 | - | - |- +dpcnn | 95.50/97.36 | - | - |- +HAN |- | - | - |- +LSTM| 95.74/- |- |- |88.52/- +AWD-LSTM| 95.96/- |- |- |88.91/- +LSTM+self_attention| 96.34/- | - | - |89.53/- + diff --git a/reproduction/text_classification/data/IMDBLoader.py b/reproduction/text_classification/data/IMDBLoader.py new file mode 100644 index 00000000..30daf233 --- /dev/null +++ b/reproduction/text_classification/data/IMDBLoader.py @@ -0,0 +1,110 @@ +from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader +from fastNLP.core.vocabulary import VocabularyOption +from fastNLP.io.base_loader import DataSetLoader, DataInfo +from typing import Union, Dict, List, Iterator +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP import Vocabulary +from fastNLP import Const +# from reproduction.utils import check_dataloader_paths +from functools import partial + + +class IMDBLoader(DataSetLoader): + """ + 读取IMDB数据集,DataSet包含以下fields: + + words: list(str), 需要分类的文本 + target: str, 文本的标签 + + + """ + + def __init__(self): + super(IMDBLoader, self).__init__() + + def _load(self, path): + dataset = DataSet() + with open(path, 'r', encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split('\t') + target = parts[0] + words = parts[1].lower().split() + dataset.append(Instance(words=words, target=target)) + + if len(dataset)==0: + raise RuntimeError(f"{path} has no valid data.") + + return dataset + + def process(self, + paths: Union[str, Dict[str, str]], + src_vocab_opt: VocabularyOption = None, + tgt_vocab_opt: VocabularyOption = None, + src_embed_opt: EmbeddingOption = None, + char_level_op=False): + + datasets = {} + info = DataInfo() + for name, path in paths.items(): + dataset = self.load(path) + datasets[name] = dataset + + def wordtochar(words): + chars = [] + for word in words: + word = word.lower() + for char in word: + chars.append(char) + return chars + + if char_level_op: + for dataset in datasets.values(): + dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') + + datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False) + + src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) + src_vocab.from_dataset(datasets['train'], field_name='words') + + src_vocab.index_dataset(*datasets.values(), field_name='words') + + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) + tgt_vocab.from_dataset(datasets['train'], field_name='target') + tgt_vocab.index_dataset(*datasets.values(), field_name='target') + + info.vocabs = { + "words": src_vocab, + "target": tgt_vocab + } + + info.datasets = datasets + + if src_embed_opt is not None: + embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) + info.embeddings['words'] = embed + + for name, dataset in info.datasets.items(): + dataset.set_input("words") + dataset.set_target("target") + + return info + + + +if __name__=="__main__": + datapath = {"train": "/remote-home/ygwang/IMDB_data/train.csv", + "test": "/remote-home/ygwang/IMDB_data/test.csv"} + datainfo=IMDBLoader().process(datapath,char_level_op=True) + #print(datainfo.datasets["train"]) + len_count = 0 + for instance in datainfo.datasets["train"]: + len_count += len(instance["chars"]) + + ave_len = len_count / len(datainfo.datasets["train"]) + print(ave_len) + diff --git a/reproduction/text_classification/data/MTL16Loader.py b/reproduction/text_classification/data/MTL16Loader.py index 1b3e6245..066b53b4 100644 --- a/reproduction/text_classification/data/MTL16Loader.py +++ b/reproduction/text_classification/data/MTL16Loader.py @@ -32,7 +32,7 @@ class MTL16Loader(DataSetLoader): continue parts = line.split('\t') target = parts[0] - words = parts[1].split() + words = parts[1].lower().split() dataset.append(Instance(words=words, target=target)) if len(dataset)==0: raise RuntimeError(f"{path} has no valid data.") @@ -72,4 +72,8 @@ class MTL16Loader(DataSetLoader): embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) info.embeddings['words'] = embed + for name, dataset in info.datasets.items(): + dataset.set_input("words") + dataset.set_target("target") + return info diff --git a/reproduction/text_classification/data/SSTLoader.py b/reproduction/text_classification/data/SSTLoader.py new file mode 100644 index 00000000..d8403b7a --- /dev/null +++ b/reproduction/text_classification/data/SSTLoader.py @@ -0,0 +1,187 @@ +from typing import Iterable +from nltk import Tree +from fastNLP.io.base_loader import DataInfo, DataSetLoader +from fastNLP.core.vocabulary import VocabularyOption, Vocabulary +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader +import csv +from typing import Union, Dict + +class SSTLoader(DataSetLoader): + URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' + DATA_DIR = 'sst/' + + """ + 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` + + 读取SST数据集, DataSet包含fields:: + + words: list(str) 需要分类的文本 + target: str 文本的标签 + + 数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip + + :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` + :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` + """ + + def __init__(self, subtree=False, fine_grained=False): + self.subtree = subtree + + tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', + '3': 'positive', '4': 'very positive'} + if not fine_grained: + tag_v['0'] = tag_v['1'] + tag_v['4'] = tag_v['3'] + self.tag_v = tag_v + + def _load(self, path): + """ + + :param str path: 存储数据的路径 + :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 + """ + datalist = [] + with open(path, 'r', encoding='utf-8') as f: + datas = [] + for l in f: + datas.extend([(s, self.tag_v[t]) + for s, t in self._get_one(l, self.subtree)]) + ds = DataSet() + for words, tag in datas: + ds.append(Instance(words=words, target=tag)) + return ds + + @staticmethod + def _get_one(data, subtree): + tree = Tree.fromstring(data) + if subtree: + return [(t.leaves(), t.label()) for t in tree.subtrees()] + return [(tree.leaves(), tree.label())] + + def process(self, + paths, + train_ds: Iterable[str] = None, + src_vocab_op: VocabularyOption = None, + tgt_vocab_op: VocabularyOption = None, + src_embed_op: EmbeddingOption = None): + input_name, target_name = 'words', 'target' + src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) + + info = DataInfo(datasets=self.load(paths)) + _train_ds = [info.datasets[name] + for name in train_ds] if train_ds else info.datasets.values() + src_vocab.from_dataset(*_train_ds, field_name=input_name) + tgt_vocab.from_dataset(*_train_ds, field_name=target_name) + src_vocab.index_dataset( + *info.datasets.values(), + field_name=input_name, new_field_name=input_name) + tgt_vocab.index_dataset( + *info.datasets.values(), + field_name=target_name, new_field_name=target_name) + info.vocabs = { + input_name: src_vocab, + target_name: tgt_vocab + } + + if src_embed_op is not None: + src_embed_op.vocab = src_vocab + init_emb = EmbedLoader.load_with_vocab(**src_embed_op) + info.embeddings[input_name] = init_emb + + for name, dataset in info.datasets.items(): + dataset.set_input(input_name) + dataset.set_target(target_name) + + return info + +class sst2Loader(DataSetLoader): + ''' + 数据来源"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', + ''' + def __init__(self): + super(sst2Loader, self).__init__() + + def _load(self, path: str) -> DataSet: + ds = DataSet() + all_count=0 + csv_reader = csv.reader(open(path, encoding='utf-8'),delimiter='\t') + skip_row = 0 + for idx,row in enumerate(csv_reader): + if idx<=skip_row: + continue + target = row[1] + words = row[0].split() + ds.append(Instance(words=words,target=target)) + all_count+=1 + print("all count:", all_count) + return ds + + def process(self, + paths: Union[str, Dict[str, str]], + src_vocab_opt: VocabularyOption = None, + tgt_vocab_opt: VocabularyOption = None, + src_embed_opt: EmbeddingOption = None, + char_level_op=False): + + paths = check_dataloader_paths(paths) + datasets = {} + info = DataInfo() + for name, path in paths.items(): + dataset = self.load(path) + datasets[name] = dataset + + def wordtochar(words): + chars=[] + for word in words: + word=word.lower() + for char in word: + chars.append(char) + return chars + + input_name, target_name = 'words', 'target' + info.vocabs={} + + # 就分隔为char形式 + if char_level_op: + for dataset in datasets.values(): + dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') + + src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) + src_vocab.from_dataset(datasets['train'], field_name='words') + src_vocab.index_dataset(*datasets.values(), field_name='words') + + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) + tgt_vocab.from_dataset(datasets['train'], field_name='target') + tgt_vocab.index_dataset(*datasets.values(), field_name='target') + + + info.vocabs = { + "words": src_vocab, + "target": tgt_vocab + } + + info.datasets = datasets + + + if src_embed_opt is not None: + embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) + info.embeddings['words'] = embed + + return info + +if __name__=="__main__": + datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv", + "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"} + datainfo=sst2Loader().process(datapath,char_level_op=True) + #print(datainfo.datasets["train"]) + len_count = 0 + for instance in datainfo.datasets["train"]: + len_count += len(instance["chars"]) + + ave_len = len_count / len(datainfo.datasets["train"]) + print(ave_len) \ No newline at end of file diff --git a/reproduction/text_classification/data/sstLoader.py b/reproduction/text_classification/data/sstLoader.py new file mode 100644 index 00000000..d8403b7a --- /dev/null +++ b/reproduction/text_classification/data/sstLoader.py @@ -0,0 +1,187 @@ +from typing import Iterable +from nltk import Tree +from fastNLP.io.base_loader import DataInfo, DataSetLoader +from fastNLP.core.vocabulary import VocabularyOption, Vocabulary +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader +import csv +from typing import Union, Dict + +class SSTLoader(DataSetLoader): + URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' + DATA_DIR = 'sst/' + + """ + 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` + + 读取SST数据集, DataSet包含fields:: + + words: list(str) 需要分类的文本 + target: str 文本的标签 + + 数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip + + :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` + :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` + """ + + def __init__(self, subtree=False, fine_grained=False): + self.subtree = subtree + + tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', + '3': 'positive', '4': 'very positive'} + if not fine_grained: + tag_v['0'] = tag_v['1'] + tag_v['4'] = tag_v['3'] + self.tag_v = tag_v + + def _load(self, path): + """ + + :param str path: 存储数据的路径 + :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 + """ + datalist = [] + with open(path, 'r', encoding='utf-8') as f: + datas = [] + for l in f: + datas.extend([(s, self.tag_v[t]) + for s, t in self._get_one(l, self.subtree)]) + ds = DataSet() + for words, tag in datas: + ds.append(Instance(words=words, target=tag)) + return ds + + @staticmethod + def _get_one(data, subtree): + tree = Tree.fromstring(data) + if subtree: + return [(t.leaves(), t.label()) for t in tree.subtrees()] + return [(tree.leaves(), tree.label())] + + def process(self, + paths, + train_ds: Iterable[str] = None, + src_vocab_op: VocabularyOption = None, + tgt_vocab_op: VocabularyOption = None, + src_embed_op: EmbeddingOption = None): + input_name, target_name = 'words', 'target' + src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) + + info = DataInfo(datasets=self.load(paths)) + _train_ds = [info.datasets[name] + for name in train_ds] if train_ds else info.datasets.values() + src_vocab.from_dataset(*_train_ds, field_name=input_name) + tgt_vocab.from_dataset(*_train_ds, field_name=target_name) + src_vocab.index_dataset( + *info.datasets.values(), + field_name=input_name, new_field_name=input_name) + tgt_vocab.index_dataset( + *info.datasets.values(), + field_name=target_name, new_field_name=target_name) + info.vocabs = { + input_name: src_vocab, + target_name: tgt_vocab + } + + if src_embed_op is not None: + src_embed_op.vocab = src_vocab + init_emb = EmbedLoader.load_with_vocab(**src_embed_op) + info.embeddings[input_name] = init_emb + + for name, dataset in info.datasets.items(): + dataset.set_input(input_name) + dataset.set_target(target_name) + + return info + +class sst2Loader(DataSetLoader): + ''' + 数据来源"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', + ''' + def __init__(self): + super(sst2Loader, self).__init__() + + def _load(self, path: str) -> DataSet: + ds = DataSet() + all_count=0 + csv_reader = csv.reader(open(path, encoding='utf-8'),delimiter='\t') + skip_row = 0 + for idx,row in enumerate(csv_reader): + if idx<=skip_row: + continue + target = row[1] + words = row[0].split() + ds.append(Instance(words=words,target=target)) + all_count+=1 + print("all count:", all_count) + return ds + + def process(self, + paths: Union[str, Dict[str, str]], + src_vocab_opt: VocabularyOption = None, + tgt_vocab_opt: VocabularyOption = None, + src_embed_opt: EmbeddingOption = None, + char_level_op=False): + + paths = check_dataloader_paths(paths) + datasets = {} + info = DataInfo() + for name, path in paths.items(): + dataset = self.load(path) + datasets[name] = dataset + + def wordtochar(words): + chars=[] + for word in words: + word=word.lower() + for char in word: + chars.append(char) + return chars + + input_name, target_name = 'words', 'target' + info.vocabs={} + + # 就分隔为char形式 + if char_level_op: + for dataset in datasets.values(): + dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') + + src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) + src_vocab.from_dataset(datasets['train'], field_name='words') + src_vocab.index_dataset(*datasets.values(), field_name='words') + + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) + tgt_vocab.from_dataset(datasets['train'], field_name='target') + tgt_vocab.index_dataset(*datasets.values(), field_name='target') + + + info.vocabs = { + "words": src_vocab, + "target": tgt_vocab + } + + info.datasets = datasets + + + if src_embed_opt is not None: + embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) + info.embeddings['words'] = embed + + return info + +if __name__=="__main__": + datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv", + "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"} + datainfo=sst2Loader().process(datapath,char_level_op=True) + #print(datainfo.datasets["train"]) + len_count = 0 + for instance in datainfo.datasets["train"]: + len_count += len(instance["chars"]) + + ave_len = len_count / len(datainfo.datasets["train"]) + print(ave_len) \ No newline at end of file diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py index c47d48fd..c5c91f17 100644 --- a/reproduction/text_classification/data/yelpLoader.py +++ b/reproduction/text_classification/data/yelpLoader.py @@ -1,18 +1,64 @@ import ast +import csv +from typing import Iterable from fastNLP import DataSet, Instance, Vocabulary from fastNLP.core.vocabulary import VocabularyOption from fastNLP.io import JsonLoader -from fastNLP.io.base_loader import DataInfo +from fastNLP.io.base_loader import DataInfo,DataSetLoader from fastNLP.io.embed_loader import EmbeddingOption from fastNLP.io.file_reader import _read_json from typing import Union, Dict -from reproduction.Star_transformer.datasets import EmbedLoader -from reproduction.utils import check_dataloader_paths +from reproduction.utils import check_dataloader_paths, get_tokenizer + +def clean_str(sentence, tokenizer, char_lower=False): + """ + heavily borrowed from github + https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb + :param sentence: is a str + :return: + """ + if char_lower: + sentence = sentence.lower() + import re + nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') + words = tokenizer(sentence) + words_collection = [] + for word in words: + if word in ['-lrb-', '-rrb-', '', '-r', '-l', 'b-']: + continue + tt = nonalpnum.split(word) + t = ''.join(tt) + if t != '': + words_collection.append(t) + + return words_collection -class yelpLoader(JsonLoader): +class yelpLoader(DataSetLoader): """ + 读取Yelp_full/Yelp_polarity数据集, DataSet包含fields: + words: list(str), 需要分类的文本 + target: str, 文本的标签 + chars:list(str),未index的字符列表 + + 数据集:yelp_full/yelp_polarity + :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` + """ + + def __init__(self, fine_grained=False,lower=False): + super(yelpLoader, self).__init__() + tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', + '4.0': 'positive', '5.0': 'very positive'} + if not fine_grained: + tag_v['1.0'] = tag_v['2.0'] + tag_v['5.0'] = tag_v['4.0'] + self.fine_grained = fine_grained + self.tag_v = tag_v + self.lower = lower + self.tokenizer = get_tokenizer() + + ''' 读取Yelp数据集, DataSet包含fields: review_id: str, 22 character unique review id @@ -27,20 +73,8 @@ class yelpLoader(JsonLoader): 数据来源: https://www.yelp.com/dataset/download - :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` - """ - - def __init__(self, fine_grained=False): - super(yelpLoader, self).__init__() - tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', - '4.0': 'positive', '5.0': 'very positive'} - if not fine_grained: - tag_v['1.0'] = tag_v['2.0'] - tag_v['5.0'] = tag_v['4.0'] - self.fine_grained = fine_grained - self.tag_v = tag_v - - def _load(self, path): + + def _load_json(self, path): ds = DataSet() for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): d = ast.literal_eval(d) @@ -48,21 +82,116 @@ class yelpLoader(JsonLoader): d["target"] = self.tag_v[str(d.pop("stars"))] ds.append(Instance(**d)) return ds + + def _load_yelp2015_broken(self,path): + ds = DataSet() + with open (path,encoding='ISO 8859-1') as f: + row=f.readline() + all_count=0 + exp_count=0 + while row: + row=row.split("\t\t") + all_count+=1 + if len(row)>=3: + words=row[-1].split() + try: + target=self.tag_v[str(row[-2])+".0"] + ds.append(Instance(words=words, target=target)) + except KeyError: + exp_count+=1 + else: + exp_count+=1 + row = f.readline() + print("error sample count:",exp_count) + print("all count:",all_count) + return ds + ''' + + def _load(self, path): + ds = DataSet() + csv_reader=csv.reader(open(path,encoding='utf-8')) + all_count=0 + real_count=0 + for row in csv_reader: + all_count+=1 + if len(row)==2: + target=self.tag_v[row[0]+".0"] + words = clean_str(row[1], self.tokenizer, self.lower) + if len(words)!=0: + ds.append(Instance(words=words,target=target)) + real_count += 1 + print("all count:", all_count) + print("real count:", real_count) + return ds + - def process(self, paths: Union[str, Dict[str, str]], vocab_opt: VocabularyOption = None, - embed_opt: EmbeddingOption = None): + + def process(self, paths: Union[str, Dict[str, str]], + train_ds: Iterable[str] = None, + src_vocab_op: VocabularyOption = None, + tgt_vocab_op: VocabularyOption = None, + embed_opt: EmbeddingOption = None, + char_level_op=False): paths = check_dataloader_paths(paths) datasets = {} - info = DataInfo() - vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt) - for name, path in paths.items(): - dataset = self.load(path) - datasets[name] = dataset - vocab.from_dataset(dataset, field_name="words") - info.vocabs = vocab - info.datasets = datasets - if embed_opt is not None: - embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab) - info.embeddings['words'] = embed + info = DataInfo(datasets=self.load(paths)) + src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) + _train_ds = [info.datasets[name] + for name in train_ds] if train_ds else info.datasets.values() + + + def wordtochar(words): + + chars=[] + for word in words: + word=word.lower() + for char in word: + chars.append(char) + return chars + + input_name, target_name = 'words', 'target' + info.vocabs={} + #就分隔为char形式 + if char_level_op: + for dataset in info.datasets.values(): + dataset.apply_field(wordtochar, field_name="words",new_field_name='chars') + # if embed_opt is not None: + # embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab) + # info.embeddings['words'] = embed + else: + src_vocab.from_dataset(*_train_ds, field_name=input_name) + src_vocab.index_dataset(*info.datasets.values(),field_name=input_name, new_field_name=input_name) + info.vocabs[input_name]=src_vocab + + tgt_vocab.from_dataset(*_train_ds, field_name=target_name) + tgt_vocab.index_dataset( + *info.datasets.values(), + field_name=target_name, new_field_name=target_name) + + info.vocabs[target_name]=tgt_vocab + + info.datasets['train'],info.datasets['dev']=info.datasets['train'].split(0.1, shuffle=False) + + for name, dataset in info.datasets.items(): + dataset.set_input("words") + dataset.set_target("target") + return info +if __name__=="__main__": + testloader=yelpLoader() + # datapath = {"train": "/remote-home/ygwang/yelp_full/train.csv", + # "test": "/remote-home/ygwang/yelp_full/test.csv"} + #datapath={"train": "/remote-home/ygwang/yelp_full/test.csv"} + datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv", + "test": "/remote-home/ygwang/yelp_polarity/test.csv"} + datainfo=testloader.process(datapath,char_level_op=True) + + len_count=0 + for instance in datainfo.datasets["train"]: + len_count+=len(instance["chars"]) + + ave_len=len_count/len(datainfo.datasets["train"]) + print(ave_len) diff --git a/reproduction/text_classification/model/HAN.py b/reproduction/text_classification/model/HAN.py new file mode 100644 index 00000000..0902d1e4 --- /dev/null +++ b/reproduction/text_classification/model/HAN.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +from fastNLP.modules.utils import get_embeddings +from fastNLP.core import Const as C + + +def pack_sequence(tensor_seq, padding_value=0.0): + if len(tensor_seq) <= 0: + return + length = [v.size(0) for v in tensor_seq] + max_len = max(length) + size = [len(tensor_seq), max_len] + size.extend(list(tensor_seq[0].size()[1:])) + ans = torch.Tensor(*size).fill_(padding_value) + if tensor_seq[0].data.is_cuda: + ans = ans.cuda() + ans = Variable(ans) + for i, v in enumerate(tensor_seq): + ans[i, :length[i], :] = v + return ans + + +class HANCLS(nn.Module): + def __init__(self, init_embed, num_cls): + super(HANCLS, self).__init__() + + self.embed = get_embeddings(init_embed) + self.han = HAN(input_size=300, + output_size=num_cls, + word_hidden_size=50, word_num_layers=1, word_context_size=100, + sent_hidden_size=50, sent_num_layers=1, sent_context_size=100 + ) + + def forward(self, input_sents): + # input_sents [B, num_sents, seq-len] dtype long + # target + B, num_sents, seq_len = input_sents.size() + input_sents = input_sents.view(-1, seq_len) # flat + words_embed = self.embed(input_sents) # should be [B*num-sent, seqlen , word-dim] + words_embed = words_embed.view(B, num_sents, seq_len, -1) # recover # [B, num-sent, seqlen , word-dim] + out = self.han(words_embed) + + return {C.OUTPUT: out} + + def predict(self, input_sents): + x = self.forward(input_sents)[C.OUTPUT] + return {C.OUTPUT: torch.argmax(x, 1)} + + +class HAN(nn.Module): + def __init__(self, input_size, output_size, + word_hidden_size, word_num_layers, word_context_size, + sent_hidden_size, sent_num_layers, sent_context_size): + super(HAN, self).__init__() + + self.word_layer = AttentionNet(input_size, + word_hidden_size, + word_num_layers, + word_context_size) + self.sent_layer = AttentionNet(2 * word_hidden_size, + sent_hidden_size, + sent_num_layers, + sent_context_size) + self.output_layer = nn.Linear(2 * sent_hidden_size, output_size) + self.softmax = nn.LogSoftmax(dim=1) + + def forward(self, batch_doc): + # input is a sequence of matrix + doc_vec_list = [] + for doc in batch_doc: + sent_mat = self.word_layer(doc) # doc's dim (num_sent, seq_len, word_dim) + doc_vec_list.append(sent_mat) # sent_mat's dim (num_sent, vec_dim) + doc_vec = self.sent_layer(pack_sequence(doc_vec_list)) + output = self.softmax(self.output_layer(doc_vec)) + return output + + +class AttentionNet(nn.Module): + def __init__(self, input_size, gru_hidden_size, gru_num_layers, context_vec_size): + super(AttentionNet, self).__init__() + + self.input_size = input_size + self.gru_hidden_size = gru_hidden_size + self.gru_num_layers = gru_num_layers + self.context_vec_size = context_vec_size + + # Encoder + self.gru = nn.GRU(input_size=input_size, + hidden_size=gru_hidden_size, + num_layers=gru_num_layers, + batch_first=True, + bidirectional=True) + # Attention + self.fc = nn.Linear(2 * gru_hidden_size, context_vec_size) + self.tanh = nn.Tanh() + self.softmax = nn.Softmax(dim=1) + # context vector + self.context_vec = nn.Parameter(torch.Tensor(context_vec_size, 1)) + self.context_vec.data.uniform_(-0.1, 0.1) + + def forward(self, inputs): + # GRU part + h_t, hidden = self.gru(inputs) # inputs's dim (batch_size, seq_len, word_dim) + u = self.tanh(self.fc(h_t)) + # Attention part + alpha = self.softmax(torch.matmul(u, self.context_vec)) # u's dim (batch_size, seq_len, context_vec_size) + output = torch.bmm(torch.transpose(h_t, 1, 2), alpha) # alpha's dim (batch_size, seq_len, 1) + return torch.squeeze(output, dim=2) # output's dim (batch_size, 2*hidden_size, 1) diff --git a/reproduction/text_classification/model/awd_lstm.py b/reproduction/text_classification/model/awd_lstm.py new file mode 100644 index 00000000..0d8f711a --- /dev/null +++ b/reproduction/text_classification/model/awd_lstm.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from fastNLP.core.const import Const as C +from .awdlstm_module import LSTM +from fastNLP.modules import encoder +from fastNLP.modules.decoder.mlp import MLP + + +class AWDLSTMSentiment(nn.Module): + def __init__(self, init_embed, + num_classes, + hidden_dim=256, + num_layers=1, + nfc=128, + wdrop=0.5): + super(AWDLSTMSentiment,self).__init__() + self.embed = encoder.Embedding(init_embed) + self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True, wdrop=wdrop) + self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes]) + + def forward(self, words): + x_emb = self.embed(words) + output, _ = self.lstm(x_emb) + output = self.mlp(output[:,-1,:]) + return {C.OUTPUT: output} + + def predict(self, words): + output = self(words) + _, predict = output[C.OUTPUT].max(dim=1) + return {C.OUTPUT: predict} + diff --git a/reproduction/text_classification/model/awdlstm_module.py b/reproduction/text_classification/model/awdlstm_module.py new file mode 100644 index 00000000..87bfe730 --- /dev/null +++ b/reproduction/text_classification/model/awdlstm_module.py @@ -0,0 +1,86 @@ +""" +轻量封装的 Pytorch LSTM 模块. +可在 forward 时传入序列的长度, 自动对padding做合适的处理. +""" +__all__ = [ + "LSTM" +] + +import torch +import torch.nn as nn +import torch.nn.utils.rnn as rnn + +from fastNLP.modules.utils import initial_parameter +from torch import autograd +from .weight_drop import WeightDrop + + +class LSTM(nn.Module): + """ + 别名::class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM` + + LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 + 为1; 且可以应对DataParallel中LSTM的使用问题。 + + :param input_size: 输入 `x` 的特征维度 + :param hidden_size: 隐状态 `h` 的特征维度. + :param num_layers: rnn的层数. Default: 1 + :param dropout: 层间dropout概率. Default: 0 + :param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` + :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 + :(batch, seq, feature). Default: ``False`` + :param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` + """ + + def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, + bidirectional=False, bias=True, wdrop=0.5): + super(LSTM, self).__init__() + self.batch_first = batch_first + self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, + dropout=dropout, bidirectional=bidirectional) + self.lstm = WeightDrop(self.lstm, ['weight_hh_l0'], dropout=wdrop) + self.init_param() + + def init_param(self): + for name, param in self.named_parameters(): + if 'bias' in name: + # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871 + param.data.fill_(0) + n = param.size(0) + start, end = n // 4, n // 2 + param.data[start:end].fill_(1) + else: + nn.init.xavier_uniform_(param) + + def forward(self, x, seq_len=None, h0=None, c0=None): + """ + + :param x: [batch, seq_len, input_size] 输入序列 + :param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` + :param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None`` + :param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None`` + :return (output, ht) 或 output: 若 ``get_hidden=True`` [batch, seq_len, hidden_size*num_direction] 输出序列 + 和 [batch, hidden_size*num_direction] 最后时刻隐状态. + """ + batch_size, max_len, _ = x.size() + if h0 is not None and c0 is not None: + hx = (h0, c0) + else: + hx = None + if seq_len is not None and not isinstance(x, rnn.PackedSequence): + sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) + if self.batch_first: + x = x[sort_idx] + else: + x = x[:, sort_idx] + x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first) + output, hx = self.lstm(x, hx) # -> [N,L,C] + output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) + _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) + if self.batch_first: + output = output[unsort_idx] + else: + output = output[:, unsort_idx] + else: + output, hx = self.lstm(x, hx) + return output, hx diff --git a/reproduction/text_classification/model/char_cnn.py b/reproduction/text_classification/model/char_cnn.py index f87f5c14..ac370082 100644 --- a/reproduction/text_classification/model/char_cnn.py +++ b/reproduction/text_classification/model/char_cnn.py @@ -1 +1,90 @@ -# TODO \ No newline at end of file +''' +@author: https://github.com/ahmedbesbes/character-based-cnn +这里借鉴了上述链接中char-cnn model的代码,改动主要为将其改动为符合fastnlp的pipline +''' +import torch +import torch.nn as nn +from fastNLP.core.const import Const as C + +class CharacterLevelCNN(nn.Module): + def __init__(self, args,embedding): + super(CharacterLevelCNN, self).__init__() + + self.config=args.char_cnn_config + self.embedding=embedding + + conv_layers = [] + for i, conv_layer_parameter in enumerate(self.config['model_parameters'][args.model_size]['conv']): + if i == 0: + #in_channels = args.number_of_characters + len(args.extra_characters) + in_channels = args.embedding_dim + out_channels = conv_layer_parameter[0] + else: + in_channels, out_channels = conv_layer_parameter[0], conv_layer_parameter[0] + + if conv_layer_parameter[2] != -1: + conv_layer = nn.Sequential(nn.Conv1d(in_channels, + out_channels, + kernel_size=conv_layer_parameter[1], padding=0), + nn.ReLU(), + nn.MaxPool1d(conv_layer_parameter[2])) + else: + conv_layer = nn.Sequential(nn.Conv1d(in_channels, + out_channels, + kernel_size=conv_layer_parameter[1], padding=0), + nn.ReLU()) + conv_layers.append(conv_layer) + self.conv_layers = nn.ModuleList(conv_layers) + + input_shape = (args.batch_size, args.max_length, + args.number_of_characters + len(args.extra_characters)) + dimension = self._get_conv_output(input_shape) + + print('dimension :', dimension) + + fc_layer_parameter = self.config['model_parameters'][args.model_size]['fc'][0] + fc_layers = nn.ModuleList([ + nn.Sequential( + nn.Linear(dimension, fc_layer_parameter), nn.Dropout(0.5)), + nn.Sequential(nn.Linear(fc_layer_parameter, + fc_layer_parameter), nn.Dropout(0.5)), + nn.Linear(fc_layer_parameter, args.num_classes), + ]) + + self.fc_layers = fc_layers + + if args.model_size == 'small': + self._create_weights(mean=0.0, std=0.05) + elif args.model_size == 'large': + self._create_weights(mean=0.0, std=0.02) + + def _create_weights(self, mean=0.0, std=0.05): + for module in self.modules(): + if isinstance(module, nn.Conv1d) or isinstance(module, nn.Linear): + module.weight.data.normal_(mean, std) + + def _get_conv_output(self, shape): + input = torch.rand(shape) + output = input.transpose(1, 2) + # forward pass through conv layers + for i in range(len(self.conv_layers)): + output = self.conv_layers[i](output) + + output = output.view(output.size(0), -1) + n_size = output.size(1) + return n_size + + def forward(self, chars): + input=self.embedding(chars) + output = input.transpose(1, 2) + # forward pass through conv layers + for i in range(len(self.conv_layers)): + output = self.conv_layers[i](output) + + output = output.view(output.size(0), -1) + + # forward pass through fc layers + for i in range(len(self.fc_layers)): + output = self.fc_layers[i](output) + + return {C.OUTPUT: output} \ No newline at end of file diff --git a/reproduction/text_classification/model/dpcnn.py b/reproduction/text_classification/model/dpcnn.py index f87f5c14..dafe62bc 100644 --- a/reproduction/text_classification/model/dpcnn.py +++ b/reproduction/text_classification/model/dpcnn.py @@ -1 +1,97 @@ -# TODO \ No newline at end of file +import torch +import torch.nn as nn +from fastNLP.modules.utils import get_embeddings +from fastNLP.core import Const as C + + +class DPCNN(nn.Module): + def __init__(self, init_embed, num_cls, n_filters=256, + kernel_size=3, n_layers=7, embed_dropout=0.1, cls_dropout=0.1): + super().__init__() + self.region_embed = RegionEmbedding( + init_embed, out_dim=n_filters, kernel_sizes=[1, 3, 5]) + embed_dim = self.region_embed.embedding_dim + self.conv_list = nn.ModuleList() + for i in range(n_layers): + self.conv_list.append(nn.Sequential( + nn.ReLU(), + nn.Conv1d(n_filters, n_filters, kernel_size, + padding=kernel_size//2), + nn.Conv1d(n_filters, n_filters, kernel_size, + padding=kernel_size//2), + )) + self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + self.embed_drop = nn.Dropout(embed_dropout) + self.classfier = nn.Sequential( + nn.Dropout(cls_dropout), + nn.Linear(n_filters, num_cls), + ) + self.reset_parameters() + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): + nn.init.normal_(m.weight, mean=0, std=0.01) + if m.bias is not None: + nn.init.normal_(m.bias, mean=0, std=0.01) + + def forward(self, words, seq_len=None): + words = words.long() + # get region embeddings + x = self.region_embed(words) + x = self.embed_drop(x) + + # not pooling on first conv + x = self.conv_list[0](x) + x + for conv in self.conv_list[1:]: + x = self.pool(x) + x = conv(x) + x + + # B, C, L => B, C + x, _ = torch.max(x, dim=2) + x = self.classfier(x) + return {C.OUTPUT: x} + + def predict(self, words, seq_len=None): + x = self.forward(words, seq_len)[C.OUTPUT] + return {C.OUTPUT: torch.argmax(x, 1)} + + +class RegionEmbedding(nn.Module): + def __init__(self, init_embed, out_dim=300, kernel_sizes=None): + super().__init__() + if kernel_sizes is None: + kernel_sizes = [5, 9] + assert isinstance( + kernel_sizes, list), 'kernel_sizes should be List(int)' + self.embed = get_embeddings(init_embed) + try: + embed_dim = self.embed.embedding_dim + except Exception: + embed_dim = self.embed.embed_size + self.region_embeds = nn.ModuleList() + for ksz in kernel_sizes: + self.region_embeds.append(nn.Sequential( + nn.Conv1d(embed_dim, embed_dim, ksz, padding=ksz // 2), + )) + self.linears = nn.ModuleList([nn.Conv1d(embed_dim, out_dim, 1) + for _ in range(len(kernel_sizes))]) + self.embedding_dim = embed_dim + + def forward(self, x): + x = self.embed(x) + x = x.transpose(1, 2) + # B, C, L + out = 0 + for conv, fc in zip(self.region_embeds, self.linears[1:]): + conv_i = conv(x) + out = out + fc(conv_i) + # B, C, L + return out + + +if __name__ == '__main__': + x = torch.randint(0, 10000, size=(5, 15), dtype=torch.long) + model = DPCNN((10000, 300), 20) + y = model(x) + print(y.size(), y.mean(1), y.std(1)) diff --git a/reproduction/text_classification/model/lstm.py b/reproduction/text_classification/model/lstm.py new file mode 100644 index 00000000..388f3f1c --- /dev/null +++ b/reproduction/text_classification/model/lstm.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +from fastNLP.core.const import Const as C +from fastNLP.modules.encoder.lstm import LSTM +from fastNLP.modules import encoder +from fastNLP.modules.decoder.mlp import MLP + + +class BiLSTMSentiment(nn.Module): + def __init__(self, init_embed, + num_classes, + hidden_dim=256, + num_layers=1, + nfc=128): + super(BiLSTMSentiment,self).__init__() + self.embed = encoder.Embedding(init_embed) + self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True) + self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes]) + + def forward(self, words): + x_emb = self.embed(words) + output, _ = self.lstm(x_emb) + output = self.mlp(output[:,-1,:]) + return {C.OUTPUT: output} + + def predict(self, words): + output = self(words) + _, predict = output[C.OUTPUT].max(dim=1) + return {C.OUTPUT: predict} + diff --git a/reproduction/text_classification/model/lstm_self_attention.py b/reproduction/text_classification/model/lstm_self_attention.py new file mode 100644 index 00000000..239635fe --- /dev/null +++ b/reproduction/text_classification/model/lstm_self_attention.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +from fastNLP.core.const import Const as C +from fastNLP.modules.encoder.lstm import LSTM +from fastNLP.modules import encoder +from fastNLP.modules.aggregator.attention import SelfAttention +from fastNLP.modules.decoder.mlp import MLP + + +class BiLSTM_SELF_ATTENTION(nn.Module): + def __init__(self, init_embed, + num_classes, + hidden_dim=256, + num_layers=1, + attention_unit=256, + attention_hops=1, + nfc=128): + super(BiLSTM_SELF_ATTENTION,self).__init__() + self.embed = encoder.Embedding(init_embed) + self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True) + self.attention = SelfAttention(input_size=hidden_dim * 2 , attention_unit=attention_unit, attention_hops=attention_hops) + self.mlp = MLP(size_layer=[hidden_dim* 2*attention_hops, nfc, num_classes]) + + def forward(self, words): + x_emb = self.embed(words) + output, _ = self.lstm(x_emb) + after_attention, penalty = self.attention(output,words) + after_attention =after_attention.view(after_attention.size(0),-1) + output = self.mlp(after_attention) + return {C.OUTPUT: output} + + def predict(self, words): + output = self(words) + _, predict = output[C.OUTPUT].max(dim=1) + return {C.OUTPUT: predict} diff --git a/reproduction/text_classification/model/weight_drop.py b/reproduction/text_classification/model/weight_drop.py new file mode 100644 index 00000000..60fda179 --- /dev/null +++ b/reproduction/text_classification/model/weight_drop.py @@ -0,0 +1,99 @@ +import torch +from torch.nn import Parameter +from functools import wraps + +class WeightDrop(torch.nn.Module): + def __init__(self, module, weights, dropout=0, variational=False): + super(WeightDrop, self).__init__() + self.module = module + self.weights = weights + self.dropout = dropout + self.variational = variational + self._setup() + + def widget_demagnetizer_y2k_edition(*args, **kwargs): + # We need to replace flatten_parameters with a nothing function + # It must be a function rather than a lambda as otherwise pickling explodes + # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION! + # (╯°□°)╯︵ ┻━┻ + return + + def _setup(self): + # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN + if issubclass(type(self.module), torch.nn.RNNBase): + self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition + + for name_w in self.weights: + print('Applying weight drop of {} to {}'.format(self.dropout, name_w)) + w = getattr(self.module, name_w) + del self.module._parameters[name_w] + self.module.register_parameter(name_w + '_raw', Parameter(w.data)) + + def _setweights(self): + for name_w in self.weights: + raw_w = getattr(self.module, name_w + '_raw') + w = None + if self.variational: + mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1)) + if raw_w.is_cuda: mask = mask.cuda() + mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True) + w = mask.expand_as(raw_w) * raw_w + else: + w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training) + setattr(self.module, name_w, w) + + def forward(self, *args): + self._setweights() + return self.module.forward(*args) + +if __name__ == '__main__': + import torch + from weight_drop import WeightDrop + + # Input is (seq, batch, input) + x = torch.autograd.Variable(torch.randn(2, 1, 10)).cuda() + h0 = None + + ### + + print('Testing WeightDrop') + print('=-=-=-=-=-=-=-=-=-=') + + ### + + print('Testing WeightDrop with Linear') + + lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9) + lin.cuda() + run1 = [x.sum() for x in lin(x).data] + run2 = [x.sum() for x in lin(x).data] + + print('All items should be different') + print('Run 1:', run1) + print('Run 2:', run2) + + assert run1[0] != run2[0] + assert run1[1] != run2[1] + + print('---') + + ### + + print('Testing WeightDrop with LSTM') + + wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9) + wdrnn.cuda() + + run1 = [x.sum() for x in wdrnn(x, h0)[0].data] + run2 = [x.sum() for x in wdrnn(x, h0)[0].data] + + print('First timesteps should be equal, all others should differ') + print('Run 1:', run1) + print('Run 2:', run2) + + # First time step, not influenced by hidden to hidden weights, should be equal + assert run1[0] == run2[0] + # Second step should not + assert run1[1] != run2[1] + + print('---') diff --git a/reproduction/text_classification/train_HAN.py b/reproduction/text_classification/train_HAN.py new file mode 100644 index 00000000..b1135342 --- /dev/null +++ b/reproduction/text_classification/train_HAN.py @@ -0,0 +1,109 @@ +# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 + +import os +import sys +sys.path.append('../../') +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + +from fastNLP.core.const import Const as C +from fastNLP.core import LRScheduler +import torch.nn as nn +from fastNLP.io.dataset_loader import SSTLoader +from reproduction.text_classification.data.yelpLoader import yelpLoader +from reproduction.text_classification.model.HAN import HANCLS +from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP.core.trainer import Trainer +from torch.optim import SGD +import torch.cuda +from torch.optim.lr_scheduler import CosineAnnealingLR + + +##hyper + +class Config(): + model_dir_or_name = "en-base-uncased" + embedding_grad = False, + train_epoch = 30 + batch_size = 100 + num_classes = 5 + task = "yelp" + #datadir = '/remote-home/lyli/fastNLP/yelp_polarity/' + datadir = '/remote-home/ygwang/yelp_polarity/' + datafile = {"train": "train.csv", "test": "test.csv"} + lr = 1e-3 + + def __init__(self): + self.datapath = {k: os.path.join(self.datadir, v) + for k, v in self.datafile.items()} + + +ops = Config() + +##1.task相关信息:利用dataloader载入dataInfo + +datainfo = yelpLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train']) +print(len(datainfo.datasets['train'])) +print(len(datainfo.datasets['test'])) + + +# post process +def make_sents(words): + sents = [words] + return sents + + +for dataset in datainfo.datasets.values(): + dataset.apply_field(make_sents, field_name='words', new_field_name='input_sents') + +datainfo = datainfo +datainfo.datasets['train'].set_input('input_sents') +datainfo.datasets['test'].set_input('input_sents') +datainfo.datasets['train'].set_target('target') +datainfo.datasets['test'].set_target('target') + +## 2.或直接复用fastNLP的模型 + +vocab = datainfo.vocabs['words'] +# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)]) +embedding = StaticEmbedding(vocab) + +print(len(vocab)) +print(len(datainfo.vocabs['target'])) + +# model = DPCNN(init_embed=embedding, num_cls=ops.num_classes) +model = HANCLS(init_embed=embedding, num_cls=ops.num_classes) + +## 3. 声明loss,metric,optimizer +loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) +metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) +optimizer = SGD([param for param in model.parameters() if param.requires_grad == True], + lr=ops.lr, momentum=0.9, weight_decay=0) + +callbacks = [] +callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) + +device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + +print(device) + +for ds in datainfo.datasets.values(): + ds.apply_field(len, C.INPUT, C.INPUT_LEN) + ds.set_input(C.INPUT, C.INPUT_LEN) + ds.set_target(C.TARGET) + + +## 4.定义train方法 +def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch): + trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + metrics=[metrics], dev_data=datainfo.datasets['test'], device=device, + check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, + n_epochs=num_epochs) + + print(trainer.train()) + + +if __name__ == "__main__": + train(model, datainfo, loss, metric, optimizer) diff --git a/reproduction/text_classification/train_awdlstm.py b/reproduction/text_classification/train_awdlstm.py new file mode 100644 index 00000000..007b2910 --- /dev/null +++ b/reproduction/text_classification/train_awdlstm.py @@ -0,0 +1,69 @@ +# 这个模型需要在pytorch=0.4下运行,weight_drop不支持1.0 + +# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 +import os +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' + + +import torch.nn as nn + +from data.IMDBLoader import IMDBLoader +from fastNLP.modules.encoder.embedding import StaticEmbedding +from model.awd_lstm import AWDLSTMSentiment + +from fastNLP.core.const import Const as C +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP import Trainer, Tester +from torch.optim import Adam +from fastNLP.io.model_io import ModelLoader, ModelSaver + +import argparse + + +class Config(): + train_epoch= 10 + lr=0.001 + + num_classes=2 + hidden_dim=256 + num_layers=1 + nfc=128 + wdrop=0.5 + + task_name = "IMDB" + datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"} + save_model_path="./result_IMDB_test/" + +opt=Config() + + +# load data +dataloader=IMDBLoader() +datainfo=dataloader.process(opt.datapath) + +# print(datainfo.datasets["train"]) +# print(datainfo) + + +# define model +vocab=datainfo.vocabs['words'] +embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True) +model=AWDLSTMSentiment(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, nfc=opt.nfc, wdrop=opt.wdrop) + + +# define loss_function and metrics +loss=CrossEntropyLoss() +metrics=AccuracyMetric() +optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr) + + +def train(datainfo, model, optimizer, loss, metrics, opt): + trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1, + n_epochs=opt.train_epoch, save_path=opt.save_model_path) + trainer.train() + + +if __name__ == "__main__": + train(datainfo, model, optimizer, loss, metrics, opt) diff --git a/reproduction/text_classification/train_char_cnn.py b/reproduction/text_classification/train_char_cnn.py index e69de29b..050527fe 100644 --- a/reproduction/text_classification/train_char_cnn.py +++ b/reproduction/text_classification/train_char_cnn.py @@ -0,0 +1,205 @@ +# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 +import os +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' + +import sys +sys.path.append('../..') +from fastNLP.core.const import Const as C +import torch.nn as nn +from data.yelpLoader import yelpLoader +from data.sstLoader import sst2Loader +from data.IMDBLoader import IMDBLoader +from model.char_cnn import CharacterLevelCNN +from fastNLP.core.vocabulary import Vocabulary +from fastNLP.models.cnn_text_classification import CNNText +from fastNLP.modules.encoder.embedding import CNNCharEmbedding,StaticEmbedding,StackEmbedding,LSTMCharEmbedding +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP.core.trainer import Trainer +from torch.optim import SGD +from torch.autograd import Variable +import torch +from fastNLP import BucketSampler + +##hyper +#todo 这里加入fastnlp的记录 +class Config(): + model_dir_or_name="en-base-uncased" + embedding_grad= False, + bert_embedding_larers= '4,-2,-1' + train_epoch= 50 + num_classes=2 + task= "IMDB" + #yelp_p + datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv", + "test": "/remote-home/ygwang/yelp_polarity/test.csv"} + #IMDB + #datapath = {"train": "/remote-home/ygwang/IMDB_data/train.csv", + # "test": "/remote-home/ygwang/IMDB_data/test.csv"} + # sst + # datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv", + # "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"} + + lr=0.01 + batch_size=128 + model_size="large" + number_of_characters=69 + extra_characters='' + max_length=1014 + + char_cnn_config={ + "alphabet": { + "en": { + "lower": { + "alphabet": "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}", + "number_of_characters": 69 + }, + "both": { + "alphabet": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}", + "number_of_characters": 95 + } + } + }, + "model_parameters": { + "small": { + "conv": [ + #依次是channel,kennnel_size,maxpooling_size + [256,7,3], + [256,7,3], + [256,3,-1], + [256,3,-1], + [256,3,-1], + [256,3,3] + ], + "fc": [1024,1024] + }, + "large":{ + "conv":[ + [1024, 7, 3], + [1024, 7, 3], + [1024, 3, -1], + [1024, 3, -1], + [1024, 3, -1], + [1024, 3, 3] + ], + "fc": [2048,2048] + } + }, + "data": { + "text_column": "SentimentText", + "label_column": "Sentiment", + "max_length": 1014, + "num_of_classes": 2, + "encoding": None, + "chunksize": 50000, + "max_rows": 100000, + "preprocessing_steps": ["lower", "remove_hashtags", "remove_urls", "remove_user_mentions"] + }, + "training": { + "batch_size": 128, + "learning_rate": 0.01, + "epochs": 10, + "optimizer": "sgd" + } + } +ops=Config + + +##1.task相关信息:利用dataloader载入dataInfo +#dataloader=sst2Loader() +#dataloader=IMDBLoader() +dataloader=yelpLoader(fine_grained=True) +datainfo=dataloader.process(ops.datapath,char_level_op=True) +char_vocab=ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"] +ops.number_of_characters=len(char_vocab) +ops.embedding_dim=ops.number_of_characters + +#chartoindex +def chartoindex(chars): + max_seq_len=ops.max_length + zero_index=len(char_vocab) + char_index_list=[] + for char in chars: + if char in char_vocab: + char_index_list.append(char_vocab.index(char)) + else: + #均使用最后一个作为embbeding + char_index_list.append(zero_index) + if len(char_index_list) > max_seq_len: + char_index_list = char_index_list[:max_seq_len] + elif 0 < len(char_index_list) < max_seq_len: + char_index_list = char_index_list+[zero_index]*(max_seq_len-len(char_index_list)) + elif len(char_index_list) == 0: + char_index_list=[zero_index]*max_seq_len + return char_index_list + +for dataset in datainfo.datasets.values(): + dataset.apply_field(chartoindex,field_name='chars',new_field_name='chars') + +datainfo.datasets['train'].set_input('chars') +datainfo.datasets['test'].set_input('chars') +datainfo.datasets['train'].set_target('target') +datainfo.datasets['test'].set_target('target') + +##2. 定义/组装模型,这里可以随意,就如果是fastNLP封装好的,类似CNNText就直接用初始化调用就好了,这里只是给出一个伪框架表示占位,在这里建立符合fastNLP输入输出规范的model +class ModelFactory(nn.Module): + """ + 用于拼装embedding,encoder,decoder 以及设计forward过程 + + :param embedding: embbeding model + :param encoder: encoder model + :param decoder: decoder model + + """ + def __int__(self,embedding,encoder,decoder,**kwargs): + super(ModelFactory,self).__init__() + self.embedding=embedding + self.encoder=encoder + self.decoder=decoder + + def forward(self,x): + return {C.OUTPUT:None} + +## 2.或直接复用fastNLP的模型 +#vocab=datainfo.vocabs['words'] +vocab_label=datainfo.vocabs['target'] +''' +# emded_char=CNNCharEmbedding(vocab) +# embed_word = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True) +# embedding=StackEmbedding([emded_char, embed_word]) +# cnn_char_embed = CNNCharEmbedding(vocab) +# lstm_char_embed = LSTMCharEmbedding(vocab) +# embedding = StackEmbedding([cnn_char_embed, lstm_char_embed]) +''' +#one-hot embedding +embedding_weight= Variable(torch.zeros(len(char_vocab)+1, len(char_vocab))) + +for i in range(len(char_vocab)): + embedding_weight[i][i]=1 +embedding=nn.Embedding(num_embeddings=len(char_vocab)+1,embedding_dim=len(char_vocab),padding_idx=len(char_vocab),_weight=embedding_weight) +for para in embedding.parameters(): + para.requires_grad=False +#CNNText太过于简单 +#model=CNNText(init_embed=embedding, num_classes=ops.num_classes) +model=CharacterLevelCNN(ops,embedding) + +## 3. 声明loss,metric,optimizer +loss=CrossEntropyLoss +metric=AccuracyMetric +optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], lr=ops.lr) + +## 4.定义train方法 +def train(model,datainfo,loss,metrics,optimizer,num_epochs=100): + trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'), + metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=0, check_code_level=-1, + n_epochs=num_epochs) + print(trainer.train()) + + + +if __name__=="__main__": + #print(vocab_label) + + #print(datainfo.datasets["train"]) + train(model,datainfo,loss,metric,optimizer,num_epochs=ops.train_epoch) + \ No newline at end of file diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py index e69de29b..9664bf75 100644 --- a/reproduction/text_classification/train_dpcnn.py +++ b/reproduction/text_classification/train_dpcnn.py @@ -0,0 +1,120 @@ +# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 + +import torch.cuda +from fastNLP.core.utils import cache_results +from torch.optim import SGD +from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR +from fastNLP.core.trainer import Trainer +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding +from reproduction.text_classification.model.dpcnn import DPCNN +from data.yelpLoader import yelpLoader +from fastNLP.core.sampler import BucketSampler +import torch.nn as nn +from fastNLP.core import LRScheduler +from fastNLP.core.const import Const as C +from fastNLP.core.vocabulary import VocabularyOption +from utils.util_init import set_rng_seeds +import os +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + + +# hyper + +class Config(): + seed = 12345 + model_dir_or_name = "dpcnn-yelp-p" + embedding_grad = True + train_epoch = 30 + batch_size = 100 + task = "yelp_p" + #datadir = 'workdir/datasets/SST' + datadir = 'workdir/datasets/yelp_polarity' + # datadir = 'workdir/datasets/yelp_full' + #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"} + datafile = {"train": "train.csv", "test": "test.csv"} + lr = 1e-3 + src_vocab_op = VocabularyOption(max_size=100000) + embed_dropout = 0.3 + cls_dropout = 0.1 + weight_decay = 1e-5 + + def __init__(self): + self.datadir = os.path.join(os.environ['HOME'], self.datadir) + self.datapath = {k: os.path.join(self.datadir, v) + for k, v in self.datafile.items()} + + +ops = Config() + +set_rng_seeds(ops.seed) +print('RNG SEED: {}'.format(ops.seed)) + +# 1.task相关信息:利用dataloader载入dataInfo + +#datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train']) + + +@cache_results(ops.model_dir_or_name+'-data-cache') +def load_data(): + datainfo = yelpLoader(fine_grained=True, lower=True).process( + paths=ops.datapath, train_ds=['train'], src_vocab_op=ops.src_vocab_op) + for ds in datainfo.datasets.values(): + ds.apply_field(len, C.INPUT, C.INPUT_LEN) + ds.set_input(C.INPUT, C.INPUT_LEN) + ds.set_target(C.TARGET) + embedding = StaticEmbedding( + datainfo.vocabs['words'], model_dir_or_name='en-glove-840b-300', requires_grad=ops.embedding_grad, + normalize=False + ) + return datainfo, embedding + + +datainfo, embedding = load_data() + +# 2.或直接复用fastNLP的模型 + +# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)]) + +print(datainfo) +print(datainfo.datasets['train'][0]) + +model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]), + embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout) +print(model) + +# 3. 声明loss,metric,optimizer +loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) +metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) +optimizer = SGD([param for param in model.parameters() if param.requires_grad == True], + lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay) + +callbacks = [] +# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) +callbacks.append( + LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch < + ops.train_epoch * 0.8 else ops.lr * 0.1)) +) + +# callbacks.append( +# FitlogCallback(data=datainfo.datasets, verbose=1) +# ) + +device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + +print(device) + +# 4.定义train方法 +trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), + metrics=[metric], + dev_data=datainfo.datasets['test'], device=device, + check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, + n_epochs=ops.train_epoch, num_workers=4) + + + +if __name__ == "__main__": + print(trainer.train()) diff --git a/reproduction/text_classification/train_lstm.py b/reproduction/text_classification/train_lstm.py new file mode 100644 index 00000000..4ecc61a1 --- /dev/null +++ b/reproduction/text_classification/train_lstm.py @@ -0,0 +1,66 @@ +# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 +import os +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' + + +import torch.nn as nn + +from data.IMDBLoader import IMDBLoader +from fastNLP.modules.encoder.embedding import StaticEmbedding +from model.lstm import BiLSTMSentiment + +from fastNLP.core.const import Const as C +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP import Trainer, Tester +from torch.optim import Adam +from fastNLP.io.model_io import ModelLoader, ModelSaver + +import argparse + + +class Config(): + train_epoch= 10 + lr=0.001 + + num_classes=2 + hidden_dim=256 + num_layers=1 + nfc=128 + + task_name = "IMDB" + datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"} + save_model_path="./result_IMDB_test/" + +opt=Config() + + +# load data +dataloader=IMDBLoader() +datainfo=dataloader.process(opt.datapath) + +# print(datainfo.datasets["train"]) +# print(datainfo) + + +# define model +vocab=datainfo.vocabs['words'] +embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True) +model=BiLSTMSentiment(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, nfc=opt.nfc) + + +# define loss_function and metrics +loss=CrossEntropyLoss() +metrics=AccuracyMetric() +optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr) + + +def train(datainfo, model, optimizer, loss, metrics, opt): + trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1, + n_epochs=opt.train_epoch, save_path=opt.save_model_path) + trainer.train() + + +if __name__ == "__main__": + train(datainfo, model, optimizer, loss, metrics, opt) \ No newline at end of file diff --git a/reproduction/text_classification/train_lstm_att.py b/reproduction/text_classification/train_lstm_att.py new file mode 100644 index 00000000..a6f0dd03 --- /dev/null +++ b/reproduction/text_classification/train_lstm_att.py @@ -0,0 +1,68 @@ +# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 +import os +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' + + +import torch.nn as nn + +from data.IMDBLoader import IMDBLoader +from fastNLP.modules.encoder.embedding import StaticEmbedding +from model.lstm_self_attention import BiLSTM_SELF_ATTENTION + +from fastNLP.core.const import Const as C +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP import Trainer, Tester +from torch.optim import Adam +from fastNLP.io.model_io import ModelLoader, ModelSaver + +import argparse + + +class Config(): + train_epoch= 10 + lr=0.001 + + num_classes=2 + hidden_dim=256 + num_layers=1 + attention_unit=256 + attention_hops=1 + nfc=128 + + task_name = "IMDB" + datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"} + save_model_path="./result_IMDB_test/" + +opt=Config() + + +# load data +dataloader=IMDBLoader() +datainfo=dataloader.process(opt.datapath) + +# print(datainfo.datasets["train"]) +# print(datainfo) + + +# define model +vocab=datainfo.vocabs['words'] +embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True) +model=BiLSTM_SELF_ATTENTION(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, attention_unit=opt.attention_unit, attention_hops=opt.attention_hops, nfc=opt.nfc) + + +# define loss_function and metrics +loss=CrossEntropyLoss() +metrics=AccuracyMetric() +optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr) + + +def train(datainfo, model, optimizer, loss, metrics, opt): + trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1, + n_epochs=opt.train_epoch, save_path=opt.save_model_path) + trainer.train() + + +if __name__ == "__main__": + train(datainfo, model, optimizer, loss, metrics, opt) diff --git a/reproduction/text_classification/utils/util_init.py b/reproduction/text_classification/utils/util_init.py new file mode 100644 index 00000000..fcb8fffb --- /dev/null +++ b/reproduction/text_classification/utils/util_init.py @@ -0,0 +1,11 @@ +import numpy +import torch +import random + + +def set_rng_seeds(seed): + random.seed(seed) + numpy.random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # print('RNG_SEED {}'.format(seed)) diff --git a/reproduction/utils.py b/reproduction/utils.py index 536b8eec..d6cd1af3 100644 --- a/reproduction/utils.py +++ b/reproduction/utils.py @@ -59,4 +59,13 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: else: raise TypeError(f"paths only supports str and dict. not {type(paths)}.") - +def get_tokenizer(): + try: + import spacy + spacy.prefer_gpu() + en = spacy.load('en') + print('use spacy tokenizer') + return lambda x: [w.text for w in en.tokenizer(x)] + except Exception as e: + print('use raw tokenizer') + return lambda x: x.split()