From 32917dac7f8e87adbad72c69d3b269dd48606b8a Mon Sep 17 00:00:00 2001
From: yunfan
Date: Thu, 27 Jun 2019 21:56:49 +0800
Subject: [PATCH 01/20] add dpcnn

---
 .../text_classification/model/dpcnn.py | 92 ++++++++++++++++++-
 .../text_classification/train_dpcnn.py | 80 ++++++++++++++++
 2 files changed, 171 insertions(+), 1 deletion(-)

diff --git a/reproduction/text_classification/model/dpcnn.py b/reproduction/text_classification/model/dpcnn.py
index f87f5c14..2da7b3e5 100644
--- a/reproduction/text_classification/model/dpcnn.py
+++ b/reproduction/text_classification/model/dpcnn.py
@@ -1 +1,91 @@
-# TODO
\ No newline at end of file
+import torch
+import torch.nn as nn
+from fastNLP.modules.utils import get_embeddings
+from fastNLP.core import Const as C
+
+
+class DPCNN(nn.Module):
+    def __init__(self, init_embed, num_cls, n_filters=256, kernel_size=3, n_layers=7, embed_dropout=0.1, dropout=0.1):
+        super().__init__()
+        self.region_embed = RegionEmbedding(init_embed, out_dim=n_filters, kernel_sizes=[3, 5, 9])
+        embed_dim = self.region_embed.embedding_dim
+        self.conv_list = nn.ModuleList()
+        for i in range(n_layers):
+            self.conv_list.append(nn.Sequential(
+                nn.ReLU(),
+                nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2),
+                nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2),
+            ))
+        self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
+        self.embed_drop = nn.Dropout(embed_dropout)
+        self.classfier = nn.Sequential(
+            nn.Dropout(dropout),
+            nn.Linear(n_filters, num_cls),
+        )
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for m in self.modules():
+            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+                nn.init.normal_(m.weight, mean=0, std=0.01)
+                if m.bias is not None:
+                    nn.init.normal_(m.bias, mean=0, std=0.01)
+
+    def forward(self, words, seq_len=None):
+        words = words.long()
+        # get region embeddings
+        x = self.region_embed(words)
+        x = self.embed_drop(x)
+
+        # not pooling on first conv
+        x = self.conv_list[0](x) + x
+        for conv in self.conv_list[1:]:
+            x = self.pool(x)
+            x = conv(x) + x
+
+        # B, C, L => B, C
+        x, _ = torch.max(x, dim=2)
+        x = self.classfier(x)
+        return {C.OUTPUT: x}
+
+    def predict(self, words, seq_len=None):
+        x = self.forward(words, seq_len)[C.OUTPUT]
+        return {C.OUTPUT: torch.argmax(x, 1)}
+
+
+class RegionEmbedding(nn.Module):
+    def __init__(self, init_embed, out_dim=300, kernel_sizes=None):
+        super().__init__()
+        if kernel_sizes is None:
+            kernel_sizes = [5, 9]
+        assert isinstance(kernel_sizes, list), 'kernel_sizes should be List(int)'
+        self.embed = get_embeddings(init_embed)
+        try:
+            embed_dim = self.embed.embedding_dim
+        except Exception:
+            embed_dim = self.embed.embed_size
+        self.region_embeds = nn.ModuleList()
+        for ksz in kernel_sizes:
+            self.region_embeds.append(nn.Sequential(
+                nn.Conv1d(embed_dim, embed_dim, ksz, padding=ksz // 2),
+            ))
+        self.linears = nn.ModuleList([nn.Conv1d(embed_dim, out_dim, 1)
+                                      for _ in range(len(kernel_sizes) + 1)])
+        self.embedding_dim = embed_dim
+
+    def forward(self, x):
+        x = self.embed(x)
+        x = x.transpose(1, 2)
+        # B, C, L
+        out = self.linears[0](x)
+        for conv, fc in zip(self.region_embeds, self.linears[1:]):
+            conv_i = conv(x)
+            out = out + fc(conv_i)
+        # B, C, L
+        return out
+
+
+if __name__ == '__main__':
+    x = torch.randint(0, 10000, size=(5, 15), dtype=torch.long)
+    model = DPCNN((10000, 300), 20)
+    y = model(x)[C.OUTPUT]
+    print(y.size(), y.mean(1), y.std(1))

diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py
index e69de29b..13ff4fc1 100644
--- a/reproduction/text_classification/train_dpcnn.py
+++ b/reproduction/text_classification/train_dpcnn.py
@@ -0,0 +1,80 @@
+# First, add the following paths to the environment variables. This is currently open for internal testing only, so the paths have to be declared by hand.
+import os
+os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
+os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+
+from fastNLP.core.const import Const as C
+from fastNLP.core import LRScheduler
+import torch.nn as nn
+from fastNLP.io.dataset_loader import SSTLoader
+from reproduction.text_classification.model.dpcnn import DPCNN
+from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
+from fastNLP import CrossEntropyLoss, AccuracyMetric
+from fastNLP.core.trainer import Trainer
+from torch.optim import SGD
+import torch.cuda
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+##hyper
+class Config():
+    model_dir_or_name="en-base-uncased"
+    embedding_grad= False,
+    train_epoch= 30
+    batch_size = 100
+    num_classes=5
+    task= "SST"
+    datadir = '/remote-home/yfshao/workdir/datasets/SST'
+    datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
+    lr=1e-3
+    def __init__(self):
+        self.datapath = {k:os.path.join(self.datadir, v)
+                         for k, v in self.datafile.items()}
+
+ops=Config()
+
+
+##1. Task-related info: load the dataInfo with the dataloader
+datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds='train')
+
+print(len(datainfo.datasets['train']))
+print(len(datainfo.datasets['dev']))
+
+
+## 2. Or directly reuse a fastNLP model
+vocab = datainfo.vocabs['words']
+
+# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
+embedding = StaticEmbedding(vocab)
+print(len(vocab))
+print(len(datainfo.vocabs['target']))
+model = DPCNN(init_embed=embedding, num_cls=ops.num_classes)
+
+## 3. Declare loss, metric, optimizer
+loss=CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
+metric=AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
+optimizer= SGD([param for param in model.parameters() if param.requires_grad==True],
+               lr=ops.lr, momentum=0.9, weight_decay=0)
+
+callbacks = []
+callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
+
+device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+print(device)
+
+for ds in datainfo.datasets.values():
+    ds.apply_field(len, C.INPUT, C.INPUT_LEN)
+    ds.set_input(C.INPUT, C.INPUT_LEN)
+    ds.set_target(C.TARGET)
+
+## 4. Define the train method
+def train(model,datainfo,loss,metrics,optimizer,num_epochs=ops.train_epoch):
+    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+                      metrics=[metrics], dev_data=datainfo.datasets['dev'], device=device,
+                      check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
+                      n_epochs=num_epochs)
+    print(trainer.train())
+
+
+if __name__=="__main__":
+    train(model,datainfo,loss,metric,optimizer)
\ No newline at end of file
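A note on the DPCNN geometry this patch introduces: every residual block after the first is preceded by the stride-2 max-pool, so the feature map halves in length at each stage; that is what makes the network a "pyramid" and lets n_layers=7 cover sequences of a few hundred tokens. A minimal shape walk-through (a sketch only; the batch size, channel count and sequence length are illustrative, not taken from the patch):

import torch
import torch.nn as nn

pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)  # the same pool DPCNN uses
x = torch.zeros(2, 256, 100)  # (batch, n_filters, seq_len)
for _ in range(3):
    x = pool(x)
    print(tuple(x.shape))  # seq_len shrinks 100 -> 50 -> 25 -> 13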
From f1adb0f9156357944183fdb692d5206aad7e3b0d Mon Sep 17 00:00:00 2001
From: yunfan
Date: Thu, 4 Jul 2019 13:56:37 +0800
Subject: [PATCH 02/20] add ID-CNN

---
 .../ner/model/dilated_cnn.py | 111 ++++++++++++++++++
 1 file changed, 111 insertions(+)
 create mode 100644 reproduction/seqence_labelling/ner/model/dilated_cnn.py

diff --git a/reproduction/seqence_labelling/ner/model/dilated_cnn.py b/reproduction/seqence_labelling/ner/model/dilated_cnn.py
new file mode 100644
index 00000000..cd2fa64b
--- /dev/null
+++ b/reproduction/seqence_labelling/ner/model/dilated_cnn.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from fastNLP.modules.decoder import ConditionalRandomField
+from fastNLP.modules.encoder import Embedding
+from fastNLP.core.utils import seq_len_to_mask
+from fastNLP.core.const import Const as C
+
+
+class IDCNN(nn.Module):
+    def __init__(self, init_embed, char_embed,
+                 num_cls,
+                 repeats, num_layers, num_filters, kernel_size,
+                 use_crf=False, use_projection=False, block_loss=False,
+                 input_dropout=0.3, hidden_dropout=0.2, inner_dropout=0.0):
+        super(IDCNN, self).__init__()
+        self.word_embeddings = Embedding(init_embed)
+        self.char_embeddings = Embedding(char_embed)
+        embedding_size = self.word_embeddings.embedding_dim + \
+            self.char_embeddings.embedding_dim
+
+        self.conv0 = nn.Sequential(
+            nn.Conv1d(in_channels=embedding_size,
+                      out_channels=num_filters,
+                      kernel_size=kernel_size,
+                      stride=1, dilation=1,
+                      padding=kernel_size//2,
+                      bias=True),
+            nn.ReLU(),
+        )
+
+        block = []
+        for layer_i in range(num_layers):
+            dilated = 2 ** layer_i
+            block.append(nn.Conv1d(
+                in_channels=num_filters,
+                out_channels=num_filters,
+                kernel_size=kernel_size,
+                stride=1, dilation=dilated,
+                padding=(kernel_size//2) * dilated,
+                bias=True))
+            block.append(nn.ReLU())
+        self.block = nn.Sequential(*block)
+
+        if use_projection:
+            self.projection = nn.Sequential(
+                nn.Conv1d(
+                    in_channels=num_filters,
+                    out_channels=num_filters//2,
+                    kernel_size=1,
+                    bias=True),
+                nn.ReLU(),)
+            encode_dim = num_filters // 2
+        else:
+            self.projection = None
+            encode_dim = num_filters
+
+        self.input_drop = nn.Dropout(input_dropout)
+        self.hidden_drop = nn.Dropout(hidden_dropout)
+        self.inner_drop = nn.Dropout(inner_dropout)
+        self.repeats = repeats
+        self.out_fc = nn.Conv1d(
+            in_channels=encode_dim,
+            out_channels=num_cls,
+            kernel_size=1,
+            bias=True)
+        self.crf = ConditionalRandomField(
+            num_tags=num_cls) if use_crf else None
+        self.block_loss = block_loss
+
+    def forward(self, words, chars, seq_len, target=None):
+        e1 = self.word_embeddings(words)
+        e2 = self.char_embeddings(chars)
+        x = torch.cat((e1, e2), dim=-1)  # b,l,h
+        mask = seq_len_to_mask(seq_len)
+
+        x = x.transpose(1, 2)  # b,h,l
+        last_output = self.conv0(x)
+        output = []
+        for repeat in range(self.repeats):
+            last_output = self.block(last_output)
+            hidden = self.projection(last_output) if self.projection is not None else last_output
+            output.append(self.out_fc(hidden))
+
+        def compute_loss(y, t, mask):
+            if self.crf is not None and target is not None:
+                loss = self.crf(y, t, mask)
+            else:
+                t.masked_fill_(mask == 0, -100)
+                loss = F.cross_entropy(y, t, ignore_index=-100)
+            return loss
+
+        if self.block_loss:
+            losses = [compute_loss(o, target, mask) for o in output]
+            loss = sum(losses)
+        else:
+            loss = compute_loss(output[-1], target, mask)
+
+        scores = output[-1]
+        if self.crf is not None:
+            pred = self.crf.viterbi_decode(scores, target, mask)
+        else:
+            pred = scores.max(1)[1] * mask.long()
+
+        return {
+            C.LOSS: loss,
+            C.OUTPUT: pred,
+        }
+
+    def predict(self, words, chars, seq_len):
+        return self.forward(words, chars, seq_len)[C.OUTPUT]
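One property of the ID-CNN block above worth spelling out: with kernel_size=3 and the dilation doubled at each layer (2 ** layer_i), the receptive field of a single repeat grows exponentially in num_layers. A rough calculation (a sketch; num_layers=4 is an assumed example, not a value fixed by the patch):

kernel_size = 3
receptive_field = 1
for layer_i in range(4):  # num_layers = 4 -> dilations 1, 2, 4, 8
    dilation = 2 ** layer_i
    receptive_field += (kernel_size - 1) * dilation  # each conv adds 2*dilation tokens
print(receptive_field)  # 1 + 2*(1+2+4+8) = 31 tokens seen per repeat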
From 372496ca32a5015c19563c82ca1b29d6071a6e81 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Thu, 4 Jul 2019 14:03:53 +0800
Subject: [PATCH 03/20] update model & dataloader in text_classification

---
 .../text_classification/data/IMDBLoader.py |  82 +++++++++
 .../text_classification/data/yelpLoader.py | 164 +++++++++++++++---
 .../text_classification/train_dpcnn.py     |  97 +++++++----
 3 files changed, 283 insertions(+), 60 deletions(-)
 create mode 100644 reproduction/text_classification/data/IMDBLoader.py

diff --git a/reproduction/text_classification/data/IMDBLoader.py b/reproduction/text_classification/data/IMDBLoader.py
new file mode 100644
index 00000000..2df87e26
--- /dev/null
+++ b/reproduction/text_classification/data/IMDBLoader.py
@@ -0,0 +1,82 @@
+from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
+from fastNLP.core.vocabulary import VocabularyOption
+from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from typing import Union, Dict, List, Iterator
+from fastNLP import DataSet
+from fastNLP import Instance
+from fastNLP import Vocabulary
+from fastNLP import Const
+# from reproduction.utils import check_dataloader_paths
+from functools import partial
+
+
+class IMDBLoader(DataSetLoader):
+    """
+    Load the IMDB dataset. The DataSet contains the following fields:
+
+        words: list(str), the text to classify
+        target: str, the label of the text
+
+    """
+
+    def __init__(self):
+        super(IMDBLoader, self).__init__()
+
+    def _load(self, path):
+        dataset = DataSet()
+        with open(path, 'r', encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    continue
+                parts = line.split('\t')
+                target = parts[0]
+                words = parts[1].split()
+                dataset.append(Instance(words=words, target=target))
+        if len(dataset) == 0:
+            raise RuntimeError(f"{path} has no valid data.")
+
+        return dataset
+
+    def process(self,
+                paths: Union[str, Dict[str, str]],
+                src_vocab_opt: VocabularyOption = None,
+                tgt_vocab_opt: VocabularyOption = None,
+                src_embed_opt: EmbeddingOption = None):
+
+        # paths = check_dataloader_paths(paths)
+        datasets = {}
+        info = DataInfo()
+        for name, path in paths.items():
+            dataset = self.load(path)
+            datasets[name] = dataset
+
+        datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False)
+
+        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
+        src_vocab.from_dataset(datasets['train'], field_name='words')
+        # src_vocab.from_dataset(datasets['train'], datasets["dev"], datasets["test"], field_name='words')
+        src_vocab.index_dataset(*datasets.values(), field_name='words')
+
+        tgt_vocab = Vocabulary(unknown=None, padding=None) \
+            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
+        tgt_vocab.from_dataset(datasets['train'], field_name='target')
+        tgt_vocab.index_dataset(*datasets.values(), field_name='target')
+
+        info.vocabs = {
+            "words": src_vocab,
+            "target": tgt_vocab
+        }
+
+        info.datasets = datasets
+
+        if src_embed_opt is not None:
+            embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
+            info.embeddings['words'] = embed
+
+        for name, dataset in info.datasets.items():
+            dataset.set_input("words")
+            dataset.set_target("target")
+
+        return info

diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py
index c47d48fd..63605ecf 100644
--- a/reproduction/text_classification/data/yelpLoader.py
+++ b/reproduction/text_classification/data/yelpLoader.py
@@ -1,4 +1,6 @@
 import ast
+import csv
+from typing import Iterable
 from fastNLP import DataSet, Instance, Vocabulary
 from fastNLP.core.vocabulary import VocabularyOption
 from fastNLP.io import JsonLoader
@@ -10,11 +12,34 @@ from reproduction.Star_transformer.datasets import EmbedLoader
 from reproduction.utils import check_dataloader_paths
 
 
+def clean_str(sentence, char_lower=False):
+    """
+    heavily borrowed from github
+    https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb
+    :param sentence: is a str
+    :return:
+    """
+    if char_lower:
+        sentence = sentence.lower()
+    import re
+    nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
+    words = sentence.split()
+    words_collection = []
+    for word in words:
+        if word in ['-lrb-', '-rrb-', '', '-r', '-l', 'b-']:
+            continue
+        tt = nonalpnum.split(word)
+        t = ''.join(tt)
+        if t != '':
+            words_collection.append(t)
+
+    return words_collection
+
+
 class yelpLoader(JsonLoader):
-
     """
     Load the Yelp dataset. The DataSet contains the fields:
-
+
     review_id: str, 22 character unique review id
     user_id: str, 22 character unique user id
     business_id: str, 22 character business id
     useful: int, number of useful votes received
     funny: int, number of funny votes received
     cool: int, number of cool votes received
     date: str, date formatted YYYY-MM-DD
     words: list(str), the text to classify
     target: str, the label of the text
-
+
     Data source: https://www.yelp.com/dataset/download
-
+
     :param fine_grained: whether to use the SST-5 standard; if ``False``, use SST-2. Default: ``False``
     """
-
-    def __init__(self, fine_grained=False):
+
+    def __init__(self, fine_grained=False, lower=False):
         super(yelpLoader, self).__init__()
         tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral',
-                 '4.0': 'positive', '5.0': 'very positive'}
+                 '4.0': 'positive', '5.0': 'very positive'}
         if not fine_grained:
             tag_v['1.0'] = tag_v['2.0']
             tag_v['5.0'] = tag_v['4.0']
         self.fine_grained = fine_grained
         self.tag_v = tag_v
+        self.lower = lower
 
-    def _load(self, path):
+    '''
+    def _load_json(self, path):
         ds = DataSet()
         for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
             d = ast.literal_eval(d)
@@ -49,20 +76,113 @@ class yelpLoader(JsonLoader):
             ds.append(Instance(**d))
         return ds
 
+    def _load_yelp2015_broken(self,path):
+        ds = DataSet()
+        with open (path,encoding='ISO 8859-1') as f:
+            row=f.readline()
+            all_count=0
+            exp_count=0
+            while row:
+                row=row.split("\t\t")
+                all_count+=1
+                if len(row)>=3:
+                    words=row[-1].split()
+                    try:
+                        target=self.tag_v[str(row[-2])+".0"]
+                        ds.append(Instance(words=words, target=target))
+                    except KeyError:
+                        exp_count+=1
+                else:
+                    exp_count+=1
+                row = f.readline()
+        print("error sample count:",exp_count)
+        print("all count:",all_count)
+        return ds
+    '''
+
+    def _load(self, path):
+        ds = DataSet()
+        csv_reader = csv.reader(open(path, encoding='utf-8'))
+        all_count = 0
+        real_count = 0
+        for row in csv_reader:
+            all_count += 1
+            if len(row) == 2:
+                target = self.tag_v[row[0] + ".0"]
+                words = clean_str(row[1], self.lower)
+                if len(words) != 0:
+                    ds.append(Instance(words=words, target=target))
+                    real_count += 1
+        print("all count:", all_count)
+        print("real count:", real_count)
+        return ds
+
-    def process(self, paths: Union[str, Dict[str, str]], vocab_opt: VocabularyOption = None,
-                embed_opt: EmbeddingOption = None):
+    def process(self, paths: Union[str, Dict[str, str]],
+                train_ds: Iterable[str] = None,
+                src_vocab_op: VocabularyOption = None,
+                tgt_vocab_op: VocabularyOption = None,
+                embed_opt: EmbeddingOption = None,
+                char_level_op=False):
         paths = check_dataloader_paths(paths)
         datasets = {}
-        info = DataInfo()
-        vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt)
-        for name, path in paths.items():
-            dataset = self.load(path)
-            datasets[name] = dataset
-            vocab.from_dataset(dataset, field_name="words")
-        info.vocabs = vocab
-        info.datasets = datasets
-        if embed_opt is not None:
-            embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab)
-            info.embeddings['words'] = embed
+        info = DataInfo(datasets=self.load(paths))
+        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
+        tgt_vocab = Vocabulary(unknown=None, padding=None) \
+            if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
+        _train_ds = [info.datasets[name]
+                     for name in train_ds] if train_ds else info.datasets.values()
+
+        # vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt)
+        # for name, path in paths.items():
+        #     dataset = self.load(path)
+        #     datasets[name] = dataset
+        #     vocab.from_dataset(dataset, field_name="words")
+        # info.vocabs = vocab
+        # info.datasets = datasets
+
+        def wordtochar(words):
+            chars = []
+            for word in words:
+                word = word.lower()
+                for char in word:
+                    chars.append(char)
+            return chars
+
+        input_name, target_name = 'words', 'target'
+        info.vocabs = {}
+        # split into char form
+        if char_level_op:
+            for dataset in info.datasets.values():
+                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')
+        # if embed_opt is not None:
+        #     embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab)
+        #     info.embeddings['words'] = embed
+        else:
+            src_vocab.from_dataset(*_train_ds, field_name=input_name)
+            src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name)
+            info.vocabs[input_name] = src_vocab
+
+        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
+        tgt_vocab.index_dataset(
+            *info.datasets.values(),
+            field_name=target_name, new_field_name=target_name)
+        info.vocabs[target_name] = tgt_vocab
 
         return info
+
+if __name__ == "__main__":
+    testloader = yelpLoader()
+    # datapath = {"train": "/remote-home/ygwang/yelp_full/train.csv",
+    #             "test": "/remote-home/ygwang/yelp_full/test.csv"}
+    # datapath={"train": "/remote-home/ygwang/yelp_full/test.csv"}
+    datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv",
+                "test": "/remote-home/ygwang/yelp_polarity/test.csv"}
+    datainfo = testloader.process(datapath, char_level_op=True)
+
+    len_count = 0
+    for instance in datainfo.datasets["train"]:
+        len_count += len(instance["chars"])
+
+    ave_len = len_count / len(datainfo.datasets["train"])
+    print(ave_len)
\ No newline at end of file

diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py
index 13ff4fc1..bf243ffb 100644
--- a/reproduction/text_classification/train_dpcnn.py
+++ b/reproduction/text_classification/train_dpcnn.py
@@ -1,65 +1,83 @@
 # First, add the following paths to the environment variables. This is currently open for internal testing only, so the paths have to be declared by hand.
+
+from torch.optim.lr_scheduler import CosineAnnealingLR
+import torch.cuda
+from torch.optim import SGD
+from fastNLP.core.trainer import Trainer
+from fastNLP import CrossEntropyLoss, AccuracyMetric
+from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
+from reproduction.text_classification.model.dpcnn import DPCNN
+from .data.yelpLoader import yelpLoader
+from fastNLP.io.dataset_loader import SSTLoader
+import torch.nn as nn
+from fastNLP.core import LRScheduler
+from fastNLP.core.const import Const as C
+import sys
 import os
 os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
 os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 
-from fastNLP.core.const import Const as C
-from fastNLP.core import LRScheduler
-import torch.nn as nn
-from fastNLP.io.dataset_loader import SSTLoader
-from reproduction.text_classification.model.dpcnn import DPCNN
-from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
-from fastNLP import CrossEntropyLoss, AccuracyMetric
-from fastNLP.core.trainer import Trainer
-from torch.optim import SGD
-import torch.cuda
-from torch.optim.lr_scheduler import CosineAnnealingLR
+sys.path.append('../..')
+
+
+# hyper
 
-##hyper
 class Config():
-    model_dir_or_name="en-base-uncased"
-    embedding_grad= False,
-    train_epoch= 30
+    model_dir_or_name = "en-base-uncased"
+    embedding_grad = False,
+    train_epoch = 30
     batch_size = 100
-    num_classes=5
-    task= "SST"
-    datadir = '/remote-home/yfshao/workdir/datasets/SST'
-    datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
-    lr=1e-3
+    num_classes = 2
+    task = "yelp_p"
+    #datadir = '/remote-home/yfshao/workdir/datasets/SST'
+    datadir = '/remote-home/ygwang/yelp_polarity'
+    #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
+    datafile = {"train": "train.csv", "test": "test.csv"}
+    lr = 1e-3
+
     def __init__(self):
-        self.datapath = {k:os.path.join(self.datadir, v)
+        self.datapath = {k: os.path.join(self.datadir, v)
                          for k, v in self.datafile.items()}
 
-ops=Config()
+
+ops = Config()
 
 
-##1. Task-related info: load the dataInfo with the dataloader
-datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds='train')
+# 1. Task-related info: load the dataInfo with the dataloader
+#datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train'])
+datainfo = yelpLoader(fine_grained=True, lower=True).process(
+    paths=ops.datapath, train_ds=['train'])
 
 print(len(datainfo.datasets['train']))
-print(len(datainfo.datasets['dev']))
+print(len(datainfo.datasets['test']))
 
 
-## 2. Or directly reuse a fastNLP model
-vocab = datainfo.vocabs['words']
+# 2. Or directly reuse a fastNLP model
+vocab = datainfo.vocabs['words']
 # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
-embedding = StaticEmbedding(vocab)
+#embedding = StaticEmbedding(vocab)
+embedding = StaticEmbedding(
+    vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
+
 print(len(vocab))
 print(len(datainfo.vocabs['target']))
+
 model = DPCNN(init_embed=embedding, num_cls=ops.num_classes)
 
-## 3. Declare loss, metric, optimizer
-loss=CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
-metric=AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
-optimizer= SGD([param for param in model.parameters() if param.requires_grad==True],
-               lr=ops.lr, momentum=0.9, weight_decay=0)
+
+# 3. Declare loss, metric, optimizer
+loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
+metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
+optimizer = SGD([param for param in model.parameters() if param.requires_grad == True],
+                lr=ops.lr, momentum=0.9, weight_decay=0)
 
 callbacks = []
 callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
 
 device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+
 print(device)
 
 for ds in datainfo.datasets.values():
@@ -67,14 +85,17 @@ for ds in datainfo.datasets.values():
     ds.set_input(C.INPUT, C.INPUT_LEN)
     ds.set_target(C.TARGET)
 
-## 4. Define the train method
-def train(model,datainfo,loss,metrics,optimizer,num_epochs=ops.train_epoch):
+
+# 4. Define the train method
+def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch):
     trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-                      metrics=[metrics], dev_data=datainfo.datasets['dev'], device=device,
+                      metrics=[metrics],
+                      dev_data=datainfo.datasets['test'], device=device,
                       check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
                       n_epochs=num_epochs)
+
     print(trainer.train())
 
 
-if __name__=="__main__":
-    train(model,datainfo,loss,metric,optimizer)
\ No newline at end of file
+if __name__ == "__main__":
+    train(model, datainfo, loss, metric, optimizer)
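The clean_str helper introduced for the Yelp loader keeps only alphanumerics plus ?, ! and ' inside each token and skips a few bracket artifacts entirely. A quick illustration of its behavior (the input sentence is made up; the regex and skip list are the ones from the patch):

import re

nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')  # same pattern as clean_str
for word in "-lrb- Great food!! 5/5 stars -rrb-".split():
    if word in ['-lrb-', '-rrb-', '', '-r', '-l', 'b-']:
        continue  # bracket tokens are dropped
    t = ''.join(nonalpnum.split(word))
    if t != '':
        print(t)  # Great, food!!, 55, stars -- note '/' is stripped, so the digits merge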
From c5fc29dfef2a48729aafed42ccf6309aa5bee17e Mon Sep 17 00:00:00 2001
From: yunfan
Date: Sat, 6 Jul 2019 13:15:40 +0800
Subject: [PATCH 04/20] -update DPCNN & train script -use spacy tokenizer for
 yelp data -add set_rng_seed

---
 fastNLP/modules/aggregator/attention.py     |  9 +-
 .../text_classification/data/yelpLoader.py  | 18 +++-
 .../text_classification/model/dpcnn.py      | 22 +++--
 .../text_classification/train_dpcnn.py      | 83 ++++++++++------
 .../text_classification/utils/util_init.py  | 11 +++
 5 files changed, 94 insertions(+), 49 deletions(-)
 create mode 100644 reproduction/text_classification/utils/util_init.py

diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py
index 4101b033..2bee7f2e 100644
--- a/fastNLP/modules/aggregator/attention.py
+++ b/fastNLP/modules/aggregator/attention.py
@@ -19,7 +19,7 @@ class DotAttention(nn.Module):
     TODO: add documentation
     """
 
-    def __init__(self, key_size, value_size, dropout=0):
+    def __init__(self, key_size, value_size, dropout=0.0):
         super(DotAttention, self).__init__()
         self.key_size = key_size
         self.value_size = value_size
@@ -37,7 +37,7 @@ class DotAttention(nn.Module):
         """
         output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
         if mask_out is not None:
-            output.masked_fill_(mask_out, -1e8)
+            output.masked_fill_(mask_out, -1e18)
         output = self.softmax(output)
         output = self.drop(output)
         return torch.matmul(output, V)
@@ -67,9 +67,8 @@ class MultiHeadAttention(nn.Module):
         self.k_in = nn.Linear(input_size, in_size)
         self.v_in = nn.Linear(input_size, in_size)
         # follow the paper, do not apply dropout within dot-product
-        self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=0)
+        self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout)
         self.out = nn.Linear(value_size * num_head, input_size)
-        self.drop = TimestepDropout(dropout)
         self.reset_parameters()
 
     def reset_parameters(self):
@@ -105,7 +104,7 @@ class MultiHeadAttention(nn.Module):
         # concat all heads, do output linear
         atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1)
-        output = self.drop(self.out(atte))
+        output = self.out(atte)
         return output

diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py
index 63605ecf..d97f9399 100644
--- a/reproduction/text_classification/data/yelpLoader.py
+++ b/reproduction/text_classification/data/yelpLoader.py
@@ -8,11 +8,20 @@ from fastNLP.io.base_loader import DataInfo
 from fastNLP.io.embed_loader import EmbeddingOption
 from fastNLP.io.file_reader import _read_json
 from typing import Union, Dict
-from reproduction.Star_transformer.datasets import EmbedLoader
 from reproduction.utils import check_dataloader_paths
 
 
-def clean_str(sentence, char_lower=False):
+def get_tokenizer():
+    try:
+        import spacy
+        en = spacy.load('en')
+        print('use spacy tokenizer')
+        return lambda x: [w.text for w in en.tokenizer(x)]
+    except Exception as e:
+        print('use raw tokenizer')
+        return lambda x: x.split()
+
+def clean_str(sentence, tokenizer, char_lower=False):
     """
     heavily borrowed from github
     https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb
@@ -23,7 +32,7 @@ def clean_str(sentence, tokenizer, char_lower=False):
         sentence = sentence.lower()
     import re
     nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
-    words = sentence.split()
+    words = tokenizer(sentence)
     words_collection = []
     for word in words:
         if word in ['-lrb-', '-rrb-', '', '-r', '-l', 'b-']:
@@ -65,6 +74,7 @@ class yelpLoader(JsonLoader):
         self.fine_grained = fine_grained
         self.tag_v = tag_v
         self.lower = lower
+        self.tokenizer = get_tokenizer()
 
     '''
     def _load_json(self, path):
@@ -109,7 +119,7 @@ class yelpLoader(JsonLoader):
             all_count += 1
             if len(row) == 2:
                 target = self.tag_v[row[0] + ".0"]
-                words = clean_str(row[1], self.lower)
+                words = clean_str(row[1], self.tokenizer, self.lower)
                 if len(words) != 0:
                     ds.append(Instance(words=words, target=target))
                     real_count += 1

diff --git a/reproduction/text_classification/model/dpcnn.py b/reproduction/text_classification/model/dpcnn.py
index 2da7b3e5..dafe62bc 100644
--- a/reproduction/text_classification/model/dpcnn.py
+++ b/reproduction/text_classification/model/dpcnn.py
@@ -3,22 +3,27 @@ import torch.nn as nn
 from fastNLP.modules.utils import get_embeddings
 from fastNLP.core import Const as C
 
+
 class DPCNN(nn.Module):
-    def __init__(self, init_embed, num_cls, n_filters=256, kernel_size=3, n_layers=7, embed_dropout=0.1, dropout=0.1):
+    def __init__(self, init_embed, num_cls, n_filters=256,
+                 kernel_size=3, n_layers=7, embed_dropout=0.1, cls_dropout=0.1):
         super().__init__()
-        self.region_embed = RegionEmbedding(init_embed, out_dim=n_filters, kernel_sizes=[3, 5, 9])
+        self.region_embed = RegionEmbedding(
+            init_embed, out_dim=n_filters, kernel_sizes=[1, 3, 5])
         embed_dim = self.region_embed.embedding_dim
         self.conv_list = nn.ModuleList()
         for i in range(n_layers):
             self.conv_list.append(nn.Sequential(
                 nn.ReLU(),
-                nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2),
-                nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2),
+                nn.Conv1d(n_filters, n_filters, kernel_size,
+                          padding=kernel_size//2),
+                nn.Conv1d(n_filters, n_filters, kernel_size,
+                          padding=kernel_size//2),
             ))
         self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
         self.embed_drop = nn.Dropout(embed_dropout)
         self.classfier = nn.Sequential(
-            nn.Dropout(dropout),
+            nn.Dropout(cls_dropout),
             nn.Linear(n_filters, num_cls),
         )
         self.reset_parameters()
@@ -57,7 +62,8 @@ class RegionEmbedding(nn.Module):
         super().__init__()
         if kernel_sizes is None:
             kernel_sizes = [5, 9]
-        assert isinstance(kernel_sizes, list), 'kernel_sizes should be List(int)'
+        assert isinstance(
+            kernel_sizes, list), 'kernel_sizes should be List(int)'
         self.embed = get_embeddings(init_embed)
         try:
             embed_dim = self.embed.embedding_dim
@@ -69,14 +75,14 @@ class RegionEmbedding(nn.Module):
                 nn.Conv1d(embed_dim, embed_dim, ksz, padding=ksz // 2),
             ))
         self.linears = nn.ModuleList([nn.Conv1d(embed_dim, out_dim, 1)
-                                      for _ in range(len(kernel_sizes) + 1)])
+                                      for _ in range(len(kernel_sizes))])
         self.embedding_dim = embed_dim
 
     def forward(self, x):
         x = self.embed(x)
         x = x.transpose(1, 2)
         # B, C, L
-        out = self.linears[0](x)
+        out = 0
         for conv, fc in zip(self.region_embeds, self.linears[1:]):
             conv_i = conv(x)
             out = out + fc(conv_i)

diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py
index bf243ffb..294a0742 100644
--- a/reproduction/text_classification/train_dpcnn.py
+++ b/reproduction/text_classification/train_dpcnn.py
@@ -1,101 +1,120 @@
 # First, add the following paths to the environment variables. This is currently open for internal testing only, so the paths have to be declared by hand.
 
-from torch.optim.lr_scheduler import CosineAnnealingLR
 import torch.cuda
+from fastNLP.core.utils import cache_results
 from torch.optim import SGD
+from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
 from fastNLP.core.trainer import Trainer
 from fastNLP import CrossEntropyLoss, AccuracyMetric
 from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
 from reproduction.text_classification.model.dpcnn import DPCNN
-from .data.yelpLoader import yelpLoader
-from fastNLP.io.dataset_loader import SSTLoader
+from data.yelpLoader import yelpLoader
 import torch.nn as nn
 from fastNLP.core import LRScheduler
 from fastNLP.core.const import Const as C
-import sys
+from fastNLP.core.vocabulary import VocabularyOption
+from utils.util_init import set_rng_seeds
 import os
 os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
 os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 
-sys.path.append('../..')
-
 
 # hyper
 
 class Config():
-    model_dir_or_name = "en-base-uncased"
-    embedding_grad = False,
+    seed = 12345
+    model_dir_or_name = "dpcnn-yelp-p"
+    embedding_grad = True
     train_epoch = 30
     batch_size = 100
     num_classes = 2
     task = "yelp_p"
     #datadir = '/remote-home/yfshao/workdir/datasets/SST'
-    datadir = '/remote-home/ygwang/yelp_polarity'
+    datadir = '/remote-home/yfshao/workdir/datasets/yelp_polarity'
     #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
     datafile = {"train": "train.csv", "test": "test.csv"}
     lr = 1e-3
+    src_vocab_op = VocabularyOption()
+    embed_dropout = 0.3
+    cls_dropout = 0.1
+    weight_decay = 1e-4
 
     def __init__(self):
         self.datapath = {k: os.path.join(self.datadir, v)
                          for k, v in self.datafile.items()}
 
 
 ops = Config()
+set_rng_seeds(ops.seed)
+print('RNG SEED: {}'.format(ops.seed))
 
 
 # 1. Task-related info: load the dataInfo with the dataloader
 #datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train'])
-datainfo = yelpLoader(fine_grained=True, lower=True).process(
-    paths=ops.datapath, train_ds=['train'])
 
-print(len(datainfo.datasets['train']))
-print(len(datainfo.datasets['test']))
+@cache_results(ops.model_dir_or_name+'-data-cache')
+def load_data():
+    datainfo = yelpLoader(fine_grained=True, lower=True).process(
+        paths=ops.datapath, train_ds=['train'], src_vocab_op=ops.src_vocab_op)
+    for ds in datainfo.datasets.values():
+        ds.apply_field(len, C.INPUT, C.INPUT_LEN)
+        ds.set_input(C.INPUT, C.INPUT_LEN)
+        ds.set_target(C.TARGET)
+    return datainfo
+
+datainfo = load_data()
 
 
 # 2. Or directly reuse a fastNLP model
 vocab = datainfo.vocabs['words']
 # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
 #embedding = StaticEmbedding(vocab)
-embedding = StaticEmbedding(
-    vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
+embedding = StaticEmbedding(
+    vocab, model_dir_or_name='en-word2vec-300', requires_grad=ops.embedding_grad,
+    normalize=False
+)
+
+print(len(datainfo.datasets['train']))
+print(len(datainfo.datasets['test']))
+print(datainfo.datasets['train'][0])
 
 print(len(vocab))
 print(len(datainfo.vocabs['target']))
 
-model = DPCNN(init_embed=embedding, num_cls=ops.num_classes)
+model = DPCNN(init_embed=embedding, num_cls=ops.num_classes,
+              embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout)
+print(model)
 
 # 3. Declare loss, metric, optimizer
 loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
 metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
 optimizer = SGD([param for param in model.parameters() if param.requires_grad == True],
-                lr=ops.lr, momentum=0.9, weight_decay=0)
+                lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)
 
 callbacks = []
 callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
+# callbacks.append
+#     LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch <
+#                          ops.train_epoch * 0.8 else ops.lr * 0.1))
+# )
+
+# callbacks.append(
+#     FitlogCallback(data=datainfo.datasets, verbose=1)
+# )
 
 device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 
 print(device)
 
-for ds in datainfo.datasets.values():
-    ds.apply_field(len, C.INPUT, C.INPUT_LEN)
-    ds.set_input(C.INPUT, C.INPUT_LEN)
-    ds.set_target(C.TARGET)
-
 
 # 4. Define the train method
-def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch):
-    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-                      metrics=[metrics],
-                      dev_data=datainfo.datasets['test'], device=device,
-                      check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
-                      n_epochs=num_epochs)
-
-    print(trainer.train())
+trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
                  metrics=[metric],
                  dev_data=datainfo.datasets['test'], device=device,
                  check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
                  n_epochs=ops.train_epoch, num_workers=4)
 
 
 if __name__ == "__main__":
-    train(model, datainfo, loss, metric, optimizer)
+    print(trainer.train())

diff --git a/reproduction/text_classification/utils/util_init.py b/reproduction/text_classification/utils/util_init.py
new file mode 100644
index 00000000..fcb8fffb
--- /dev/null
+++ b/reproduction/text_classification/utils/util_init.py
@@ -0,0 +1,11 @@
+import numpy
+import torch
+import random
+
+
+def set_rng_seeds(seed):
+    random.seed(seed)
+    numpy.random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    # print('RNG_SEED {}'.format(seed))
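The train script keeps LRScheduler(CosineAnnealingLR(optimizer, 5)), i.e. T_max=5 epochs, so over the 30-epoch run the learning rate follows a cosine wave with a period of roughly 10 epochs rather than decaying monotonically. A standalone sketch of the resulting schedule (assumes a PyTorch recent enough to provide get_last_lr; base lr=1e-3 as in Config):

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

opt = SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
sched = CosineAnnealingLR(opt, T_max=5)
for epoch in range(10):
    print(epoch, round(sched.get_last_lr()[0], 6))  # 1e-3 -> ~0 at epoch 5 -> back up
    opt.step()
    sched.step()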
Content-Transfer-Encoding: 8bit --- .../text_classification/data/IMDBLoader.py | 107 +++++++++ .../text_classification/data/sstLoader.py | 98 +++++++++ .../text_classification/data/yelpLoader.py | 184 +++++++++++++--- .../text_classification/model/char_cnn.py | 91 +++++++- .../text_classification/model/dpcnn.py | 112 +++++++++- .../text_classification/train_char_cnn.py | 206 ++++++++++++++++++ .../text_classification/train_dpcnn.py | 101 +++++++++ 7 files changed, 868 insertions(+), 31 deletions(-) create mode 100644 reproduction/text_classification/data/IMDBLoader.py create mode 100644 reproduction/text_classification/data/sstLoader.py diff --git a/reproduction/text_classification/data/IMDBLoader.py b/reproduction/text_classification/data/IMDBLoader.py new file mode 100644 index 00000000..cb422524 --- /dev/null +++ b/reproduction/text_classification/data/IMDBLoader.py @@ -0,0 +1,107 @@ +from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader +from fastNLP.core.vocabulary import VocabularyOption +from fastNLP.io.base_loader import DataSetLoader, DataInfo +from typing import Union, Dict, List, Iterator +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP import Vocabulary +from fastNLP import Const +# from reproduction.utils import check_dataloader_paths +from functools import partial + +class IMDBLoader(DataSetLoader): + """ + 读取IMDB数据集,DataSet包含以下fields: + + words: list(str), 需要分类的文本 + target: str, 文本的标签 + + + """ + + def __init__(self): + super(IMDBLoader, self).__init__() + + def _load(self, path): + dataset = DataSet() + with open(path, 'r', encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split('\t') + target = parts[0] + words = parts[1].split() + dataset.append(Instance(words=words, target=target)) + if len(dataset)==0: + raise RuntimeError(f"{path} has no valid data.") + + return dataset + + def process(self, + paths: Union[str, Dict[str, str]], + src_vocab_opt: VocabularyOption = None, + tgt_vocab_opt: VocabularyOption = None, + src_embed_opt: EmbeddingOption = None, + char_level_op=False): + + # paths = check_dataloader_paths(paths) + + datasets = {} + info = DataInfo() + for name, path in paths.items(): + dataset = self.load(path) + datasets[name] = dataset + + def wordtochar(words): + chars = [] + for word in words: + word = word.lower() + for char in word: + chars.append(char) + return chars + + if char_level_op: + for dataset in datasets.values(): + dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') + + datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False) + + src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) + src_vocab.from_dataset(datasets['train'], field_name='words') + # src_vocab.from_dataset(datasets['train'], datasets["dev"], datasets["test"], field_name='words') + src_vocab.index_dataset(*datasets.values(), field_name='words') + + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) + tgt_vocab.from_dataset(datasets['train'], field_name='target') + tgt_vocab.index_dataset(*datasets.values(), field_name='target') + + info.vocabs = { + "words": src_vocab, + "target": tgt_vocab + } + + info.datasets = datasets + + if src_embed_opt is not None: + embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) + info.embeddings['words'] = embed + + for name, dataset in info.datasets.items(): + dataset.set_input("words") + 
dataset.set_target("target") + + return info + +if __name__=="__main__": + datapath = {"train": "/remote-home/ygwang/IMDB_data/train.csv", + "test": "/remote-home/ygwang/IMDB_data/test.csv"} + datainfo=IMDBLoader().process(datapath,char_level_op=True) + #print(datainfo.datasets["train"]) + len_count = 0 + for instance in datainfo.datasets["train"]: + len_count += len(instance["chars"]) + + ave_len = len_count / len(datainfo.datasets["train"]) + print(ave_len) \ No newline at end of file diff --git a/reproduction/text_classification/data/sstLoader.py b/reproduction/text_classification/data/sstLoader.py new file mode 100644 index 00000000..bffb67fd --- /dev/null +++ b/reproduction/text_classification/data/sstLoader.py @@ -0,0 +1,98 @@ +import csv +from typing import Iterable +from fastNLP import DataSet, Instance, Vocabulary +from fastNLP.core.vocabulary import VocabularyOption +from fastNLP.io.base_loader import DataInfo,DataSetLoader +from fastNLP.io.embed_loader import EmbeddingOption +from fastNLP.io.file_reader import _read_json +from typing import Union, Dict +from reproduction.Star_transformer.datasets import EmbedLoader +from reproduction.utils import check_dataloader_paths + +class sst2Loader(DataSetLoader): + ''' + 数据来源"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', + ''' + def __init__(self): + super(sst2Loader, self).__init__() + + def _load(self, path: str) -> DataSet: + ds = DataSet() + all_count=0 + csv_reader = csv.reader(open(path, encoding='utf-8'),delimiter='\t') + skip_row = 0 + for idx,row in enumerate(csv_reader): + if idx<=skip_row: + continue + target = row[1] + words = row[0].split() + ds.append(Instance(words=words,target=target)) + all_count+=1 + print("all count:", all_count) + return ds + + def process(self, + paths: Union[str, Dict[str, str]], + src_vocab_opt: VocabularyOption = None, + tgt_vocab_opt: VocabularyOption = None, + src_embed_opt: EmbeddingOption = None, + char_level_op=False): + + paths = check_dataloader_paths(paths) + datasets = {} + info = DataInfo() + for name, path in paths.items(): + dataset = self.load(path) + datasets[name] = dataset + + def wordtochar(words): + chars=[] + for word in words: + word=word.lower() + for char in word: + chars.append(char) + return chars + + input_name, target_name = 'words', 'target' + info.vocabs={} + + # 就分隔为char形式 + if char_level_op: + for dataset in datasets.values(): + dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') + + src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) + src_vocab.from_dataset(datasets['train'], field_name='words') + src_vocab.index_dataset(*datasets.values(), field_name='words') + + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) + tgt_vocab.from_dataset(datasets['train'], field_name='target') + tgt_vocab.index_dataset(*datasets.values(), field_name='target') + + + info.vocabs = { + "words": src_vocab, + "target": tgt_vocab + } + + info.datasets = datasets + + + if src_embed_opt is not None: + embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) + info.embeddings['words'] = embed + + return info + +if __name__=="__main__": + datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv", + "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"} + datainfo=sst2Loader().process(datapath,char_level_op=True) + 
#print(datainfo.datasets["train"]) + len_count = 0 + for instance in datainfo.datasets["train"]: + len_count += len(instance["chars"]) + + ave_len = len_count / len(datainfo.datasets["train"]) + print(ave_len) \ No newline at end of file diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py index c47d48fd..9d34004d 100644 --- a/reproduction/text_classification/data/yelpLoader.py +++ b/reproduction/text_classification/data/yelpLoader.py @@ -1,8 +1,10 @@ import ast +import csv +from typing import Iterable from fastNLP import DataSet, Instance, Vocabulary from fastNLP.core.vocabulary import VocabularyOption from fastNLP.io import JsonLoader -from fastNLP.io.base_loader import DataInfo +from fastNLP.io.base_loader import DataInfo,DataSetLoader from fastNLP.io.embed_loader import EmbeddingOption from fastNLP.io.file_reader import _read_json from typing import Union, Dict @@ -10,27 +12,44 @@ from reproduction.Star_transformer.datasets import EmbedLoader from reproduction.utils import check_dataloader_paths -class yelpLoader(JsonLoader): - +def clean_str(sentence,char_lower=False): """ - 读取Yelp数据集, DataSet包含fields: + heavily borrowed from github + https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb + :param sentence: is a str + :return: + """ + if char_lower: + sentence=sentence.lower() + import re + nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') + words = sentence.split() + words_collection = [] + for word in words: + if word in ['-lrb-', '-rrb-', '', '-r', '-l', 'b-']: + continue + tt = nonalpnum.split(word) + t = ''.join(tt) + if t != '': + words_collection.append(t) + + return words_collection + + +class yelpLoader(DataSetLoader): - review_id: str, 22 character unique review id - user_id: str, 22 character unique user id - business_id: str, 22 character business id - useful: int, number of useful votes received - funny: int, number of funny votes received - cool: int, number of cool votes received - date: str, date formatted YYYY-MM-DD + """ + 读取Yelp_full/Yelp_polarity数据集, DataSet包含fields: + words: list(str), 需要分类的文本 target: str, 文本的标签 - - 数据来源: https://www.yelp.com/dataset/download - + chars:list(str),未index的字符列表 + + 数据集:yelp_full/yelp_polarity :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` """ - def __init__(self, fine_grained=False): + def __init__(self, fine_grained=False,lower=False): super(yelpLoader, self).__init__() tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', '4.0': 'positive', '5.0': 'very positive'} @@ -39,8 +58,24 @@ class yelpLoader(JsonLoader): tag_v['5.0'] = tag_v['4.0'] self.fine_grained = fine_grained self.tag_v = tag_v + self.lower=lower + + ''' + 读取Yelp数据集, DataSet包含fields: - def _load(self, path): + review_id: str, 22 character unique review id + user_id: str, 22 character unique user id + business_id: str, 22 character business id + useful: int, number of useful votes received + funny: int, number of funny votes received + cool: int, number of cool votes received + date: str, date formatted YYYY-MM-DD + words: list(str), 需要分类的文本 + target: str, 文本的标签 + + 数据来源: https://www.yelp.com/dataset/download + + def _load_json(self, path): ds = DataSet() for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): d = ast.literal_eval(d) @@ -48,21 +83,112 @@ class yelpLoader(JsonLoader): d["target"] = self.tag_v[str(d.pop("stars"))] ds.append(Instance(**d)) return ds + + def _load_yelp2015_broken(self,path): + ds 
= DataSet() + with open (path,encoding='ISO 8859-1') as f: + row=f.readline() + all_count=0 + exp_count=0 + while row: + row=row.split("\t\t") + all_count+=1 + if len(row)>=3: + words=row[-1].split() + try: + target=self.tag_v[str(row[-2])+".0"] + ds.append(Instance(words=words, target=target)) + except KeyError: + exp_count+=1 + else: + exp_count+=1 + row = f.readline() + print("error sample count:",exp_count) + print("all count:",all_count) + return ds + ''' + def _load(self, path): + ds = DataSet() + csv_reader=csv.reader(open(path,encoding='utf-8')) + all_count=0 + real_count=0 + for row in csv_reader: + all_count+=1 + if len(row)==2: + target=self.tag_v[row[0]+".0"] + words=clean_str(row[1],self.lower) + if len(words)!=0: + ds.append(Instance(words=words,target=target)) + real_count += 1 + print("all count:", all_count) + print("real count:", real_count) + return ds + - def process(self, paths: Union[str, Dict[str, str]], vocab_opt: VocabularyOption = None, - embed_opt: EmbeddingOption = None): + def process(self, paths: Union[str, Dict[str, str]], + train_ds: Iterable[str] = None, + src_vocab_op: VocabularyOption = None, + tgt_vocab_op: VocabularyOption = None, + embed_opt: EmbeddingOption = None, + char_level_op=False): paths = check_dataloader_paths(paths) datasets = {} - info = DataInfo() - vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt) - for name, path in paths.items(): - dataset = self.load(path) - datasets[name] = dataset - vocab.from_dataset(dataset, field_name="words") - info.vocabs = vocab - info.datasets = datasets - if embed_opt is not None: - embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab) - info.embeddings['words'] = embed + info = DataInfo(datasets=self.load(paths)) + src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) + _train_ds = [info.datasets[name] + for name in train_ds] if train_ds else info.datasets.values() + #vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt) + # for name, path in paths.items(): + # dataset = self.load(path) + # datasets[name] = dataset + # vocab.from_dataset(dataset, field_name="words") + # info.vocabs = vocab + # info.datasets = datasets + + def wordtochar(words): + chars=[] + for word in words: + word=word.lower() + for char in word: + chars.append(char) + return chars + + input_name, target_name = 'words', 'target' + info.vocabs={} + #就分隔为char形式 + if char_level_op: + for dataset in info.datasets.values(): + dataset.apply_field(wordtochar, field_name="words",new_field_name='chars') + # if embed_opt is not None: + # embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab) + # info.embeddings['words'] = embed + else: + src_vocab.from_dataset(*_train_ds, field_name=input_name) + src_vocab.index_dataset(*info.datasets.values(),field_name=input_name, new_field_name=input_name) + info.vocabs[input_name]=src_vocab + + tgt_vocab.from_dataset(*_train_ds, field_name=target_name) + tgt_vocab.index_dataset( + *info.datasets.values(), + field_name=target_name, new_field_name=target_name) + info.vocabs[target_name]=tgt_vocab + return info +if __name__=="__main__": + testloader=yelpLoader() + # datapath = {"train": "/remote-home/ygwang/yelp_full/train.csv", + # "test": "/remote-home/ygwang/yelp_full/test.csv"} + #datapath={"train": "/remote-home/ygwang/yelp_full/test.csv"} + datapath = {"train": 
"/remote-home/ygwang/yelp_polarity/train.csv", + "test": "/remote-home/ygwang/yelp_polarity/test.csv"} + datainfo=testloader.process(datapath,char_level_op=True) + + len_count=0 + for instance in datainfo.datasets["train"]: + len_count+=len(instance["chars"]) + + ave_len=len_count/len(datainfo.datasets["train"]) + print(ave_len) \ No newline at end of file diff --git a/reproduction/text_classification/model/char_cnn.py b/reproduction/text_classification/model/char_cnn.py index f87f5c14..ac370082 100644 --- a/reproduction/text_classification/model/char_cnn.py +++ b/reproduction/text_classification/model/char_cnn.py @@ -1 +1,90 @@ -# TODO \ No newline at end of file +''' +@author: https://github.com/ahmedbesbes/character-based-cnn +这里借鉴了上述链接中char-cnn model的代码,改动主要为将其改动为符合fastnlp的pipline +''' +import torch +import torch.nn as nn +from fastNLP.core.const import Const as C + +class CharacterLevelCNN(nn.Module): + def __init__(self, args,embedding): + super(CharacterLevelCNN, self).__init__() + + self.config=args.char_cnn_config + self.embedding=embedding + + conv_layers = [] + for i, conv_layer_parameter in enumerate(self.config['model_parameters'][args.model_size]['conv']): + if i == 0: + #in_channels = args.number_of_characters + len(args.extra_characters) + in_channels = args.embedding_dim + out_channels = conv_layer_parameter[0] + else: + in_channels, out_channels = conv_layer_parameter[0], conv_layer_parameter[0] + + if conv_layer_parameter[2] != -1: + conv_layer = nn.Sequential(nn.Conv1d(in_channels, + out_channels, + kernel_size=conv_layer_parameter[1], padding=0), + nn.ReLU(), + nn.MaxPool1d(conv_layer_parameter[2])) + else: + conv_layer = nn.Sequential(nn.Conv1d(in_channels, + out_channels, + kernel_size=conv_layer_parameter[1], padding=0), + nn.ReLU()) + conv_layers.append(conv_layer) + self.conv_layers = nn.ModuleList(conv_layers) + + input_shape = (args.batch_size, args.max_length, + args.number_of_characters + len(args.extra_characters)) + dimension = self._get_conv_output(input_shape) + + print('dimension :', dimension) + + fc_layer_parameter = self.config['model_parameters'][args.model_size]['fc'][0] + fc_layers = nn.ModuleList([ + nn.Sequential( + nn.Linear(dimension, fc_layer_parameter), nn.Dropout(0.5)), + nn.Sequential(nn.Linear(fc_layer_parameter, + fc_layer_parameter), nn.Dropout(0.5)), + nn.Linear(fc_layer_parameter, args.num_classes), + ]) + + self.fc_layers = fc_layers + + if args.model_size == 'small': + self._create_weights(mean=0.0, std=0.05) + elif args.model_size == 'large': + self._create_weights(mean=0.0, std=0.02) + + def _create_weights(self, mean=0.0, std=0.05): + for module in self.modules(): + if isinstance(module, nn.Conv1d) or isinstance(module, nn.Linear): + module.weight.data.normal_(mean, std) + + def _get_conv_output(self, shape): + input = torch.rand(shape) + output = input.transpose(1, 2) + # forward pass through conv layers + for i in range(len(self.conv_layers)): + output = self.conv_layers[i](output) + + output = output.view(output.size(0), -1) + n_size = output.size(1) + return n_size + + def forward(self, chars): + input=self.embedding(chars) + output = input.transpose(1, 2) + # forward pass through conv layers + for i in range(len(self.conv_layers)): + output = self.conv_layers[i](output) + + output = output.view(output.size(0), -1) + + # forward pass through fc layers + for i in range(len(self.fc_layers)): + output = self.fc_layers[i](output) + + return {C.OUTPUT: output} \ No newline at end of file diff --git 
a/reproduction/text_classification/model/dpcnn.py b/reproduction/text_classification/model/dpcnn.py index f87f5c14..a846af72 100644 --- a/reproduction/text_classification/model/dpcnn.py +++ b/reproduction/text_classification/model/dpcnn.py @@ -1 +1,111 @@ -# TODO \ No newline at end of file + + +import torch +import torch.nn as nn +from fastNLP.modules.utils import get_embeddings +from fastNLP.core import Const as C + + + +class DPCNN(nn.Module): + + def __init__(self, init_embed, num_cls, n_filters=256, kernel_size=3, n_layers=7, embed_dropout=0.1, dropout=0.1): + super().__init__() + self.region_embed = RegionEmbedding(init_embed, out_dim=n_filters, kernel_sizes=[3, 5, 9]) + embed_dim = self.region_embed.embedding_dim + self.conv_list = nn.ModuleList() + for i in range(n_layers): + self.conv_list.append(nn.Sequential( + nn.ReLU(), + nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2), + nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2), + )) + + self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + self.embed_drop = nn.Dropout(embed_dropout) + self.classfier = nn.Sequential( + nn.Dropout(dropout), + nn.Linear(n_filters, num_cls), + ) + self.reset_parameters() + + + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): + nn.init.normal_(m.weight, mean=0, std=0.01) + if m.bias is not None: + nn.init.normal_(m.bias, mean=0, std=0.01) + + + + def forward(self, words, seq_len=None): + words = words.long() + # get region embeddings + x = self.region_embed(words) + x = self.embed_drop(x) + + # not pooling on first conv + x = self.conv_list[0](x) + x + for conv in self.conv_list[1:]: + x = self.pool(x) + x = conv(x) + x + + # B, C, L => B, C + x, _ = torch.max(x, dim=2) + x = self.classfier(x) + return {C.OUTPUT: x} + + + + def predict(self, words, seq_len=None): + x = self.forward(words, seq_len)[C.OUTPUT] + return {C.OUTPUT: torch.argmax(x, 1)} + + + + + +class RegionEmbedding(nn.Module): + def __init__(self, init_embed, out_dim=300, kernel_sizes=None): + super().__init__() + if kernel_sizes is None: + kernel_sizes = [5, 9] + assert isinstance(kernel_sizes, list), 'kernel_sizes should be List(int)' + self.embed = get_embeddings(init_embed) + try: + embed_dim = self.embed.embedding_dim + except Exception: + embed_dim = self.embed.embed_size + self.region_embeds = nn.ModuleList() + for ksz in kernel_sizes: + self.region_embeds.append(nn.Sequential( + nn.Conv1d(embed_dim, embed_dim, ksz, padding=ksz // 2), + )) + self.linears = nn.ModuleList([nn.Conv1d(embed_dim, out_dim, 1) + for _ in range(len(kernel_sizes) + 1)]) + self.embedding_dim = embed_dim + + + def forward(self, x): + x = self.embed(x) + x = x.transpose(1, 2) + # B, C, L + out = self.linears[0](x) + for conv, fc in zip(self.region_embeds, self.linears[1:]): + conv_i = conv(x) + out = out + fc(conv_i) + # B, C, L + + return out + + + + + +if __name__ == '__main__': + x = torch.randint(0, 10000, size=(5, 15), dtype=torch.long) + model = DPCNN((10000, 300), 20) + y = model(x) + print(y.size(), y.mean(1), y.std(1)) \ No newline at end of file diff --git a/reproduction/text_classification/train_char_cnn.py b/reproduction/text_classification/train_char_cnn.py index e69de29b..c2c983a4 100644 --- a/reproduction/text_classification/train_char_cnn.py +++ b/reproduction/text_classification/train_char_cnn.py @@ -0,0 +1,206 @@ +# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 +import os +os.environ['FASTNLP_BASE_URL'] = 
'http://10.141.222.118:8888/file/download/'
+os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
+
+import sys
+sys.path.append('../..')
+from fastNLP.core.const import Const as C
+import torch.nn as nn
+from fastNLP.io.dataset_loader import SSTLoader
+from data.yelpLoader import yelpLoader
+from data.sstLoader import sst2Loader
+from data.IMDBLoader import IMDBLoader
+from model.char_cnn import CharacterLevelCNN
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.models.cnn_text_classification import CNNText
+from fastNLP.modules.encoder.embedding import CNNCharEmbedding,StaticEmbedding,StackEmbedding,LSTMCharEmbedding
+from fastNLP import CrossEntropyLoss, AccuracyMetric
+from fastNLP.core.trainer import Trainer
+from torch.optim import SGD
+from torch.autograd import Variable
+import torch
+from fastNLP import BucketSampler
+
+##hyper
+# TODO: hook fastNLP's experiment logging in here
+class Config():
+    model_dir_or_name="en-base-uncased"
+    embedding_grad= False
+    bert_embedding_layers= '4,-2,-1'
+    train_epoch= 50
+    num_classes=2
+    task= "IMDB"
+    #yelp_p
+    datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv",
+                "test": "/remote-home/ygwang/yelp_polarity/test.csv"}
+    #IMDB
+    #datapath = {"train": "/remote-home/ygwang/IMDB_data/train.csv",
+    #            "test": "/remote-home/ygwang/IMDB_data/test.csv"}
+    # sst
+    # datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
+    #             "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"}
+
+    lr=0.01
+    batch_size=128
+    model_size="large"
+    number_of_characters=69
+    extra_characters=''
+    max_length=1014
+
+    char_cnn_config={
+        "alphabet": {
+            "en": {
+                "lower": {
+                    "alphabet": "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
+                    "number_of_characters": 69
+                },
+                "both": {
+                    "alphabet": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
+                    "number_of_characters": 95
+                }
+            }
+        },
+        "model_parameters": {
+            "small": {
+                "conv": [
+                    # each entry: [out_channels, kernel_size, max_pooling_size]; -1 means no pooling
+                    [256,7,3],
+                    [256,7,3],
+                    [256,3,-1],
+                    [256,3,-1],
+                    [256,3,-1],
+                    [256,3,3]
+                ],
+                "fc": [1024,1024]
+            },
+            "large":{
+                "conv":[
+                    [1024, 7, 3],
+                    [1024, 7, 3],
+                    [1024, 3, -1],
+                    [1024, 3, -1],
+                    [1024, 3, -1],
+                    [1024, 3, 3]
+                ],
+                "fc": [2048,2048]
+            }
+        },
+        "data": {
+            "text_column": "SentimentText",
+            "label_column": "Sentiment",
+            "max_length": 1014,
+            "num_of_classes": 2,
+            "encoding": None,
+            "chunksize": 50000,
+            "max_rows": 100000,
+            "preprocessing_steps": ["lower", "remove_hashtags", "remove_urls", "remove_user_mentions"]
+        },
+        "training": {
+            "batch_size": 128,
+            "learning_rate": 0.01,
+            "epochs": 10,
+            "optimizer": "sgd"
+        }
+    }
+ops=Config
+
+
+##1. Task setup: load dataInfo with a dataloader
+dataloader=sst2Loader()
+dataloader=IMDBLoader()
+#dataloader=yelpLoader(fine_grained=True)
+datainfo=dataloader.process(ops.datapath,char_level_op=True)
+char_vocab=ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
+ops.number_of_characters=len(char_vocab)
+ops.embedding_dim=ops.number_of_characters
+
+# map characters to indices, truncating/padding to max_length
+def chartoindex(chars):
+    max_seq_len=ops.max_length
+    zero_index=len(char_vocab)
+    char_index_list=[]
+    for char in chars:
+        if char in char_vocab:
+            char_index_list.append(char_vocab.index(char))
+        else:
+            # every out-of-alphabet character shares the last embedding row
+            char_index_list.append(zero_index)
+    if len(char_index_list) > max_seq_len:
+        char_index_list = char_index_list[:max_seq_len]
+    elif 0 < len(char_index_list) < max_seq_len:
+        char_index_list = char_index_list+[zero_index]*(max_seq_len-len(char_index_list))
+    elif len(char_index_list) == 0:
+        char_index_list=[zero_index]*max_seq_len
+    return char_index_list
+
+for dataset in datainfo.datasets.values():
+    dataset.apply_field(chartoindex,field_name='chars',new_field_name='chars')
+
+datainfo.datasets['train'].set_input('chars')
+datainfo.datasets['test'].set_input('chars')
+datainfo.datasets['train'].set_target('target')
+datainfo.datasets['test'].set_target('target')
+
+##2. Define/assemble the model. Anything goes here: a model already wrapped by fastNLP, such as CNNText, can simply be instantiated. ModelFactory below is only a placeholder skeleton showing how to build a model that follows fastNLP's input/output conventions.
+class ModelFactory(nn.Module):
+    """
+    Assembles the embedding, encoder and decoder, and defines the forward pass.
+
+    :param embedding: embedding model
+    :param encoder: encoder model
+    :param decoder: decoder model
+
+    """
+    def __init__(self,embedding,encoder,decoder,**kwargs):
+        super(ModelFactory,self).__init__()
+        self.embedding=embedding
+        self.encoder=encoder
+        self.decoder=decoder
+
+    def forward(self,x):
+        return {C.OUTPUT:None}  # placeholder only
+
+## 2. Or directly reuse a fastNLP model
+#vocab=datainfo.vocabs['words']
+vocab_label=datainfo.vocabs['target']
+'''
+# emded_char=CNNCharEmbedding(vocab)
+# embed_word = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
+# embedding=StackEmbedding([emded_char, embed_word])
+# cnn_char_embed = CNNCharEmbedding(vocab)
+# lstm_char_embed = LSTMCharEmbedding(vocab)
+# embedding = StackEmbedding([cnn_char_embed, lstm_char_embed])
+'''
+# one-hot embedding: row i is the one-hot vector of character i; the extra
+# all-zero last row serves as both the padding and the OOV embedding
+embedding_weight= Variable(torch.zeros(len(char_vocab)+1, len(char_vocab)))
+
+for i in range(len(char_vocab)):
+    embedding_weight[i][i]=1
+embedding=nn.Embedding(num_embeddings=len(char_vocab)+1,embedding_dim=len(char_vocab),padding_idx=len(char_vocab),_weight=embedding_weight)
+for para in embedding.parameters():
+    para.requires_grad=False
+# CNNText is too simple for this task
+#model=CNNText(init_embed=embedding, num_classes=ops.num_classes)
+model=CharacterLevelCNN(ops,embedding)
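The padding/truncation rules in chartoindex are easiest to see on a toy alphabet. Here is a minimal, self-contained sketch with a 3-character vocabulary and max length 5 instead of the real 69/1014; as in the script above, out-of-alphabet characters and padding share the final all-zero embedding row:

toy_vocab = "abc"
toy_max_len = 5

def toy_chartoindex(chars):
    unk = len(toy_vocab)  # index of the shared OOV/padding row
    idx = [toy_vocab.index(c) if c in toy_vocab else unk for c in chars]
    # truncate to the maximum length, then pad with the unk index
    return idx[:toy_max_len] + [unk] * max(0, toy_max_len - len(idx))

print(toy_chartoindex("ab?"))       # [0, 1, 3, 3, 3]: '?' maps to unk, then padded
print(toy_chartoindex("abcabcab"))  # [0, 1, 2, 0, 1]: truncated to length 5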
+## 3. Declare the loss, metric, and optimizer
+loss=CrossEntropyLoss
+metric=AccuracyMetric
+optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], lr=ops.lr)
+
+## 4. Define the train routine
+def train(model,datainfo,loss,metrics,optimizer,num_epochs=100):
+    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'),
+                      metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=0, check_code_level=-1,
+                      n_epochs=num_epochs)
+    print(trainer.train())
+
+
+
+if __name__=="__main__":
+    #print(vocab_label)
+
+    #print(datainfo.datasets["train"])
+    train(model,datainfo,loss,metric,optimizer,num_epochs=ops.train_epoch)
\ No newline at end of file
diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py
index e69de29b..8ddea1a3 100644
--- a/reproduction/text_classification/train_dpcnn.py
+++ b/reproduction/text_classification/train_dpcnn.py
@@ -0,0 +1,101 @@
+# The following paths must be put into the environment variables first; access is currently restricted to internal testing, so they are declared manually
+
+import os
+os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
+os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+
+import sys
+sys.path.append('../..')
+from fastNLP.core.const import Const as C
+from fastNLP.core import LRScheduler
+import torch.nn as nn
+from fastNLP.io.dataset_loader import SSTLoader
+from data.yelpLoader import yelpLoader
+from reproduction.text_classification.model.dpcnn import DPCNN
+from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
+from fastNLP import CrossEntropyLoss, AccuracyMetric
+from fastNLP.core.trainer import Trainer
+from torch.optim import SGD
+import torch.cuda
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+
+##hyper
+
+class Config():
+    model_dir_or_name="en-base-uncased"
+    embedding_grad= False
+    train_epoch= 30
+    batch_size = 100
+    num_classes=2
+    task= "yelp_p"
+    #datadir = '/remote-home/yfshao/workdir/datasets/SST'
+    datadir = '/remote-home/ygwang/yelp_polarity'
+    #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
+    datafile = {"train": "train.csv", "test": "test.csv"}
+    lr=1e-3
+
+    def __init__(self):
+        self.datapath = {k:os.path.join(self.datadir, v)
+                         for k, v in self.datafile.items()}
+
+ops=Config()
+
+
+
+##1. Task setup: load dataInfo with a dataloader
+
+#datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train'])
+datainfo=yelpLoader(fine_grained=True,lower=True).process(paths=ops.datapath, train_ds=['train'])
+print(len(datainfo.datasets['train']))
+print(len(datainfo.datasets['test']))
+
+
+## 2. Or directly reuse a fastNLP model
+
+vocab = datainfo.vocabs['words']
+# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
+#embedding = StaticEmbedding(vocab)
+embedding = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
+
+print(len(vocab))
+print(len(datainfo.vocabs['target']))
+
+model = DPCNN(init_embed=embedding, num_cls=ops.num_classes)
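DPCNN earns the "pyramid" in its name from the stride-2 pooling between residual blocks: with the model defaults used here (n_layers=7, so six MaxPool1d(kernel_size=3, stride=2, padding=1) steps after the first block), the sequence length roughly halves at each block and the receptive field grows exponentially with depth. A quick arithmetic sketch, taking a 1014-token input purely as an illustration:

L = 1014
for i in range(6):                  # six pooled blocks follow the first one
    L = (L + 2 * 1 - 3) // 2 + 1    # MaxPool1d(kernel_size=3, stride=2, padding=1)
    print(i + 1, L)
# lengths: 507, 254, 127, 64, 32, 16; torch.max then pools the survivors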
+## 3. Declare the loss, metric, and optimizer
+loss=CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
+metric=AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
+optimizer= SGD([param for param in model.parameters() if param.requires_grad==True],
+               lr=ops.lr, momentum=0.9, weight_decay=0)
+
+callbacks = []
+callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
+
+
+device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
+
+print(device)
+
+for ds in datainfo.datasets.values():
+    ds.apply_field(len, C.INPUT, C.INPUT_LEN)
+    ds.set_input(C.INPUT, C.INPUT_LEN)
+    ds.set_target(C.TARGET)
+
+
+## 4. Define the train routine
+def train(model,datainfo,loss,metrics,optimizer,num_epochs=ops.train_epoch):
+    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+                      metrics=[metrics], dev_data=datainfo.datasets['test'], device=3,
+                      check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
+                      n_epochs=num_epochs)
+
+    print(trainer.train())
+
+
+
+if __name__=="__main__":
+    train(model,datainfo,loss,metric,optimizer)
\ No newline at end of file
From c7463cf0907b49f20007b3fd8d487d9ee84051e1 Mon Sep 17 00:00:00 2001
From: wyg <1505116161@qq.com>
Date: Sat, 6 Jul 2019 16:45:37 +0800
Subject: [PATCH 06/20] [verify] yelpdataloader

---
 reproduction/text_classification/data/yelpLoader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py
index 9e1e1c6b..90a80301 100644
--- a/reproduction/text_classification/data/yelpLoader.py
+++ b/reproduction/text_classification/data/yelpLoader.py
@@ -128,7 +128,7 @@ class yelpLoader(DataSetLoader):
                     all_count+=1
                     if len(row)==2:
                         target=self.tag_v[row[0]+".0"]
-                        words=clean_str(row[1],self.lower)
+                        words=clean_str(row[1],self.tokenizer,self.lower)
                         if len(words)!=0:
                             ds.append(Instance(words=words,target=target))
                             real_count += 1
From d05aca6da62ecf1d7f41bbeeb03ea470a40ca165 Mon Sep 17 00:00:00 2001
From: lyhuang18 <42239874+lyhuang18@users.noreply.github.com>
Date: Sun, 7 Jul 2019 08:21:51 +0800
Subject: [PATCH 07/20] TC/LSTM
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Three models: LSTM, AWD-LSTM, and LSTM + self-attention

---
 .../text_classification/data/IMDBLoader.py    |  80 +++++++++++++
 .../text_classification/data/MTL16Loader.py   |   6 +-
 .../text_classification/data/SSTLoader.py     |  99 ++++++++++++++++
 .../text_classification/data/yelpLoader.py    | 111 ++++++++++--------
 .../text_classification/model/awd_lstm.py     |  31 +++++
 .../model/awdlstm_module.py                   |  86 ++++++++++++++
 .../text_classification/model/lstm.py         |  30 +++++
 .../model/lstm_self_attention.py              |  35 ++++++
 .../text_classification/model/weight_drop.py  |  99 ++++++++++++++++
 .../text_classification/results_LSTM.xlsx     | Bin 0 -> 9944 bytes
 .../text_classification/train_awdlstm.py      | 102 ++++++++++++++++
 .../text_classification/train_lstm.py         |  99 ++++++++++++++++
 .../text_classification/train_lstm_att.py     | 101 ++++++++++++++++
 13 files changed, 827 insertions(+), 52 deletions(-)
 create mode 100644 reproduction/text_classification/data/IMDBLoader.py
 create mode 100644 reproduction/text_classification/data/SSTLoader.py
 create mode 100644 reproduction/text_classification/model/awd_lstm.py
 create mode 100644 reproduction/text_classification/model/awdlstm_module.py
 create mode 100644 reproduction/text_classification/model/lstm.py
 create mode 100644 reproduction/text_classification/model/lstm_self_attention.py
 create mode 100644 reproduction/text_classification/model/weight_drop.py
 create
mode 100644 reproduction/text_classification/results_LSTM.xlsx create mode 100644 reproduction/text_classification/train_awdlstm.py create mode 100644 reproduction/text_classification/train_lstm.py create mode 100644 reproduction/text_classification/train_lstm_att.py diff --git a/reproduction/text_classification/data/IMDBLoader.py b/reproduction/text_classification/data/IMDBLoader.py new file mode 100644 index 00000000..d591cdf8 --- /dev/null +++ b/reproduction/text_classification/data/IMDBLoader.py @@ -0,0 +1,80 @@ +from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader +from fastNLP.core.vocabulary import VocabularyOption +from fastNLP.io.base_loader import DataSetLoader, DataInfo +from typing import Union, Dict, List, Iterator +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP import Vocabulary +from fastNLP import Const +# from reproduction.utils import check_dataloader_paths +from functools import partial + +class IMDBLoader(DataSetLoader): + """ + 读取IMDB数据集,DataSet包含以下fields: + + words: list(str), 需要分类的文本 + target: str, 文本的标签 + + + """ + + def __init__(self): + super(IMDBLoader, self).__init__() + + def _load(self, path): + dataset = DataSet() + with open(path, 'r', encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split('\t') + target = parts[0] + words = parts[1].lower().split() + dataset.append(Instance(words=words, target=target)) + if len(dataset)==0: + raise RuntimeError(f"{path} has no valid data.") + + return dataset + + def process(self, + paths: Union[str, Dict[str, str]], + src_vocab_opt: VocabularyOption = None, + tgt_vocab_opt: VocabularyOption = None, + src_embed_opt: EmbeddingOption = None): + + # paths = check_dataloader_paths(paths) + datasets = {} + info = DataInfo() + for name, path in paths.items(): + dataset = self.load(path) + datasets[name] = dataset + + datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False) + + src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) + src_vocab.from_dataset(datasets['train'], field_name='words') + src_vocab.index_dataset(*datasets.values(), field_name='words') + + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) + tgt_vocab.from_dataset(datasets['train'], field_name='target') + tgt_vocab.index_dataset(*datasets.values(), field_name='target') + + info.vocabs = { + "words": src_vocab, + "target": tgt_vocab + } + + info.datasets = datasets + + if src_embed_opt is not None: + embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) + info.embeddings['words'] = embed + + for name, dataset in info.datasets.items(): + dataset.set_input("words") + dataset.set_target("target") + + return info diff --git a/reproduction/text_classification/data/MTL16Loader.py b/reproduction/text_classification/data/MTL16Loader.py index 1b3e6245..066b53b4 100644 --- a/reproduction/text_classification/data/MTL16Loader.py +++ b/reproduction/text_classification/data/MTL16Loader.py @@ -32,7 +32,7 @@ class MTL16Loader(DataSetLoader): continue parts = line.split('\t') target = parts[0] - words = parts[1].split() + words = parts[1].lower().split() dataset.append(Instance(words=words, target=target)) if len(dataset)==0: raise RuntimeError(f"{path} has no valid data.") @@ -72,4 +72,8 @@ class MTL16Loader(DataSetLoader): embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) info.embeddings['words'] = embed + for name, dataset in 
info.datasets.items(): + dataset.set_input("words") + dataset.set_target("target") + return info diff --git a/reproduction/text_classification/data/SSTLoader.py b/reproduction/text_classification/data/SSTLoader.py new file mode 100644 index 00000000..b570994e --- /dev/null +++ b/reproduction/text_classification/data/SSTLoader.py @@ -0,0 +1,99 @@ +from typing import Iterable +from nltk import Tree +from fastNLP.io.base_loader import DataInfo, DataSetLoader +from fastNLP.core.vocabulary import VocabularyOption, Vocabulary +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader + + +class SSTLoader(DataSetLoader): + URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' + DATA_DIR = 'sst/' + + """ + 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` + + 读取SST数据集, DataSet包含fields:: + + words: list(str) 需要分类的文本 + target: str 文本的标签 + + 数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip + + :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` + :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` + """ + + def __init__(self, subtree=False, fine_grained=False): + self.subtree = subtree + + tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', + '3': 'positive', '4': 'very positive'} + if not fine_grained: + tag_v['0'] = tag_v['1'] + tag_v['4'] = tag_v['3'] + self.tag_v = tag_v + + def _load(self, path): + """ + + :param str path: 存储数据的路径 + :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 + """ + datalist = [] + with open(path, 'r', encoding='utf-8') as f: + datas = [] + for l in f: + datas.extend([(s, self.tag_v[t]) + for s, t in self._get_one(l, self.subtree)]) + ds = DataSet() + for words, tag in datas: + ds.append(Instance(words=words, target=tag)) + return ds + + @staticmethod + def _get_one(data, subtree): + tree = Tree.fromstring(data) + if subtree: + return [(t.leaves(), t.label()) for t in tree.subtrees()] + return [(tree.leaves(), tree.label())] + + def process(self, + paths, + train_ds: Iterable[str] = None, + src_vocab_op: VocabularyOption = None, + tgt_vocab_op: VocabularyOption = None, + src_embed_op: EmbeddingOption = None): + input_name, target_name = 'words', 'target' + src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) + + info = DataInfo(datasets=self.load(paths)) + _train_ds = [info.datasets[name] + for name in train_ds] if train_ds else info.datasets.values() + src_vocab.from_dataset(*_train_ds, field_name=input_name) + tgt_vocab.from_dataset(*_train_ds, field_name=target_name) + src_vocab.index_dataset( + *info.datasets.values(), + field_name=input_name, new_field_name=input_name) + tgt_vocab.index_dataset( + *info.datasets.values(), + field_name=target_name, new_field_name=target_name) + info.vocabs = { + input_name: src_vocab, + target_name: tgt_vocab + } + + if src_embed_op is not None: + src_embed_op.vocab = src_vocab + init_emb = EmbedLoader.load_with_vocab(**src_embed_op) + info.embeddings[input_name] = init_emb + + for name, dataset in info.datasets.items(): + dataset.set_input(input_name) + dataset.set_target(target_name) + + return info + diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py index c47d48fd..680b3488 100644 --- a/reproduction/text_classification/data/yelpLoader.py +++ 
b/reproduction/text_classification/data/yelpLoader.py @@ -1,68 +1,77 @@ -import ast -from fastNLP import DataSet, Instance, Vocabulary +from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader from fastNLP.core.vocabulary import VocabularyOption -from fastNLP.io import JsonLoader -from fastNLP.io.base_loader import DataInfo -from fastNLP.io.embed_loader import EmbeddingOption -from fastNLP.io.file_reader import _read_json -from typing import Union, Dict -from reproduction.Star_transformer.datasets import EmbedLoader -from reproduction.utils import check_dataloader_paths +from fastNLP.io.base_loader import DataSetLoader, DataInfo +from typing import Union, Dict, List, Iterator +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP import Vocabulary +from fastNLP import Const +# from reproduction.utils import check_dataloader_paths +from functools import partial +import pandas as pd - -class yelpLoader(JsonLoader): - +class yelpLoader(DataSetLoader): """ - 读取Yelp数据集, DataSet包含fields: - - review_id: str, 22 character unique review id - user_id: str, 22 character unique user id - business_id: str, 22 character business id - useful: int, number of useful votes received - funny: int, number of funny votes received - cool: int, number of cool votes received - date: str, date formatted YYYY-MM-DD + 读取IMDB数据集,DataSet包含以下fields: + words: list(str), 需要分类的文本 target: str, 文本的标签 - - 数据来源: https://www.yelp.com/dataset/download - - :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` + + """ - - def __init__(self, fine_grained=False): + + def __init__(self): super(yelpLoader, self).__init__() - tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', - '4.0': 'positive', '5.0': 'very positive'} - if not fine_grained: - tag_v['1.0'] = tag_v['2.0'] - tag_v['5.0'] = tag_v['4.0'] - self.fine_grained = fine_grained - self.tag_v = tag_v - + def _load(self, path): - ds = DataSet() - for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): - d = ast.literal_eval(d) - d["words"] = d.pop("text").split() - d["target"] = self.tag_v[str(d.pop("stars"))] - ds.append(Instance(**d)) - return ds + dataset = DataSet() + data = pd.read_csv(path, header=None, sep=",").values + for line in data: + target = str(line[0]) + words = str(line[1]).lower().split() + dataset.append(Instance(words=words, target=target)) + if len(dataset)==0: + raise RuntimeError(f"{path} has no valid data.") - def process(self, paths: Union[str, Dict[str, str]], vocab_opt: VocabularyOption = None, - embed_opt: EmbeddingOption = None): - paths = check_dataloader_paths(paths) + return dataset + + def process(self, + paths: Union[str, Dict[str, str]], + src_vocab_opt: VocabularyOption = None, + tgt_vocab_opt: VocabularyOption = None, + src_embed_opt: EmbeddingOption = None): + + # paths = check_dataloader_paths(paths) datasets = {} info = DataInfo() - vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt) for name, path in paths.items(): dataset = self.load(path) datasets[name] = dataset - vocab.from_dataset(dataset, field_name="words") - info.vocabs = vocab + + datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False) + + src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) + src_vocab.from_dataset(datasets['train'], field_name='words') + src_vocab.index_dataset(*datasets.values(), field_name='words') + + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_opt is None else 
Vocabulary(**tgt_vocab_opt) + tgt_vocab.from_dataset(datasets['train'], field_name='target') + tgt_vocab.index_dataset(*datasets.values(), field_name='target') + + info.vocabs = { + "words": src_vocab, + "target": tgt_vocab + } + info.datasets = datasets - if embed_opt is not None: - embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab) + + if src_embed_opt is not None: + embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) info.embeddings['words'] = embed - return info + for name, dataset in info.datasets.items(): + dataset.set_input("words") + dataset.set_target("target") + + return info diff --git a/reproduction/text_classification/model/awd_lstm.py b/reproduction/text_classification/model/awd_lstm.py new file mode 100644 index 00000000..0d8f711a --- /dev/null +++ b/reproduction/text_classification/model/awd_lstm.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from fastNLP.core.const import Const as C +from .awdlstm_module import LSTM +from fastNLP.modules import encoder +from fastNLP.modules.decoder.mlp import MLP + + +class AWDLSTMSentiment(nn.Module): + def __init__(self, init_embed, + num_classes, + hidden_dim=256, + num_layers=1, + nfc=128, + wdrop=0.5): + super(AWDLSTMSentiment,self).__init__() + self.embed = encoder.Embedding(init_embed) + self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True, wdrop=wdrop) + self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes]) + + def forward(self, words): + x_emb = self.embed(words) + output, _ = self.lstm(x_emb) + output = self.mlp(output[:,-1,:]) + return {C.OUTPUT: output} + + def predict(self, words): + output = self(words) + _, predict = output[C.OUTPUT].max(dim=1) + return {C.OUTPUT: predict} + diff --git a/reproduction/text_classification/model/awdlstm_module.py b/reproduction/text_classification/model/awdlstm_module.py new file mode 100644 index 00000000..87bfe730 --- /dev/null +++ b/reproduction/text_classification/model/awdlstm_module.py @@ -0,0 +1,86 @@ +""" +轻量封装的 Pytorch LSTM 模块. +可在 forward 时传入序列的长度, 自动对padding做合适的处理. +""" +__all__ = [ + "LSTM" +] + +import torch +import torch.nn as nn +import torch.nn.utils.rnn as rnn + +from fastNLP.modules.utils import initial_parameter +from torch import autograd +from .weight_drop import WeightDrop + + +class LSTM(nn.Module): + """ + 别名::class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM` + + LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 + 为1; 且可以应对DataParallel中LSTM的使用问题。 + + :param input_size: 输入 `x` 的特征维度 + :param hidden_size: 隐状态 `h` 的特征维度. + :param num_layers: rnn的层数. Default: 1 + :param dropout: 层间dropout概率. Default: 0 + :param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` + :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 + :(batch, seq, feature). Default: ``False`` + :param bias: 如果为 ``False``, 模型将不会使用bias. 
Default: ``True`` + """ + + def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, + bidirectional=False, bias=True, wdrop=0.5): + super(LSTM, self).__init__() + self.batch_first = batch_first + self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, + dropout=dropout, bidirectional=bidirectional) + self.lstm = WeightDrop(self.lstm, ['weight_hh_l0'], dropout=wdrop) + self.init_param() + + def init_param(self): + for name, param in self.named_parameters(): + if 'bias' in name: + # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871 + param.data.fill_(0) + n = param.size(0) + start, end = n // 4, n // 2 + param.data[start:end].fill_(1) + else: + nn.init.xavier_uniform_(param) + + def forward(self, x, seq_len=None, h0=None, c0=None): + """ + + :param x: [batch, seq_len, input_size] 输入序列 + :param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` + :param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None`` + :param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None`` + :return (output, ht) 或 output: 若 ``get_hidden=True`` [batch, seq_len, hidden_size*num_direction] 输出序列 + 和 [batch, hidden_size*num_direction] 最后时刻隐状态. + """ + batch_size, max_len, _ = x.size() + if h0 is not None and c0 is not None: + hx = (h0, c0) + else: + hx = None + if seq_len is not None and not isinstance(x, rnn.PackedSequence): + sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) + if self.batch_first: + x = x[sort_idx] + else: + x = x[:, sort_idx] + x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first) + output, hx = self.lstm(x, hx) # -> [N,L,C] + output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) + _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) + if self.batch_first: + output = output[unsort_idx] + else: + output = output[:, unsort_idx] + else: + output, hx = self.lstm(x, hx) + return output, hx diff --git a/reproduction/text_classification/model/lstm.py b/reproduction/text_classification/model/lstm.py new file mode 100644 index 00000000..388f3f1c --- /dev/null +++ b/reproduction/text_classification/model/lstm.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +from fastNLP.core.const import Const as C +from fastNLP.modules.encoder.lstm import LSTM +from fastNLP.modules import encoder +from fastNLP.modules.decoder.mlp import MLP + + +class BiLSTMSentiment(nn.Module): + def __init__(self, init_embed, + num_classes, + hidden_dim=256, + num_layers=1, + nfc=128): + super(BiLSTMSentiment,self).__init__() + self.embed = encoder.Embedding(init_embed) + self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True) + self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes]) + + def forward(self, words): + x_emb = self.embed(words) + output, _ = self.lstm(x_emb) + output = self.mlp(output[:,-1,:]) + return {C.OUTPUT: output} + + def predict(self, words): + output = self(words) + _, predict = output[C.OUTPUT].max(dim=1) + return {C.OUTPUT: predict} + diff --git a/reproduction/text_classification/model/lstm_self_attention.py b/reproduction/text_classification/model/lstm_self_attention.py new file mode 100644 index 00000000..239635fe --- /dev/null +++ b/reproduction/text_classification/model/lstm_self_attention.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +from fastNLP.core.const import Const as C +from 
fastNLP.modules.encoder.lstm import LSTM +from fastNLP.modules import encoder +from fastNLP.modules.aggregator.attention import SelfAttention +from fastNLP.modules.decoder.mlp import MLP + + +class BiLSTM_SELF_ATTENTION(nn.Module): + def __init__(self, init_embed, + num_classes, + hidden_dim=256, + num_layers=1, + attention_unit=256, + attention_hops=1, + nfc=128): + super(BiLSTM_SELF_ATTENTION,self).__init__() + self.embed = encoder.Embedding(init_embed) + self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True) + self.attention = SelfAttention(input_size=hidden_dim * 2 , attention_unit=attention_unit, attention_hops=attention_hops) + self.mlp = MLP(size_layer=[hidden_dim* 2*attention_hops, nfc, num_classes]) + + def forward(self, words): + x_emb = self.embed(words) + output, _ = self.lstm(x_emb) + after_attention, penalty = self.attention(output,words) + after_attention =after_attention.view(after_attention.size(0),-1) + output = self.mlp(after_attention) + return {C.OUTPUT: output} + + def predict(self, words): + output = self(words) + _, predict = output[C.OUTPUT].max(dim=1) + return {C.OUTPUT: predict} diff --git a/reproduction/text_classification/model/weight_drop.py b/reproduction/text_classification/model/weight_drop.py new file mode 100644 index 00000000..60fda179 --- /dev/null +++ b/reproduction/text_classification/model/weight_drop.py @@ -0,0 +1,99 @@ +import torch +from torch.nn import Parameter +from functools import wraps + +class WeightDrop(torch.nn.Module): + def __init__(self, module, weights, dropout=0, variational=False): + super(WeightDrop, self).__init__() + self.module = module + self.weights = weights + self.dropout = dropout + self.variational = variational + self._setup() + + def widget_demagnetizer_y2k_edition(*args, **kwargs): + # We need to replace flatten_parameters with a nothing function + # It must be a function rather than a lambda as otherwise pickling explodes + # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION! 
+ # (╯°□°)╯︵ ┻━┻ + return + + def _setup(self): + # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN + if issubclass(type(self.module), torch.nn.RNNBase): + self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition + + for name_w in self.weights: + print('Applying weight drop of {} to {}'.format(self.dropout, name_w)) + w = getattr(self.module, name_w) + del self.module._parameters[name_w] + self.module.register_parameter(name_w + '_raw', Parameter(w.data)) + + def _setweights(self): + for name_w in self.weights: + raw_w = getattr(self.module, name_w + '_raw') + w = None + if self.variational: + mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1)) + if raw_w.is_cuda: mask = mask.cuda() + mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True) + w = mask.expand_as(raw_w) * raw_w + else: + w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training) + setattr(self.module, name_w, w) + + def forward(self, *args): + self._setweights() + return self.module.forward(*args) + +if __name__ == '__main__': + import torch + from weight_drop import WeightDrop + + # Input is (seq, batch, input) + x = torch.autograd.Variable(torch.randn(2, 1, 10)).cuda() + h0 = None + + ### + + print('Testing WeightDrop') + print('=-=-=-=-=-=-=-=-=-=') + + ### + + print('Testing WeightDrop with Linear') + + lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9) + lin.cuda() + run1 = [x.sum() for x in lin(x).data] + run2 = [x.sum() for x in lin(x).data] + + print('All items should be different') + print('Run 1:', run1) + print('Run 2:', run2) + + assert run1[0] != run2[0] + assert run1[1] != run2[1] + + print('---') + + ### + + print('Testing WeightDrop with LSTM') + + wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9) + wdrnn.cuda() + + run1 = [x.sum() for x in wdrnn(x, h0)[0].data] + run2 = [x.sum() for x in wdrnn(x, h0)[0].data] + + print('First timesteps should be equal, all others should differ') + print('Run 1:', run1) + print('Run 2:', run2) + + # First time step, not influenced by hidden to hidden weights, should be equal + assert run1[0] == run2[0] + # Second step should not + assert run1[1] != run2[1] + + print('---') diff --git a/reproduction/text_classification/results_LSTM.xlsx b/reproduction/text_classification/results_LSTM.xlsx new file mode 100644 index 0000000000000000000000000000000000000000..0d7b841b12b43ee346c4d9db17032fade8c0b40a GIT binary patch literal 9944 zcmeHNWmFtlw{0APHPE=by9Rd+1a}G2xVu{jmf!?;OK^8dkYEYHgIj>0!ChY`GxIXT zSV;*Y zQ45hC)h0rUXS_C5!#Y0QjUTPKibgeJC$)<0upZ5^AYS<_jgGkJsx8Zv&Qq#q$Zw*` zoeza>WwWJ%3(7R3YGYq==is64@57JRBRpxEHDR%)xTm1T;jVkcxzTx~v?X2|E+RqQ z&hV`S`f_?lXrN06T7ljT&xFtOxhgY({GwJ{n+gY~CRJs+3$X@E;&~K3s9L$X?ut}y zdvFYtV^D30&o4nLYE=gH){cV|g)q^$+P==7=ZFW(zITKjq(H%dru&($dEKO zat7PFu(JHv|2IAVgLCq4uU;0f__Ui1C3Ii*CanK_av>HduHYdi+e)tPA1F7EULTcD zMY!;emI$a$@Ek@apxyswaBe{`YP+BO>?>EL&<93FpCxe5>_EmD2`BSkZ!M{&Z5ah1=OUt z+V}F%sz%4$4_?0QnCA8Oe|(9eEV(5PcI2H)GQxcRheWt^HX>k znc2QSmP~(#b?eKll+mwB!G(FjF)Y(ZopTzXTgP?yG1IGu7pc3n|6m~aCF$EmNRRdJ zB;mXrhQ)#b05;(P0Cb3FJZ)J$9Gq=T92{(Z^kZcj1`g?LAiwOo`@kEe&XCHyB5Wkd z(g`<}R(n?2*ESi;*le|m2_@SBOWgEZ4Q;NZh_)=a(ZlWM`Yz55cGf#Vav!7wq%QiU zaBPa{yjrvJ(ua`+q?5)9nB|ytd8A^V;N03_5M|;nC^2; zOxZ=FJUAbF5`H0_!!*AcxhV<*4>Hga4rGfQc-GlVfBWO%JQmI;^%R?><^<9ZmXbSTw}hEt4w`G*JdfoV%`SrmhYGM z*wFP!2i#Ad@osrt4zAF5>#V_LIb3^f&UD16`S5tC5n4b+J0*Tdkh-zW<8EER{miqi 
diff --git a/reproduction/text_classification/results_LSTM.xlsx b/reproduction/text_classification/results_LSTM.xlsx
new file mode 100644
index 0000000000000000000000000000000000000000..0d7b841b12b43ee346c4d9db17032fade8c0b40a
GIT binary patch
literal 9944
[base85-encoded binary payload of results_LSTM.xlsx omitted]

literal 0
HcmV?d00001
diff --git a/reproduction/text_classification/train_awdlstm.py b/reproduction/text_classification/train_awdlstm.py
new file mode 100644
index 00000000..ce3e52bc
--- /dev/null
+++ b/reproduction/text_classification/train_awdlstm.py
@@ -0,0 +1,102 @@
+# This model must be run with pytorch==0.4; weight_drop does not support 1.0
+
+# The following paths must be put into the environment variables first; access is currently restricted to internal testing, so they are declared manually
+import os
+os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
+os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
+
+
+import torch.nn as nn
+
+from data.SSTLoader import SSTLoader
+from data.IMDBLoader import IMDBLoader
+from data.yelpLoader import yelpLoader
+from fastNLP.modules.encoder.embedding import StaticEmbedding
+from model.awd_lstm import AWDLSTMSentiment
+
+from fastNLP.core.const import Const as C
+from fastNLP import CrossEntropyLoss, AccuracyMetric
+from fastNLP import Trainer, Tester
+from torch.optim import Adam
+from fastNLP.io.model_io import ModelLoader, ModelSaver
+
+import argparse
+
+
+class Config():
+    train_epoch= 10
+    lr=0.001
+
+    num_classes=2
+    hidden_dim=256
+    num_layers=1
+    nfc=128
+    wdrop=0.5
+
+    task_name = "IMDB"
+    datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"}
+    load_model_path="./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51"
+    save_model_path="./result_IMDB_test/"
+opt=Config
+
+
+# load data
+dataloaders = {
+    "IMDB":IMDBLoader(),
+    "YELP":yelpLoader(),
+    "SST-5":SSTLoader(subtree=True,fine_grained=True),
+    "SST-3":SSTLoader(subtree=True,fine_grained=False)
+}
+
+if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]:
+    raise ValueError("task name must be in ['IMDB', 'YELP', 'SST-5', 'SST-3']")
+
+dataloader = dataloaders[opt.task_name]
+datainfo=dataloader.process(opt.datapath)
+# print(datainfo.datasets["train"])
+# print(datainfo)
+
+
+# define model
+vocab=datainfo.vocabs['words']
+embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True)
+model=AWDLSTMSentiment(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, nfc=opt.nfc, wdrop=opt.wdrop)
+
+
+# define loss_function and metrics
+loss=CrossEntropyLoss()
+metrics=AccuracyMetric()
+optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr)
+
+
+def train(datainfo, model, optimizer, loss, metrics, opt):
+    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+                      metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1,
+                      n_epochs=opt.train_epoch, save_path=opt.save_model_path)
+    trainer.train()
+
+
+def test(datainfo, metrics, opt):
+    # load model
+    model = ModelLoader.load_pytorch_model(opt.load_model_path)
+    print("model loaded!")
+
+    # Tester
+    tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0)
+    acc = tester.test()
+    print("acc=",acc)
+ + +parser = argparse.ArgumentParser() +parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model') + + +args = parser.parse_args() +if args.mode == 'train': + train(datainfo, model, optimizer, loss, metrics, opt) +elif args.mode == 'test': + test(datainfo, metrics, opt) +else: + print('no mode specified for model!') + parser.print_help() diff --git a/reproduction/text_classification/train_lstm.py b/reproduction/text_classification/train_lstm.py new file mode 100644 index 00000000..b320e79c --- /dev/null +++ b/reproduction/text_classification/train_lstm.py @@ -0,0 +1,99 @@ +# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 +import os +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' + + +import torch.nn as nn + +from data.SSTLoader import SSTLoader +from data.IMDBLoader import IMDBLoader +from data.yelpLoader import yelpLoader +from fastNLP.modules.encoder.embedding import StaticEmbedding +from model.lstm import BiLSTMSentiment + +from fastNLP.core.const import Const as C +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP import Trainer, Tester +from torch.optim import Adam +from fastNLP.io.model_io import ModelLoader, ModelSaver + +import argparse + + +class Config(): + train_epoch= 10 + lr=0.001 + + num_classes=2 + hidden_dim=256 + num_layers=1 + nfc=128 + + task_name = "IMDB" + datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"} + load_model_path="./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51" + save_model_path="./result_IMDB_test/" +opt=Config + + +# load data +dataloaders = { + "IMDB":IMDBLoader(), + "YELP":yelpLoader(), + "SST-5":SSTLoader(subtree=True,fine_grained=True), + "SST-3":SSTLoader(subtree=True,fine_grained=False) +} + +if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]: + raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']") + +dataloader = dataloaders[opt.task_name] +datainfo=dataloader.process(opt.datapath) +# print(datainfo.datasets["train"]) +# print(datainfo) + + +# define model +vocab=datainfo.vocabs['words'] +embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True) +model=BiLSTMSentiment(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, nfc=opt.nfc) + + +# define loss_function and metrics +loss=CrossEntropyLoss() +metrics=AccuracyMetric() +optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr) + + +def train(datainfo, model, optimizer, loss, metrics, opt): + trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1, + n_epochs=opt.train_epoch, save_path=opt.save_model_path) + trainer.train() + + +def test(datainfo, metrics, opt): + # load model + model = ModelLoader.load_pytorch_model(opt.load_model_path) + print("model loaded!") + + # Tester + tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0) + acc = tester.test() + print("acc=",acc) + + + +parser = argparse.ArgumentParser() +parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model') + + +args = parser.parse_args() +if args.mode == 'train': + train(datainfo, model, optimizer, loss, metrics, opt) +elif args.mode == 'test': + test(datainfo, metrics, opt) +else: + print('no mode specified for model!') + parser.print_help() 
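train_awdlstm.py and train_lstm.py above (and train_lstm_att.py below) repeat the same argparse scaffold verbatim; each script is invoked as, for example, `python train_lstm.py --mode train`. A hypothetical refactor, not part of this patch, could move that dispatch into one shared helper; `train_fn`/`test_fn` below stand in for each script's own train/test closures:

import argparse

def dispatch(train_fn, test_fn):
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', required=True, choices=['train', 'test'],
                        help="whether to train a model or test a saved one")
    args = parser.parse_args()
    (train_fn if args.mode == 'train' else test_fn)()

if __name__ == '__main__':
    # in train_lstm.py this would be:
    # dispatch(lambda: train(datainfo, model, optimizer, loss, metrics, opt),
    #          lambda: test(datainfo, metrics, opt))
    dispatch(lambda: print('training...'), lambda: print('testing...'))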
diff --git a/reproduction/text_classification/train_lstm_att.py b/reproduction/text_classification/train_lstm_att.py new file mode 100644 index 00000000..8db27d09 --- /dev/null +++ b/reproduction/text_classification/train_lstm_att.py @@ -0,0 +1,101 @@ +# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 +import os +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' + + +import torch.nn as nn + +from data.SSTLoader import SSTLoader +from data.IMDBLoader import IMDBLoader +from data.yelpLoader import yelpLoader +from fastNLP.modules.encoder.embedding import StaticEmbedding +from model.lstm_self_attention import BiLSTM_SELF_ATTENTION + +from fastNLP.core.const import Const as C +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP import Trainer, Tester +from torch.optim import Adam +from fastNLP.io.model_io import ModelLoader, ModelSaver + +import argparse + + +class Config(): + train_epoch= 10 + lr=0.001 + + num_classes=2 + hidden_dim=256 + num_layers=1 + attention_unit=256 + attention_hops=1 + nfc=128 + + task_name = "IMDB" + datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"} + load_model_path="./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51" + save_model_path="./result_IMDB_test/" +opt=Config + + +# load data +dataloaders = { + "IMDB":IMDBLoader(), + "YELP":yelpLoader(), + "SST-5":SSTLoader(subtree=True,fine_grained=True), + "SST-3":SSTLoader(subtree=True,fine_grained=False) +} + +if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]: + raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']") + +dataloader = dataloaders[opt.task_name] +datainfo=dataloader.process(opt.datapath) +# print(datainfo.datasets["train"]) +# print(datainfo) + + +# define model +vocab=datainfo.vocabs['words'] +embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True) +model=BiLSTM_SELF_ATTENTION(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, attention_unit=opt.attention_unit, attention_hops=opt.attention_hops, nfc=opt.nfc) + + +# define loss_function and metrics +loss=CrossEntropyLoss() +metrics=AccuracyMetric() +optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr) + + +def train(datainfo, model, optimizer, loss, metrics, opt): + trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1, + n_epochs=opt.train_epoch, save_path=opt.save_model_path) + trainer.train() + + +def test(datainfo, metrics, opt): + # load model + model = ModelLoader.load_pytorch_model(opt.load_model_path) + print("model loaded!") + + # Tester + tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0) + acc = tester.test() + print("acc=",acc) + + + +parser = argparse.ArgumentParser() +parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model') + + +args = parser.parse_args() +if args.mode == 'train': + train(datainfo, model, optimizer, loss, metrics, opt) +elif args.mode == 'test': + test(datainfo, metrics, opt) +else: + print('no mode specified for model!') + parser.print_help() From f6bba93696b8266d27a657eceace16653df3525f Mon Sep 17 00:00:00 2001 From: wyg <1505116161@qq.com> Date: Sun, 7 Jul 2019 14:50:36 +0800 Subject: [PATCH 08/20] [verify] yelpdataloader [add] HAN train_HAN --- 
.../text_classification/data/sstLoader.py | 106 +++++++++++++++-- .../text_classification/data/yelpLoader.py | 6 + reproduction/text_classification/model/HAN.py | 109 ++++++++++++++++++ reproduction/text_classification/train_HAN.py | 109 ++++++++++++++++++ .../text_classification/train_char_cnn.py | 7 +- 5 files changed, 324 insertions(+), 13 deletions(-) create mode 100644 reproduction/text_classification/model/HAN.py create mode 100644 reproduction/text_classification/train_HAN.py diff --git a/reproduction/text_classification/data/sstLoader.py b/reproduction/text_classification/data/sstLoader.py index bffb67fd..0d1b647c 100644 --- a/reproduction/text_classification/data/sstLoader.py +++ b/reproduction/text_classification/data/sstLoader.py @@ -1,13 +1,101 @@ -import csv from typing import Iterable -from fastNLP import DataSet, Instance, Vocabulary -from fastNLP.core.vocabulary import VocabularyOption -from fastNLP.io.base_loader import DataInfo,DataSetLoader -from fastNLP.io.embed_loader import EmbeddingOption -from fastNLP.io.file_reader import _read_json -from typing import Union, Dict -from reproduction.Star_transformer.datasets import EmbedLoader -from reproduction.utils import check_dataloader_paths +from nltk import Tree +from fastNLP.io.base_loader import DataInfo, DataSetLoader +from fastNLP.core.vocabulary import VocabularyOption, Vocabulary +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader + + +class SSTLoader(DataSetLoader): + URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' + DATA_DIR = 'sst/' + + """ + 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` + + 读取SST数据集, DataSet包含fields:: + + words: list(str) 需要分类的文本 + target: str 文本的标签 + + 数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip + + :param subtree: 是否将数据展开为子树,扩充数据量. 
Default: ``False`` + :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` + """ + + def __init__(self, subtree=False, fine_grained=False): + self.subtree = subtree + + tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', + '3': 'positive', '4': 'very positive'} + if not fine_grained: + tag_v['0'] = tag_v['1'] + tag_v['4'] = tag_v['3'] + self.tag_v = tag_v + + def _load(self, path): + """ + + :param str path: 存储数据的路径 + :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 + """ + datalist = [] + with open(path, 'r', encoding='utf-8') as f: + datas = [] + for l in f: + datas.extend([(s, self.tag_v[t]) + for s, t in self._get_one(l, self.subtree)]) + ds = DataSet() + for words, tag in datas: + ds.append(Instance(words=words, target=tag)) + return ds + + @staticmethod + def _get_one(data, subtree): + tree = Tree.fromstring(data) + if subtree: + return [(t.leaves(), t.label()) for t in tree.subtrees()] + return [(tree.leaves(), tree.label())] + + def process(self, + paths, + train_ds: Iterable[str] = None, + src_vocab_op: VocabularyOption = None, + tgt_vocab_op: VocabularyOption = None, + src_embed_op: EmbeddingOption = None): + input_name, target_name = 'words', 'target' + src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) + + info = DataInfo(datasets=self.load(paths)) + _train_ds = [info.datasets[name] + for name in train_ds] if train_ds else info.datasets.values() + src_vocab.from_dataset(*_train_ds, field_name=input_name) + tgt_vocab.from_dataset(*_train_ds, field_name=target_name) + src_vocab.index_dataset( + *info.datasets.values(), + field_name=input_name, new_field_name=input_name) + tgt_vocab.index_dataset( + *info.datasets.values(), + field_name=target_name, new_field_name=target_name) + info.vocabs = { + input_name: src_vocab, + target_name: tgt_vocab + } + + if src_embed_op is not None: + src_embed_op.vocab = src_vocab + init_emb = EmbedLoader.load_with_vocab(**src_embed_op) + info.embeddings[input_name] = init_emb + + for name, dataset in info.datasets.items(): + dataset.set_input(input_name) + dataset.set_target(target_name) + + return info class sst2Loader(DataSetLoader): ''' diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py index 0e65fb20..280e8be0 100644 --- a/reproduction/text_classification/data/yelpLoader.py +++ b/reproduction/text_classification/data/yelpLoader.py @@ -184,6 +184,12 @@ class yelpLoader(DataSetLoader): info.vocabs[target_name]=tgt_vocab + info.datasets['train'],info.datasets['dev']=info.datasets['train'].split(0.1, shuffle=False) + + for name, dataset in info.datasets.items(): + dataset.set_input("words") + dataset.set_target("target") + return info if __name__=="__main__": diff --git a/reproduction/text_classification/model/HAN.py b/reproduction/text_classification/model/HAN.py new file mode 100644 index 00000000..0902d1e4 --- /dev/null +++ b/reproduction/text_classification/model/HAN.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +from fastNLP.modules.utils import get_embeddings +from fastNLP.core import Const as C + + +def pack_sequence(tensor_seq, padding_value=0.0): + if len(tensor_seq) <= 0: + return + length = [v.size(0) for v in tensor_seq] + max_len = max(length) + size = [len(tensor_seq), max_len] + size.extend(list(tensor_seq[0].size()[1:])) + ans = 
torch.Tensor(*size).fill_(padding_value)
+    if tensor_seq[0].data.is_cuda:
+        ans = ans.cuda()
+    ans = Variable(ans)
+    for i, v in enumerate(tensor_seq):
+        ans[i, :length[i], :] = v
+    return ans
+
+
+class HANCLS(nn.Module):
+    def __init__(self, init_embed, num_cls):
+        super(HANCLS, self).__init__()
+
+        self.embed = get_embeddings(init_embed)
+        self.han = HAN(input_size=300,  # assumes 300-d word embeddings
+                       output_size=num_cls,
+                       word_hidden_size=50, word_num_layers=1, word_context_size=100,
+                       sent_hidden_size=50, sent_num_layers=1, sent_context_size=100
+                       )
+
+    def forward(self, input_sents):
+        # input_sents [B, num_sents, seq_len] dtype long
+        B, num_sents, seq_len = input_sents.size()
+        input_sents = input_sents.view(-1, seq_len)  # flatten
+        words_embed = self.embed(input_sents)  # [B*num_sents, seq_len, word_dim]
+        words_embed = words_embed.view(B, num_sents, seq_len, -1)  # recover [B, num_sents, seq_len, word_dim]
+        out = self.han(words_embed)
+
+        return {C.OUTPUT: out}
+
+    def predict(self, input_sents):
+        x = self.forward(input_sents)[C.OUTPUT]
+        return {C.OUTPUT: torch.argmax(x, 1)}
+
+
+class HAN(nn.Module):
+    def __init__(self, input_size, output_size,
+                 word_hidden_size, word_num_layers, word_context_size,
+                 sent_hidden_size, sent_num_layers, sent_context_size):
+        super(HAN, self).__init__()
+
+        self.word_layer = AttentionNet(input_size,
+                                       word_hidden_size,
+                                       word_num_layers,
+                                       word_context_size)
+        self.sent_layer = AttentionNet(2 * word_hidden_size,
+                                       sent_hidden_size,
+                                       sent_num_layers,
+                                       sent_context_size)
+        self.output_layer = nn.Linear(2 * sent_hidden_size, output_size)
+        self.softmax = nn.LogSoftmax(dim=1)
+
+    def forward(self, batch_doc):
+        # input is a sequence of matrices, one per document
+        doc_vec_list = []
+        for doc in batch_doc:
+            sent_mat = self.word_layer(doc)  # doc: (num_sent, seq_len, word_dim) -> sent_mat: (num_sent, vec_dim)
+            doc_vec_list.append(sent_mat)
+        doc_vec = self.sent_layer(pack_sequence(doc_vec_list))
+        output = self.softmax(self.output_layer(doc_vec))
+        return output
+
+
+class AttentionNet(nn.Module):
+    def __init__(self, input_size, gru_hidden_size, gru_num_layers, context_vec_size):
+        super(AttentionNet, self).__init__()
+
+        self.input_size = input_size
+        self.gru_hidden_size = gru_hidden_size
+        self.gru_num_layers = gru_num_layers
+        self.context_vec_size = context_vec_size
+
+        # Encoder
+        self.gru = nn.GRU(input_size=input_size,
+                          hidden_size=gru_hidden_size,
+                          num_layers=gru_num_layers,
+                          batch_first=True,
+                          bidirectional=True)
+        # Attention
+        self.fc = nn.Linear(2 * gru_hidden_size, context_vec_size)
+        self.tanh = nn.Tanh()
+        self.softmax = nn.Softmax(dim=1)
+        # context vector
+        self.context_vec = nn.Parameter(torch.Tensor(context_vec_size, 1))
+        self.context_vec.data.uniform_(-0.1, 0.1)
+
+    def forward(self, inputs):
+        # GRU part; inputs: (batch_size, seq_len, word_dim)
+        h_t, hidden = self.gru(inputs)
+        u = self.tanh(self.fc(h_t))  # u: (batch_size, seq_len, context_vec_size)
+        # Attention part
+        alpha = self.softmax(torch.matmul(u, self.context_vec))  # alpha: (batch_size, seq_len, 1)
+        output = torch.bmm(torch.transpose(h_t, 1, 2), alpha)  # output: (batch_size, 2*hidden_size, 1)
+        return torch.squeeze(output, dim=2)  # squeezed to (batch_size, 2*hidden_size)
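AttentionNet's pooling is the standard context-vector attention from the HAN paper: score each timestep with u = tanh(W h_t), normalize the scores with a softmax over the sequence, and take the weighted sum of the bidirectional GRU states. A minimal shape check of exactly those three lines, with random tensors and illustrative sizes:

import torch

torch.manual_seed(0)
h_t = torch.randn(2, 7, 100)       # (batch, seq_len, 2*gru_hidden_size)
fc = torch.nn.Linear(100, 64)
context = torch.randn(64, 1)       # stands in for the learned context vector

u = torch.tanh(fc(h_t))                                  # (2, 7, 64)
alpha = torch.softmax(u.matmul(context), dim=1)          # (2, 7, 1) weights over timesteps
doc = torch.bmm(h_t.transpose(1, 2), alpha).squeeze(2)   # (2, 100) pooled vector
print(doc.shape)  # torch.Size([2, 100])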
+os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + +from fastNLP.core.const import Const as C +from fastNLP.core import LRScheduler +import torch.nn as nn +from fastNLP.io.dataset_loader import SSTLoader +from reproduction.text_classification.data.yelpLoader import yelpLoader +from reproduction.text_classification.model.HAN import HANCLS +from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP.core.trainer import Trainer +from torch.optim import SGD +import torch.cuda +from torch.optim.lr_scheduler import CosineAnnealingLR + + +##hyper + +class Config(): + model_dir_or_name = "en-base-uncased" + embedding_grad = False, + train_epoch = 30 + batch_size = 100 + num_classes = 5 + task = "yelp" + #datadir = '/remote-home/lyli/fastNLP/yelp_polarity/' + datadir = '/remote-home/ygwang/yelp_polarity/' + datafile = {"train": "train.csv", "test": "test.csv"} + lr = 1e-3 + + def __init__(self): + self.datapath = {k: os.path.join(self.datadir, v) + for k, v in self.datafile.items()} + + +ops = Config() + +##1.task相关信息:利用dataloader载入dataInfo + +datainfo = yelpLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train']) +print(len(datainfo.datasets['train'])) +print(len(datainfo.datasets['test'])) + + +# post process +def make_sents(words): + sents = [words] + return sents + + +for dataset in datainfo.datasets.values(): + dataset.apply_field(make_sents, field_name='words', new_field_name='input_sents') + +datainfo = datainfo +datainfo.datasets['train'].set_input('input_sents') +datainfo.datasets['test'].set_input('input_sents') +datainfo.datasets['train'].set_target('target') +datainfo.datasets['test'].set_target('target') + +## 2.或直接复用fastNLP的模型 + +vocab = datainfo.vocabs['words'] +# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)]) +embedding = StaticEmbedding(vocab) + +print(len(vocab)) +print(len(datainfo.vocabs['target'])) + +# model = DPCNN(init_embed=embedding, num_cls=ops.num_classes) +model = HANCLS(init_embed=embedding, num_cls=ops.num_classes) + +## 3. 
Declare the loss, metric and optimizer
+loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
+metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
+optimizer = SGD([param for param in model.parameters() if param.requires_grad == True],
+                lr=ops.lr, momentum=0.9, weight_decay=0)
+
+callbacks = []
+callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
+
+device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+
+print(device)
+
+for ds in datainfo.datasets.values():
+    ds.apply_field(len, C.INPUT, C.INPUT_LEN)
+    ds.set_input(C.INPUT, C.INPUT_LEN)
+    ds.set_target(C.TARGET)
+
+
+## 4. Define the train routine
+def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch):
+    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+                      metrics=[metrics], dev_data=datainfo.datasets['test'], device=device,
+                      check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
+                      n_epochs=num_epochs)
+
+    print(trainer.train())
+
+
+if __name__ == "__main__":
+    train(model, datainfo, loss, metric, optimizer)
diff --git a/reproduction/text_classification/train_char_cnn.py b/reproduction/text_classification/train_char_cnn.py
index c2c983a4..050527fe 100644
--- a/reproduction/text_classification/train_char_cnn.py
+++ b/reproduction/text_classification/train_char_cnn.py
@@ -7,7 +7,6 @@ import sys
 sys.path.append('../..')
 from fastNLP.core.const import Const as C
 import torch.nn as nn
-from fastNLP.io.dataset_loader import SSTLoader
 from data.yelpLoader import yelpLoader
 from data.sstLoader import sst2Loader
 from data.IMDBLoader import IMDBLoader
@@ -107,9 +106,9 @@ ops=Config
 
 ##1. Task information: load the dataInfo with a dataloader
-dataloader=sst2Loader()
-dataloader=IMDBLoader()
-#dataloader=yelpLoader(fine_grained=True)
+#dataloader=sst2Loader()
+#dataloader=IMDBLoader()
+dataloader=yelpLoader(fine_grained=True)
 datainfo=dataloader.process(ops.datapath,char_level_op=True)
 char_vocab=ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
 ops.number_of_characters=len(char_vocab)
From f369778ab33aa91447821db0de37b1bce4be4b62 Mon Sep 17 00:00:00 2001
From: wyg <1505116161@qq.com>
Date: Sun, 7 Jul 2019 15:33:08 +0800
Subject: [PATCH 09/20] [verify] sstdataloader add sst2 [add] readme

---
 reproduction/text_classification/README.md | 22 +++++++++++++++++++
 .../text_classification/data/sstLoader.py  |  3 ++-
 2 files changed, 24 insertions(+), 1 deletion(-)
 create mode 100644 reproduction/text_classification/README.md

diff --git a/reproduction/text_classification/README.md b/reproduction/text_classification/README.md
new file mode 100644
index 00000000..b058fbb2
--- /dev/null
+++ b/reproduction/text_classification/README.md
@@ -0,0 +1,22 @@
+# Reproducing text classification models
+The following models are reproduced with fastNLP:
+char_cnn: paper [Character-level Convolutional Networks for Text Classification](https://arxiv.org/pdf/1509.01626v3.pdf)
+dpcnn: paper [Deep Pyramid Convolutional Neural Networks for Text Categorization](https://ai.tencent.com/ailab/media/publications/ACL3-Brady.pdf)
+HAN: paper [Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf)
+
+# To be added
+awd_lstm:
+lstm_self_attention (BCN?):
+
+# Datasets and results
+
+fastNLP reproduction vs. the papers' reported numbers (the value before the slash is the fastNLP run, the value after it is the paper's; "-" means the paper does not report a result on that dataset)
+
+model name | yelp_p | sst-2 | IMDB
+:---: | :---: | :---: | :---:
+char_cnn | 93.80/95.12 | - | -
+dpcnn | 95.50/97.36 | - | -
+HAN | - | - | -
+BCN | - | - | -
+awd-lstm | - | - | -
+
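
As a quick smoke test for the reproduced models, a minimal sketch along the
following lines can be run from the repository root. The vocabulary size,
class count and batch shape below are made-up values that mirror the
`__main__` block of model/dpcnn.py; `C.OUTPUT` is fastNLP's standard key for
model predictions:

    import torch
    from fastNLP.core import Const as C
    from reproduction.text_classification.model.dpcnn import DPCNN

    # (10000, 300) asks get_embeddings for random 300-d vectors over a
    # hypothetical 10000-word vocabulary; 5 classes as in SST-5
    model = DPCNN(init_embed=(10000, 300), num_cls=5)
    words = torch.randint(0, 10000, size=(4, 32), dtype=torch.long)
    logits = model(words)[C.OUTPUT]         # (4, 5) class scores
    preds = model.predict(words)[C.OUTPUT]  # (4,) argmax labels

diff --git a/reproduction/text_classification/data/sstLoader.py b/reproduction/text_classification/data/sstLoader.py
index 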
0d1b647c..d8403b7a 100644 --- a/reproduction/text_classification/data/sstLoader.py +++ b/reproduction/text_classification/data/sstLoader.py @@ -5,7 +5,8 @@ from fastNLP.core.vocabulary import VocabularyOption, Vocabulary from fastNLP import DataSet from fastNLP import Instance from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader - +import csv +from typing import Union, Dict class SSTLoader(DataSetLoader): URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' From d4fa6986019e9ea857fb10591d7c1d647251033b Mon Sep 17 00:00:00 2001 From: yunfan Date: Thu, 27 Jun 2019 21:56:49 +0800 Subject: [PATCH 10/20] add dpcnn --- .../text_classification/model/dpcnn.py | 92 ++++++++++++++++++- .../text_classification/train_dpcnn.py | 80 ++++++++++++++++ 2 files changed, 171 insertions(+), 1 deletion(-) diff --git a/reproduction/text_classification/model/dpcnn.py b/reproduction/text_classification/model/dpcnn.py index f87f5c14..2da7b3e5 100644 --- a/reproduction/text_classification/model/dpcnn.py +++ b/reproduction/text_classification/model/dpcnn.py @@ -1 +1,91 @@ -# TODO \ No newline at end of file +import torch +import torch.nn as nn +from fastNLP.modules.utils import get_embeddings +from fastNLP.core import Const as C + +class DPCNN(nn.Module): + def __init__(self, init_embed, num_cls, n_filters=256, kernel_size=3, n_layers=7, embed_dropout=0.1, dropout=0.1): + super().__init__() + self.region_embed = RegionEmbedding(init_embed, out_dim=n_filters, kernel_sizes=[3, 5, 9]) + embed_dim = self.region_embed.embedding_dim + self.conv_list = nn.ModuleList() + for i in range(n_layers): + self.conv_list.append(nn.Sequential( + nn.ReLU(), + nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2), + nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2), + )) + self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + self.embed_drop = nn.Dropout(embed_dropout) + self.classfier = nn.Sequential( + nn.Dropout(dropout), + nn.Linear(n_filters, num_cls), + ) + self.reset_parameters() + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): + nn.init.normal_(m.weight, mean=0, std=0.01) + if m.bias is not None: + nn.init.normal_(m.bias, mean=0, std=0.01) + + def forward(self, words, seq_len=None): + words = words.long() + # get region embeddings + x = self.region_embed(words) + x = self.embed_drop(x) + + # not pooling on first conv + x = self.conv_list[0](x) + x + for conv in self.conv_list[1:]: + x = self.pool(x) + x = conv(x) + x + + # B, C, L => B, C + x, _ = torch.max(x, dim=2) + x = self.classfier(x) + return {C.OUTPUT: x} + + def predict(self, words, seq_len=None): + x = self.forward(words, seq_len)[C.OUTPUT] + return {C.OUTPUT: torch.argmax(x, 1)} + + +class RegionEmbedding(nn.Module): + def __init__(self, init_embed, out_dim=300, kernel_sizes=None): + super().__init__() + if kernel_sizes is None: + kernel_sizes = [5, 9] + assert isinstance(kernel_sizes, list), 'kernel_sizes should be List(int)' + self.embed = get_embeddings(init_embed) + try: + embed_dim = self.embed.embedding_dim + except Exception: + embed_dim = self.embed.embed_size + self.region_embeds = nn.ModuleList() + for ksz in kernel_sizes: + self.region_embeds.append(nn.Sequential( + nn.Conv1d(embed_dim, embed_dim, ksz, padding=ksz // 2), + )) + self.linears = nn.ModuleList([nn.Conv1d(embed_dim, out_dim, 1) + for _ in range(len(kernel_sizes) + 1)]) + self.embedding_dim = embed_dim + + def forward(self, x): + x = self.embed(x) + x = 
x.transpose(1, 2) + # B, C, L + out = self.linears[0](x) + for conv, fc in zip(self.region_embeds, self.linears[1:]): + conv_i = conv(x) + out = out + fc(conv_i) + # B, C, L + return out + + +if __name__ == '__main__': + x = torch.randint(0, 10000, size=(5, 15), dtype=torch.long) + model = DPCNN((10000, 300), 20) + y = model(x) + print(y.size(), y.mean(1), y.std(1)) diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py index e69de29b..13ff4fc1 100644 --- a/reproduction/text_classification/train_dpcnn.py +++ b/reproduction/text_classification/train_dpcnn.py @@ -0,0 +1,80 @@ +# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 +import os +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + +from fastNLP.core.const import Const as C +from fastNLP.core import LRScheduler +import torch.nn as nn +from fastNLP.io.dataset_loader import SSTLoader +from reproduction.text_classification.model.dpcnn import DPCNN +from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP.core.trainer import Trainer +from torch.optim import SGD +import torch.cuda +from torch.optim.lr_scheduler import CosineAnnealingLR + +##hyper +class Config(): + model_dir_or_name="en-base-uncased" + embedding_grad= False, + train_epoch= 30 + batch_size = 100 + num_classes=5 + task= "SST" + datadir = '/remote-home/yfshao/workdir/datasets/SST' + datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"} + lr=1e-3 + def __init__(self): + self.datapath = {k:os.path.join(self.datadir, v) + for k, v in self.datafile.items()} + +ops=Config() + + +##1.task相关信息:利用dataloader载入dataInfo +datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds='train') + +print(len(datainfo.datasets['train'])) +print(len(datainfo.datasets['dev'])) + + +## 2.或直接复用fastNLP的模型 +vocab = datainfo.vocabs['words'] + +# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)]) +embedding = StaticEmbedding(vocab) +print(len(vocab)) +print(len(datainfo.vocabs['target'])) +model = DPCNN(init_embed=embedding, num_cls=ops.num_classes) + +## 3. 
声明loss,metric,optimizer +loss=CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) +metric=AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) +optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], + lr=ops.lr, momentum=0.9, weight_decay=0) + +callbacks = [] +callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) + +device = 'cuda:0' if torch.cuda.is_available() else 'cpu' +print(device) + +for ds in datainfo.datasets.values(): + ds.apply_field(len, C.INPUT, C.INPUT_LEN) + ds.set_input(C.INPUT, C.INPUT_LEN) + ds.set_target(C.TARGET) + +## 4.定义train方法 +def train(model,datainfo,loss,metrics,optimizer,num_epochs=ops.train_epoch): + trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + metrics=[metrics], dev_data=datainfo.datasets['dev'], device=device, + check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, + n_epochs=num_epochs) + print(trainer.train()) + + +if __name__=="__main__": + train(model,datainfo,loss,metric,optimizer) \ No newline at end of file From 4e3fba55d884250a65463a564c4730a4275b3819 Mon Sep 17 00:00:00 2001 From: yunfan Date: Thu, 4 Jul 2019 13:56:37 +0800 Subject: [PATCH 11/20] add ID-CNN --- .../ner/model/dilated_cnn.py | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 reproduction/seqence_labelling/ner/model/dilated_cnn.py diff --git a/reproduction/seqence_labelling/ner/model/dilated_cnn.py b/reproduction/seqence_labelling/ner/model/dilated_cnn.py new file mode 100644 index 00000000..cd2fa64b --- /dev/null +++ b/reproduction/seqence_labelling/ner/model/dilated_cnn.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from fastNLP.modules.decoder import ConditionalRandomField +from fastNLP.modules.encoder import Embedding +from fastNLP.core.utils import seq_len_to_mask +from fastNLP.core.const import Const as C + + +class IDCNN(nn.Module): + def __init__(self, init_embed, char_embed, + num_cls, + repeats, num_layers, num_filters, kernel_size, + use_crf=False, use_projection=False, block_loss=False, + input_dropout=0.3, hidden_dropout=0.2, inner_dropout=0.0): + super(IDCNN, self).__init__() + self.word_embeddings = Embedding(init_embed) + self.char_embeddings = Embedding(char_embed) + embedding_size = self.word_embeddings.embedding_dim + \ + self.char_embeddings.embedding_dim + + self.conv0 = nn.Sequential( + nn.Conv1d(in_channels=embedding_size, + out_channels=num_filters, + kernel_size=kernel_size, + stride=1, dilation=1, + padding=kernel_size//2, + bias=True), + nn.ReLU(), + ) + + block = [] + for layer_i in range(num_layers): + dilated = 2 ** layer_i + block.append(nn.Conv1d( + in_channels=num_filters, + out_channels=num_filters, + kernel_size=kernel_size, + stride=1, dilation=dilated, + padding=(kernel_size//2) * dilated, + bias=True)) + block.append(nn.ReLU()) + self.block = nn.Sequential(*block) + + if use_projection: + self.projection = nn.Sequential( + nn.Conv1d( + in_channels=num_filters, + out_channels=num_filters//2, + kernel_size=1, + bias=True), + nn.ReLU(),) + encode_dim = num_filters // 2 + else: + self.projection = None + encode_dim = num_filters + + self.input_drop = nn.Dropout(input_dropout) + self.hidden_drop = nn.Dropout(hidden_dropout) + self.inner_drop = nn.Dropout(inner_dropout) + self.repeats = repeats + self.out_fc = nn.Conv1d( + in_channels=encode_dim, + out_channels=num_cls, + kernel_size=1, + bias=True) + self.crf = ConditionalRandomField( + num_tags=num_cls) if use_crf else None + self.block_loss = 
block_loss
+
+    def forward(self, words, chars, seq_len, target=None):
+        e1 = self.word_embeddings(words)
+        e2 = self.char_embeddings(chars)
+        x = torch.cat((e1, e2), dim=-1)  # b,l,h
+        mask = seq_len_to_mask(seq_len)
+
+        x = x.transpose(1, 2)  # b,h,l
+        last_output = self.conv0(x)
+        output = []
+        for repeat in range(self.repeats):
+            last_output = self.block(last_output)
+            hidden = self.projection(last_output) if self.projection is not None else last_output
+            output.append(self.out_fc(hidden))
+
+        def compute_loss(y, t, mask):
+            if self.crf is not None and target is not None:
+                loss = self.crf(y, t, mask)
+            else:
+                t.masked_fill_(mask == 0, -100)
+                loss = F.cross_entropy(y, t, ignore_index=-100)
+            return loss
+
+        if self.block_loss:
+            losses = [compute_loss(o, target, mask) for o in output]
+            loss = sum(losses)
+        else:
+            loss = compute_loss(output[-1], target, mask)
+
+        scores = output[-1]
+        if self.crf is not None:
+            pred = self.crf.viterbi_decode(scores, target, mask)
+        else:
+            pred = scores.max(1)[1] * mask.long()
+
+        return {
+            C.LOSS: loss,
+            C.OUTPUT: pred,
+        }
+
+    def predict(self, words, chars, seq_len):
+        return self.forward(words, chars, seq_len)[C.OUTPUT]
From 4b5713b5a2e2e2dcedf2b757659295016db6a9ce Mon Sep 17 00:00:00 2001
From: yunfan
Date: Thu, 4 Jul 2019 14:03:53 +0800
Subject: [PATCH 12/20] update model & dataloader in text_classification

---
 .../text_classification/data/IMDBLoader.py |  82 +++++++++
 .../text_classification/data/yelpLoader.py | 164 +++++++++++++++---
 .../text_classification/train_dpcnn.py     |  97 +++++++----
 3 files changed, 283 insertions(+), 60 deletions(-)
 create mode 100644 reproduction/text_classification/data/IMDBLoader.py

diff --git a/reproduction/text_classification/data/IMDBLoader.py b/reproduction/text_classification/data/IMDBLoader.py
new file mode 100644
index 00000000..2df87e26
--- /dev/null
+++ b/reproduction/text_classification/data/IMDBLoader.py
@@ -0,0 +1,82 @@
+from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
+from fastNLP.core.vocabulary import VocabularyOption
+from fastNLP.io.base_loader import DataSetLoader, DataInfo
+from typing import Union, Dict, List, Iterator
+from fastNLP import DataSet
+from fastNLP import Instance
+from fastNLP import Vocabulary
+from fastNLP import Const
+# from reproduction.utils import check_dataloader_paths
+from functools import partial
+
+
+class IMDBLoader(DataSetLoader):
+    """
+    Load the IMDB dataset. The resulting DataSet has the following fields:
+
+        words: list(str), the text to be classified
+        target: str, the label of the text
+    """
+
+    def __init__(self):
+        super(IMDBLoader, self).__init__()
+
+    def _load(self, path):
+        dataset = DataSet()
+        with open(path, 'r', encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    continue
+                parts = line.split('\t')
+                target = parts[0]
+                words = parts[1].split()
+                dataset.append(Instance(words=words, target=target))
+        if len(dataset) == 0:
+            raise RuntimeError(f"{path} has no valid data.")
+
+        return dataset
+
+    def process(self,
+                paths: Union[str, Dict[str, str]],
+                src_vocab_opt: VocabularyOption = None,
+                tgt_vocab_opt: VocabularyOption = None,
+                src_embed_opt: EmbeddingOption = None):
+
+        # paths = check_dataloader_paths(paths)
+        datasets = {}
+        info = DataInfo()
+        for name, path in paths.items():
+            dataset = self.load(path)
+            datasets[name] = dataset
+
+        datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False)
+
+        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
+        src_vocab.from_dataset(datasets['train'], field_name='words')
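+        # NOTE: the source vocabulary is built on the training split only;
+        # dev/test tokens never seen in train are mapped to <unk>, the usual
+        # protocol for these benchmarks. The commented line below would
+        # instead build it on all three splits:
+        # 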
src_vocab.from_dataset(datasets['train'], datasets["dev"], datasets["test"], field_name='words') + src_vocab.index_dataset(*datasets.values(), field_name='words') + + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) + tgt_vocab.from_dataset(datasets['train'], field_name='target') + tgt_vocab.index_dataset(*datasets.values(), field_name='target') + + info.vocabs = { + "words": src_vocab, + "target": tgt_vocab + } + + info.datasets = datasets + + if src_embed_opt is not None: + embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) + info.embeddings['words'] = embed + + for name, dataset in info.datasets.items(): + dataset.set_input("words") + dataset.set_target("target") + + return info diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py index c47d48fd..63605ecf 100644 --- a/reproduction/text_classification/data/yelpLoader.py +++ b/reproduction/text_classification/data/yelpLoader.py @@ -1,4 +1,6 @@ import ast +import csv +from typing import Iterable from fastNLP import DataSet, Instance, Vocabulary from fastNLP.core.vocabulary import VocabularyOption from fastNLP.io import JsonLoader @@ -10,11 +12,34 @@ from reproduction.Star_transformer.datasets import EmbedLoader from reproduction.utils import check_dataloader_paths +def clean_str(sentence, char_lower=False): + """ + heavily borrowed from github + https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb + :param sentence: is a str + :return: + """ + if char_lower: + sentence = sentence.lower() + import re + nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') + words = sentence.split() + words_collection = [] + for word in words: + if word in ['-lrb-', '-rrb-', '', '-r', '-l', 'b-']: + continue + tt = nonalpnum.split(word) + t = ''.join(tt) + if t != '': + words_collection.append(t) + + return words_collection + + class yelpLoader(JsonLoader): - """ 读取Yelp数据集, DataSet包含fields: - + review_id: str, 22 character unique review id user_id: str, 22 character unique user id business_id: str, 22 character business id @@ -24,23 +49,25 @@ class yelpLoader(JsonLoader): date: str, date formatted YYYY-MM-DD words: list(str), 需要分类的文本 target: str, 文本的标签 - + 数据来源: https://www.yelp.com/dataset/download - + :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` """ - - def __init__(self, fine_grained=False): + + def __init__(self, fine_grained=False, lower=False): super(yelpLoader, self).__init__() tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', - '4.0': 'positive', '5.0': 'very positive'} + '4.0': 'positive', '5.0': 'very positive'} if not fine_grained: tag_v['1.0'] = tag_v['2.0'] tag_v['5.0'] = tag_v['4.0'] self.fine_grained = fine_grained self.tag_v = tag_v - - def _load(self, path): + self.lower = lower + + ''' + def _load_json(self, path): ds = DataSet() for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): d = ast.literal_eval(d) @@ -49,20 +76,113 @@ class yelpLoader(JsonLoader): ds.append(Instance(**d)) return ds - def process(self, paths: Union[str, Dict[str, str]], vocab_opt: VocabularyOption = None, - embed_opt: EmbeddingOption = None): + def _load_yelp2015_broken(self,path): + ds = DataSet() + with open (path,encoding='ISO 8859-1') as f: + row=f.readline() + all_count=0 + exp_count=0 + while row: + row=row.split("\t\t") + all_count+=1 + if len(row)>=3: + words=row[-1].split() + try: + 
target=self.tag_v[str(row[-2])+".0"] + ds.append(Instance(words=words, target=target)) + except KeyError: + exp_count+=1 + else: + exp_count+=1 + row = f.readline() + print("error sample count:",exp_count) + print("all count:",all_count) + return ds + ''' + + def _load(self, path): + ds = DataSet() + csv_reader = csv.reader(open(path, encoding='utf-8')) + all_count = 0 + real_count = 0 + for row in csv_reader: + all_count += 1 + if len(row) == 2: + target = self.tag_v[row[0] + ".0"] + words = clean_str(row[1], self.lower) + if len(words) != 0: + ds.append(Instance(words=words, target=target)) + real_count += 1 + print("all count:", all_count) + print("real count:", real_count) + return ds + + def process(self, paths: Union[str, Dict[str, str]], + train_ds: Iterable[str] = None, + src_vocab_op: VocabularyOption = None, + tgt_vocab_op: VocabularyOption = None, + embed_opt: EmbeddingOption = None, + char_level_op=False): paths = check_dataloader_paths(paths) datasets = {} - info = DataInfo() - vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt) - for name, path in paths.items(): - dataset = self.load(path) - datasets[name] = dataset - vocab.from_dataset(dataset, field_name="words") - info.vocabs = vocab - info.datasets = datasets - if embed_opt is not None: - embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab) - info.embeddings['words'] = embed + info = DataInfo(datasets=self.load(paths)) + src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) + _train_ds = [info.datasets[name] + for name in train_ds] if train_ds else info.datasets.values() + + # vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt) + # for name, path in paths.items(): + # dataset = self.load(path) + # datasets[name] = dataset + # vocab.from_dataset(dataset, field_name="words") + # info.vocabs = vocab + # info.datasets = datasets + + def wordtochar(words): + chars = [] + for word in words: + word = word.lower() + for char in word: + chars.append(char) + return chars + + input_name, target_name = 'words', 'target' + info.vocabs = {} + # 就分隔为char形式 + if char_level_op: + for dataset in info.datasets.values(): + dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') + # if embed_opt is not None: + # embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab) + # info.embeddings['words'] = embed + else: + src_vocab.from_dataset(*_train_ds, field_name=input_name) + src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name) + info.vocabs[input_name] = src_vocab + + tgt_vocab.from_dataset(*_train_ds, field_name=target_name) + tgt_vocab.index_dataset( + *info.datasets.values(), + field_name=target_name, new_field_name=target_name) + info.vocabs[target_name] = tgt_vocab + return info + +if __name__ == "__main__": + testloader = yelpLoader() + # datapath = {"train": "/remote-home/ygwang/yelp_full/train.csv", + # "test": "/remote-home/ygwang/yelp_full/test.csv"} + # datapath={"train": "/remote-home/ygwang/yelp_full/test.csv"} + datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv", + "test": "/remote-home/ygwang/yelp_polarity/test.csv"} + datainfo = testloader.process(datapath, char_level_op=True) + + len_count = 0 + for instance in datainfo.datasets["train"]: + len_count += len(instance["chars"]) + + ave_len = len_count / len(datainfo.datasets["train"]) + 
print(ave_len) \ No newline at end of file diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py index 13ff4fc1..bf243ffb 100644 --- a/reproduction/text_classification/train_dpcnn.py +++ b/reproduction/text_classification/train_dpcnn.py @@ -1,65 +1,83 @@ # 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 + +from torch.optim.lr_scheduler import CosineAnnealingLR +import torch.cuda +from torch.optim import SGD +from fastNLP.core.trainer import Trainer +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding +from reproduction.text_classification.model.dpcnn import DPCNN +from .data.yelpLoader import yelpLoader +from fastNLP.io.dataset_loader import SSTLoader +import torch.nn as nn +from fastNLP.core import LRScheduler +from fastNLP.core.const import Const as C +import sys import os os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -from fastNLP.core.const import Const as C -from fastNLP.core import LRScheduler -import torch.nn as nn -from fastNLP.io.dataset_loader import SSTLoader -from reproduction.text_classification.model.dpcnn import DPCNN -from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding -from fastNLP import CrossEntropyLoss, AccuracyMetric -from fastNLP.core.trainer import Trainer -from torch.optim import SGD -import torch.cuda -from torch.optim.lr_scheduler import CosineAnnealingLR +sys.path.append('../..') + + +# hyper -##hyper class Config(): - model_dir_or_name="en-base-uncased" - embedding_grad= False, - train_epoch= 30 + model_dir_or_name = "en-base-uncased" + embedding_grad = False, + train_epoch = 30 batch_size = 100 - num_classes=5 - task= "SST" - datadir = '/remote-home/yfshao/workdir/datasets/SST' - datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"} - lr=1e-3 + num_classes = 2 + task = "yelp_p" + #datadir = '/remote-home/yfshao/workdir/datasets/SST' + datadir = '/remote-home/ygwang/yelp_polarity' + #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"} + datafile = {"train": "train.csv", "test": "test.csv"} + lr = 1e-3 + def __init__(self): - self.datapath = {k:os.path.join(self.datadir, v) + self.datapath = {k: os.path.join(self.datadir, v) for k, v in self.datafile.items()} -ops=Config() + +ops = Config() -##1.task相关信息:利用dataloader载入dataInfo -datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds='train') +# 1.task相关信息:利用dataloader载入dataInfo +#datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train']) +datainfo = yelpLoader(fine_grained=True, lower=True).process( + paths=ops.datapath, train_ds=['train']) print(len(datainfo.datasets['train'])) -print(len(datainfo.datasets['dev'])) +print(len(datainfo.datasets['test'])) -## 2.或直接复用fastNLP的模型 -vocab = datainfo.vocabs['words'] +# 2.或直接复用fastNLP的模型 +vocab = datainfo.vocabs['words'] # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)]) -embedding = StaticEmbedding(vocab) +#embedding = StaticEmbedding(vocab) +embedding = StaticEmbedding( + vocab, model_dir_or_name='en-word2vec-300', requires_grad=True) + print(len(vocab)) print(len(datainfo.vocabs['target'])) + model = DPCNN(init_embed=embedding, num_cls=ops.num_classes) -## 3. 
声明loss,metric,optimizer -loss=CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) -metric=AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) -optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], - lr=ops.lr, momentum=0.9, weight_decay=0) + +# 3. 声明loss,metric,optimizer +loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) +metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) +optimizer = SGD([param for param in model.parameters() if param.requires_grad == True], + lr=ops.lr, momentum=0.9, weight_decay=0) callbacks = [] callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + print(device) for ds in datainfo.datasets.values(): @@ -67,14 +85,17 @@ for ds in datainfo.datasets.values(): ds.set_input(C.INPUT, C.INPUT_LEN) ds.set_target(C.TARGET) -## 4.定义train方法 -def train(model,datainfo,loss,metrics,optimizer,num_epochs=ops.train_epoch): + +# 4.定义train方法 +def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch): trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, - metrics=[metrics], dev_data=datainfo.datasets['dev'], device=device, + metrics=[metrics], + dev_data=datainfo.datasets['test'], device=device, check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, n_epochs=num_epochs) + print(trainer.train()) -if __name__=="__main__": - train(model,datainfo,loss,metric,optimizer) \ No newline at end of file +if __name__ == "__main__": + train(model, datainfo, loss, metric, optimizer) From 4272778c9ae6be70204d77261e35ceebd8e8f31c Mon Sep 17 00:00:00 2001 From: yunfan Date: Sat, 6 Jul 2019 13:15:40 +0800 Subject: [PATCH 13/20] -update DPCNN & train script -use spacy tokenizer for yelp data -add set_rng_seed --- fastNLP/modules/aggregator/attention.py | 9 +- .../text_classification/data/yelpLoader.py | 18 +++- .../text_classification/model/dpcnn.py | 22 +++-- .../text_classification/train_dpcnn.py | 83 ++++++++++++------- .../text_classification/utils/util_init.py | 11 +++ 5 files changed, 94 insertions(+), 49 deletions(-) create mode 100644 reproduction/text_classification/utils/util_init.py diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py index 4101b033..2bee7f2e 100644 --- a/fastNLP/modules/aggregator/attention.py +++ b/fastNLP/modules/aggregator/attention.py @@ -19,7 +19,7 @@ class DotAttention(nn.Module): 补上文档 """ - def __init__(self, key_size, value_size, dropout=0): + def __init__(self, key_size, value_size, dropout=0.0): super(DotAttention, self).__init__() self.key_size = key_size self.value_size = value_size @@ -37,7 +37,7 @@ class DotAttention(nn.Module): """ output = torch.matmul(Q, K.transpose(1, 2)) / self.scale if mask_out is not None: - output.masked_fill_(mask_out, -1e8) + output.masked_fill_(mask_out, -1e18) output = self.softmax(output) output = self.drop(output) return torch.matmul(output, V) @@ -67,9 +67,8 @@ class MultiHeadAttention(nn.Module): self.k_in = nn.Linear(input_size, in_size) self.v_in = nn.Linear(input_size, in_size) # follow the paper, do not apply dropout within dot-product - self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=0) + self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout) self.out = nn.Linear(value_size * num_head, input_size) - self.drop = TimestepDropout(dropout) self.reset_parameters() def reset_parameters(self): @@ -105,7 +104,7 @@ class MultiHeadAttention(nn.Module): # concat 
all heads, do output linear atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1) - output = self.drop(self.out(atte)) + output = self.out(atte) return output diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py index 63605ecf..d97f9399 100644 --- a/reproduction/text_classification/data/yelpLoader.py +++ b/reproduction/text_classification/data/yelpLoader.py @@ -8,11 +8,20 @@ from fastNLP.io.base_loader import DataInfo from fastNLP.io.embed_loader import EmbeddingOption from fastNLP.io.file_reader import _read_json from typing import Union, Dict -from reproduction.Star_transformer.datasets import EmbedLoader from reproduction.utils import check_dataloader_paths -def clean_str(sentence, char_lower=False): +def get_tokenizer(): + try: + import spacy + en = spacy.load('en') + print('use spacy tokenizer') + return lambda x: [w.text for w in en.tokenizer(x)] + except Exception as e: + print('use raw tokenizer') + return lambda x: x.split() + +def clean_str(sentence, tokenizer, char_lower=False): """ heavily borrowed from github https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb @@ -23,7 +32,7 @@ def clean_str(sentence, char_lower=False): sentence = sentence.lower() import re nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') - words = sentence.split() + words = tokenizer(sentence) words_collection = [] for word in words: if word in ['-lrb-', '-rrb-', '', '-r', '-l', 'b-']: @@ -65,6 +74,7 @@ class yelpLoader(JsonLoader): self.fine_grained = fine_grained self.tag_v = tag_v self.lower = lower + self.tokenizer = get_tokenizer() ''' def _load_json(self, path): @@ -109,7 +119,7 @@ class yelpLoader(JsonLoader): all_count += 1 if len(row) == 2: target = self.tag_v[row[0] + ".0"] - words = clean_str(row[1], self.lower) + words = clean_str(row[1], self.tokenizer, self.lower) if len(words) != 0: ds.append(Instance(words=words, target=target)) real_count += 1 diff --git a/reproduction/text_classification/model/dpcnn.py b/reproduction/text_classification/model/dpcnn.py index 2da7b3e5..dafe62bc 100644 --- a/reproduction/text_classification/model/dpcnn.py +++ b/reproduction/text_classification/model/dpcnn.py @@ -3,22 +3,27 @@ import torch.nn as nn from fastNLP.modules.utils import get_embeddings from fastNLP.core import Const as C + class DPCNN(nn.Module): - def __init__(self, init_embed, num_cls, n_filters=256, kernel_size=3, n_layers=7, embed_dropout=0.1, dropout=0.1): + def __init__(self, init_embed, num_cls, n_filters=256, + kernel_size=3, n_layers=7, embed_dropout=0.1, cls_dropout=0.1): super().__init__() - self.region_embed = RegionEmbedding(init_embed, out_dim=n_filters, kernel_sizes=[3, 5, 9]) + self.region_embed = RegionEmbedding( + init_embed, out_dim=n_filters, kernel_sizes=[1, 3, 5]) embed_dim = self.region_embed.embedding_dim self.conv_list = nn.ModuleList() for i in range(n_layers): self.conv_list.append(nn.Sequential( nn.ReLU(), - nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2), - nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2), + nn.Conv1d(n_filters, n_filters, kernel_size, + padding=kernel_size//2), + nn.Conv1d(n_filters, n_filters, kernel_size, + padding=kernel_size//2), )) self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) self.embed_drop = nn.Dropout(embed_dropout) self.classfier = nn.Sequential( - nn.Dropout(dropout), + nn.Dropout(cls_dropout), nn.Linear(n_filters, num_cls), ) self.reset_parameters() @@ -57,7 +62,8 @@ class 
RegionEmbedding(nn.Module):
         super().__init__()
         if kernel_sizes is None:
             kernel_sizes = [5, 9]
-        assert isinstance(kernel_sizes, list), 'kernel_sizes should be List(int)'
+        assert isinstance(
+            kernel_sizes, list), 'kernel_sizes should be List(int)'
         self.embed = get_embeddings(init_embed)
         try:
             embed_dim = self.embed.embedding_dim
@@ -69,14 +75,14 @@ class RegionEmbedding(nn.Module):
                 nn.Conv1d(embed_dim, embed_dim, ksz, padding=ksz // 2),
             ))
         self.linears = nn.ModuleList([nn.Conv1d(embed_dim, out_dim, 1)
-                                      for _ in range(len(kernel_sizes) + 1)])
+                                      for _ in range(len(kernel_sizes))])
         self.embedding_dim = embed_dim
 
     def forward(self, x):
         x = self.embed(x)
         x = x.transpose(1, 2)
         # B, C, L
-        out = self.linears[0](x)
+        out = 0
-        for conv, fc in zip(self.region_embeds, self.linears[1:]):
+        for conv, fc in zip(self.region_embeds, self.linears):
             conv_i = conv(x)
             out = out + fc(conv_i)
diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py
index 294a0742..9664bf75 100644
--- a/reproduction/text_classification/train_dpcnn.py
+++ b/reproduction/text_classification/train_dpcnn.py
@@ -9,6 +9,7 @@ from fastNLP import CrossEntropyLoss, AccuracyMetric
 from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
 from reproduction.text_classification.model.dpcnn import DPCNN
 from data.yelpLoader import yelpLoader
+from fastNLP.core.sampler import BucketSampler
 import torch.nn as nn
 from fastNLP.core import LRScheduler
 from fastNLP.core.const import Const as C
@@ -28,19 +29,20 @@ class Config():
     embedding_grad = True
     train_epoch = 30
     batch_size = 100
-    num_classes = 2
     task = "yelp_p"
-    #datadir = '/remote-home/yfshao/workdir/datasets/SST'
-    datadir = '/remote-home/yfshao/workdir/datasets/yelp_polarity'
+    #datadir = 'workdir/datasets/SST'
+    datadir = 'workdir/datasets/yelp_polarity'
+    # datadir = 'workdir/datasets/yelp_full'
     #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
     datafile = {"train": "train.csv", "test": "test.csv"}
     lr = 1e-3
-    src_vocab_op = VocabularyOption()
+    src_vocab_op = VocabularyOption(max_size=100000)
     embed_dropout = 0.3
     cls_dropout = 0.1
-    weight_decay = 1e-4
+    weight_decay = 1e-5
 
     def __init__(self):
+        self.datadir = os.path.join(os.environ['HOME'], self.datadir)
         self.datapath = {k: os.path.join(self.datadir, v)
                          for k, v in self.datafile.items()}
@@ -53,6 +55,8 @@ print('RNG SEED: {}'.format(ops.seed))
 
 # 1. Task information: load the dataInfo with the dataloader
 #datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train'])
+
+
+@cache_results(ops.model_dir_or_name+'-data-cache')
+def load_data():
+    datainfo = yelpLoader(fine_grained=True, lower=True).process(
+        paths=ops.datapath, train_ds=['train'], src_vocab_op=ops.src_vocab_op)
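+    # @cache_results (a fastNLP utility) pickles the return value of
+    # load_data() to '<model_dir_or_name>-data-cache' on the first run and
+    # reloads it on later runs; delete that file to force re-processing.
+    for ds in 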
datainfo.datasets.values(): + ds.apply_field(len, C.INPUT, C.INPUT_LEN) + ds.set_input(C.INPUT, C.INPUT_LEN) + ds.set_target(C.TARGET) + return datainfo + +datainfo = load_data() # 2.或直接复用fastNLP的模型 @@ -59,43 +71,50 @@ vocab = datainfo.vocabs['words'] # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)]) #embedding = StaticEmbedding(vocab) embedding = StaticEmbedding( - vocab, model_dir_or_name='en-word2vec-300', requires_grad=True) + vocab, model_dir_or_name='en-word2vec-300', requires_grad=ops.embedding_grad, + normalize=False +) + +print(len(datainfo.datasets['train'])) +print(len(datainfo.datasets['test'])) +print(datainfo.datasets['train'][0]) print(len(vocab)) print(len(datainfo.vocabs['target'])) -model = DPCNN(init_embed=embedding, num_cls=ops.num_classes) - +model = DPCNN(init_embed=embedding, num_cls=ops.num_classes, + embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout) +print(model) # 3. 声明loss,metric,optimizer loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) optimizer = SGD([param for param in model.parameters() if param.requires_grad == True], - lr=ops.lr, momentum=0.9, weight_decay=0) + lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay) callbacks = [] callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) +# callbacks.append +# LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch < +# ops.train_epoch * 0.8 else ops.lr * 0.1)) +# ) + +# callbacks.append( +# FitlogCallback(data=datainfo.datasets, verbose=1) +# ) device = 'cuda:0' if torch.cuda.is_available() else 'cpu' print(device) -for ds in datainfo.datasets.values(): - ds.apply_field(len, C.INPUT, C.INPUT_LEN) - ds.set_input(C.INPUT, C.INPUT_LEN) - ds.set_target(C.TARGET) - - # 4.定义train方法 -def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch): - trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, - metrics=[metrics], - dev_data=datainfo.datasets['test'], device=device, - check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, - n_epochs=num_epochs) +trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + metrics=[metric], + dev_data=datainfo.datasets['test'], device=device, + check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, + n_epochs=ops.train_epoch, num_workers=4) - print(trainer.train()) if __name__ == "__main__": - train(model, datainfo, loss, metric, optimizer) + print(trainer.train()) diff --git a/reproduction/text_classification/utils/util_init.py b/reproduction/text_classification/utils/util_init.py new file mode 100644 index 00000000..fcb8fffb --- /dev/null +++ b/reproduction/text_classification/utils/util_init.py @@ -0,0 +1,11 @@ +import numpy +import torch +import random + + +def set_rng_seeds(seed): + random.seed(seed) + numpy.random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # print('RNG_SEED {}'.format(seed)) From 451af5318753e854deabd73ce9b6877a55480bbd Mon Sep 17 00:00:00 2001 From: yunfan Date: Sun, 7 Jul 2019 16:07:11 +0800 Subject: [PATCH 14/20] - update dpcnn - add train_idcnn - update sst loader --- fastNLP/io/data_loader/sst.py | 44 +++++---- fastNLP/io/utils.py | 69 +++++++++++++ .../ner/model/dilated_cnn.py | 65 ++++++++---- .../seqence_labelling/ner/train_idcnn.py | 99 +++++++++++++++++++ .../text_classification/data/yelpLoader.py | 13 +-- .../text_classification/train_dpcnn.py | 48 ++++----- reproduction/utils.py | 
11 ++- 7 files changed, 274 insertions(+), 75 deletions(-) create mode 100644 fastNLP/io/utils.py create mode 100644 reproduction/seqence_labelling/ner/train_idcnn.py diff --git a/fastNLP/io/data_loader/sst.py b/fastNLP/io/data_loader/sst.py index 021a79b7..8d0d005f 100644 --- a/fastNLP/io/data_loader/sst.py +++ b/fastNLP/io/data_loader/sst.py @@ -5,10 +5,8 @@ from ..base_loader import DataInfo, DataSetLoader from ...core.vocabulary import VocabularyOption, Vocabulary from ...core.dataset import DataSet from ...core.instance import Instance -from ..embed_loader import EmbeddingOption, EmbedLoader +from ..utils import check_dataloader_paths, get_tokenizer -spacy.prefer_gpu() -sptk = spacy.load('en') class SSTLoader(DataSetLoader): URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' @@ -37,6 +35,7 @@ class SSTLoader(DataSetLoader): tag_v['0'] = tag_v['1'] tag_v['4'] = tag_v['3'] self.tag_v = tag_v + self.tokenizer = get_tokenizer() def _load(self, path): """ @@ -55,29 +54,37 @@ class SSTLoader(DataSetLoader): ds.append(Instance(words=words, target=tag)) return ds - @staticmethod - def _get_one(data, subtree): + def _get_one(self, data, subtree): tree = Tree.fromstring(data) if subtree: - return [([x.text for x in sptk.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees() ] - return [([x.text for x in sptk.tokenizer(' '.join(tree.leaves()))], tree.label())] + return [([x.text for x in self.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees() ] + return [([x.text for x in self.tokenizer(' '.join(tree.leaves()))], tree.label())] def process(self, - paths, - train_ds: Iterable[str] = None, + paths, train_subtree=True, src_vocab_op: VocabularyOption = None, - tgt_vocab_op: VocabularyOption = None, - src_embed_op: EmbeddingOption = None): + tgt_vocab_op: VocabularyOption = None,): + paths = check_dataloader_paths(paths) input_name, target_name = 'words', 'target' src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) tgt_vocab = Vocabulary(unknown=None, padding=None) \ if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) - info = DataInfo(datasets=self.load(paths)) - _train_ds = [info.datasets[name] - for name in train_ds] if train_ds else info.datasets.values() - src_vocab.from_dataset(*_train_ds, field_name=input_name) - tgt_vocab.from_dataset(*_train_ds, field_name=target_name) + info = DataInfo() + origin_subtree = self.subtree + self.subtree = train_subtree + info.datasets['train'] = self._load(paths['train']) + self.subtree = origin_subtree + for n, p in paths.items(): + if n != 'train': + info.datasets[n] = self._load(p) + + src_vocab.from_dataset( + info.datasets['train'], + field_name=input_name, + no_create_entry_dataset=[ds for n, ds in info.datasets.items() if n != 'train']) + tgt_vocab.from_dataset(info.datasets['train'], field_name=target_name) + src_vocab.index_dataset( *info.datasets.values(), field_name=input_name, new_field_name=input_name) @@ -89,10 +96,5 @@ class SSTLoader(DataSetLoader): target_name: tgt_vocab } - if src_embed_op is not None: - src_embed_op.vocab = src_vocab - init_emb = EmbedLoader.load_with_vocab(**src_embed_op) - info.embeddings[input_name] = init_emb - return info diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py new file mode 100644 index 00000000..a7d2de85 --- /dev/null +++ b/fastNLP/io/utils.py @@ -0,0 +1,69 @@ +import os + +from typing import Union, Dict + + +def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: + """ + 
Check that the file paths handed to a dataloader are legal. For a legal path, return
+    a dict that contains at least the key 'train', similar to:
+
+    {
+        'train': '/some/path/to/',  # always present; the vocabulary should be built on this
+                                    # split, the other files only need to be processed and indexed.
+        'test': 'xxx'               # may or may not be present
+        ...
+    }
+
+    If paths is illegal, the corresponding error is raised directly.
+
+    :param paths: a path. Either a single file (taken to be the train file); a directory, in
+        which a train file (name containing 'train'), test.txt and dev.txt are looked up;
+        or a dict whose keys are user-defined file names and whose values are the file paths.
+    :return:
+    """
+    if isinstance(paths, str):
+        if os.path.isfile(paths):
+            return {'train': paths}
+        elif os.path.isdir(paths):
+            filenames = os.listdir(paths)
+            files = {}
+            for filename in filenames:
+                path_pair = None
+                if 'train' in filename:
+                    path_pair = ('train', filename)
+                if 'dev' in filename:
+                    if path_pair:
+                        raise Exception("File:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0]))
+                    path_pair = ('dev', filename)
+                if 'test' in filename:
+                    if path_pair:
+                        raise Exception("File:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0]))
+                    path_pair = ('test', filename)
+                if path_pair:
+                    files[path_pair[0]] = os.path.join(paths, path_pair[1])
+            return files
+        else:
+            raise FileNotFoundError(f"{paths} is not a valid file path.")
+
+    elif isinstance(paths, dict):
+        if paths:
+            if 'train' not in paths:
+                raise KeyError("You have to include `train` in your dict.")
+            for key, value in paths.items():
+                if isinstance(key, str) and isinstance(value, str):
+                    if not os.path.isfile(value):
+                        raise TypeError(f"{value} is not a valid file.")
+                else:
+                    raise TypeError("All keys and values in paths should be str.")
+            return paths
+        else:
+            raise ValueError("Empty paths is not allowed.")
+    else:
+        raise TypeError(f"paths only supports str and dict. not {type(paths)}.")
+
+def get_tokenizer():
+    try:
+        import spacy
+        spacy.prefer_gpu()
+        en = spacy.load('en')
+        print('use spacy tokenizer')
+        return lambda x: [w.text for w in en.tokenizer(x)]
+    except Exception as e:
+        print('use raw tokenizer')
+        return lambda x: x.split()
diff --git a/reproduction/seqence_labelling/ner/model/dilated_cnn.py b/reproduction/seqence_labelling/ner/model/dilated_cnn.py
index cd2fa64b..a4e02159 100644
--- a/reproduction/seqence_labelling/ner/model/dilated_cnn.py
+++ b/reproduction/seqence_labelling/ner/model/dilated_cnn.py
@@ -8,16 +8,23 @@ from fastNLP.core.const import Const as C
 
 class IDCNN(nn.Module):
-    def __init__(self, init_embed, char_embed,
+    def __init__(self,
+                 init_embed,
+                 char_embed,
                  num_cls,
                  repeats, num_layers, num_filters, kernel_size,
                  use_crf=False, use_projection=False, block_loss=False,
                  input_dropout=0.3, hidden_dropout=0.2, inner_dropout=0.0):
         super(IDCNN, self).__init__()
         self.word_embeddings = Embedding(init_embed)
-        self.char_embeddings = Embedding(char_embed)
-        embedding_size = self.word_embeddings.embedding_dim + \
-            self.char_embeddings.embedding_dim
+
+        if char_embed is None:
+            self.char_embeddings = None
+            embedding_size = self.word_embeddings.embedding_dim
+        else:
+            self.char_embeddings = Embedding(char_embed)
+            embedding_size = self.word_embeddings.embedding_dim + \
+                self.char_embeddings.embedding_dim
 
         self.conv0 = nn.Sequential(
             nn.Conv1d(in_channels=embedding_size,
@@ -31,7 +38,7 @@ class IDCNN(nn.Module):
         block = []
         for layer_i in range(num_layers):
-            dilated = 2 ** layer_i
+            dilated = 2 ** layer_i if layer_i+1 < num_layers else 1
             block.append(nn.Conv1d(
                 in_channels=num_filters,
                 out_channels=num_filters,
@@ -67,11 +74,28 @@ class IDCNN(nn.Module):
         self.crf = ConditionalRandomField(
             num_tags=num_cls) if use_crf else None
         self.block_loss = block_loss
+        self.reset_parameters()
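+        # NOTE: the dilation schedule above now grows as 1, 2, 4, ... with the
+        # final block layer pinned back to dilation 1, matching the iterated
+        # dilated-CNN design of Strubell et al. (2017); char_embed has also
+        # become optional, so the model can run on word embeddings alone.
-    def 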
forward(self, words, chars, seq_len, target=None): - e1 = self.word_embeddings(words) - e2 = self.char_embeddings(chars) - x = torch.cat((e1, e2), dim=-1) # b,l,h + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): + nn.init.xavier_normal_(m.weight, gain=1) + if m.bias is not None: + nn.init.normal_(m.bias, mean=0, std=0.01) + + def forward(self, words, seq_len, target=None, chars=None): + if self.char_embeddings is None: + x = self.word_embeddings(words) + else: + if chars is None: + raise ValueError('must provide chars for model with char embedding') + e1 = self.word_embeddings(words) + e2 = self.char_embeddings(chars) + x = torch.cat((e1, e2), dim=-1) # b,l,h mask = seq_len_to_mask(seq_len) x = x.transpose(1, 2) # b,h,l @@ -84,21 +104,24 @@ class IDCNN(nn.Module): def compute_loss(y, t, mask): if self.crf is not None and target is not None: - loss = self.crf(y, t, mask) + loss = self.crf(y.transpose(1, 2), t, mask) else: t.masked_fill_(mask == 0, -100) loss = F.cross_entropy(y, t, ignore_index=-100) return loss - if self.block_loss: - losses = [compute_loss(o, target, mask) for o in output] - loss = sum(losses) + if target is not None: + if self.block_loss: + losses = [compute_loss(o, target, mask) for o in output] + loss = sum(losses) + else: + loss = compute_loss(output[-1], target, mask) else: - loss = compute_loss(output[-1], target, mask) + loss = None scores = output[-1] if self.crf is not None: - pred = self.crf.viterbi_decode(scores, target, mask) + pred, _ = self.crf.viterbi_decode(scores.transpose(1, 2), mask) else: pred = scores.max(1)[1] * mask.long() @@ -107,5 +130,13 @@ class IDCNN(nn.Module): C.OUTPUT: pred, } - def predict(self, words, chars, seq_len): - return self.forward(words, chars, seq_len)[C.OUTPUT] + def predict(self, words, seq_len, chars=None): + res = self.forward( + words=words, + seq_len=seq_len, + chars=chars, + target=None + )[C.OUTPUT] + return { + C.OUTPUT: res + } diff --git a/reproduction/seqence_labelling/ner/train_idcnn.py b/reproduction/seqence_labelling/ner/train_idcnn.py new file mode 100644 index 00000000..1781c763 --- /dev/null +++ b/reproduction/seqence_labelling/ner/train_idcnn.py @@ -0,0 +1,99 @@ +from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader +from fastNLP.core.callback import FitlogCallback, LRScheduler +from fastNLP import GradientClipCallback +from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR +from torch.optim import SGD, Adam +from fastNLP import Const +from fastNLP import RandomSampler, BucketSampler +from fastNLP import SpanFPreRecMetric +from fastNLP import Trainer +from reproduction.seqence_labelling.ner.model.dilated_cnn import IDCNN +from fastNLP.core.utils import Option +from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding +from fastNLP.core.utils import cache_results +import sys +import torch.cuda +import os +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + +encoding_type = 'bioes' + + +def get_path(path): + return os.path.join(os.environ['HOME'], path) + +data_path = get_path('workdir/datasets/ontonotes-v4') + +ops = Option( + batch_size=128, + num_epochs=100, + lr=3e-4, + repeats=3, + num_layers=3, + num_filters=400, + use_crf=True, + gradient_clip=5, +) + +@cache_results('ontonotes-cache') +def load_data(): + + data = 
OntoNoteNERDataLoader(encoding_type=encoding_type).process(data_path, + lower=True) + + # char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30], + # kernel_sizes=[3]) + + word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], + model_dir_or_name='en-glove-840b-300', + requires_grad=True) + return data, [word_embed] + +data, embeds = load_data() +print(data.datasets['train'][0]) +print(list(data.vocabs.keys())) + +for ds in data.datasets.values(): + ds.rename_field('cap_words', 'chars') + ds.set_input('chars') + +word_embed = embeds[0] +char_embed = CNNCharEmbedding(data.vocabs['cap_words']) +# for ds in data.datasets: +# ds.rename_field('') + +print(data.vocabs[Const.TARGET].word2idx) + +model = IDCNN(init_embed=word_embed, + char_embed=char_embed, + num_cls=len(data.vocabs[Const.TARGET]), + repeats=ops.repeats, + num_layers=ops.num_layers, + num_filters=ops.num_filters, + kernel_size=3, + use_crf=ops.use_crf, use_projection=True, + block_loss=True, + input_dropout=0.33, hidden_dropout=0.2, inner_dropout=0.2) + +print(model) + +callbacks = [GradientClipCallback(clip_value=ops.gradient_clip, clip_type='norm'),] + +optimizer = Adam(model.parameters(), lr=ops.lr, weight_decay=0) +# scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))) +# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 15))) +# optimizer = SWATS(model.parameters(), verbose=True) +# optimizer = Adam(model.parameters(), lr=0.005) + +device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + +trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, + sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), + device=device, dev_data=data.datasets['dev'], batch_size=ops.batch_size, + metrics=SpanFPreRecMetric( + tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type), + check_code_level=-1, + callbacks=callbacks, num_workers=2, n_epochs=ops.num_epochs) +trainer.train() diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py index d97f9399..704c29e5 100644 --- a/reproduction/text_classification/data/yelpLoader.py +++ b/reproduction/text_classification/data/yelpLoader.py @@ -8,18 +8,7 @@ from fastNLP.io.base_loader import DataInfo from fastNLP.io.embed_loader import EmbeddingOption from fastNLP.io.file_reader import _read_json from typing import Union, Dict -from reproduction.utils import check_dataloader_paths - - -def get_tokenizer(): - try: - import spacy - en = spacy.load('en') - print('use spacy tokenizer') - return lambda x: [w.text for w in en.tokenizer(x)] - except Exception as e: - print('use raw tokenizer') - return lambda x: x.split() +from reproduction.utils import check_dataloader_paths, get_tokenizer def clean_str(sentence, tokenizer, char_lower=False): """ diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py index 294a0742..9664bf75 100644 --- a/reproduction/text_classification/train_dpcnn.py +++ b/reproduction/text_classification/train_dpcnn.py @@ -9,6 +9,7 @@ from fastNLP import CrossEntropyLoss, AccuracyMetric from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding from reproduction.text_classification.model.dpcnn import DPCNN from data.yelpLoader import yelpLoader +from fastNLP.core.sampler import BucketSampler import torch.nn as nn from fastNLP.core import LRScheduler from fastNLP.core.const 
diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py
index 294a0742..9664bf75 100644
--- a/reproduction/text_classification/train_dpcnn.py
+++ b/reproduction/text_classification/train_dpcnn.py
@@ -9,6 +9,7 @@ from fastNLP import CrossEntropyLoss, AccuracyMetric
 from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
 from reproduction.text_classification.model.dpcnn import DPCNN
 from data.yelpLoader import yelpLoader
+from fastNLP.core.sampler import BucketSampler
 import torch.nn as nn
 from fastNLP.core import LRScheduler
 from fastNLP.core.const import Const as C
@@ -28,19 +29,20 @@ class Config():
     embedding_grad = True
     train_epoch = 30
     batch_size = 100
-    num_classes = 2
     task = "yelp_p"
-    #datadir = '/remote-home/yfshao/workdir/datasets/SST'
-    datadir = '/remote-home/yfshao/workdir/datasets/yelp_polarity'
+    #datadir = 'workdir/datasets/SST'
+    datadir = 'workdir/datasets/yelp_polarity'
+    # datadir = 'workdir/datasets/yelp_full'
     #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
     datafile = {"train": "train.csv", "test": "test.csv"}
     lr = 1e-3
-    src_vocab_op = VocabularyOption()
+    src_vocab_op = VocabularyOption(max_size=100000)
     embed_dropout = 0.3
     cls_dropout = 0.1
-    weight_decay = 1e-4
+    weight_decay = 1e-5
 
     def __init__(self):
+        self.datadir = os.path.join(os.environ['HOME'], self.datadir)
         self.datapath = {k: os.path.join(self.datadir, v)
                          for k, v in self.datafile.items()}
@@ -53,6 +55,8 @@ print('RNG SEED: {}'.format(ops.seed))
 
 # 1. Task setup: load the DataInfo via the dataloader
 #datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train'])
+
+
 @cache_results(ops.model_dir_or_name+'-data-cache')
 def load_data():
     datainfo = yelpLoader(fine_grained=True, lower=True).process(
@@ -61,28 +65,23 @@ def load_data():
         ds.apply_field(len, C.INPUT, C.INPUT_LEN)
         ds.set_input(C.INPUT, C.INPUT_LEN)
         ds.set_target(C.TARGET)
-    return datainfo
+    embedding = StaticEmbedding(
+        datainfo.vocabs['words'], model_dir_or_name='en-glove-840b-300', requires_grad=ops.embedding_grad,
+        normalize=False
+    )
+    return datainfo, embedding
 
 
-datainfo = load_data()
+datainfo, embedding = load_data()
 
 # 2. Or directly reuse a fastNLP model
-vocab = datainfo.vocabs['words']
 
 # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
-#embedding = StaticEmbedding(vocab)
-embedding = StaticEmbedding(
-    vocab, model_dir_or_name='en-word2vec-300', requires_grad=ops.embedding_grad,
-    normalize=False
-)
-print(len(datainfo.datasets['train']))
-print(len(datainfo.datasets['test']))
+print(datainfo)
 print(datainfo.datasets['train'][0])
-print(len(vocab))
-print(len(datainfo.vocabs['target']))
-
-model = DPCNN(init_embed=embedding, num_cls=ops.num_classes,
+model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]),
               embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout)
 print(model)
@@ -93,11 +92,11 @@ optimizer = SGD([param for param in model.parameters() if param.requires_grad ==
                 lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)
 
 callbacks = []
-callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
-# callbacks.append
-# LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch <
-# ops.train_epoch * 0.8 else ops.lr * 0.1))
-# )
+# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
+callbacks.append(
+    LRScheduler(LambdaLR(optimizer, lambda epoch: 1.0 if epoch <
+                         ops.train_epoch * 0.8 else 0.1))
+)
 
 # callbacks.append(
 #     FitlogCallback(data=datainfo.datasets, verbose=1)
@@ -109,6 +108,7 @@ print(device)
 
 # 4. Define the train procedure
 trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+                  sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
                   metrics=[metric], dev_data=datainfo.datasets['test'], device=device,
                   check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
diff --git a/reproduction/utils.py b/reproduction/utils.py
index 4f0d021e..a7d2de85 100644
--- a/reproduction/utils.py
+++ b/reproduction/utils.py
@@ -57,4 +57,13 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
     else:
         raise TypeError(f"paths only supports str and dict. not {type(paths)}.")
-
+def get_tokenizer():
+    try:
+        import spacy
+        spacy.prefer_gpu()
+        en = spacy.load('en')
+        print('use spacy tokenizer')
+        return lambda x: [w.text for w in en.tokenizer(x)]
+    except Exception as e:
+        print('use raw tokenizer')
+        return lambda x: x.split()
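The LambdaLR callback enabled in train_dpcnn.py above is multiplicative: PyTorch computes effective_lr = base_lr * lr_lambda(epoch), so the lambda must return a scaling factor (1.0, then 0.1 after 80% of the epochs), not an absolute learning rate. A self-contained check, with a dummy parameter standing in for the model:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

param = torch.nn.Parameter(torch.zeros(1))      # dummy model parameter
base_lr, n_epochs = 1e-3, 30
optimizer = SGD([param], lr=base_lr, momentum=0.9)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 1.0 if epoch < n_epochs * 0.8 else 0.1)

for epoch in range(n_epochs):
    optimizer.step()                            # one (empty) training epoch
    scheduler.step()
    # stays at 1e-3 for roughly the first 80% of epochs, then drops to 1e-4
    print(epoch, optimizer.param_groups[0]['lr'])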
From 1d4e9968f242f2547a13ad0ab1586fc9f15a59a8 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Sun, 7 Jul 2019 16:29:01 +0800
Subject: [PATCH 15/20] - update star-transformer README

---
 reproduction/Star_transformer/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/reproduction/Star_transformer/README.md b/reproduction/Star_transformer/README.md
index 37c5f1e9..d07d5536 100644
--- a/reproduction/Star_transformer/README.md
+++ b/reproduction/Star_transformer/README.md
@@ -6,7 +6,7 @@ paper: [Star-Transformer](https://arxiv.org/abs/1902.09113)
 |Pos Tagging|CTB 9.0|-|ACC 92.31|
 |Pos Tagging|CONLL 2012|-|ACC 96.51|
 |Named Entity Recognition|CONLL 2012|-|F1 85.66|
-|Text Classification|SST|-|49.18|
+|Text Classification|SST|-|51.2|
 |Natural Language Inference|SNLI|-|83.76|
 
 ## Usage

From 5d9e064ec27595a1e09d7ddcbb27280c685b0701 Mon Sep 17 00:00:00 2001
From: lyhuang18 <42239874+lyhuang18@users.noreply.github.com>
Date: Mon, 8 Jul 2019 01:11:46 +0800
Subject: [PATCH 16/20] text_classfication

---
 .../text_classification/data/SSTLoader.py     | 90 ++++++++++++++++++-
 .../text_classification/train_awdlstm.py      | 41 +--------
 .../text_classification/train_lstm.py         | 43 ++-------
 .../text_classification/train_lstm_att.py     | 41 +--------
 4 files changed, 102 insertions(+), 113 deletions(-)

diff --git a/reproduction/text_classification/data/SSTLoader.py b/reproduction/text_classification/data/SSTLoader.py
index b570994e..d8403b7a 100644
--- a/reproduction/text_classification/data/SSTLoader.py
+++ b/reproduction/text_classification/data/SSTLoader.py
@@ -5,7 +5,8 @@ from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
-
+import csv
+from typing import Union, Dict
 
 class SSTLoader(DataSetLoader):
     URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
@@ -97,3 +98,90 @@ class SSTLoader(DataSetLoader):
 
         return info
 
+class sst2Loader(DataSetLoader):
+    '''
+    Data source ("SST-2"): 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
+    '''
+    def __init__(self):
+        super(sst2Loader, self).__init__()
+
+    def _load(self, path: str) -> DataSet:
+        ds = DataSet()
+        all_count=0
+        csv_reader = csv.reader(open(path, encoding='utf-8'),delimiter='\t')
+        skip_row = 0
+        for idx,row in enumerate(csv_reader):
+            if idx<=skip_row:
+                continue
+            target = row[1]
+            words = row[0].split()
+            ds.append(Instance(words=words,target=target))
+            all_count+=1
+        print("all count:", all_count)
+        return ds
+
+    def process(self,
+                paths: Union[str, Dict[str, str]],
+                src_vocab_opt: VocabularyOption = None,
+                tgt_vocab_opt: VocabularyOption = None,
+                src_embed_opt: EmbeddingOption = None,
+                char_level_op=False):
+
+        paths = check_dataloader_paths(paths)
+        datasets = {}
+        info = DataInfo()
+        for name, path in paths.items():
+            dataset = self.load(path)
+            datasets[name] = dataset
+
+        def wordtochar(words):
+            chars=[]
+            for word in words:
+                word=word.lower()
+                for char in word:
+                    chars.append(char)
+            return chars
+
+        input_name, target_name = 'words', 'target'
+        info.vocabs={}
+
+        # split the words into characters
+        if char_level_op:
+            for dataset in datasets.values():
+                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')
+
+        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
+        src_vocab.from_dataset(datasets['train'], field_name='words')
+        src_vocab.index_dataset(*datasets.values(), field_name='words')
+
+        tgt_vocab = Vocabulary(unknown=None, padding=None) \
+            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
+        tgt_vocab.from_dataset(datasets['train'], field_name='target')
+        tgt_vocab.index_dataset(*datasets.values(), field_name='target')
+
+
+        info.vocabs = {
+            "words": src_vocab,
+            "target": tgt_vocab
+        }
+
+        info.datasets = datasets
+
+
+        if src_embed_opt is not None:
+            embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
+            info.embeddings['words'] = embed
+
+        return info
+
+if __name__=="__main__":
+    datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
+                "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"}
+    datainfo=sst2Loader().process(datapath,char_level_op=True)
+    #print(datainfo.datasets["train"])
+    len_count = 0
+    for instance in datainfo.datasets["train"]:
+        len_count += len(instance["chars"])
+
+    ave_len = len_count / len(datainfo.datasets["train"])
+    print(ave_len)
\ No newline at end of file
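A word-level smoke test of the new loader, mirroring the char-level __main__ block above; the SST-2 paths here are placeholders for wherever the GLUE tsv files actually live:

from data.SSTLoader import sst2Loader

paths = {"train": "SST-2/train.tsv",   # placeholder paths
         "dev": "SST-2/dev.tsv"}
datainfo = sst2Loader().process(paths, char_level_op=False)
print(len(datainfo.vocabs['words']), len(datainfo.vocabs['target']))
print(datainfo.datasets['train'][0])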
diff --git a/reproduction/text_classification/train_awdlstm.py b/reproduction/text_classification/train_awdlstm.py
index ce3e52bc..e67bd25b 100644
--- a/reproduction/text_classification/train_awdlstm.py
+++ b/reproduction/text_classification/train_awdlstm.py
@@ -8,9 +8,7 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 
 import torch.nn as nn
 
-from data.SSTLoader import SSTLoader
 from data.IMDBLoader import IMDBLoader
-from data.yelpLoader import yelpLoader
 from fastNLP.modules.encoder.embedding import StaticEmbedding
 from model.awd_lstm import AWDLSTMSentiment
 
@@ -41,18 +39,9 @@ opt=Config
 
 
 # load data
-dataloaders = {
-    "IMDB":IMDBLoader(),
-    "YELP":yelpLoader(),
-    "SST-5":SSTLoader(subtree=True,fine_grained=True),
-    "SST-3":SSTLoader(subtree=True,fine_grained=False)
-}
-
-if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]:
-    raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']")
-
-dataloader = dataloaders[opt.task_name]
+dataloader=IMDBLoader()
 datainfo=dataloader.process(opt.datapath)
+
 # print(datainfo.datasets["train"])
 # print(datainfo)
 
@@ -71,32 +60,10 @@ optimizer= Adam([param for param in model.parameters() if param.requires_grad==T
 
 def train(datainfo, model, optimizer, loss, metrics, opt):
     trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-                        metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1,
+                        metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1,
                         n_epochs=opt.train_epoch, save_path=opt.save_model_path)
     trainer.train()
 
-def test(datainfo, metrics, opt):
-    # load model
-    model = ModelLoader.load_pytorch_model(opt.load_model_path)
-    print("model loaded!")
-
-    # Tester
-    tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0)
-    acc = tester.test()
-    print("acc=",acc)
-
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model')
-
-
-args = parser.parse_args()
-if args.mode == 'train':
+if __name__ == "__main__":
     train(datainfo, model, optimizer, loss, metrics, opt)
-elif args.mode == 'test':
-    test(datainfo, metrics, opt)
-else:
-    print('no mode specified for model!')
-    parser.print_help()
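The deleted test() entry point can still be reproduced for offline evaluation; a sketch assembled from the removed code, assuming a saved checkpoint path and fastNLP 0.4.x import locations:

from fastNLP import Tester
from fastNLP.io.model_io import ModelLoader

def evaluate(datainfo, metrics, load_model_path):
    # load_model_path: a checkpoint written under the Trainer's save_path
    model = ModelLoader.load_pytorch_model(load_model_path)
    tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0)
    return tester.test()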
diff --git a/reproduction/text_classification/train_lstm.py b/reproduction/text_classification/train_lstm.py
index b320e79c..b89abc14 100644
--- a/reproduction/text_classification/train_lstm.py
+++ b/reproduction/text_classification/train_lstm.py
@@ -6,9 +6,7 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 
 import torch.nn as nn
 
-from data.SSTLoader import SSTLoader
 from data.IMDBLoader import IMDBLoader
-from data.yelpLoader import yelpLoader
 from fastNLP.modules.encoder.embedding import StaticEmbedding
 from model.lstm import BiLSTMSentiment
 
@@ -38,18 +36,9 @@ opt=Config
 
 
 # load data
-dataloaders = {
-    "IMDB":IMDBLoader(),
-    "YELP":yelpLoader(),
-    "SST-5":SSTLoader(subtree=True,fine_grained=True),
-    "SST-3":SSTLoader(subtree=True,fine_grained=False)
-}
-
-if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]:
-    raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']")
-
-dataloader = dataloaders[opt.task_name]
+dataloader=IMDBLoader()
 datainfo=dataloader.process(opt.datapath)
+
 # print(datainfo.datasets["train"])
 # print(datainfo)
 
@@ -68,32 +57,10 @@ optimizer= Adam([param for param in model.parameters() if param.requires_grad==T
 
 def train(datainfo, model, optimizer, loss, metrics, opt):
     trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-                        metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1,
+                        metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1,
                         n_epochs=opt.train_epoch, save_path=opt.save_model_path)
     trainer.train()
 
-def test(datainfo, metrics, opt):
-    # load model
-    model = ModelLoader.load_pytorch_model(opt.load_model_path)
-    print("model loaded!")
-
-    # Tester
-    tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0)
-    acc = tester.test()
-    print("acc=",acc)
-
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model')
-
-
-args = parser.parse_args()
-if args.mode == 'train':
-    train(datainfo, model, optimizer, loss, metrics, opt)
-elif args.mode == 'test':
-    test(datainfo, metrics, opt)
-else:
-    print('no mode specified for model!')
-    parser.print_help()
+if __name__ == "__main__":
+    train(datainfo, model, optimizer, loss, metrics, opt)
\ No newline at end of file
diff --git a/reproduction/text_classification/train_lstm_att.py b/reproduction/text_classification/train_lstm_att.py
index 8db27d09..b4d37525 100644
--- a/reproduction/text_classification/train_lstm_att.py
+++ b/reproduction/text_classification/train_lstm_att.py
@@ -6,9 +6,7 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 
 import torch.nn as nn
 
-from data.SSTLoader import SSTLoader
 from data.IMDBLoader import IMDBLoader
-from data.yelpLoader import yelpLoader
 from fastNLP.modules.encoder.embedding import StaticEmbedding
 from model.lstm_self_attention import BiLSTM_SELF_ATTENTION
 
@@ -40,18 +38,9 @@ opt=Config
 
 
 # load data
-dataloaders = {
-    "IMDB":IMDBLoader(),
-    "YELP":yelpLoader(),
-    "SST-5":SSTLoader(subtree=True,fine_grained=True),
-    "SST-3":SSTLoader(subtree=True,fine_grained=False)
-}
-
-if opt.task_name not in ["IMDB", "YELP", "SST-5", "SST-3"]:
-    raise ValueError("task name must in ['IMDB', 'YELP, 'SST-5', 'SST-3']")
-
-dataloader = dataloaders[opt.task_name]
+dataloader=IMDBLoader()
 datainfo=dataloader.process(opt.datapath)
+
 # print(datainfo.datasets["train"])
 # print(datainfo)
 
@@ -70,32 +59,10 @@ optimizer= Adam([param for param in model.parameters() if param.requires_grad==T
 
 def train(datainfo, model, optimizer, loss, metrics, opt):
     trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-                        metrics=metrics, dev_data=datainfo.datasets['dev'], device=0, check_code_level=-1,
+                        metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1,
                         n_epochs=opt.train_epoch, save_path=opt.save_model_path)
     trainer.train()
 
-def test(datainfo, metrics, opt):
-    # load model
-    model = ModelLoader.load_pytorch_model(opt.load_model_path)
-    print("model loaded!")
-
-    # Tester
-    tester = Tester(datainfo.datasets['test'], model, metrics, batch_size=4, device=0)
-    acc = tester.test()
-    print("acc=",acc)
-
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument('--mode', required=True, dest="mode",help='set the model\'s model')
-
-
-args = parser.parse_args()
-if args.mode == 'train':
+if __name__ == "__main__":
     train(datainfo, model, optimizer, loss, metrics, opt)
-elif args.mode == 'test':
-    test(datainfo, metrics, opt)
-else:
-    print('no mode specified for model!')
-    parser.print_help()
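All three IMDB scripts now build StaticEmbedding(vocab) with randomly initialized vectors; pretrained vectors drop in the same way train_dpcnn.py loads GloVe earlier in this series. A sketch, with the model identifier taken from that script:

from fastNLP.modules.encoder.embedding import StaticEmbedding

def build_glove_embedding(vocab):
    # vocab: the 'words' Vocabulary produced by IMDBLoader().process(...)
    return StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300',
                           requires_grad=False, normalize=False)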
From 46c82a7daac64c10ce425fe9c1201569dc02a75d Mon Sep 17 00:00:00 2001
From: lyhuang18 <42239874+lyhuang18@users.noreply.github.com>
Date: Mon, 8 Jul 2019 01:29:42 +0800
Subject: [PATCH 17/20] text_classfication

---
 .../text_classification/results_LSTM.xlsx      | Bin 9944 -> 0 bytes
 .../text_classification/train_awdlstm.py       | 4 ++--
 reproduction/text_classification/train_lstm.py | 4 ++--
 .../text_classification/train_lstm_att.py      | 4 ++--
 4 files changed, 6 insertions(+), 6 deletions(-)
 delete mode 100644 reproduction/text_classification/results_LSTM.xlsx

diff --git a/reproduction/text_classification/results_LSTM.xlsx b/reproduction/text_classification/results_LSTM.xlsx
deleted file mode 100644
index 0d7b841b12b43ee346c4d9db17032fade8c0b40a..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

diff --git a/reproduction/text_classification/train_awdlstm.py b/reproduction/text_classification/train_awdlstm.py
index e67bd25b..007b2910 100644
--- a/reproduction/text_classification/train_awdlstm.py
+++ b/reproduction/text_classification/train_awdlstm.py
@@ -33,9 +33,9 @@ class Config():
     task_name = "IMDB"
     datapath={"train":"IMDB_data/train.csv",
               "test":"IMDB_data/test.csv"}
-    load_model_path="./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51"
     save_model_path="./result_IMDB_test/"
-opt=Config
+
+opt=Config()
 
 
 # load data
diff --git a/reproduction/text_classification/train_lstm.py b/reproduction/text_classification/train_lstm.py
index b89abc14..4ecc61a1 100644
--- a/reproduction/text_classification/train_lstm.py
+++ b/reproduction/text_classification/train_lstm.py
@@ -30,9 +30,9 @@ class Config():
     task_name = "IMDB"
     datapath={"train":"IMDB_data/train.csv",
               "test":"IMDB_data/test.csv"}
-    load_model_path="./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51"
     save_model_path="./result_IMDB_test/"
-opt=Config
+
+opt=Config()
 
 
 # load data
diff --git a/reproduction/text_classification/train_lstm_att.py b/reproduction/text_classification/train_lstm_att.py
index b4d37525..a6f0dd03 100644
--- a/reproduction/text_classification/train_lstm_att.py
+++ b/reproduction/text_classification/train_lstm_att.py
@@ -32,9 +32,9 @@ class Config():
     task_name = "IMDB"
     datapath={"train":"IMDB_data/train.csv",
               "test":"IMDB_data/test.csv"}
-    load_model_path="./result_IMDB/best_BiLSTM_SELF_ATTENTION_acc_2019-07-07-04-16-51"
     save_model_path="./result_IMDB_test/"
-opt=Config
+
+opt=Config()
 
 
 # load data
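The opt=Config to opt=Config() changes above matter because the first binds the class object itself, so nothing in __init__ ever runs; the dpcnn Config earlier in this series builds its datapath there. A minimal illustration (the attribute names are illustrative, not the scripts' actual fields):

class Config:
    lr = 1e-3                                   # class attribute: visible either way
    def __init__(self):
        self.datapath = {"train": "train.csv"}  # instance attribute: needs Config()

opt_cls = Config    # the class object: opt_cls.lr works, opt_cls.datapath does not
opt = Config()      # the instance: __init__ has run
assert hasattr(opt, 'datapath') and not hasattr(opt_cls, 'datapath')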
From 8156f3c69e8e79eb5050f20fea046092e9d3ad4f Mon Sep 17 00:00:00 2001
From: lyhuang18 <42239874+lyhuang18@users.noreply.github.com>
Date: Mon, 8 Jul 2019 05:14:36 +0800
Subject: [PATCH 18/20] =?UTF-8?q?=E7=BB=93=E6=9E=9C?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 reproduction/text_classification/README.md | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/reproduction/text_classification/README.md b/reproduction/text_classification/README.md
index b058fbb2..4b8f44bd 100644
--- a/reproduction/text_classification/README.md
+++ b/reproduction/text_classification/README.md
@@ -3,20 +3,20 @@
 char_cnn: paper [Character-level Convolutional Networks for Text Classification](https://arxiv.org/pdf/1509.01626v3.pdf)
 dpcnn: paper [Deep Pyramid Convolutional Neural Networks for Text Categorization](https://ai.tencent.com/ailab/media/publications/ACL3-Brady.pdf)
 HAN: paper [Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf)
+LSTM+self_attention: paper [A Structured Self-attentive Sentence Embedding](https://arxiv.org/abs/1703.03130)
+AWD-LSTM: paper [Regularizing and Optimizing LSTM Language Models](https://arxiv.org/abs/1708.02182)
 #To be added
-awd_lstm:
-lstm_self_attention(BCN?):
-awd-sltm:
 # Datasets and summary of reproduced results
 fastNLP reproduction vs. results reported in the papers (in "x/y", x is the fastNLP run and y the paper's; "-" means the paper lists no result on that dataset)
-model name | yelp_p | sst-2|IMDB|
-:---: | :---: | :---: | :---:
-char_cnn | 93.80/95.12 | - |- |
-dpcnn | 95.50/97.36 | - |- |
-HAN |- | - |-|
-BCN| - |- |-|
-awd-lstm| - |- |-|
+model name | yelp_p | yelp_f | sst-2|IMDB|
+:---: | :---: | :---: | :---: |----- |:---:
+char_cnn | 93.80/95.12 | - | - |- |
+dpcnn | 95.50/97.36 | - | - |- |
+HAN |- | - | - |-|
+LSTM| 95.74/- |- |- |88.52/-|
+AWD-LSTM| 95.96/- |- |- |88.91/-|
+LSTM+self_attention| 96.34/- | - | - |89.53/-|
From 8f78bf5250e7f183575a7dd6603aa1315668b217 Mon Sep 17 00:00:00 2001
From: lyhuang18 <42239874+lyhuang18@users.noreply.github.com>
Date: Mon, 8 Jul 2019 05:17:50 +0800
Subject: [PATCH 19/20] =?UTF-8?q?readme=E6=A0=BC=E5=BC=8F=E4=BF=AE?=
 =?UTF-8?q?=E6=94=B9?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 reproduction/text_classification/README.md | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/reproduction/text_classification/README.md b/reproduction/text_classification/README.md
index 4b8f44bd..08c893b7 100644
--- a/reproduction/text_classification/README.md
+++ b/reproduction/text_classification/README.md
@@ -1,22 +1,28 @@
 # Reproduction of text classification models
 The following models are reproduced here with fastNLP:
+
 char_cnn: paper [Character-level Convolutional Networks for Text Classification](https://arxiv.org/pdf/1509.01626v3.pdf)
+
 dpcnn: paper [Deep Pyramid Convolutional Neural Networks for Text Categorization](https://ai.tencent.com/ailab/media/publications/ACL3-Brady.pdf)
+
 HAN: paper [Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf)
+
 LSTM+self_attention: paper [A Structured Self-attentive Sentence Embedding](https://arxiv.org/abs/1703.03130)
+
 AWD-LSTM: paper [Regularizing and Optimizing LSTM Language Models](https://arxiv.org/abs/1708.02182)
+
 #To be added
+
 # Datasets and summary of reproduced results
 fastNLP reproduction vs. results reported in the papers (in "x/y", x is the fastNLP run and y the paper's; "-" means the paper lists no result on that dataset)
-model name | yelp_p | yelp_f | sst-2|IMDB|
-:---: | :---: | :---: | :---: |----- |:---:
-char_cnn | 93.80/95.12 | - | - |- |
-dpcnn | 95.50/97.36 | - | - |- |
-HAN |- | - | - |-|
-LSTM| 95.74/- |- |- |88.52/-|
-AWD-LSTM| 95.96/- |- |- |88.91/-|
-LSTM+self_attention| 96.34/- | - | - |89.53/-|
+model name | yelp_p | yelp_f | sst-2 | IMDB
+:---: | :---: | :---: | :---: | :---:
+char_cnn | 93.80/95.12 | - | - | -
+dpcnn | 95.50/97.36 | - | - | -
+HAN | - | - | - | -
+LSTM | 95.74/- | - | - | 88.52/-
+AWD-LSTM | 95.96/- | - | - | 88.91/-
+LSTM+self_attention | 96.34/- | - | - | 89.53/-

From 4687b378bb224e03317aedb08f0307af4116e1a8 Mon Sep 17 00:00:00 2001
From: wyg <1505116161@qq.com>
Date: Mon, 8 Jul 2019 11:03:00 +0800
Subject: [PATCH 20/20] [verify] readme

---
 reproduction/text_classification/README.md | 2 --
 1 file changed, 2 deletions(-)

diff --git a/reproduction/text_classification/README.md b/reproduction/text_classification/README.md
index 08c893b7..a318cc61 100644
--- a/reproduction/text_classification/README.md
+++ b/reproduction/text_classification/README.md
@@ -11,8 +11,6 @@ LSTM+self_attention: paper [A Structured Self-attentive Sentence Embedding](https://arxiv.org/abs/1703.03130)
 
 AWD-LSTM: paper [Regularizing and Optimizing LSTM Language Models](https://arxiv.org/abs/1708.02182)
 
-#To be added
-
 # Datasets and summary of reproduced results
 fastNLP reproduction vs. results reported in the papers (in "x/y", x is the fastNLP run and y the paper's; "-" means the paper lists no result on that dataset)