diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py index 4101b033..2bee7f2e 100644 --- a/fastNLP/modules/aggregator/attention.py +++ b/fastNLP/modules/aggregator/attention.py @@ -19,7 +19,7 @@ class DotAttention(nn.Module): 补上文档 """ - def __init__(self, key_size, value_size, dropout=0): + def __init__(self, key_size, value_size, dropout=0.0): super(DotAttention, self).__init__() self.key_size = key_size self.value_size = value_size @@ -37,7 +37,7 @@ class DotAttention(nn.Module): """ output = torch.matmul(Q, K.transpose(1, 2)) / self.scale if mask_out is not None: - output.masked_fill_(mask_out, -1e8) + output.masked_fill_(mask_out, -1e18) output = self.softmax(output) output = self.drop(output) return torch.matmul(output, V) @@ -67,9 +67,8 @@ class MultiHeadAttention(nn.Module): self.k_in = nn.Linear(input_size, in_size) self.v_in = nn.Linear(input_size, in_size) # follow the paper, do not apply dropout within dot-product - self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=0) + self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout) self.out = nn.Linear(value_size * num_head, input_size) - self.drop = TimestepDropout(dropout) self.reset_parameters() def reset_parameters(self): @@ -105,7 +104,7 @@ class MultiHeadAttention(nn.Module): # concat all heads, do output linear atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1) - output = self.drop(self.out(atte)) + output = self.out(atte) return output diff --git a/reproduction/seqence_labelling/ner/model/dilated_cnn.py b/reproduction/seqence_labelling/ner/model/dilated_cnn.py new file mode 100644 index 00000000..cd2fa64b --- /dev/null +++ b/reproduction/seqence_labelling/ner/model/dilated_cnn.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from fastNLP.modules.decoder import ConditionalRandomField +from fastNLP.modules.encoder import Embedding +from fastNLP.core.utils import seq_len_to_mask +from fastNLP.core.const import Const as C + + +class IDCNN(nn.Module): + def __init__(self, init_embed, char_embed, + num_cls, + repeats, num_layers, num_filters, kernel_size, + use_crf=False, use_projection=False, block_loss=False, + input_dropout=0.3, hidden_dropout=0.2, inner_dropout=0.0): + super(IDCNN, self).__init__() + self.word_embeddings = Embedding(init_embed) + self.char_embeddings = Embedding(char_embed) + embedding_size = self.word_embeddings.embedding_dim + \ + self.char_embeddings.embedding_dim + + self.conv0 = nn.Sequential( + nn.Conv1d(in_channels=embedding_size, + out_channels=num_filters, + kernel_size=kernel_size, + stride=1, dilation=1, + padding=kernel_size//2, + bias=True), + nn.ReLU(), + ) + + block = [] + for layer_i in range(num_layers): + dilated = 2 ** layer_i + block.append(nn.Conv1d( + in_channels=num_filters, + out_channels=num_filters, + kernel_size=kernel_size, + stride=1, dilation=dilated, + padding=(kernel_size//2) * dilated, + bias=True)) + block.append(nn.ReLU()) + self.block = nn.Sequential(*block) + + if use_projection: + self.projection = nn.Sequential( + nn.Conv1d( + in_channels=num_filters, + out_channels=num_filters//2, + kernel_size=1, + bias=True), + nn.ReLU(),) + encode_dim = num_filters // 2 + else: + self.projection = None + encode_dim = num_filters + + self.input_drop = nn.Dropout(input_dropout) + self.hidden_drop = nn.Dropout(hidden_dropout) + self.inner_drop = nn.Dropout(inner_dropout) + self.repeats = repeats + self.out_fc = nn.Conv1d( + in_channels=encode_dim, + out_channels=num_cls, + kernel_size=1, + bias=True) + self.crf = ConditionalRandomField( + num_tags=num_cls) if use_crf else None + self.block_loss = block_loss + + def forward(self, words, chars, seq_len, target=None): + e1 = self.word_embeddings(words) + e2 = self.char_embeddings(chars) + x = torch.cat((e1, e2), dim=-1) # b,l,h + mask = seq_len_to_mask(seq_len) + + x = x.transpose(1, 2) # b,h,l + last_output = self.conv0(x) + output = [] + for repeat in range(self.repeats): + last_output = self.block(last_output) + hidden = self.projection(last_output) if self.projection is not None else last_output + output.append(self.out_fc(hidden)) + + def compute_loss(y, t, mask): + if self.crf is not None and target is not None: + loss = self.crf(y, t, mask) + else: + t.masked_fill_(mask == 0, -100) + loss = F.cross_entropy(y, t, ignore_index=-100) + return loss + + if self.block_loss: + losses = [compute_loss(o, target, mask) for o in output] + loss = sum(losses) + else: + loss = compute_loss(output[-1], target, mask) + + scores = output[-1] + if self.crf is not None: + pred = self.crf.viterbi_decode(scores, target, mask) + else: + pred = scores.max(1)[1] * mask.long() + + return { + C.LOSS: loss, + C.OUTPUT: pred, + } + + def predict(self, words, chars, seq_len): + return self.forward(words, chars, seq_len)[C.OUTPUT] diff --git a/reproduction/text_classification/data/IMDBLoader.py b/reproduction/text_classification/data/IMDBLoader.py index cb422524..22841c4d 100644 --- a/reproduction/text_classification/data/IMDBLoader.py +++ b/reproduction/text_classification/data/IMDBLoader.py @@ -9,6 +9,7 @@ from fastNLP import Const # from reproduction.utils import check_dataloader_paths from functools import partial + class IMDBLoader(DataSetLoader): """ 读取IMDB数据集,DataSet包含以下fields: @@ -33,6 +34,7 @@ class IMDBLoader(DataSetLoader): target = parts[0] words = parts[1].split() dataset.append(Instance(words=words, target=target)) + if len(dataset)==0: raise RuntimeError(f"{path} has no valid data.") @@ -44,15 +46,13 @@ class IMDBLoader(DataSetLoader): tgt_vocab_opt: VocabularyOption = None, src_embed_opt: EmbeddingOption = None, char_level_op=False): - - # paths = check_dataloader_paths(paths) - + datasets = {} info = DataInfo() for name, path in paths.items(): dataset = self.load(path) datasets[name] = dataset - + def wordtochar(words): chars = [] for word in words: @@ -94,6 +94,7 @@ class IMDBLoader(DataSetLoader): return info + if __name__=="__main__": datapath = {"train": "/remote-home/ygwang/IMDB_data/train.csv", "test": "/remote-home/ygwang/IMDB_data/test.csv"} @@ -104,4 +105,4 @@ if __name__=="__main__": len_count += len(instance["chars"]) ave_len = len_count / len(datainfo.datasets["train"]) - print(ave_len) \ No newline at end of file + print(ave_len) diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py index 9d34004d..9e1e1c6b 100644 --- a/reproduction/text_classification/data/yelpLoader.py +++ b/reproduction/text_classification/data/yelpLoader.py @@ -8,11 +8,21 @@ from fastNLP.io.base_loader import DataInfo,DataSetLoader from fastNLP.io.embed_loader import EmbeddingOption from fastNLP.io.file_reader import _read_json from typing import Union, Dict -from reproduction.Star_transformer.datasets import EmbedLoader from reproduction.utils import check_dataloader_paths -def clean_str(sentence,char_lower=False): + +def get_tokenizer(): + try: + import spacy + en = spacy.load('en') + print('use spacy tokenizer') + return lambda x: [w.text for w in en.tokenizer(x)] + except Exception as e: + print('use raw tokenizer') + return lambda x: x.split() + +def clean_str(sentence, tokenizer, char_lower=False): """ heavily borrowed from github https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb @@ -20,10 +30,10 @@ def clean_str(sentence,char_lower=False): :return: """ if char_lower: - sentence=sentence.lower() + sentence = sentence.lower() import re nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') - words = sentence.split() + words = tokenizer(sentence) words_collection = [] for word in words: if word in ['-lrb-', '-rrb-', '', '-r', '-l', 'b-']: @@ -40,7 +50,6 @@ class yelpLoader(DataSetLoader): """ 读取Yelp_full/Yelp_polarity数据集, DataSet包含fields: - words: list(str), 需要分类的文本 target: str, 文本的标签 chars:list(str),未index的字符列表 @@ -52,13 +61,14 @@ class yelpLoader(DataSetLoader): def __init__(self, fine_grained=False,lower=False): super(yelpLoader, self).__init__() tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', - '4.0': 'positive', '5.0': 'very positive'} + '4.0': 'positive', '5.0': 'very positive'} if not fine_grained: tag_v['1.0'] = tag_v['2.0'] tag_v['5.0'] = tag_v['4.0'] self.fine_grained = fine_grained self.tag_v = tag_v - self.lower=lower + self.lower = lower + self.tokenizer = get_tokenizer() ''' 读取Yelp数据集, DataSet包含fields: @@ -75,6 +85,7 @@ class yelpLoader(DataSetLoader): 数据来源: https://www.yelp.com/dataset/download + def _load_json(self, path): ds = DataSet() for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): @@ -107,6 +118,7 @@ class yelpLoader(DataSetLoader): print("all count:",all_count) return ds ''' + def _load(self, path): ds = DataSet() csv_reader=csv.reader(open(path,encoding='utf-8')) @@ -125,6 +137,7 @@ class yelpLoader(DataSetLoader): return ds + def process(self, paths: Union[str, Dict[str, str]], train_ds: Iterable[str] = None, src_vocab_op: VocabularyOption = None, @@ -139,15 +152,10 @@ class yelpLoader(DataSetLoader): if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) _train_ds = [info.datasets[name] for name in train_ds] if train_ds else info.datasets.values() - #vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt) - # for name, path in paths.items(): - # dataset = self.load(path) - # datasets[name] = dataset - # vocab.from_dataset(dataset, field_name="words") - # info.vocabs = vocab - # info.datasets = datasets + def wordtochar(words): + chars=[] for word in words: word=word.lower() @@ -173,6 +181,7 @@ class yelpLoader(DataSetLoader): tgt_vocab.index_dataset( *info.datasets.values(), field_name=target_name, new_field_name=target_name) + info.vocabs[target_name]=tgt_vocab return info diff --git a/reproduction/text_classification/model/dpcnn.py b/reproduction/text_classification/model/dpcnn.py index a846af72..c31307bc 100644 --- a/reproduction/text_classification/model/dpcnn.py +++ b/reproduction/text_classification/model/dpcnn.py @@ -1,36 +1,38 @@ - import torch import torch.nn as nn from fastNLP.modules.utils import get_embeddings from fastNLP.core import Const as C - class DPCNN(nn.Module): - - def __init__(self, init_embed, num_cls, n_filters=256, kernel_size=3, n_layers=7, embed_dropout=0.1, dropout=0.1): + def __init__(self, init_embed, num_cls, n_filters=256, + kernel_size=3, n_layers=7, embed_dropout=0.1, cls_dropout=0.1): super().__init__() - self.region_embed = RegionEmbedding(init_embed, out_dim=n_filters, kernel_sizes=[3, 5, 9]) + self.region_embed = RegionEmbedding( + init_embed, out_dim=n_filters, kernel_sizes=[1, 3, 5]) + embed_dim = self.region_embed.embedding_dim self.conv_list = nn.ModuleList() for i in range(n_layers): self.conv_list.append(nn.Sequential( nn.ReLU(), - nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2), - nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size//2), - )) + nn.Conv1d(n_filters, n_filters, kernel_size, + padding=kernel_size//2), + nn.Conv1d(n_filters, n_filters, kernel_size, + padding=kernel_size//2), + )) self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) self.embed_drop = nn.Dropout(embed_dropout) self.classfier = nn.Sequential( - nn.Dropout(dropout), + nn.Dropout(cls_dropout), + nn.Linear(n_filters, num_cls), ) self.reset_parameters() - def reset_parameters(self): for m in self.modules(): if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): @@ -39,7 +41,6 @@ class DPCNN(nn.Module): nn.init.normal_(m.bias, mean=0, std=0.01) - def forward(self, words, seq_len=None): words = words.long() # get region embeddings @@ -58,21 +59,19 @@ class DPCNN(nn.Module): return {C.OUTPUT: x} - def predict(self, words, seq_len=None): x = self.forward(words, seq_len)[C.OUTPUT] return {C.OUTPUT: torch.argmax(x, 1)} - - - - class RegionEmbedding(nn.Module): def __init__(self, init_embed, out_dim=300, kernel_sizes=None): super().__init__() if kernel_sizes is None: kernel_sizes = [5, 9] - assert isinstance(kernel_sizes, list), 'kernel_sizes should be List(int)' + + assert isinstance( + kernel_sizes, list), 'kernel_sizes should be List(int)' + self.embed = get_embeddings(init_embed) try: embed_dim = self.embed.embedding_dim @@ -84,28 +83,24 @@ class RegionEmbedding(nn.Module): nn.Conv1d(embed_dim, embed_dim, ksz, padding=ksz // 2), )) self.linears = nn.ModuleList([nn.Conv1d(embed_dim, out_dim, 1) - for _ in range(len(kernel_sizes) + 1)]) + for _ in range(len(kernel_sizes))]) self.embedding_dim = embed_dim - def forward(self, x): x = self.embed(x) x = x.transpose(1, 2) # B, C, L - out = self.linears[0](x) + out = 0 for conv, fc in zip(self.region_embeds, self.linears[1:]): conv_i = conv(x) out = out + fc(conv_i) # B, C, L - return out - - - if __name__ == '__main__': x = torch.randint(0, 10000, size=(5, 15), dtype=torch.long) model = DPCNN((10000, 300), 20) y = model(x) - print(y.size(), y.mean(1), y.std(1)) \ No newline at end of file + print(y.size(), y.mean(1), y.std(1)) + diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py index 8ddea1a3..fcfa138b 100644 --- a/reproduction/text_classification/train_dpcnn.py +++ b/reproduction/text_classification/train_dpcnn.py @@ -1,101 +1,125 @@ # 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 +import torch.cuda +from fastNLP.core.utils import cache_results +from torch.optim import SGD +from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR +from fastNLP.core.trainer import Trainer +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding +from reproduction.text_classification.model.dpcnn import DPCNN +from data.yelpLoader import yelpLoader +import torch.nn as nn +from fastNLP.core import LRScheduler +from fastNLP.core.const import Const as C +from fastNLP.core.vocabulary import VocabularyOption +from utils.util_init import set_rng_seeds import os os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -import sys -sys.path.append('../..') -from fastNLP.core.const import Const as C -from fastNLP.core import LRScheduler -import torch.nn as nn -from fastNLP.io.dataset_loader import SSTLoader -from data.yelpLoader import yelpLoader -from reproduction.text_classification.model.dpcnn import DPCNN -from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding -from fastNLP import CrossEntropyLoss, AccuracyMetric -from fastNLP.core.trainer import Trainer -from torch.optim import SGD -import torch.cuda -from torch.optim.lr_scheduler import CosineAnnealingLR -##hyper +# hyper class Config(): - model_dir_or_name="en-base-uncased" - embedding_grad= False, - train_epoch= 30 + seed = 12345 + model_dir_or_name = "dpcnn-yelp-p" + embedding_grad = True + train_epoch = 30 batch_size = 100 - num_classes=2 - task= "yelp_p" + num_classes = 2 + task = "yelp_p" #datadir = '/remote-home/yfshao/workdir/datasets/SST' - datadir = '/remote-home/ygwang/yelp_polarity' + datadir = '/remote-home/yfshao/workdir/datasets/yelp_polarity' #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"} datafile = {"train": "train.csv", "test": "test.csv"} - lr=1e-3 + lr = 1e-3 + src_vocab_op = VocabularyOption() + embed_dropout = 0.3 + cls_dropout = 0.1 + weight_decay = 1e-4 def __init__(self): - self.datapath = {k:os.path.join(self.datadir, v) + self.datapath = {k: os.path.join(self.datadir, v) for k, v in self.datafile.items()} -ops=Config() +ops = Config() +set_rng_seeds(ops.seed) +print('RNG SEED: {}'.format(ops.seed)) -##1.task相关信息:利用dataloader载入dataInfo +# 1.task相关信息:利用dataloader载入dataInfo #datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train']) -datainfo=yelpLoader(fine_grained=True,lower=True).process(paths=ops.datapath, train_ds=['train']) -print(len(datainfo.datasets['train'])) -print(len(datainfo.datasets['test'])) +@cache_results(ops.model_dir_or_name+'-data-cache') +def load_data(): + datainfo = yelpLoader(fine_grained=True, lower=True).process( + paths=ops.datapath, train_ds=['train'], src_vocab_op=ops.src_vocab_op) + for ds in datainfo.datasets.values(): + ds.apply_field(len, C.INPUT, C.INPUT_LEN) + ds.set_input(C.INPUT, C.INPUT_LEN) + ds.set_target(C.TARGET) + return datainfo +datainfo = load_data() -## 2.或直接复用fastNLP的模型 +# 2.或直接复用fastNLP的模型 vocab = datainfo.vocabs['words'] # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)]) #embedding = StaticEmbedding(vocab) -embedding = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True) + +embedding = StaticEmbedding( + vocab, model_dir_or_name='en-word2vec-300', requires_grad=ops.embedding_grad, + normalize=False +) + +print(len(datainfo.datasets['train'])) +print(len(datainfo.datasets['test'])) +print(datainfo.datasets['train'][0]) + print(len(vocab)) print(len(datainfo.vocabs['target'])) -model = DPCNN(init_embed=embedding, num_cls=ops.num_classes) +model = DPCNN(init_embed=embedding, num_cls=ops.num_classes, + embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout) +print(model) - -## 3. 声明loss,metric,optimizer -loss=CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) -metric=AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) -optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], - lr=ops.lr, momentum=0.9, weight_decay=0) +# 3. 声明loss,metric,optimizer +loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) +metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) +optimizer = SGD([param for param in model.parameters() if param.requires_grad == True], + lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay) callbacks = [] callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) +# callbacks.append +# LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch < +# ops.train_epoch * 0.8 else ops.lr * 0.1)) +# ) +# callbacks.append( +# FitlogCallback(data=datainfo.datasets, verbose=1) +# ) -device = 'cuda:3' if torch.cuda.is_available() else 'cpu' +device = 'cuda:0' if torch.cuda.is_available() else 'cpu' print(device) -for ds in datainfo.datasets.values(): - ds.apply_field(len, C.INPUT, C.INPUT_LEN) - ds.set_input(C.INPUT, C.INPUT_LEN) - ds.set_target(C.TARGET) +# 4.定义train方法 +trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + metrics=[metric], + dev_data=datainfo.datasets['test'], device=device, + check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, + n_epochs=ops.train_epoch, num_workers=4) -## 4.定义train方法 -def train(model,datainfo,loss,metrics,optimizer,num_epochs=ops.train_epoch): - trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, - metrics=[metrics], dev_data=datainfo.datasets['test'], device=3, - check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, - n_epochs=num_epochs) +if __name__ == "__main__": print(trainer.train()) - - -if __name__=="__main__": - train(model,datainfo,loss,metric,optimizer) \ No newline at end of file diff --git a/reproduction/text_classification/utils/util_init.py b/reproduction/text_classification/utils/util_init.py new file mode 100644 index 00000000..fcb8fffb --- /dev/null +++ b/reproduction/text_classification/utils/util_init.py @@ -0,0 +1,11 @@ +import numpy +import torch +import random + + +def set_rng_seeds(seed): + random.seed(seed) + numpy.random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # print('RNG_SEED {}'.format(seed))