diff --git a/reproduction/text_classification/model/dpcnn.py b/reproduction/text_classification/model/dpcnn.py
index f87f5c14..2da7b3e5 100644
--- a/reproduction/text_classification/model/dpcnn.py
+++ b/reproduction/text_classification/model/dpcnn.py
@@ -1 +1,91 @@
-# TODO
\ No newline at end of file
+import torch
+import torch.nn as nn
+from fastNLP.modules.utils import get_embeddings
+from fastNLP.core import Const as C
+
+class DPCNN(nn.Module):
+    def __init__(self, init_embed, num_cls, n_filters=256, kernel_size=3, n_layers=7, embed_dropout=0.1, dropout=0.1):
+        super().__init__()
+        self.region_embed = RegionEmbedding(init_embed, out_dim=n_filters, kernel_sizes=[3, 5, 9])
+        embed_dim = self.region_embed.embedding_dim
+        self.conv_list = nn.ModuleList()
+        for i in range(n_layers):
+            self.conv_list.append(nn.Sequential(
+                nn.ReLU(),
+                nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size // 2),
+                nn.Conv1d(n_filters, n_filters, kernel_size, padding=kernel_size // 2),
+            ))
+        self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
+        self.embed_drop = nn.Dropout(embed_dropout)
+        self.classifier = nn.Sequential(
+            nn.Dropout(dropout),
+            nn.Linear(n_filters, num_cls),
+        )
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for m in self.modules():
+            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+                nn.init.normal_(m.weight, mean=0, std=0.01)
+                if m.bias is not None:
+                    nn.init.normal_(m.bias, mean=0, std=0.01)
+
+    def forward(self, words, seq_len=None):
+        words = words.long()
+        # region embeddings: (B, n_filters, L)
+        x = self.region_embed(words)
+        x = self.embed_drop(x)
+
+        # no pooling before the first conv block
+        x = self.conv_list[0](x) + x
+        for conv in self.conv_list[1:]:
+            x = self.pool(x)
+            x = conv(x) + x
+
+        # global max pooling over the length dimension: (B, C, L) -> (B, C)
+        x, _ = torch.max(x, dim=2)
+        x = self.classifier(x)
+        return {C.OUTPUT: x}
+
+    def predict(self, words, seq_len=None):
+        x = self.forward(words, seq_len)[C.OUTPUT]
+        return {C.OUTPUT: torch.argmax(x, 1)}
+
+
+class RegionEmbedding(nn.Module):
+    def __init__(self, init_embed, out_dim=300, kernel_sizes=None):
+        super().__init__()
+        if kernel_sizes is None:
+            kernel_sizes = [5, 9]
+        assert isinstance(kernel_sizes, list), 'kernel_sizes should be a list of int'
+        self.embed = get_embeddings(init_embed)
+        try:
+            embed_dim = self.embed.embedding_dim
+        except AttributeError:
+            embed_dim = self.embed.embed_size
+        self.region_embeds = nn.ModuleList()
+        for ksz in kernel_sizes:
+            self.region_embeds.append(nn.Sequential(
+                nn.Conv1d(embed_dim, embed_dim, ksz, padding=ksz // 2),
+            ))
+        self.linears = nn.ModuleList([nn.Conv1d(embed_dim, out_dim, 1)
+                                      for _ in range(len(kernel_sizes) + 1)])
+        self.embedding_dim = embed_dim
+
+    def forward(self, x):
+        x = self.embed(x)
+        x = x.transpose(1, 2)
+        # (B, C, L): sum 1x1 projections of the raw and the convolved embeddings
+        out = self.linears[0](x)
+        for conv, fc in zip(self.region_embeds, self.linears[1:]):
+            conv_i = conv(x)
+            out = out + fc(conv_i)
+        # (B, C, L)
+        return out
+
+
+if __name__ == '__main__':
+    x = torch.randint(0, 10000, size=(5, 15), dtype=torch.long)
+    model = DPCNN((10000, 300), 20)
+    y = model(x)[C.OUTPUT]
+    print(y.size(), y.mean(1), y.std(1))
diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py
index e69de29b..13ff4fc1 100644
--- a/reproduction/text_classification/train_dpcnn.py
+++ b/reproduction/text_classification/train_dpcnn.py
@@ -0,0 +1,80 @@
+# The following paths need to be set as environment variables first; the download service is currently open for internal testing only, so the paths have to be declared manually.
+import os
+os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
+os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+
+from fastNLP.core.const import Const as C
+from fastNLP.core import LRScheduler
+import torch.nn as nn
+from fastNLP.io.dataset_loader import SSTLoader
+from reproduction.text_classification.model.dpcnn import DPCNN
+from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
+from fastNLP import CrossEntropyLoss, AccuracyMetric
+from fastNLP.core.trainer import Trainer
+from torch.optim import SGD
+import torch.cuda
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+## hyperparameters
+class Config:
+    model_dir_or_name = "en-base-uncased"
+    embedding_grad = False
+    train_epoch = 30
+    batch_size = 100
+    num_classes = 5
+    task = "SST"
+    datadir = '/remote-home/yfshao/workdir/datasets/SST'
+    datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
+    lr = 1e-3
+    def __init__(self):
+        self.datapath = {k: os.path.join(self.datadir, v)
+                         for k, v in self.datafile.items()}
+
+ops = Config()
+
+
+## 1. task setup: load the DataInfo with the dataset loader
+datainfo = SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds='train')
+
+print(len(datainfo.datasets['train']))
+print(len(datainfo.datasets['dev']))
+
+
+## 2. build the model (or reuse one of fastNLP's models directly)
+vocab = datainfo.vocabs['words']
+
+# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
+embedding = StaticEmbedding(vocab)
+print(len(vocab))
+print(len(datainfo.vocabs['target']))
+model = DPCNN(init_embed=embedding, num_cls=ops.num_classes)
+
+## 3. declare the loss, metric, and optimizer
+loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
+metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
+optimizer = SGD([param for param in model.parameters() if param.requires_grad],
+                lr=ops.lr, momentum=0.9, weight_decay=0)
+
+callbacks = []
+callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
+
+device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+print(device)
+
+for ds in datainfo.datasets.values():
+    ds.apply_field(len, C.INPUT, C.INPUT_LEN)
+    ds.set_input(C.INPUT, C.INPUT_LEN)
+    ds.set_target(C.TARGET)
+
+## 4. define the train procedure
+def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch):
+    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+                      metrics=[metrics], dev_data=datainfo.datasets['dev'], device=device,
+                      check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
+                      n_epochs=num_epochs)
+    print(trainer.train())
+
+
+if __name__ == "__main__":
+    train(model, datainfo, loss, metric, optimizer)
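A quick sanity check on the shape arithmetic in the patch (not part of it): every DPCNN block after the first applies the stride-2 max pool before its residual convolutions, so the sequence length roughly halves at each depth until the pyramid bottoms out at length 1, and the final torch.max over dim 2 then feeds the classifier. A minimal sketch, assuming only torch:

    import torch
    import torch.nn as nn

    # Same pooling as DPCNN.pool: kernel 3, stride 2, padding 1.
    pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
    x = torch.randn(2, 256, 64)  # (batch, n_filters, seq_len)
    for depth in range(6):
        x = pool(x)
        print(depth + 1, x.shape[-1])  # lengths: 32, 16, 8, 4, 2, 1

With n_layers=7 and typical SST sentence lengths this means the deepest blocks operate on length-1 feature maps, which is harmless here because the pool's padding keeps the output length at a minimum of 1.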