diff --git a/reproduction/text_classification/data/IMDBLoader.py b/reproduction/text_classification/data/IMDBLoader.py new file mode 100644 index 00000000..d591cdf8 --- /dev/null +++ b/reproduction/text_classification/data/IMDBLoader.py @@ -0,0 +1,80 @@ +from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader +from fastNLP.core.vocabulary import VocabularyOption +from fastNLP.io.base_loader import DataSetLoader, DataInfo +from typing import Union, Dict, List, Iterator +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP import Vocabulary +from fastNLP import Const +# from reproduction.utils import check_dataloader_paths +from functools import partial + +class IMDBLoader(DataSetLoader): + """ + 读取IMDB数据集,DataSet包含以下fields: + + words: list(str), 需要分类的文本 + target: str, 文本的标签 + + + """ + + def __init__(self): + super(IMDBLoader, self).__init__() + + def _load(self, path): + dataset = DataSet() + with open(path, 'r', encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split('\t') + target = parts[0] + words = parts[1].lower().split() + dataset.append(Instance(words=words, target=target)) + if len(dataset)==0: + raise RuntimeError(f"{path} has no valid data.") + + return dataset + + def process(self, + paths: Union[str, Dict[str, str]], + src_vocab_opt: VocabularyOption = None, + tgt_vocab_opt: VocabularyOption = None, + src_embed_opt: EmbeddingOption = None): + + # paths = check_dataloader_paths(paths) + datasets = {} + info = DataInfo() + for name, path in paths.items(): + dataset = self.load(path) + datasets[name] = dataset + + datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False) + + src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) + src_vocab.from_dataset(datasets['train'], field_name='words') + src_vocab.index_dataset(*datasets.values(), field_name='words') + + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) + tgt_vocab.from_dataset(datasets['train'], field_name='target') + tgt_vocab.index_dataset(*datasets.values(), field_name='target') + + info.vocabs = { + "words": src_vocab, + "target": tgt_vocab + } + + info.datasets = datasets + + if src_embed_opt is not None: + embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) + info.embeddings['words'] = embed + + for name, dataset in info.datasets.items(): + dataset.set_input("words") + dataset.set_target("target") + + return info diff --git a/reproduction/text_classification/data/MTL16Loader.py b/reproduction/text_classification/data/MTL16Loader.py index 1b3e6245..066b53b4 100644 --- a/reproduction/text_classification/data/MTL16Loader.py +++ b/reproduction/text_classification/data/MTL16Loader.py @@ -32,7 +32,7 @@ class MTL16Loader(DataSetLoader): continue parts = line.split('\t') target = parts[0] - words = parts[1].split() + words = parts[1].lower().split() dataset.append(Instance(words=words, target=target)) if len(dataset)==0: raise RuntimeError(f"{path} has no valid data.") @@ -72,4 +72,8 @@ class MTL16Loader(DataSetLoader): embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) info.embeddings['words'] = embed + for name, dataset in info.datasets.items(): + dataset.set_input("words") + dataset.set_target("target") + return info diff --git a/reproduction/text_classification/data/SSTLoader.py b/reproduction/text_classification/data/SSTLoader.py new file mode 100644 index 00000000..b570994e --- /dev/null +++ b/reproduction/text_classification/data/SSTLoader.py @@ -0,0 +1,99 @@ +from typing import Iterable +from nltk import Tree +from fastNLP.io.base_loader import DataInfo, DataSetLoader +from fastNLP.core.vocabulary import VocabularyOption, Vocabulary +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader + + +class SSTLoader(DataSetLoader): + URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' + DATA_DIR = 'sst/' + + """ + 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` + + 读取SST数据集, DataSet包含fields:: + + words: list(str) 需要分类的文本 + target: str 文本的标签 + + 数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip + + :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` + :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` + """ + + def __init__(self, subtree=False, fine_grained=False): + self.subtree = subtree + + tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', + '3': 'positive', '4': 'very positive'} + if not fine_grained: + tag_v['0'] = tag_v['1'] + tag_v['4'] = tag_v['3'] + self.tag_v = tag_v + + def _load(self, path): + """ + + :param str path: 存储数据的路径 + :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 + """ + datalist = [] + with open(path, 'r', encoding='utf-8') as f: + datas = [] + for l in f: + datas.extend([(s, self.tag_v[t]) + for s, t in self._get_one(l, self.subtree)]) + ds = DataSet() + for words, tag in datas: + ds.append(Instance(words=words, target=tag)) + return ds + + @staticmethod + def _get_one(data, subtree): + tree = Tree.fromstring(data) + if subtree: + return [(t.leaves(), t.label()) for t in tree.subtrees()] + return [(tree.leaves(), tree.label())] + + def process(self, + paths, + train_ds: Iterable[str] = None, + src_vocab_op: VocabularyOption = None, + tgt_vocab_op: VocabularyOption = None, + src_embed_op: EmbeddingOption = None): + input_name, target_name = 'words', 'target' + src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) + + info = DataInfo(datasets=self.load(paths)) + _train_ds = [info.datasets[name] + for name in train_ds] if train_ds else info.datasets.values() + src_vocab.from_dataset(*_train_ds, field_name=input_name) + tgt_vocab.from_dataset(*_train_ds, field_name=target_name) + src_vocab.index_dataset( + *info.datasets.values(), + field_name=input_name, new_field_name=input_name) + tgt_vocab.index_dataset( + *info.datasets.values(), + field_name=target_name, new_field_name=target_name) + info.vocabs = { + input_name: src_vocab, + target_name: tgt_vocab + } + + if src_embed_op is not None: + src_embed_op.vocab = src_vocab + init_emb = EmbedLoader.load_with_vocab(**src_embed_op) + info.embeddings[input_name] = init_emb + + for name, dataset in info.datasets.items(): + dataset.set_input(input_name) + dataset.set_target(target_name) + + return info + diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py index c47d48fd..680b3488 100644 --- a/reproduction/text_classification/data/yelpLoader.py +++ b/reproduction/text_classification/data/yelpLoader.py @@ -1,68 +1,77 @@ -import ast -from fastNLP import DataSet, Instance, Vocabulary +from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader from fastNLP.core.vocabulary import VocabularyOption -from fastNLP.io import JsonLoader -from fastNLP.io.base_loader import DataInfo -from fastNLP.io.embed_loader import EmbeddingOption -from fastNLP.io.file_reader import _read_json -from typing import Union, Dict -from reproduction.Star_transformer.datasets import EmbedLoader -from reproduction.utils import check_dataloader_paths +from fastNLP.io.base_loader import DataSetLoader, DataInfo +from typing import Union, Dict, List, Iterator +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP import Vocabulary +from fastNLP import Const +# from reproduction.utils import check_dataloader_paths +from functools import partial +import pandas as pd - -class yelpLoader(JsonLoader): - +class yelpLoader(DataSetLoader): """ - 读取Yelp数据集, DataSet包含fields: - - review_id: str, 22 character unique review id - user_id: str, 22 character unique user id - business_id: str, 22 character business id - useful: int, number of useful votes received - funny: int, number of funny votes received - cool: int, number of cool votes received - date: str, date formatted YYYY-MM-DD + 读取IMDB数据集,DataSet包含以下fields: + words: list(str), 需要分类的文本 target: str, 文本的标签 - - 数据来源: https://www.yelp.com/dataset/download - - :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` + + """ - - def __init__(self, fine_grained=False): + + def __init__(self): super(yelpLoader, self).__init__() - tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', - '4.0': 'positive', '5.0': 'very positive'} - if not fine_grained: - tag_v['1.0'] = tag_v['2.0'] - tag_v['5.0'] = tag_v['4.0'] - self.fine_grained = fine_grained - self.tag_v = tag_v - + def _load(self, path): - ds = DataSet() - for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): - d = ast.literal_eval(d) - d["words"] = d.pop("text").split() - d["target"] = self.tag_v[str(d.pop("stars"))] - ds.append(Instance(**d)) - return ds + dataset = DataSet() + data = pd.read_csv(path, header=None, sep=",").values + for line in data: + target = str(line[0]) + words = str(line[1]).lower().split() + dataset.append(Instance(words=words, target=target)) + if len(dataset)==0: + raise RuntimeError(f"{path} has no valid data.") - def process(self, paths: Union[str, Dict[str, str]], vocab_opt: VocabularyOption = None, - embed_opt: EmbeddingOption = None): - paths = check_dataloader_paths(paths) + return dataset + + def process(self, + paths: Union[str, Dict[str, str]], + src_vocab_opt: VocabularyOption = None, + tgt_vocab_opt: VocabularyOption = None, + src_embed_opt: EmbeddingOption = None): + + # paths = check_dataloader_paths(paths) datasets = {} info = DataInfo() - vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt) for name, path in paths.items(): dataset = self.load(path) datasets[name] = dataset - vocab.from_dataset(dataset, field_name="words") - info.vocabs = vocab + + datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False) + + src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) + src_vocab.from_dataset(datasets['train'], field_name='words') + src_vocab.index_dataset(*datasets.values(), field_name='words') + + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) + tgt_vocab.from_dataset(datasets['train'], field_name='target') + tgt_vocab.index_dataset(*datasets.values(), field_name='target') + + info.vocabs = { + "words": src_vocab, + "target": tgt_vocab + } + info.datasets = datasets - if embed_opt is not None: - embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab) + + if src_embed_opt is not None: + embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) info.embeddings['words'] = embed - return info + for name, dataset in info.datasets.items(): + dataset.set_input("words") + dataset.set_target("target") + + return info diff --git a/reproduction/text_classification/model/awd_lstm.py b/reproduction/text_classification/model/awd_lstm.py new file mode 100644 index 00000000..0d8f711a --- /dev/null +++ b/reproduction/text_classification/model/awd_lstm.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from fastNLP.core.const import Const as C +from .awdlstm_module import LSTM +from fastNLP.modules import encoder +from fastNLP.modules.decoder.mlp import MLP + + +class AWDLSTMSentiment(nn.Module): + def __init__(self, init_embed, + num_classes, + hidden_dim=256, + num_layers=1, + nfc=128, + wdrop=0.5): + super(AWDLSTMSentiment,self).__init__() + self.embed = encoder.Embedding(init_embed) + self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True, wdrop=wdrop) + self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes]) + + def forward(self, words): + x_emb = self.embed(words) + output, _ = self.lstm(x_emb) + output = self.mlp(output[:,-1,:]) + return {C.OUTPUT: output} + + def predict(self, words): + output = self(words) + _, predict = output[C.OUTPUT].max(dim=1) + return {C.OUTPUT: predict} + diff --git a/reproduction/text_classification/model/awdlstm_module.py b/reproduction/text_classification/model/awdlstm_module.py new file mode 100644 index 00000000..87bfe730 --- /dev/null +++ b/reproduction/text_classification/model/awdlstm_module.py @@ -0,0 +1,86 @@ +""" +轻量封装的 Pytorch LSTM 模块. +可在 forward 时传入序列的长度, 自动对padding做合适的处理. +""" +__all__ = [ + "LSTM" +] + +import torch +import torch.nn as nn +import torch.nn.utils.rnn as rnn + +from fastNLP.modules.utils import initial_parameter +from torch import autograd +from .weight_drop import WeightDrop + + +class LSTM(nn.Module): + """ + 别名::class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM` + + LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 + 为1; 且可以应对DataParallel中LSTM的使用问题。 + + :param input_size: 输入 `x` 的特征维度 + :param hidden_size: 隐状态 `h` 的特征维度. + :param num_layers: rnn的层数. Default: 1 + :param dropout: 层间dropout概率. Default: 0 + :param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` + :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 + :(batch, seq, feature). Default: ``False`` + :param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` + """ + + def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, + bidirectional=False, bias=True, wdrop=0.5): + super(LSTM, self).__init__() + self.batch_first = batch_first + self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, + dropout=dropout, bidirectional=bidirectional) + self.lstm = WeightDrop(self.lstm, ['weight_hh_l0'], dropout=wdrop) + self.init_param() + + def init_param(self): + for name, param in self.named_parameters(): + if 'bias' in name: + # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871 + param.data.fill_(0) + n = param.size(0) + start, end = n // 4, n // 2 + param.data[start:end].fill_(1) + else: + nn.init.xavier_uniform_(param) + + def forward(self, x, seq_len=None, h0=None, c0=None): + """ + + :param x: [batch, seq_len, input_size] 输入序列 + :param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` + :param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None`` + :param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None`` + :return (output, ht) 或 output: 若 ``get_hidden=True`` [batch, seq_len, hidden_size*num_direction] 输出序列 + 和 [batch, hidden_size*num_direction] 最后时刻隐状态. + """ + batch_size, max_len, _ = x.size() + if h0 is not None and c0 is not None: + hx = (h0, c0) + else: + hx = None + if seq_len is not None and not isinstance(x, rnn.PackedSequence): + sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) + if self.batch_first: + x = x[sort_idx] + else: + x = x[:, sort_idx] + x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first) + output, hx = self.lstm(x, hx) # -> [N,L,C] + output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) + _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) + if self.batch_first: + output = output[unsort_idx] + else: + output = output[:, unsort_idx] + else: + output, hx = self.lstm(x, hx) + return output, hx diff --git a/reproduction/text_classification/model/lstm.py b/reproduction/text_classification/model/lstm.py new file mode 100644 index 00000000..388f3f1c --- /dev/null +++ b/reproduction/text_classification/model/lstm.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +from fastNLP.core.const import Const as C +from fastNLP.modules.encoder.lstm import LSTM +from fastNLP.modules import encoder +from fastNLP.modules.decoder.mlp import MLP + + +class BiLSTMSentiment(nn.Module): + def __init__(self, init_embed, + num_classes, + hidden_dim=256, + num_layers=1, + nfc=128): + super(BiLSTMSentiment,self).__init__() + self.embed = encoder.Embedding(init_embed) + self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True) + self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes]) + + def forward(self, words): + x_emb = self.embed(words) + output, _ = self.lstm(x_emb) + output = self.mlp(output[:,-1,:]) + return {C.OUTPUT: output} + + def predict(self, words): + output = self(words) + _, predict = output[C.OUTPUT].max(dim=1) + return {C.OUTPUT: predict} + diff --git a/reproduction/text_classification/model/lstm_self_attention.py b/reproduction/text_classification/model/lstm_self_attention.py new file mode 100644 index 00000000..239635fe --- /dev/null +++ b/reproduction/text_classification/model/lstm_self_attention.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +from fastNLP.core.const import Const as C +from fastNLP.modules.encoder.lstm import LSTM +from fastNLP.modules import encoder +from fastNLP.modules.aggregator.attention import SelfAttention +from fastNLP.modules.decoder.mlp import MLP + + +class BiLSTM_SELF_ATTENTION(nn.Module): + def __init__(self, init_embed, + num_classes, + hidden_dim=256, + num_layers=1, + attention_unit=256, + attention_hops=1, + nfc=128): + super(BiLSTM_SELF_ATTENTION,self).__init__() + self.embed = encoder.Embedding(init_embed) + self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True) + self.attention = SelfAttention(input_size=hidden_dim * 2 , attention_unit=attention_unit, attention_hops=attention_hops) + self.mlp = MLP(size_layer=[hidden_dim* 2*attention_hops, nfc, num_classes]) + + def forward(self, words): + x_emb = self.embed(words) + output, _ = self.lstm(x_emb) + after_attention, penalty = self.attention(output,words) + after_attention =after_attention.view(after_attention.size(0),-1) + output = self.mlp(after_attention) + return {C.OUTPUT: output} + + def predict(self, words): + output = self(words) + _, predict = output[C.OUTPUT].max(dim=1) + return {C.OUTPUT: predict} diff --git a/reproduction/text_classification/model/weight_drop.py b/reproduction/text_classification/model/weight_drop.py new file mode 100644 index 00000000..60fda179 --- /dev/null +++ b/reproduction/text_classification/model/weight_drop.py @@ -0,0 +1,99 @@ +import torch +from torch.nn import Parameter +from functools import wraps + +class WeightDrop(torch.nn.Module): + def __init__(self, module, weights, dropout=0, variational=False): + super(WeightDrop, self).__init__() + self.module = module + self.weights = weights + self.dropout = dropout + self.variational = variational + self._setup() + + def widget_demagnetizer_y2k_edition(*args, **kwargs): + # We need to replace flatten_parameters with a nothing function + # It must be a function rather than a lambda as otherwise pickling explodes + # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION! + # (╯°□°)╯︵ ┻━┻ + return + + def _setup(self): + # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN + if issubclass(type(self.module), torch.nn.RNNBase): + self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition + + for name_w in self.weights: + print('Applying weight drop of {} to {}'.format(self.dropout, name_w)) + w = getattr(self.module, name_w) + del self.module._parameters[name_w] + self.module.register_parameter(name_w + '_raw', Parameter(w.data)) + + def _setweights(self): + for name_w in self.weights: + raw_w = getattr(self.module, name_w + '_raw') + w = None + if self.variational: + mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1)) + if raw_w.is_cuda: mask = mask.cuda() + mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True) + w = mask.expand_as(raw_w) * raw_w + else: + w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training) + setattr(self.module, name_w, w) + + def forward(self, *args): + self._setweights() + return self.module.forward(*args) + +if __name__ == '__main__': + import torch + from weight_drop import WeightDrop + + # Input is (seq, batch, input) + x = torch.autograd.Variable(torch.randn(2, 1, 10)).cuda() + h0 = None + + ### + + print('Testing WeightDrop') + print('=-=-=-=-=-=-=-=-=-=') + + ### + + print('Testing WeightDrop with Linear') + + lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9) + lin.cuda() + run1 = [x.sum() for x in lin(x).data] + run2 = [x.sum() for x in lin(x).data] + + print('All items should be different') + print('Run 1:', run1) + print('Run 2:', run2) + + assert run1[0] != run2[0] + assert run1[1] != run2[1] + + print('---') + + ### + + print('Testing WeightDrop with LSTM') + + wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9) + wdrnn.cuda() + + run1 = [x.sum() for x in wdrnn(x, h0)[0].data] + run2 = [x.sum() for x in wdrnn(x, h0)[0].data] + + print('First timesteps should be equal, all others should differ') + print('Run 1:', run1) + print('Run 2:', run2) + + # First time step, not influenced by hidden to hidden weights, should be equal + assert run1[0] == run2[0] + # Second step should not + assert run1[1] != run2[1] + + print('---') diff --git a/reproduction/text_classification/results_LSTM.xlsx b/reproduction/text_classification/results_LSTM.xlsx new file mode 100644 index 00000000..0d7b841b Binary files /dev/null and b/reproduction/text_classification/results_LSTM.xlsx differ diff --git a/reproduction/text_classification/train_awdlstm.py b/reproduction/text_classification/train_awdlstm.py new file mode 100644 index 00000000..ce3e52bc --- /dev/null +++ b/reproduction/text_classification/train_awdlstm.py @@ -0,0 +1,102 @@ +# 这个模型需要在pytorch=0.4下运行,weight_drop不支持1.0 + +# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 +import os +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' + + +import torch.nn as nn + +from data.SSTLoader import SSTLoader +from data.IMDBLoader import IMDBLoader +from data.yelpLoader import yelpLoader +from fastNLP.modules.encoder.embedding import StaticEmbedding +from model.awd_lstm import AWDLSTMSentiment + +from fastNLP.core.const import Const as C +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP import Trainer, Tester +from torch.optim import Adam +from fastNLP.io.model_io import ModelLoader, ModelSaver + +import argparse + + +class Config(): + train_epoch= 10 + lr=0.001 + + num_classes=2 + hidden_dim=256 + num_layers=1 + nfc=128 + wdrop=0.5 + + task_name = "IMDB" + datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"} + load_model_path="./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51" + save_model_path="./result_IMDB_test/" +opt=Config + + +# load data +dataloaders = { + "IMDB":IMDBLoader(), + "YELP":yelpLoader(), + "SST-5":SSTLoader(subtree=True,fine_grained=True), + "SST-3":SSTLoader(subtree=True,fine_grained=False) +} + +if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]: + raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']") + +dataloader = dataloaders[opt.task_name] +datainfo=dataloader.process(opt.datapath) +# print(datainfo.datasets["train"]) +# print(datainfo) + + +# define model +vocab=datainfo.vocabs['words'] +embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True) +model=AWDLSTMSentiment(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, nfc=opt.nfc, wdrop=opt.wdrop) + + +# define loss_function and metrics +loss=CrossEntropyLoss() +metrics=AccuracyMetric() +optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr) + + +def train(datainfo, model, optimizer, loss, metrics, opt): + trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1, + n_epochs=opt.train_epoch, save_path=opt.save_model_path) + trainer.train() + + +def test(datainfo, metrics, opt): + # load model + model = ModelLoader.load_pytorch_model(opt.load_model_path) + print("model loaded!") + + # Tester + tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0) + acc = tester.test() + print("acc=",acc) + + + +parser = argparse.ArgumentParser() +parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model') + + +args = parser.parse_args() +if args.mode == 'train': + train(datainfo, model, optimizer, loss, metrics, opt) +elif args.mode == 'test': + test(datainfo, metrics, opt) +else: + print('no mode specified for model!') + parser.print_help() diff --git a/reproduction/text_classification/train_lstm.py b/reproduction/text_classification/train_lstm.py new file mode 100644 index 00000000..b320e79c --- /dev/null +++ b/reproduction/text_classification/train_lstm.py @@ -0,0 +1,99 @@ +# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 +import os +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' + + +import torch.nn as nn + +from data.SSTLoader import SSTLoader +from data.IMDBLoader import IMDBLoader +from data.yelpLoader import yelpLoader +from fastNLP.modules.encoder.embedding import StaticEmbedding +from model.lstm import BiLSTMSentiment + +from fastNLP.core.const import Const as C +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP import Trainer, Tester +from torch.optim import Adam +from fastNLP.io.model_io import ModelLoader, ModelSaver + +import argparse + + +class Config(): + train_epoch= 10 + lr=0.001 + + num_classes=2 + hidden_dim=256 + num_layers=1 + nfc=128 + + task_name = "IMDB" + datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"} + load_model_path="./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51" + save_model_path="./result_IMDB_test/" +opt=Config + + +# load data +dataloaders = { + "IMDB":IMDBLoader(), + "YELP":yelpLoader(), + "SST-5":SSTLoader(subtree=True,fine_grained=True), + "SST-3":SSTLoader(subtree=True,fine_grained=False) +} + +if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]: + raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']") + +dataloader = dataloaders[opt.task_name] +datainfo=dataloader.process(opt.datapath) +# print(datainfo.datasets["train"]) +# print(datainfo) + + +# define model +vocab=datainfo.vocabs['words'] +embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True) +model=BiLSTMSentiment(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, nfc=opt.nfc) + + +# define loss_function and metrics +loss=CrossEntropyLoss() +metrics=AccuracyMetric() +optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr) + + +def train(datainfo, model, optimizer, loss, metrics, opt): + trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1, + n_epochs=opt.train_epoch, save_path=opt.save_model_path) + trainer.train() + + +def test(datainfo, metrics, opt): + # load model + model = ModelLoader.load_pytorch_model(opt.load_model_path) + print("model loaded!") + + # Tester + tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0) + acc = tester.test() + print("acc=",acc) + + + +parser = argparse.ArgumentParser() +parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model') + + +args = parser.parse_args() +if args.mode == 'train': + train(datainfo, model, optimizer, loss, metrics, opt) +elif args.mode == 'test': + test(datainfo, metrics, opt) +else: + print('no mode specified for model!') + parser.print_help() diff --git a/reproduction/text_classification/train_lstm_att.py b/reproduction/text_classification/train_lstm_att.py new file mode 100644 index 00000000..8db27d09 --- /dev/null +++ b/reproduction/text_classification/train_lstm_att.py @@ -0,0 +1,101 @@ +# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 +import os +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' + + +import torch.nn as nn + +from data.SSTLoader import SSTLoader +from data.IMDBLoader import IMDBLoader +from data.yelpLoader import yelpLoader +from fastNLP.modules.encoder.embedding import StaticEmbedding +from model.lstm_self_attention import BiLSTM_SELF_ATTENTION + +from fastNLP.core.const import Const as C +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP import Trainer, Tester +from torch.optim import Adam +from fastNLP.io.model_io import ModelLoader, ModelSaver + +import argparse + + +class Config(): + train_epoch= 10 + lr=0.001 + + num_classes=2 + hidden_dim=256 + num_layers=1 + attention_unit=256 + attention_hops=1 + nfc=128 + + task_name = "IMDB" + datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"} + load_model_path="./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51" + save_model_path="./result_IMDB_test/" +opt=Config + + +# load data +dataloaders = { + "IMDB":IMDBLoader(), + "YELP":yelpLoader(), + "SST-5":SSTLoader(subtree=True,fine_grained=True), + "SST-3":SSTLoader(subtree=True,fine_grained=False) +} + +if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]: + raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']") + +dataloader = dataloaders[opt.task_name] +datainfo=dataloader.process(opt.datapath) +# print(datainfo.datasets["train"]) +# print(datainfo) + + +# define model +vocab=datainfo.vocabs['words'] +embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True) +model=BiLSTM_SELF_ATTENTION(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, attention_unit=opt.attention_unit, attention_hops=opt.attention_hops, nfc=opt.nfc) + + +# define loss_function and metrics +loss=CrossEntropyLoss() +metrics=AccuracyMetric() +optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr) + + +def train(datainfo, model, optimizer, loss, metrics, opt): + trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1, + n_epochs=opt.train_epoch, save_path=opt.save_model_path) + trainer.train() + + +def test(datainfo, metrics, opt): + # load model + model = ModelLoader.load_pytorch_model(opt.load_model_path) + print("model loaded!") + + # Tester + tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0) + acc = tester.test() + print("acc=",acc) + + + +parser = argparse.ArgumentParser() +parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model') + + +args = parser.parse_args() +if args.mode == 'train': + train(datainfo, model, optimizer, loss, metrics, opt) +elif args.mode == 'test': + test(datainfo, metrics, opt) +else: + print('no mode specified for model!') + parser.print_help()