From 32917dac7f8e87adbad72c69d3b269dd48606b8a Mon Sep 17 00:00:00 2001
From: yunfan
Date: Thu, 27 Jun 2019 21:56:49 +0800
Subject: [PATCH 1/4] add dpcnn

---
 .../text_classification/model/dpcnn.py | 92 ++++++++++++++++++-
 .../text_classification/train_dpcnn.py | 80 ++++++++++++++++
 2 files changed, 171 insertions(+), 1 deletion(-)

diff --git a/reproduction/text_classification/model/dpcnn.py b/reproduction/text_classification/model/dpcnn.py
index f87f5c14..2da7b3e5 100644
--- a/reproduction/text_classification/model/dpcnn.py
+++ b/reproduction/text_classification/model/dpcnn.py
@@ -1 +1,91 @@
-# TODO
\ No newline at end of file
+import torch
+import torch.nn as nn
+from fastNLP.modules.utils import get_embeddings
+from fastNLP.core import Const as C
+
+class DPCNN(nn.Module):
+    def __init__(self, init_embed, num_cls, n_filters=256, kernel_size=3, n_layers=7, embed_dropout=0.1, dropout=0.1):
+        super().__init__()
+        self.region_embed = RegionEmbedding(init_embed, out_dim=n_filters, kernel_sizes=[3, 5, 9])
+        embed_dim = self.region_embed.embedding_dim
+        self.conv_list = nn.ModuleList()
+        for i in range(n_layers):
+            self.conv_list.append(nn.Sequential(
+                nn.ReLU(),
+                nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2),
+                nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2),
+            ))
+        self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
+        self.embed_drop = nn.Dropout(embed_dropout)
+        self.classfier = nn.Sequential(
+            nn.Dropout(dropout),
+            nn.Linear(n_filters, num_cls),
+        )
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for m in self.modules():
+            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+                nn.init.normal_(m.weight, mean=0, std=0.01)
+                if m.bias is not None:
+                    nn.init.normal_(m.bias, mean=0, std=0.01)
+
+    def forward(self, words, seq_len=None):
+        words = words.long()
+        # get region embeddings
+        x = self.region_embed(words)
+        x = self.embed_drop(x)
+
+        # not pooling on first conv
+        x = self.conv_list[0](x) + x
+        for conv in self.conv_list[1:]:
+            x = self.pool(x)
+            x = conv(x) + x
+
+        # B, C, L => B, C
+        x, _ = torch.max(x, dim=2)
+        x = self.classfier(x)
+        return {C.OUTPUT: x}
+
+    def predict(self, words, seq_len=None):
+        x = self.forward(words, seq_len)[C.OUTPUT]
+        return {C.OUTPUT: torch.argmax(x, 1)}
+
+
+class RegionEmbedding(nn.Module):
+    def __init__(self, init_embed, out_dim=300, kernel_sizes=None):
+        super().__init__()
+        if kernel_sizes is None:
+            kernel_sizes = [5, 9]
+        assert isinstance(kernel_sizes, list), 'kernel_sizes should be List(int)'
+        self.embed = get_embeddings(init_embed)
+        try:
+            embed_dim = self.embed.embedding_dim
+        except Exception:
+            embed_dim = self.embed.embed_size
+        self.region_embeds = nn.ModuleList()
+        for ksz in kernel_sizes:
+            self.region_embeds.append(nn.Sequential(
+                nn.Conv1d(embed_dim, embed_dim, ksz, padding=ksz // 2),
+            ))
+        self.linears = nn.ModuleList([nn.Conv1d(embed_dim, out_dim, 1)
+                                      for _ in range(len(kernel_sizes) + 1)])
+        self.embedding_dim = embed_dim
+
+    def forward(self, x):
+        x = self.embed(x)
+        x = x.transpose(1, 2)
+        # B, C, L
+        out = self.linears[0](x)
+        for conv, fc in zip(self.region_embeds, self.linears[1:]):
+            conv_i = conv(x)
+            out = out + fc(conv_i)
+        # B, C, L
+        return out
+
+
+if __name__ == '__main__':
+    x = torch.randint(0, 10000, size=(5, 15), dtype=torch.long)
+    model = DPCNN((10000, 300), 20)
+    y = model(x)
+    print(y.size(), y.mean(1), y.std(1))
diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py
index e69de29b..13ff4fc1 100644
--- a/reproduction/text_classification/train_dpcnn.py
+++ b/reproduction/text_classification/train_dpcnn.py
@@ -0,0 +1,80 @@
+# 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径
+import os
+os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
+os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+
+from fastNLP.core.const import Const as C
+from fastNLP.core import LRScheduler
+import torch.nn as nn
+from fastNLP.io.dataset_loader import SSTLoader
+from reproduction.text_classification.model.dpcnn import DPCNN
+from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
+from fastNLP import CrossEntropyLoss, AccuracyMetric
+from fastNLP.core.trainer import Trainer
+from torch.optim import SGD
+import torch.cuda
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+##hyper
+class Config():
+    model_dir_or_name="en-base-uncased"
+    embedding_grad= False,
+    train_epoch= 30
+    batch_size = 100
+    num_classes=5
+    task= "SST"
+    datadir = '/remote-home/yfshao/workdir/datasets/SST'
+    datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
+    lr=1e-3
+    def __init__(self):
+        self.datapath = {k:os.path.join(self.datadir, v)
+                         for k, v in self.datafile.items()}
+
+ops=Config()
+
+
+##1.task相关信息:利用dataloader载入dataInfo
+datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds='train')
+
+print(len(datainfo.datasets['train']))
+print(len(datainfo.datasets['dev']))
+
+
+## 2.或直接复用fastNLP的模型
+vocab = datainfo.vocabs['words']
+
+# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
+embedding = StaticEmbedding(vocab)
+print(len(vocab))
+print(len(datainfo.vocabs['target']))
+model = DPCNN(init_embed=embedding, num_cls=ops.num_classes)
+
+## 3.
声明loss,metric,optimizer +loss=CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) +metric=AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) +optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], + lr=ops.lr, momentum=0.9, weight_decay=0) + +callbacks = [] +callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) + +device = 'cuda:0' if torch.cuda.is_available() else 'cpu' +print(device) + +for ds in datainfo.datasets.values(): + ds.apply_field(len, C.INPUT, C.INPUT_LEN) + ds.set_input(C.INPUT, C.INPUT_LEN) + ds.set_target(C.TARGET) + +## 4.定义train方法 +def train(model,datainfo,loss,metrics,optimizer,num_epochs=ops.train_epoch): + trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + metrics=[metrics], dev_data=datainfo.datasets['dev'], device=device, + check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, + n_epochs=num_epochs) + print(trainer.train()) + + +if __name__=="__main__": + train(model,datainfo,loss,metric,optimizer) \ No newline at end of file From f1adb0f9156357944183fdb692d5206aad7e3b0d Mon Sep 17 00:00:00 2001 From: yunfan Date: Thu, 4 Jul 2019 13:56:37 +0800 Subject: [PATCH 2/4] add ID-CNN --- .../ner/model/dilated_cnn.py | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 reproduction/seqence_labelling/ner/model/dilated_cnn.py diff --git a/reproduction/seqence_labelling/ner/model/dilated_cnn.py b/reproduction/seqence_labelling/ner/model/dilated_cnn.py new file mode 100644 index 00000000..cd2fa64b --- /dev/null +++ b/reproduction/seqence_labelling/ner/model/dilated_cnn.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from fastNLP.modules.decoder import ConditionalRandomField +from fastNLP.modules.encoder import Embedding +from fastNLP.core.utils import seq_len_to_mask +from fastNLP.core.const import Const as C + + +class IDCNN(nn.Module): + def __init__(self, init_embed, char_embed, + num_cls, + repeats, num_layers, num_filters, kernel_size, + use_crf=False, use_projection=False, block_loss=False, + input_dropout=0.3, hidden_dropout=0.2, inner_dropout=0.0): + super(IDCNN, self).__init__() + self.word_embeddings = Embedding(init_embed) + self.char_embeddings = Embedding(char_embed) + embedding_size = self.word_embeddings.embedding_dim + \ + self.char_embeddings.embedding_dim + + self.conv0 = nn.Sequential( + nn.Conv1d(in_channels=embedding_size, + out_channels=num_filters, + kernel_size=kernel_size, + stride=1, dilation=1, + padding=kernel_size//2, + bias=True), + nn.ReLU(), + ) + + block = [] + for layer_i in range(num_layers): + dilated = 2 ** layer_i + block.append(nn.Conv1d( + in_channels=num_filters, + out_channels=num_filters, + kernel_size=kernel_size, + stride=1, dilation=dilated, + padding=(kernel_size//2) * dilated, + bias=True)) + block.append(nn.ReLU()) + self.block = nn.Sequential(*block) + + if use_projection: + self.projection = nn.Sequential( + nn.Conv1d( + in_channels=num_filters, + out_channels=num_filters//2, + kernel_size=1, + bias=True), + nn.ReLU(),) + encode_dim = num_filters // 2 + else: + self.projection = None + encode_dim = num_filters + + self.input_drop = nn.Dropout(input_dropout) + self.hidden_drop = nn.Dropout(hidden_dropout) + self.inner_drop = nn.Dropout(inner_dropout) + self.repeats = repeats + self.out_fc = nn.Conv1d( + in_channels=encode_dim, + out_channels=num_cls, + kernel_size=1, + bias=True) + self.crf = ConditionalRandomField( + num_tags=num_cls) if use_crf else None + self.block_loss = 
block_loss + + def forward(self, words, chars, seq_len, target=None): + e1 = self.word_embeddings(words) + e2 = self.char_embeddings(chars) + x = torch.cat((e1, e2), dim=-1) # b,l,h + mask = seq_len_to_mask(seq_len) + + x = x.transpose(1, 2) # b,h,l + last_output = self.conv0(x) + output = [] + for repeat in range(self.repeats): + last_output = self.block(last_output) + hidden = self.projection(last_output) if self.projection is not None else last_output + output.append(self.out_fc(hidden)) + + def compute_loss(y, t, mask): + if self.crf is not None and target is not None: + loss = self.crf(y, t, mask) + else: + t.masked_fill_(mask == 0, -100) + loss = F.cross_entropy(y, t, ignore_index=-100) + return loss + + if self.block_loss: + losses = [compute_loss(o, target, mask) for o in output] + loss = sum(losses) + else: + loss = compute_loss(output[-1], target, mask) + + scores = output[-1] + if self.crf is not None: + pred = self.crf.viterbi_decode(scores, target, mask) + else: + pred = scores.max(1)[1] * mask.long() + + return { + C.LOSS: loss, + C.OUTPUT: pred, + } + + def predict(self, words, chars, seq_len): + return self.forward(words, chars, seq_len)[C.OUTPUT] From 372496ca32a5015c19563c82ca1b29d6071a6e81 Mon Sep 17 00:00:00 2001 From: yunfan Date: Thu, 4 Jul 2019 14:03:53 +0800 Subject: [PATCH 3/4] update model & dataloader in text_classification --- .../text_classification/data/IMDBLoader.py | 82 +++++++++ .../text_classification/data/yelpLoader.py | 164 +++++++++++++++--- .../text_classification/train_dpcnn.py | 97 +++++++---- 3 files changed, 283 insertions(+), 60 deletions(-) create mode 100644 reproduction/text_classification/data/IMDBLoader.py diff --git a/reproduction/text_classification/data/IMDBLoader.py b/reproduction/text_classification/data/IMDBLoader.py new file mode 100644 index 00000000..2df87e26 --- /dev/null +++ b/reproduction/text_classification/data/IMDBLoader.py @@ -0,0 +1,82 @@ +from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader +from fastNLP.core.vocabulary import VocabularyOption +from fastNLP.io.base_loader import DataSetLoader, DataInfo +from typing import Union, Dict, List, Iterator +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP import Vocabulary +from fastNLP import Const +# from reproduction.utils import check_dataloader_paths +from functools import partial + + +class IMDBLoader(DataSetLoader): + """ + 读取IMDB数据集,DataSet包含以下fields: + + words: list(str), 需要分类的文本 + target: str, 文本的标签 + + + """ + + def __init__(self): + super(IMDBLoader, self).__init__() + + def _load(self, path): + dataset = DataSet() + with open(path, 'r', encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split('\t') + target = parts[0] + words = parts[1].split() + dataset.append(Instance(words=words, target=target)) + if len(dataset) == 0: + raise RuntimeError(f"{path} has no valid data.") + + return dataset + + def process(self, + paths: Union[str, Dict[str, str]], + src_vocab_opt: VocabularyOption = None, + tgt_vocab_opt: VocabularyOption = None, + src_embed_opt: EmbeddingOption = None): + + # paths = check_dataloader_paths(paths) + datasets = {} + info = DataInfo() + for name, path in paths.items(): + dataset = self.load(path) + datasets[name] = dataset + + datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False) + + src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) + src_vocab.from_dataset(datasets['train'], field_name='words') + # 
src_vocab.from_dataset(datasets['train'], datasets["dev"], datasets["test"], field_name='words') + src_vocab.index_dataset(*datasets.values(), field_name='words') + + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) + tgt_vocab.from_dataset(datasets['train'], field_name='target') + tgt_vocab.index_dataset(*datasets.values(), field_name='target') + + info.vocabs = { + "words": src_vocab, + "target": tgt_vocab + } + + info.datasets = datasets + + if src_embed_opt is not None: + embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) + info.embeddings['words'] = embed + + for name, dataset in info.datasets.items(): + dataset.set_input("words") + dataset.set_target("target") + + return info diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py index c47d48fd..63605ecf 100644 --- a/reproduction/text_classification/data/yelpLoader.py +++ b/reproduction/text_classification/data/yelpLoader.py @@ -1,4 +1,6 @@ import ast +import csv +from typing import Iterable from fastNLP import DataSet, Instance, Vocabulary from fastNLP.core.vocabulary import VocabularyOption from fastNLP.io import JsonLoader @@ -10,11 +12,34 @@ from reproduction.Star_transformer.datasets import EmbedLoader from reproduction.utils import check_dataloader_paths +def clean_str(sentence, char_lower=False): + """ + heavily borrowed from github + https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb + :param sentence: is a str + :return: + """ + if char_lower: + sentence = sentence.lower() + import re + nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') + words = sentence.split() + words_collection = [] + for word in words: + if word in ['-lrb-', '-rrb-', '', '-r', '-l', 'b-']: + continue + tt = nonalpnum.split(word) + t = ''.join(tt) + if t != '': + words_collection.append(t) + + return words_collection + + class yelpLoader(JsonLoader): - """ 读取Yelp数据集, DataSet包含fields: - + review_id: str, 22 character unique review id user_id: str, 22 character unique user id business_id: str, 22 character business id @@ -24,23 +49,25 @@ class yelpLoader(JsonLoader): date: str, date formatted YYYY-MM-DD words: list(str), 需要分类的文本 target: str, 文本的标签 - + 数据来源: https://www.yelp.com/dataset/download - + :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` """ - - def __init__(self, fine_grained=False): + + def __init__(self, fine_grained=False, lower=False): super(yelpLoader, self).__init__() tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', - '4.0': 'positive', '5.0': 'very positive'} + '4.0': 'positive', '5.0': 'very positive'} if not fine_grained: tag_v['1.0'] = tag_v['2.0'] tag_v['5.0'] = tag_v['4.0'] self.fine_grained = fine_grained self.tag_v = tag_v - - def _load(self, path): + self.lower = lower + + ''' + def _load_json(self, path): ds = DataSet() for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): d = ast.literal_eval(d) @@ -49,20 +76,113 @@ class yelpLoader(JsonLoader): ds.append(Instance(**d)) return ds - def process(self, paths: Union[str, Dict[str, str]], vocab_opt: VocabularyOption = None, - embed_opt: EmbeddingOption = None): + def _load_yelp2015_broken(self,path): + ds = DataSet() + with open (path,encoding='ISO 8859-1') as f: + row=f.readline() + all_count=0 + exp_count=0 + while row: + row=row.split("\t\t") + all_count+=1 + if len(row)>=3: + words=row[-1].split() + try: + 
target=self.tag_v[str(row[-2])+".0"] + ds.append(Instance(words=words, target=target)) + except KeyError: + exp_count+=1 + else: + exp_count+=1 + row = f.readline() + print("error sample count:",exp_count) + print("all count:",all_count) + return ds + ''' + + def _load(self, path): + ds = DataSet() + csv_reader = csv.reader(open(path, encoding='utf-8')) + all_count = 0 + real_count = 0 + for row in csv_reader: + all_count += 1 + if len(row) == 2: + target = self.tag_v[row[0] + ".0"] + words = clean_str(row[1], self.lower) + if len(words) != 0: + ds.append(Instance(words=words, target=target)) + real_count += 1 + print("all count:", all_count) + print("real count:", real_count) + return ds + + def process(self, paths: Union[str, Dict[str, str]], + train_ds: Iterable[str] = None, + src_vocab_op: VocabularyOption = None, + tgt_vocab_op: VocabularyOption = None, + embed_opt: EmbeddingOption = None, + char_level_op=False): paths = check_dataloader_paths(paths) datasets = {} - info = DataInfo() - vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt) - for name, path in paths.items(): - dataset = self.load(path) - datasets[name] = dataset - vocab.from_dataset(dataset, field_name="words") - info.vocabs = vocab - info.datasets = datasets - if embed_opt is not None: - embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab) - info.embeddings['words'] = embed + info = DataInfo(datasets=self.load(paths)) + src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) + tgt_vocab = Vocabulary(unknown=None, padding=None) \ + if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) + _train_ds = [info.datasets[name] + for name in train_ds] if train_ds else info.datasets.values() + + # vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt) + # for name, path in paths.items(): + # dataset = self.load(path) + # datasets[name] = dataset + # vocab.from_dataset(dataset, field_name="words") + # info.vocabs = vocab + # info.datasets = datasets + + def wordtochar(words): + chars = [] + for word in words: + word = word.lower() + for char in word: + chars.append(char) + return chars + + input_name, target_name = 'words', 'target' + info.vocabs = {} + # 就分隔为char形式 + if char_level_op: + for dataset in info.datasets.values(): + dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') + # if embed_opt is not None: + # embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab) + # info.embeddings['words'] = embed + else: + src_vocab.from_dataset(*_train_ds, field_name=input_name) + src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name) + info.vocabs[input_name] = src_vocab + + tgt_vocab.from_dataset(*_train_ds, field_name=target_name) + tgt_vocab.index_dataset( + *info.datasets.values(), + field_name=target_name, new_field_name=target_name) + info.vocabs[target_name] = tgt_vocab + return info + +if __name__ == "__main__": + testloader = yelpLoader() + # datapath = {"train": "/remote-home/ygwang/yelp_full/train.csv", + # "test": "/remote-home/ygwang/yelp_full/test.csv"} + # datapath={"train": "/remote-home/ygwang/yelp_full/test.csv"} + datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv", + "test": "/remote-home/ygwang/yelp_polarity/test.csv"} + datainfo = testloader.process(datapath, char_level_op=True) + + len_count = 0 + for instance in datainfo.datasets["train"]: + len_count += len(instance["chars"]) + + ave_len = len_count / len(datainfo.datasets["train"]) + 
print(ave_len) \ No newline at end of file diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py index 13ff4fc1..bf243ffb 100644 --- a/reproduction/text_classification/train_dpcnn.py +++ b/reproduction/text_classification/train_dpcnn.py @@ -1,65 +1,83 @@ # 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 + +from torch.optim.lr_scheduler import CosineAnnealingLR +import torch.cuda +from torch.optim import SGD +from fastNLP.core.trainer import Trainer +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding +from reproduction.text_classification.model.dpcnn import DPCNN +from .data.yelpLoader import yelpLoader +from fastNLP.io.dataset_loader import SSTLoader +import torch.nn as nn +from fastNLP.core import LRScheduler +from fastNLP.core.const import Const as C +import sys import os os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -from fastNLP.core.const import Const as C -from fastNLP.core import LRScheduler -import torch.nn as nn -from fastNLP.io.dataset_loader import SSTLoader -from reproduction.text_classification.model.dpcnn import DPCNN -from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding -from fastNLP import CrossEntropyLoss, AccuracyMetric -from fastNLP.core.trainer import Trainer -from torch.optim import SGD -import torch.cuda -from torch.optim.lr_scheduler import CosineAnnealingLR +sys.path.append('../..') + + +# hyper -##hyper class Config(): - model_dir_or_name="en-base-uncased" - embedding_grad= False, - train_epoch= 30 + model_dir_or_name = "en-base-uncased" + embedding_grad = False, + train_epoch = 30 batch_size = 100 - num_classes=5 - task= "SST" - datadir = '/remote-home/yfshao/workdir/datasets/SST' - datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"} - lr=1e-3 + num_classes = 2 + task = "yelp_p" + #datadir = '/remote-home/yfshao/workdir/datasets/SST' + datadir = '/remote-home/ygwang/yelp_polarity' + #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"} + datafile = {"train": "train.csv", "test": "test.csv"} + lr = 1e-3 + def __init__(self): - self.datapath = {k:os.path.join(self.datadir, v) + self.datapath = {k: os.path.join(self.datadir, v) for k, v in self.datafile.items()} -ops=Config() + +ops = Config() -##1.task相关信息:利用dataloader载入dataInfo -datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds='train') +# 1.task相关信息:利用dataloader载入dataInfo +#datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train']) +datainfo = yelpLoader(fine_grained=True, lower=True).process( + paths=ops.datapath, train_ds=['train']) print(len(datainfo.datasets['train'])) -print(len(datainfo.datasets['dev'])) +print(len(datainfo.datasets['test'])) -## 2.或直接复用fastNLP的模型 -vocab = datainfo.vocabs['words'] +# 2.或直接复用fastNLP的模型 +vocab = datainfo.vocabs['words'] # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)]) -embedding = StaticEmbedding(vocab) +#embedding = StaticEmbedding(vocab) +embedding = StaticEmbedding( + vocab, model_dir_or_name='en-word2vec-300', requires_grad=True) + print(len(vocab)) print(len(datainfo.vocabs['target'])) + model = DPCNN(init_embed=embedding, num_cls=ops.num_classes) -## 3. 
声明loss,metric,optimizer -loss=CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) -metric=AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) -optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], - lr=ops.lr, momentum=0.9, weight_decay=0) + +# 3. 声明loss,metric,optimizer +loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) +metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) +optimizer = SGD([param for param in model.parameters() if param.requires_grad == True], + lr=ops.lr, momentum=0.9, weight_decay=0) callbacks = [] callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + print(device) for ds in datainfo.datasets.values(): @@ -67,14 +85,17 @@ for ds in datainfo.datasets.values(): ds.set_input(C.INPUT, C.INPUT_LEN) ds.set_target(C.TARGET) -## 4.定义train方法 -def train(model,datainfo,loss,metrics,optimizer,num_epochs=ops.train_epoch): + +# 4.定义train方法 +def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch): trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, - metrics=[metrics], dev_data=datainfo.datasets['dev'], device=device, + metrics=[metrics], + dev_data=datainfo.datasets['test'], device=device, check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, n_epochs=num_epochs) + print(trainer.train()) -if __name__=="__main__": - train(model,datainfo,loss,metric,optimizer) \ No newline at end of file +if __name__ == "__main__": + train(model, datainfo, loss, metric, optimizer) From c5fc29dfef2a48729aafed42ccf6309aa5bee17e Mon Sep 17 00:00:00 2001 From: yunfan Date: Sat, 6 Jul 2019 13:15:40 +0800 Subject: [PATCH 4/4] -update DPCNN & train script -use spacy tokenizer for yelp data -add set_rng_seed --- fastNLP/modules/aggregator/attention.py | 9 +- .../text_classification/data/yelpLoader.py | 18 +++- .../text_classification/model/dpcnn.py | 22 +++-- .../text_classification/train_dpcnn.py | 83 ++++++++++++------- .../text_classification/utils/util_init.py | 11 +++ 5 files changed, 94 insertions(+), 49 deletions(-) create mode 100644 reproduction/text_classification/utils/util_init.py diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py index 4101b033..2bee7f2e 100644 --- a/fastNLP/modules/aggregator/attention.py +++ b/fastNLP/modules/aggregator/attention.py @@ -19,7 +19,7 @@ class DotAttention(nn.Module): 补上文档 """ - def __init__(self, key_size, value_size, dropout=0): + def __init__(self, key_size, value_size, dropout=0.0): super(DotAttention, self).__init__() self.key_size = key_size self.value_size = value_size @@ -37,7 +37,7 @@ class DotAttention(nn.Module): """ output = torch.matmul(Q, K.transpose(1, 2)) / self.scale if mask_out is not None: - output.masked_fill_(mask_out, -1e8) + output.masked_fill_(mask_out, -1e18) output = self.softmax(output) output = self.drop(output) return torch.matmul(output, V) @@ -67,9 +67,8 @@ class MultiHeadAttention(nn.Module): self.k_in = nn.Linear(input_size, in_size) self.v_in = nn.Linear(input_size, in_size) # follow the paper, do not apply dropout within dot-product - self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=0) + self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout) self.out = nn.Linear(value_size * num_head, input_size) - self.drop = TimestepDropout(dropout) self.reset_parameters() def reset_parameters(self): @@ -105,7 +104,7 @@ class MultiHeadAttention(nn.Module): # concat all 
heads, do output linear atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1) - output = self.drop(self.out(atte)) + output = self.out(atte) return output diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py index 63605ecf..d97f9399 100644 --- a/reproduction/text_classification/data/yelpLoader.py +++ b/reproduction/text_classification/data/yelpLoader.py @@ -8,11 +8,20 @@ from fastNLP.io.base_loader import DataInfo from fastNLP.io.embed_loader import EmbeddingOption from fastNLP.io.file_reader import _read_json from typing import Union, Dict -from reproduction.Star_transformer.datasets import EmbedLoader from reproduction.utils import check_dataloader_paths -def clean_str(sentence, char_lower=False): +def get_tokenizer(): + try: + import spacy + en = spacy.load('en') + print('use spacy tokenizer') + return lambda x: [w.text for w in en.tokenizer(x)] + except Exception as e: + print('use raw tokenizer') + return lambda x: x.split() + +def clean_str(sentence, tokenizer, char_lower=False): """ heavily borrowed from github https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb @@ -23,7 +32,7 @@ def clean_str(sentence, char_lower=False): sentence = sentence.lower() import re nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') - words = sentence.split() + words = tokenizer(sentence) words_collection = [] for word in words: if word in ['-lrb-', '-rrb-', '', '-r', '-l', 'b-']: @@ -65,6 +74,7 @@ class yelpLoader(JsonLoader): self.fine_grained = fine_grained self.tag_v = tag_v self.lower = lower + self.tokenizer = get_tokenizer() ''' def _load_json(self, path): @@ -109,7 +119,7 @@ class yelpLoader(JsonLoader): all_count += 1 if len(row) == 2: target = self.tag_v[row[0] + ".0"] - words = clean_str(row[1], self.lower) + words = clean_str(row[1], self.tokenizer, self.lower) if len(words) != 0: ds.append(Instance(words=words, target=target)) real_count += 1 diff --git a/reproduction/text_classification/model/dpcnn.py b/reproduction/text_classification/model/dpcnn.py index 2da7b3e5..dafe62bc 100644 --- a/reproduction/text_classification/model/dpcnn.py +++ b/reproduction/text_classification/model/dpcnn.py @@ -3,22 +3,27 @@ import torch.nn as nn from fastNLP.modules.utils import get_embeddings from fastNLP.core import Const as C + class DPCNN(nn.Module): - def __init__(self, init_embed, num_cls, n_filters=256, kernel_size=3, n_layers=7, embed_dropout=0.1, dropout=0.1): + def __init__(self, init_embed, num_cls, n_filters=256, + kernel_size=3, n_layers=7, embed_dropout=0.1, cls_dropout=0.1): super().__init__() - self.region_embed = RegionEmbedding(init_embed, out_dim=n_filters, kernel_sizes=[3, 5, 9]) + self.region_embed = RegionEmbedding( + init_embed, out_dim=n_filters, kernel_sizes=[1, 3, 5]) embed_dim = self.region_embed.embedding_dim self.conv_list = nn.ModuleList() for i in range(n_layers): self.conv_list.append(nn.Sequential( nn.ReLU(), - nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2), - nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2), + nn.Conv1d(n_filters, n_filters, kernel_size, + padding=kernel_size//2), + nn.Conv1d(n_filters, n_filters, kernel_size, + padding=kernel_size//2), )) self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) self.embed_drop = nn.Dropout(embed_dropout) self.classfier = nn.Sequential( - nn.Dropout(dropout), + nn.Dropout(cls_dropout), nn.Linear(n_filters, num_cls), ) self.reset_parameters() @@ -57,7 +62,8 @@ class 
RegionEmbedding(nn.Module): super().__init__() if kernel_sizes is None: kernel_sizes = [5, 9] - assert isinstance(kernel_sizes, list), 'kernel_sizes should be List(int)' + assert isinstance( + kernel_sizes, list), 'kernel_sizes should be List(int)' self.embed = get_embeddings(init_embed) try: embed_dim = self.embed.embedding_dim @@ -69,14 +75,14 @@ class RegionEmbedding(nn.Module): nn.Conv1d(embed_dim, embed_dim, ksz, padding=ksz // 2), )) self.linears = nn.ModuleList([nn.Conv1d(embed_dim, out_dim, 1) - for _ in range(len(kernel_sizes) + 1)]) + for _ in range(len(kernel_sizes))]) self.embedding_dim = embed_dim def forward(self, x): x = self.embed(x) x = x.transpose(1, 2) # B, C, L - out = self.linears[0](x) + out = 0 for conv, fc in zip(self.region_embeds, self.linears[1:]): conv_i = conv(x) out = out + fc(conv_i) diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py index bf243ffb..294a0742 100644 --- a/reproduction/text_classification/train_dpcnn.py +++ b/reproduction/text_classification/train_dpcnn.py @@ -1,40 +1,44 @@ # 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 -from torch.optim.lr_scheduler import CosineAnnealingLR import torch.cuda +from fastNLP.core.utils import cache_results from torch.optim import SGD +from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR from fastNLP.core.trainer import Trainer from fastNLP import CrossEntropyLoss, AccuracyMetric from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding from reproduction.text_classification.model.dpcnn import DPCNN -from .data.yelpLoader import yelpLoader -from fastNLP.io.dataset_loader import SSTLoader +from data.yelpLoader import yelpLoader import torch.nn as nn from fastNLP.core import LRScheduler from fastNLP.core.const import Const as C -import sys +from fastNLP.core.vocabulary import VocabularyOption +from utils.util_init import set_rng_seeds import os os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -sys.path.append('../..') - # hyper class Config(): - model_dir_or_name = "en-base-uncased" - embedding_grad = False, + seed = 12345 + model_dir_or_name = "dpcnn-yelp-p" + embedding_grad = True train_epoch = 30 batch_size = 100 num_classes = 2 task = "yelp_p" #datadir = '/remote-home/yfshao/workdir/datasets/SST' - datadir = '/remote-home/ygwang/yelp_polarity' + datadir = '/remote-home/yfshao/workdir/datasets/yelp_polarity' #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"} datafile = {"train": "train.csv", "test": "test.csv"} lr = 1e-3 + src_vocab_op = VocabularyOption() + embed_dropout = 0.3 + cls_dropout = 0.1 + weight_decay = 1e-4 def __init__(self): self.datapath = {k: os.path.join(self.datadir, v) @@ -43,15 +47,23 @@ class Config(): ops = Config() +set_rng_seeds(ops.seed) +print('RNG SEED: {}'.format(ops.seed)) # 1.task相关信息:利用dataloader载入dataInfo #datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train']) -datainfo = yelpLoader(fine_grained=True, lower=True).process( - paths=ops.datapath, train_ds=['train']) -print(len(datainfo.datasets['train'])) -print(len(datainfo.datasets['test'])) - +@cache_results(ops.model_dir_or_name+'-data-cache') +def load_data(): + datainfo = yelpLoader(fine_grained=True, lower=True).process( + paths=ops.datapath, train_ds=['train'], src_vocab_op=ops.src_vocab_op) + for ds in 
datainfo.datasets.values(): + ds.apply_field(len, C.INPUT, C.INPUT_LEN) + ds.set_input(C.INPUT, C.INPUT_LEN) + ds.set_target(C.TARGET) + return datainfo + +datainfo = load_data() # 2.或直接复用fastNLP的模型 @@ -59,43 +71,50 @@ vocab = datainfo.vocabs['words'] # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)]) #embedding = StaticEmbedding(vocab) embedding = StaticEmbedding( - vocab, model_dir_or_name='en-word2vec-300', requires_grad=True) + vocab, model_dir_or_name='en-word2vec-300', requires_grad=ops.embedding_grad, + normalize=False +) + +print(len(datainfo.datasets['train'])) +print(len(datainfo.datasets['test'])) +print(datainfo.datasets['train'][0]) print(len(vocab)) print(len(datainfo.vocabs['target'])) -model = DPCNN(init_embed=embedding, num_cls=ops.num_classes) - +model = DPCNN(init_embed=embedding, num_cls=ops.num_classes, + embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout) +print(model) # 3. 声明loss,metric,optimizer loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) optimizer = SGD([param for param in model.parameters() if param.requires_grad == True], - lr=ops.lr, momentum=0.9, weight_decay=0) + lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay) callbacks = [] callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) +# callbacks.append +# LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch < +# ops.train_epoch * 0.8 else ops.lr * 0.1)) +# ) + +# callbacks.append( +# FitlogCallback(data=datainfo.datasets, verbose=1) +# ) device = 'cuda:0' if torch.cuda.is_available() else 'cpu' print(device) -for ds in datainfo.datasets.values(): - ds.apply_field(len, C.INPUT, C.INPUT_LEN) - ds.set_input(C.INPUT, C.INPUT_LEN) - ds.set_target(C.TARGET) - - # 4.定义train方法 -def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch): - trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, - metrics=[metrics], - dev_data=datainfo.datasets['test'], device=device, - check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, - n_epochs=num_epochs) +trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + metrics=[metric], + dev_data=datainfo.datasets['test'], device=device, + check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, + n_epochs=ops.train_epoch, num_workers=4) - print(trainer.train()) if __name__ == "__main__": - train(model, datainfo, loss, metric, optimizer) + print(trainer.train()) diff --git a/reproduction/text_classification/utils/util_init.py b/reproduction/text_classification/utils/util_init.py new file mode 100644 index 00000000..fcb8fffb --- /dev/null +++ b/reproduction/text_classification/utils/util_init.py @@ -0,0 +1,11 @@ +import numpy +import torch +import random + + +def set_rng_seeds(seed): + random.seed(seed) + numpy.random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # print('RNG_SEED {}'.format(seed))
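
The DPCNN added in PATCH 1 is easiest to read as a repeated "downsample, then residual double convolution" pattern: stride-2 max pooling roughly halves the sequence length between blocks, each block adds its input back, and a final global max over time produces the sentence vector. The snippet below is a minimal, self-contained sketch of that pattern with toy sizes; it is illustrative only and not code from the patch (in particular it reuses one shared block, whereas the model builds a separate block per layer):

    import torch
    import torch.nn as nn

    n_filters, kernel_size = 32, 3  # toy sizes for illustration only
    block = nn.Sequential(          # one DPCNN-style residual block
        nn.ReLU(),
        nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size // 2),
        nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size // 2),
    )
    pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    x = torch.randn(2, n_filters, 64)   # (batch, channels, length)
    x = block(x) + x                    # first block: residual add, no pooling
    for _ in range(3):
        x = pool(x)                     # length 64 -> 32 -> 16 -> 8
        x = block(x) + x
    x, _ = torch.max(x, dim=2)          # global max over the length dimension
    print(x.shape)                      # torch.Size([2, 32])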
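
The ID-CNN in PATCH 2 gets its context from dilation rather than depth alone: within a block the Conv1d dilation doubles per layer (1, 2, 4, ...), so the receptive field grows exponentially while the padding keeps the sequence length fixed, and the whole block is then applied `repeats` times. A small sketch of that growth, again with toy sizes and illustrative names only (plain PyTorch, not the repository code):

    import torch
    import torch.nn as nn

    channels, kernel_size, num_layers = 16, 3, 4
    layers, receptive_field = [], 1
    for i in range(num_layers):
        dilation = 2 ** i
        layers.append(nn.Conv1d(channels, channels, kernel_size,
                                dilation=dilation,
                                padding=(kernel_size // 2) * dilation))
        layers.append(nn.ReLU())
        receptive_field += (kernel_size - 1) * dilation
    block = nn.Sequential(*layers)

    x = torch.randn(2, channels, 50)    # (batch, channels, sequence length)
    print(block(x).shape)               # torch.Size([2, 16, 50]) -- length preserved
    print(receptive_field)              # 1 + 2*(1+2+4+8) = 31 positions seen per token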