[new] text_classification:charcnn,dpcnn,HAN,awd-lstmtags/v0.4.10
| @@ -5,10 +5,8 @@ from ..base_loader import DataInfo, DataSetLoader | |||||
| from ...core.vocabulary import VocabularyOption, Vocabulary | from ...core.vocabulary import VocabularyOption, Vocabulary | ||||
| from ...core.dataset import DataSet | from ...core.dataset import DataSet | ||||
| from ...core.instance import Instance | from ...core.instance import Instance | ||||
| from ..embed_loader import EmbeddingOption, EmbedLoader | |||||
| from ..utils import check_dataloader_paths, get_tokenizer | |||||
| spacy.prefer_gpu() | |||||
| sptk = spacy.load('en') | |||||
| class SSTLoader(DataSetLoader): | class SSTLoader(DataSetLoader): | ||||
| URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | ||||
| @@ -37,6 +35,7 @@ class SSTLoader(DataSetLoader): | |||||
| tag_v['0'] = tag_v['1'] | tag_v['0'] = tag_v['1'] | ||||
| tag_v['4'] = tag_v['3'] | tag_v['4'] = tag_v['3'] | ||||
| self.tag_v = tag_v | self.tag_v = tag_v | ||||
| self.tokenizer = get_tokenizer() | |||||
| def _load(self, path): | def _load(self, path): | ||||
| """ | """ | ||||
| @@ -55,29 +54,37 @@ class SSTLoader(DataSetLoader): | |||||
| ds.append(Instance(words=words, target=tag)) | ds.append(Instance(words=words, target=tag)) | ||||
| return ds | return ds | ||||
| @staticmethod | |||||
| def _get_one(data, subtree): | |||||
| def _get_one(self, data, subtree): | |||||
| tree = Tree.fromstring(data) | tree = Tree.fromstring(data) | ||||
| if subtree: | if subtree: | ||||
| return [([x.text for x in sptk.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees() ] | |||||
| return [([x.text for x in sptk.tokenizer(' '.join(tree.leaves()))], tree.label())] | |||||
| return [([x.text for x in self.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees() ] | |||||
| return [([x.text for x in self.tokenizer(' '.join(tree.leaves()))], tree.label())] | |||||
| def process(self, | def process(self, | ||||
| paths, | |||||
| train_ds: Iterable[str] = None, | |||||
| paths, train_subtree=True, | |||||
| src_vocab_op: VocabularyOption = None, | src_vocab_op: VocabularyOption = None, | ||||
| tgt_vocab_op: VocabularyOption = None, | |||||
| src_embed_op: EmbeddingOption = None): | |||||
| tgt_vocab_op: VocabularyOption = None,): | |||||
| paths = check_dataloader_paths(paths) | |||||
| input_name, target_name = 'words', 'target' | input_name, target_name = 'words', 'target' | ||||
| src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | ||||
| tgt_vocab = Vocabulary(unknown=None, padding=None) \ | tgt_vocab = Vocabulary(unknown=None, padding=None) \ | ||||
| if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | ||||
| info = DataInfo(datasets=self.load(paths)) | |||||
| _train_ds = [info.datasets[name] | |||||
| for name in train_ds] if train_ds else info.datasets.values() | |||||
| src_vocab.from_dataset(*_train_ds, field_name=input_name) | |||||
| tgt_vocab.from_dataset(*_train_ds, field_name=target_name) | |||||
| info = DataInfo() | |||||
| origin_subtree = self.subtree | |||||
| self.subtree = train_subtree | |||||
| info.datasets['train'] = self._load(paths['train']) | |||||
| self.subtree = origin_subtree | |||||
| for n, p in paths.items(): | |||||
| if n != 'train': | |||||
| info.datasets[n] = self._load(p) | |||||
| src_vocab.from_dataset( | |||||
| info.datasets['train'], | |||||
| field_name=input_name, | |||||
| no_create_entry_dataset=[ds for n, ds in info.datasets.items() if n != 'train']) | |||||
| tgt_vocab.from_dataset(info.datasets['train'], field_name=target_name) | |||||
| src_vocab.index_dataset( | src_vocab.index_dataset( | ||||
| *info.datasets.values(), | *info.datasets.values(), | ||||
| field_name=input_name, new_field_name=input_name) | field_name=input_name, new_field_name=input_name) | ||||
| @@ -89,10 +96,5 @@ class SSTLoader(DataSetLoader): | |||||
| target_name: tgt_vocab | target_name: tgt_vocab | ||||
| } | } | ||||
| if src_embed_op is not None: | |||||
| src_embed_op.vocab = src_vocab | |||||
| init_emb = EmbedLoader.load_with_vocab(**src_embed_op) | |||||
| info.embeddings[input_name] = init_emb | |||||
| return info | return info | ||||
| @@ -0,0 +1,69 @@ | |||||
| import os | |||||
| from typing import Union, Dict | |||||
| def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
| """ | |||||
| 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果 | |||||
| { | |||||
| 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 | |||||
| 'test': 'xxx' # 可能有,也可能没有 | |||||
| ... | |||||
| } | |||||
| 如果paths为不合法的,将直接进行raise相应的错误 | |||||
| :param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 | |||||
| 中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 | |||||
| :return: | |||||
| """ | |||||
| if isinstance(paths, str): | |||||
| if os.path.isfile(paths): | |||||
| return {'train': paths} | |||||
| elif os.path.isdir(paths): | |||||
| filenames = os.listdir(paths) | |||||
| files = {} | |||||
| for filename in filenames: | |||||
| path_pair = None | |||||
| if 'train' in filename: | |||||
| path_pair = ('train', filename) | |||||
| if 'dev' in filename: | |||||
| if path_pair: | |||||
| raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0])) | |||||
| path_pair = ('dev', filename) | |||||
| if 'test' in filename: | |||||
| if path_pair: | |||||
| raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0])) | |||||
| path_pair = ('test', filename) | |||||
| if path_pair: | |||||
| files[path_pair[0]] = os.path.join(paths, path_pair[1]) | |||||
| return files | |||||
| else: | |||||
| raise FileNotFoundError(f"{paths} is not a valid file path.") | |||||
| elif isinstance(paths, dict): | |||||
| if paths: | |||||
| if 'train' not in paths: | |||||
| raise KeyError("You have to include `train` in your dict.") | |||||
| for key, value in paths.items(): | |||||
| if isinstance(key, str) and isinstance(value, str): | |||||
| if not os.path.isfile(value): | |||||
| raise TypeError(f"{value} is not a valid file.") | |||||
| else: | |||||
| raise TypeError("All keys and values in paths should be str.") | |||||
| return paths | |||||
| else: | |||||
| raise ValueError("Empty paths is not allowed.") | |||||
| else: | |||||
| raise TypeError(f"paths only supports str and dict. not {type(paths)}.") | |||||
| def get_tokenizer(): | |||||
| try: | |||||
| import spacy | |||||
| spacy.prefer_gpu() | |||||
| en = spacy.load('en') | |||||
| print('use spacy tokenizer') | |||||
| return lambda x: [w.text for w in en.tokenizer(x)] | |||||
| except Exception as e: | |||||
| print('use raw tokenizer') | |||||
| return lambda x: x.split() | |||||
| @@ -19,7 +19,7 @@ class DotAttention(nn.Module): | |||||
| 补上文档 | 补上文档 | ||||
| """ | """ | ||||
| def __init__(self, key_size, value_size, dropout=0): | |||||
| def __init__(self, key_size, value_size, dropout=0.0): | |||||
| super(DotAttention, self).__init__() | super(DotAttention, self).__init__() | ||||
| self.key_size = key_size | self.key_size = key_size | ||||
| self.value_size = value_size | self.value_size = value_size | ||||
| @@ -37,7 +37,7 @@ class DotAttention(nn.Module): | |||||
| """ | """ | ||||
| output = torch.matmul(Q, K.transpose(1, 2)) / self.scale | output = torch.matmul(Q, K.transpose(1, 2)) / self.scale | ||||
| if mask_out is not None: | if mask_out is not None: | ||||
| output.masked_fill_(mask_out, -1e8) | |||||
| output.masked_fill_(mask_out, -1e18) | |||||
| output = self.softmax(output) | output = self.softmax(output) | ||||
| output = self.drop(output) | output = self.drop(output) | ||||
| return torch.matmul(output, V) | return torch.matmul(output, V) | ||||
| @@ -67,9 +67,8 @@ class MultiHeadAttention(nn.Module): | |||||
| self.k_in = nn.Linear(input_size, in_size) | self.k_in = nn.Linear(input_size, in_size) | ||||
| self.v_in = nn.Linear(input_size, in_size) | self.v_in = nn.Linear(input_size, in_size) | ||||
| # follow the paper, do not apply dropout within dot-product | # follow the paper, do not apply dropout within dot-product | ||||
| self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=0) | |||||
| self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout) | |||||
| self.out = nn.Linear(value_size * num_head, input_size) | self.out = nn.Linear(value_size * num_head, input_size) | ||||
| self.drop = TimestepDropout(dropout) | |||||
| self.reset_parameters() | self.reset_parameters() | ||||
| def reset_parameters(self): | def reset_parameters(self): | ||||
| @@ -105,7 +104,7 @@ class MultiHeadAttention(nn.Module): | |||||
| # concat all heads, do output linear | # concat all heads, do output linear | ||||
| atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1) | atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1) | ||||
| output = self.drop(self.out(atte)) | |||||
| output = self.out(atte) | |||||
| return output | return output | ||||
| @@ -6,7 +6,7 @@ paper: [Star-Transformer](https://arxiv.org/abs/1902.09113) | |||||
| |Pos Tagging|CTB 9.0|-|ACC 92.31| | |Pos Tagging|CTB 9.0|-|ACC 92.31| | ||||
| |Pos Tagging|CONLL 2012|-|ACC 96.51| | |Pos Tagging|CONLL 2012|-|ACC 96.51| | ||||
| |Named Entity Recognition|CONLL 2012|-|F1 85.66| | |Named Entity Recognition|CONLL 2012|-|F1 85.66| | ||||
| |Text Classification|SST|-|49.18| | |||||
| |Text Classification|SST|-|51.2| | |||||
| |Natural Language Inference|SNLI|-|83.76| | |Natural Language Inference|SNLI|-|83.76| | ||||
| ## Usage | ## Usage | ||||
| @@ -0,0 +1,142 @@ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from fastNLP.modules.decoder import ConditionalRandomField | |||||
| from fastNLP.modules.encoder import Embedding | |||||
| from fastNLP.core.utils import seq_len_to_mask | |||||
| from fastNLP.core.const import Const as C | |||||
| class IDCNN(nn.Module): | |||||
| def __init__(self, | |||||
| init_embed, | |||||
| char_embed, | |||||
| num_cls, | |||||
| repeats, num_layers, num_filters, kernel_size, | |||||
| use_crf=False, use_projection=False, block_loss=False, | |||||
| input_dropout=0.3, hidden_dropout=0.2, inner_dropout=0.0): | |||||
| super(IDCNN, self).__init__() | |||||
| self.word_embeddings = Embedding(init_embed) | |||||
| if char_embed is None: | |||||
| self.char_embeddings = None | |||||
| embedding_size = self.word_embeddings.embedding_dim | |||||
| else: | |||||
| self.char_embeddings = Embedding(char_embed) | |||||
| embedding_size = self.word_embeddings.embedding_dim + \ | |||||
| self.char_embeddings.embedding_dim | |||||
| self.conv0 = nn.Sequential( | |||||
| nn.Conv1d(in_channels=embedding_size, | |||||
| out_channels=num_filters, | |||||
| kernel_size=kernel_size, | |||||
| stride=1, dilation=1, | |||||
| padding=kernel_size//2, | |||||
| bias=True), | |||||
| nn.ReLU(), | |||||
| ) | |||||
| block = [] | |||||
| for layer_i in range(num_layers): | |||||
| dilated = 2 ** layer_i if layer_i+1 < num_layers else 1 | |||||
| block.append(nn.Conv1d( | |||||
| in_channels=num_filters, | |||||
| out_channels=num_filters, | |||||
| kernel_size=kernel_size, | |||||
| stride=1, dilation=dilated, | |||||
| padding=(kernel_size//2) * dilated, | |||||
| bias=True)) | |||||
| block.append(nn.ReLU()) | |||||
| self.block = nn.Sequential(*block) | |||||
| if use_projection: | |||||
| self.projection = nn.Sequential( | |||||
| nn.Conv1d( | |||||
| in_channels=num_filters, | |||||
| out_channels=num_filters//2, | |||||
| kernel_size=1, | |||||
| bias=True), | |||||
| nn.ReLU(),) | |||||
| encode_dim = num_filters // 2 | |||||
| else: | |||||
| self.projection = None | |||||
| encode_dim = num_filters | |||||
| self.input_drop = nn.Dropout(input_dropout) | |||||
| self.hidden_drop = nn.Dropout(hidden_dropout) | |||||
| self.inner_drop = nn.Dropout(inner_dropout) | |||||
| self.repeats = repeats | |||||
| self.out_fc = nn.Conv1d( | |||||
| in_channels=encode_dim, | |||||
| out_channels=num_cls, | |||||
| kernel_size=1, | |||||
| bias=True) | |||||
| self.crf = ConditionalRandomField( | |||||
| num_tags=num_cls) if use_crf else None | |||||
| self.block_loss = block_loss | |||||
| self.reset_parameters() | |||||
| def reset_parameters(self): | |||||
| for m in self.modules(): | |||||
| if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): | |||||
| nn.init.xavier_normal_(m.weight, gain=1) | |||||
| if m.bias is not None: | |||||
| nn.init.normal_(m.bias, mean=0, std=0.01) | |||||
| def forward(self, words, seq_len, target=None, chars=None): | |||||
| if self.char_embeddings is None: | |||||
| x = self.word_embeddings(words) | |||||
| else: | |||||
| if chars is None: | |||||
| raise ValueError('must provide chars for model with char embedding') | |||||
| e1 = self.word_embeddings(words) | |||||
| e2 = self.char_embeddings(chars) | |||||
| x = torch.cat((e1, e2), dim=-1) # b,l,h | |||||
| mask = seq_len_to_mask(seq_len) | |||||
| x = x.transpose(1, 2) # b,h,l | |||||
| last_output = self.conv0(x) | |||||
| output = [] | |||||
| for repeat in range(self.repeats): | |||||
| last_output = self.block(last_output) | |||||
| hidden = self.projection(last_output) if self.projection is not None else last_output | |||||
| output.append(self.out_fc(hidden)) | |||||
| def compute_loss(y, t, mask): | |||||
| if self.crf is not None and target is not None: | |||||
| loss = self.crf(y.transpose(1, 2), t, mask) | |||||
| else: | |||||
| t.masked_fill_(mask == 0, -100) | |||||
| loss = F.cross_entropy(y, t, ignore_index=-100) | |||||
| return loss | |||||
| if target is not None: | |||||
| if self.block_loss: | |||||
| losses = [compute_loss(o, target, mask) for o in output] | |||||
| loss = sum(losses) | |||||
| else: | |||||
| loss = compute_loss(output[-1], target, mask) | |||||
| else: | |||||
| loss = None | |||||
| scores = output[-1] | |||||
| if self.crf is not None: | |||||
| pred, _ = self.crf.viterbi_decode(scores.transpose(1, 2), mask) | |||||
| else: | |||||
| pred = scores.max(1)[1] * mask.long() | |||||
| return { | |||||
| C.LOSS: loss, | |||||
| C.OUTPUT: pred, | |||||
| } | |||||
| def predict(self, words, seq_len, chars=None): | |||||
| res = self.forward( | |||||
| words=words, | |||||
| seq_len=seq_len, | |||||
| chars=chars, | |||||
| target=None | |||||
| )[C.OUTPUT] | |||||
| return { | |||||
| C.OUTPUT: res | |||||
| } | |||||
| @@ -0,0 +1,99 @@ | |||||
| from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader | |||||
| from fastNLP.core.callback import FitlogCallback, LRScheduler | |||||
| from fastNLP import GradientClipCallback | |||||
| from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR | |||||
| from torch.optim import SGD, Adam | |||||
| from fastNLP import Const | |||||
| from fastNLP import RandomSampler, BucketSampler | |||||
| from fastNLP import SpanFPreRecMetric | |||||
| from fastNLP import Trainer | |||||
| from reproduction.seqence_labelling.ner.model.dilated_cnn import IDCNN | |||||
| from fastNLP.core.utils import Option | |||||
| from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding | |||||
| from fastNLP.core.utils import cache_results | |||||
| import sys | |||||
| import torch.cuda | |||||
| import os | |||||
| os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | |||||
| os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||||
| os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | |||||
| encoding_type = 'bioes' | |||||
| def get_path(path): | |||||
| return os.path.join(os.environ['HOME'], path) | |||||
| data_path = get_path('workdir/datasets/ontonotes-v4') | |||||
| ops = Option( | |||||
| batch_size=128, | |||||
| num_epochs=100, | |||||
| lr=3e-4, | |||||
| repeats=3, | |||||
| num_layers=3, | |||||
| num_filters=400, | |||||
| use_crf=True, | |||||
| gradient_clip=5, | |||||
| ) | |||||
| @cache_results('ontonotes-cache') | |||||
| def load_data(): | |||||
| data = OntoNoteNERDataLoader(encoding_type=encoding_type).process(data_path, | |||||
| lower=True) | |||||
| # char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30], | |||||
| # kernel_sizes=[3]) | |||||
| word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], | |||||
| model_dir_or_name='en-glove-840b-300', | |||||
| requires_grad=True) | |||||
| return data, [word_embed] | |||||
| data, embeds = load_data() | |||||
| print(data.datasets['train'][0]) | |||||
| print(list(data.vocabs.keys())) | |||||
| for ds in data.datasets.values(): | |||||
| ds.rename_field('cap_words', 'chars') | |||||
| ds.set_input('chars') | |||||
| word_embed = embeds[0] | |||||
| char_embed = CNNCharEmbedding(data.vocabs['cap_words']) | |||||
| # for ds in data.datasets: | |||||
| # ds.rename_field('') | |||||
| print(data.vocabs[Const.TARGET].word2idx) | |||||
| model = IDCNN(init_embed=word_embed, | |||||
| char_embed=char_embed, | |||||
| num_cls=len(data.vocabs[Const.TARGET]), | |||||
| repeats=ops.repeats, | |||||
| num_layers=ops.num_layers, | |||||
| num_filters=ops.num_filters, | |||||
| kernel_size=3, | |||||
| use_crf=ops.use_crf, use_projection=True, | |||||
| block_loss=True, | |||||
| input_dropout=0.33, hidden_dropout=0.2, inner_dropout=0.2) | |||||
| print(model) | |||||
| callbacks = [GradientClipCallback(clip_value=ops.gradient_clip, clip_type='norm'),] | |||||
| optimizer = Adam(model.parameters(), lr=ops.lr, weight_decay=0) | |||||
| # scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))) | |||||
| # callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 15))) | |||||
| # optimizer = SWATS(model.parameters(), verbose=True) | |||||
| # optimizer = Adam(model.parameters(), lr=0.005) | |||||
| device = 'cuda:0' if torch.cuda.is_available() else 'cpu' | |||||
| trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, | |||||
| sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), | |||||
| device=device, dev_data=data.datasets['dev'], batch_size=ops.batch_size, | |||||
| metrics=SpanFPreRecMetric( | |||||
| tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type), | |||||
| check_code_level=-1, | |||||
| callbacks=callbacks, num_workers=2, n_epochs=ops.num_epochs) | |||||
| trainer.train() | |||||
| @@ -0,0 +1,26 @@ | |||||
| # text_classification任务模型复现 | |||||
| 这里使用fastNLP复现以下模型: | |||||
| char_cnn :论文链接[Character-level Convolutional Networks for Text Classification](https://arxiv.org/pdf/1509.01626v3.pdf) | |||||
| dpcnn:论文链接[Deep Pyramid Convolutional Neural Networks for TextCategorization](https://ai.tencent.com/ailab/media/publications/ACL3-Brady.pdf) | |||||
| HAN:论文链接[Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf) | |||||
| LSTM+self_attention:论文链接[A Structured Self-attentive Sentence Embedding](<https://arxiv.org/pdf/1703.03130.pdf>) | |||||
| AWD-LSTM:论文链接[Regularizing and Optimizing LSTM Language Models](<https://arxiv.org/pdf/1708.02182.pdf>) | |||||
| # 数据集及复现结果汇总 | |||||
| 使用fastNLP复现的结果vs论文汇报结果(/前为fastNLP实现,后面为论文报道,-表示论文没有在该数据集上列出结果) | |||||
| model name | yelp_p | yelp_f | sst-2|IMDB | |||||
| :---: | :---: | :---: | :---: |----- | |||||
| char_cnn | 93.80/95.12 | - | - |- | |||||
| dpcnn | 95.50/97.36 | - | - |- | |||||
| HAN |- | - | - |- | |||||
| LSTM| 95.74/- |- |- |88.52/- | |||||
| AWD-LSTM| 95.96/- |- |- |88.91/- | |||||
| LSTM+self_attention| 96.34/- | - | - |89.53/- | |||||
| @@ -0,0 +1,110 @@ | |||||
| from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | |||||
| from fastNLP.core.vocabulary import VocabularyOption | |||||
| from fastNLP.io.base_loader import DataSetLoader, DataInfo | |||||
| from typing import Union, Dict, List, Iterator | |||||
| from fastNLP import DataSet | |||||
| from fastNLP import Instance | |||||
| from fastNLP import Vocabulary | |||||
| from fastNLP import Const | |||||
| # from reproduction.utils import check_dataloader_paths | |||||
| from functools import partial | |||||
| class IMDBLoader(DataSetLoader): | |||||
| """ | |||||
| 读取IMDB数据集,DataSet包含以下fields: | |||||
| words: list(str), 需要分类的文本 | |||||
| target: str, 文本的标签 | |||||
| """ | |||||
| def __init__(self): | |||||
| super(IMDBLoader, self).__init__() | |||||
| def _load(self, path): | |||||
| dataset = DataSet() | |||||
| with open(path, 'r', encoding="utf-8") as f: | |||||
| for line in f: | |||||
| line = line.strip() | |||||
| if not line: | |||||
| continue | |||||
| parts = line.split('\t') | |||||
| target = parts[0] | |||||
| words = parts[1].lower().split() | |||||
| dataset.append(Instance(words=words, target=target)) | |||||
| if len(dataset)==0: | |||||
| raise RuntimeError(f"{path} has no valid data.") | |||||
| return dataset | |||||
| def process(self, | |||||
| paths: Union[str, Dict[str, str]], | |||||
| src_vocab_opt: VocabularyOption = None, | |||||
| tgt_vocab_opt: VocabularyOption = None, | |||||
| src_embed_opt: EmbeddingOption = None, | |||||
| char_level_op=False): | |||||
| datasets = {} | |||||
| info = DataInfo() | |||||
| for name, path in paths.items(): | |||||
| dataset = self.load(path) | |||||
| datasets[name] = dataset | |||||
| def wordtochar(words): | |||||
| chars = [] | |||||
| for word in words: | |||||
| word = word.lower() | |||||
| for char in word: | |||||
| chars.append(char) | |||||
| return chars | |||||
| if char_level_op: | |||||
| for dataset in datasets.values(): | |||||
| dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') | |||||
| datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False) | |||||
| src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||||
| src_vocab.from_dataset(datasets['train'], field_name='words') | |||||
| src_vocab.index_dataset(*datasets.values(), field_name='words') | |||||
| tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
| if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||||
| tgt_vocab.from_dataset(datasets['train'], field_name='target') | |||||
| tgt_vocab.index_dataset(*datasets.values(), field_name='target') | |||||
| info.vocabs = { | |||||
| "words": src_vocab, | |||||
| "target": tgt_vocab | |||||
| } | |||||
| info.datasets = datasets | |||||
| if src_embed_opt is not None: | |||||
| embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) | |||||
| info.embeddings['words'] = embed | |||||
| for name, dataset in info.datasets.items(): | |||||
| dataset.set_input("words") | |||||
| dataset.set_target("target") | |||||
| return info | |||||
| if __name__=="__main__": | |||||
| datapath = {"train": "/remote-home/ygwang/IMDB_data/train.csv", | |||||
| "test": "/remote-home/ygwang/IMDB_data/test.csv"} | |||||
| datainfo=IMDBLoader().process(datapath,char_level_op=True) | |||||
| #print(datainfo.datasets["train"]) | |||||
| len_count = 0 | |||||
| for instance in datainfo.datasets["train"]: | |||||
| len_count += len(instance["chars"]) | |||||
| ave_len = len_count / len(datainfo.datasets["train"]) | |||||
| print(ave_len) | |||||
| @@ -32,7 +32,7 @@ class MTL16Loader(DataSetLoader): | |||||
| continue | continue | ||||
| parts = line.split('\t') | parts = line.split('\t') | ||||
| target = parts[0] | target = parts[0] | ||||
| words = parts[1].split() | |||||
| words = parts[1].lower().split() | |||||
| dataset.append(Instance(words=words, target=target)) | dataset.append(Instance(words=words, target=target)) | ||||
| if len(dataset)==0: | if len(dataset)==0: | ||||
| raise RuntimeError(f"{path} has no valid data.") | raise RuntimeError(f"{path} has no valid data.") | ||||
| @@ -72,4 +72,8 @@ class MTL16Loader(DataSetLoader): | |||||
| embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) | embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) | ||||
| info.embeddings['words'] = embed | info.embeddings['words'] = embed | ||||
| for name, dataset in info.datasets.items(): | |||||
| dataset.set_input("words") | |||||
| dataset.set_target("target") | |||||
| return info | return info | ||||
| @@ -0,0 +1,187 @@ | |||||
| from typing import Iterable | |||||
| from nltk import Tree | |||||
| from fastNLP.io.base_loader import DataInfo, DataSetLoader | |||||
| from fastNLP.core.vocabulary import VocabularyOption, Vocabulary | |||||
| from fastNLP import DataSet | |||||
| from fastNLP import Instance | |||||
| from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | |||||
| import csv | |||||
| from typing import Union, Dict | |||||
| class SSTLoader(DataSetLoader): | |||||
| URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | |||||
| DATA_DIR = 'sst/' | |||||
| """ | |||||
| 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` | |||||
| 读取SST数据集, DataSet包含fields:: | |||||
| words: list(str) 需要分类的文本 | |||||
| target: str 文本的标签 | |||||
| 数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip | |||||
| :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` | |||||
| :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||||
| """ | |||||
| def __init__(self, subtree=False, fine_grained=False): | |||||
| self.subtree = subtree | |||||
| tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', | |||||
| '3': 'positive', '4': 'very positive'} | |||||
| if not fine_grained: | |||||
| tag_v['0'] = tag_v['1'] | |||||
| tag_v['4'] = tag_v['3'] | |||||
| self.tag_v = tag_v | |||||
| def _load(self, path): | |||||
| """ | |||||
| :param str path: 存储数据的路径 | |||||
| :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||||
| """ | |||||
| datalist = [] | |||||
| with open(path, 'r', encoding='utf-8') as f: | |||||
| datas = [] | |||||
| for l in f: | |||||
| datas.extend([(s, self.tag_v[t]) | |||||
| for s, t in self._get_one(l, self.subtree)]) | |||||
| ds = DataSet() | |||||
| for words, tag in datas: | |||||
| ds.append(Instance(words=words, target=tag)) | |||||
| return ds | |||||
| @staticmethod | |||||
| def _get_one(data, subtree): | |||||
| tree = Tree.fromstring(data) | |||||
| if subtree: | |||||
| return [(t.leaves(), t.label()) for t in tree.subtrees()] | |||||
| return [(tree.leaves(), tree.label())] | |||||
| def process(self, | |||||
| paths, | |||||
| train_ds: Iterable[str] = None, | |||||
| src_vocab_op: VocabularyOption = None, | |||||
| tgt_vocab_op: VocabularyOption = None, | |||||
| src_embed_op: EmbeddingOption = None): | |||||
| input_name, target_name = 'words', 'target' | |||||
| src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | |||||
| tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
| if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | |||||
| info = DataInfo(datasets=self.load(paths)) | |||||
| _train_ds = [info.datasets[name] | |||||
| for name in train_ds] if train_ds else info.datasets.values() | |||||
| src_vocab.from_dataset(*_train_ds, field_name=input_name) | |||||
| tgt_vocab.from_dataset(*_train_ds, field_name=target_name) | |||||
| src_vocab.index_dataset( | |||||
| *info.datasets.values(), | |||||
| field_name=input_name, new_field_name=input_name) | |||||
| tgt_vocab.index_dataset( | |||||
| *info.datasets.values(), | |||||
| field_name=target_name, new_field_name=target_name) | |||||
| info.vocabs = { | |||||
| input_name: src_vocab, | |||||
| target_name: tgt_vocab | |||||
| } | |||||
| if src_embed_op is not None: | |||||
| src_embed_op.vocab = src_vocab | |||||
| init_emb = EmbedLoader.load_with_vocab(**src_embed_op) | |||||
| info.embeddings[input_name] = init_emb | |||||
| for name, dataset in info.datasets.items(): | |||||
| dataset.set_input(input_name) | |||||
| dataset.set_target(target_name) | |||||
| return info | |||||
| class sst2Loader(DataSetLoader): | |||||
| ''' | |||||
| 数据来源"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', | |||||
| ''' | |||||
| def __init__(self): | |||||
| super(sst2Loader, self).__init__() | |||||
| def _load(self, path: str) -> DataSet: | |||||
| ds = DataSet() | |||||
| all_count=0 | |||||
| csv_reader = csv.reader(open(path, encoding='utf-8'),delimiter='\t') | |||||
| skip_row = 0 | |||||
| for idx,row in enumerate(csv_reader): | |||||
| if idx<=skip_row: | |||||
| continue | |||||
| target = row[1] | |||||
| words = row[0].split() | |||||
| ds.append(Instance(words=words,target=target)) | |||||
| all_count+=1 | |||||
| print("all count:", all_count) | |||||
| return ds | |||||
| def process(self, | |||||
| paths: Union[str, Dict[str, str]], | |||||
| src_vocab_opt: VocabularyOption = None, | |||||
| tgt_vocab_opt: VocabularyOption = None, | |||||
| src_embed_opt: EmbeddingOption = None, | |||||
| char_level_op=False): | |||||
| paths = check_dataloader_paths(paths) | |||||
| datasets = {} | |||||
| info = DataInfo() | |||||
| for name, path in paths.items(): | |||||
| dataset = self.load(path) | |||||
| datasets[name] = dataset | |||||
| def wordtochar(words): | |||||
| chars=[] | |||||
| for word in words: | |||||
| word=word.lower() | |||||
| for char in word: | |||||
| chars.append(char) | |||||
| return chars | |||||
| input_name, target_name = 'words', 'target' | |||||
| info.vocabs={} | |||||
| # 就分隔为char形式 | |||||
| if char_level_op: | |||||
| for dataset in datasets.values(): | |||||
| dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') | |||||
| src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||||
| src_vocab.from_dataset(datasets['train'], field_name='words') | |||||
| src_vocab.index_dataset(*datasets.values(), field_name='words') | |||||
| tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
| if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||||
| tgt_vocab.from_dataset(datasets['train'], field_name='target') | |||||
| tgt_vocab.index_dataset(*datasets.values(), field_name='target') | |||||
| info.vocabs = { | |||||
| "words": src_vocab, | |||||
| "target": tgt_vocab | |||||
| } | |||||
| info.datasets = datasets | |||||
| if src_embed_opt is not None: | |||||
| embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) | |||||
| info.embeddings['words'] = embed | |||||
| return info | |||||
| if __name__=="__main__": | |||||
| datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv", | |||||
| "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"} | |||||
| datainfo=sst2Loader().process(datapath,char_level_op=True) | |||||
| #print(datainfo.datasets["train"]) | |||||
| len_count = 0 | |||||
| for instance in datainfo.datasets["train"]: | |||||
| len_count += len(instance["chars"]) | |||||
| ave_len = len_count / len(datainfo.datasets["train"]) | |||||
| print(ave_len) | |||||
| @@ -0,0 +1,187 @@ | |||||
| from typing import Iterable | |||||
| from nltk import Tree | |||||
| from fastNLP.io.base_loader import DataInfo, DataSetLoader | |||||
| from fastNLP.core.vocabulary import VocabularyOption, Vocabulary | |||||
| from fastNLP import DataSet | |||||
| from fastNLP import Instance | |||||
| from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | |||||
| import csv | |||||
| from typing import Union, Dict | |||||
| class SSTLoader(DataSetLoader): | |||||
| URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | |||||
| DATA_DIR = 'sst/' | |||||
| """ | |||||
| 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` | |||||
| 读取SST数据集, DataSet包含fields:: | |||||
| words: list(str) 需要分类的文本 | |||||
| target: str 文本的标签 | |||||
| 数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip | |||||
| :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` | |||||
| :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||||
| """ | |||||
| def __init__(self, subtree=False, fine_grained=False): | |||||
| self.subtree = subtree | |||||
| tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', | |||||
| '3': 'positive', '4': 'very positive'} | |||||
| if not fine_grained: | |||||
| tag_v['0'] = tag_v['1'] | |||||
| tag_v['4'] = tag_v['3'] | |||||
| self.tag_v = tag_v | |||||
| def _load(self, path): | |||||
| """ | |||||
| :param str path: 存储数据的路径 | |||||
| :return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||||
| """ | |||||
| datalist = [] | |||||
| with open(path, 'r', encoding='utf-8') as f: | |||||
| datas = [] | |||||
| for l in f: | |||||
| datas.extend([(s, self.tag_v[t]) | |||||
| for s, t in self._get_one(l, self.subtree)]) | |||||
| ds = DataSet() | |||||
| for words, tag in datas: | |||||
| ds.append(Instance(words=words, target=tag)) | |||||
| return ds | |||||
| @staticmethod | |||||
| def _get_one(data, subtree): | |||||
| tree = Tree.fromstring(data) | |||||
| if subtree: | |||||
| return [(t.leaves(), t.label()) for t in tree.subtrees()] | |||||
| return [(tree.leaves(), tree.label())] | |||||
| def process(self, | |||||
| paths, | |||||
| train_ds: Iterable[str] = None, | |||||
| src_vocab_op: VocabularyOption = None, | |||||
| tgt_vocab_op: VocabularyOption = None, | |||||
| src_embed_op: EmbeddingOption = None): | |||||
| input_name, target_name = 'words', 'target' | |||||
| src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | |||||
| tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
| if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | |||||
| info = DataInfo(datasets=self.load(paths)) | |||||
| _train_ds = [info.datasets[name] | |||||
| for name in train_ds] if train_ds else info.datasets.values() | |||||
| src_vocab.from_dataset(*_train_ds, field_name=input_name) | |||||
| tgt_vocab.from_dataset(*_train_ds, field_name=target_name) | |||||
| src_vocab.index_dataset( | |||||
| *info.datasets.values(), | |||||
| field_name=input_name, new_field_name=input_name) | |||||
| tgt_vocab.index_dataset( | |||||
| *info.datasets.values(), | |||||
| field_name=target_name, new_field_name=target_name) | |||||
| info.vocabs = { | |||||
| input_name: src_vocab, | |||||
| target_name: tgt_vocab | |||||
| } | |||||
| if src_embed_op is not None: | |||||
| src_embed_op.vocab = src_vocab | |||||
| init_emb = EmbedLoader.load_with_vocab(**src_embed_op) | |||||
| info.embeddings[input_name] = init_emb | |||||
| for name, dataset in info.datasets.items(): | |||||
| dataset.set_input(input_name) | |||||
| dataset.set_target(target_name) | |||||
| return info | |||||
| class sst2Loader(DataSetLoader): | |||||
| ''' | |||||
| 数据来源"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', | |||||
| ''' | |||||
| def __init__(self): | |||||
| super(sst2Loader, self).__init__() | |||||
| def _load(self, path: str) -> DataSet: | |||||
| ds = DataSet() | |||||
| all_count=0 | |||||
| csv_reader = csv.reader(open(path, encoding='utf-8'),delimiter='\t') | |||||
| skip_row = 0 | |||||
| for idx,row in enumerate(csv_reader): | |||||
| if idx<=skip_row: | |||||
| continue | |||||
| target = row[1] | |||||
| words = row[0].split() | |||||
| ds.append(Instance(words=words,target=target)) | |||||
| all_count+=1 | |||||
| print("all count:", all_count) | |||||
| return ds | |||||
| def process(self, | |||||
| paths: Union[str, Dict[str, str]], | |||||
| src_vocab_opt: VocabularyOption = None, | |||||
| tgt_vocab_opt: VocabularyOption = None, | |||||
| src_embed_opt: EmbeddingOption = None, | |||||
| char_level_op=False): | |||||
| paths = check_dataloader_paths(paths) | |||||
| datasets = {} | |||||
| info = DataInfo() | |||||
| for name, path in paths.items(): | |||||
| dataset = self.load(path) | |||||
| datasets[name] = dataset | |||||
| def wordtochar(words): | |||||
| chars=[] | |||||
| for word in words: | |||||
| word=word.lower() | |||||
| for char in word: | |||||
| chars.append(char) | |||||
| return chars | |||||
| input_name, target_name = 'words', 'target' | |||||
| info.vocabs={} | |||||
| # 就分隔为char形式 | |||||
| if char_level_op: | |||||
| for dataset in datasets.values(): | |||||
| dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') | |||||
| src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||||
| src_vocab.from_dataset(datasets['train'], field_name='words') | |||||
| src_vocab.index_dataset(*datasets.values(), field_name='words') | |||||
| tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
| if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||||
| tgt_vocab.from_dataset(datasets['train'], field_name='target') | |||||
| tgt_vocab.index_dataset(*datasets.values(), field_name='target') | |||||
| info.vocabs = { | |||||
| "words": src_vocab, | |||||
| "target": tgt_vocab | |||||
| } | |||||
| info.datasets = datasets | |||||
| if src_embed_opt is not None: | |||||
| embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) | |||||
| info.embeddings['words'] = embed | |||||
| return info | |||||
| if __name__=="__main__": | |||||
| datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv", | |||||
| "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"} | |||||
| datainfo=sst2Loader().process(datapath,char_level_op=True) | |||||
| #print(datainfo.datasets["train"]) | |||||
| len_count = 0 | |||||
| for instance in datainfo.datasets["train"]: | |||||
| len_count += len(instance["chars"]) | |||||
| ave_len = len_count / len(datainfo.datasets["train"]) | |||||
| print(ave_len) | |||||
| @@ -1,18 +1,64 @@ | |||||
| import ast | import ast | ||||
| import csv | |||||
| from typing import Iterable | |||||
| from fastNLP import DataSet, Instance, Vocabulary | from fastNLP import DataSet, Instance, Vocabulary | ||||
| from fastNLP.core.vocabulary import VocabularyOption | from fastNLP.core.vocabulary import VocabularyOption | ||||
| from fastNLP.io import JsonLoader | from fastNLP.io import JsonLoader | ||||
| from fastNLP.io.base_loader import DataInfo | |||||
| from fastNLP.io.base_loader import DataInfo,DataSetLoader | |||||
| from fastNLP.io.embed_loader import EmbeddingOption | from fastNLP.io.embed_loader import EmbeddingOption | ||||
| from fastNLP.io.file_reader import _read_json | from fastNLP.io.file_reader import _read_json | ||||
| from typing import Union, Dict | from typing import Union, Dict | ||||
| from reproduction.Star_transformer.datasets import EmbedLoader | |||||
| from reproduction.utils import check_dataloader_paths | |||||
| from reproduction.utils import check_dataloader_paths, get_tokenizer | |||||
| def clean_str(sentence, tokenizer, char_lower=False): | |||||
| """ | |||||
| heavily borrowed from github | |||||
| https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb | |||||
| :param sentence: is a str | |||||
| :return: | |||||
| """ | |||||
| if char_lower: | |||||
| sentence = sentence.lower() | |||||
| import re | |||||
| nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | |||||
| words = tokenizer(sentence) | |||||
| words_collection = [] | |||||
| for word in words: | |||||
| if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']: | |||||
| continue | |||||
| tt = nonalpnum.split(word) | |||||
| t = ''.join(tt) | |||||
| if t != '': | |||||
| words_collection.append(t) | |||||
| return words_collection | |||||
| class yelpLoader(JsonLoader): | |||||
| class yelpLoader(DataSetLoader): | |||||
| """ | """ | ||||
| 读取Yelp_full/Yelp_polarity数据集, DataSet包含fields: | |||||
| words: list(str), 需要分类的文本 | |||||
| target: str, 文本的标签 | |||||
| chars:list(str),未index的字符列表 | |||||
| 数据集:yelp_full/yelp_polarity | |||||
| :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||||
| """ | |||||
| def __init__(self, fine_grained=False,lower=False): | |||||
| super(yelpLoader, self).__init__() | |||||
| tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', | |||||
| '4.0': 'positive', '5.0': 'very positive'} | |||||
| if not fine_grained: | |||||
| tag_v['1.0'] = tag_v['2.0'] | |||||
| tag_v['5.0'] = tag_v['4.0'] | |||||
| self.fine_grained = fine_grained | |||||
| self.tag_v = tag_v | |||||
| self.lower = lower | |||||
| self.tokenizer = get_tokenizer() | |||||
| ''' | |||||
| 读取Yelp数据集, DataSet包含fields: | 读取Yelp数据集, DataSet包含fields: | ||||
| review_id: str, 22 character unique review id | review_id: str, 22 character unique review id | ||||
| @@ -27,20 +73,8 @@ class yelpLoader(JsonLoader): | |||||
| 数据来源: https://www.yelp.com/dataset/download | 数据来源: https://www.yelp.com/dataset/download | ||||
| :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||||
| """ | |||||
| def __init__(self, fine_grained=False): | |||||
| super(yelpLoader, self).__init__() | |||||
| tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', | |||||
| '4.0': 'positive', '5.0': 'very positive'} | |||||
| if not fine_grained: | |||||
| tag_v['1.0'] = tag_v['2.0'] | |||||
| tag_v['5.0'] = tag_v['4.0'] | |||||
| self.fine_grained = fine_grained | |||||
| self.tag_v = tag_v | |||||
| def _load(self, path): | |||||
| def _load_json(self, path): | |||||
| ds = DataSet() | ds = DataSet() | ||||
| for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | ||||
| d = ast.literal_eval(d) | d = ast.literal_eval(d) | ||||
| @@ -48,21 +82,116 @@ class yelpLoader(JsonLoader): | |||||
| d["target"] = self.tag_v[str(d.pop("stars"))] | d["target"] = self.tag_v[str(d.pop("stars"))] | ||||
| ds.append(Instance(**d)) | ds.append(Instance(**d)) | ||||
| return ds | return ds | ||||
| def _load_yelp2015_broken(self,path): | |||||
| ds = DataSet() | |||||
| with open (path,encoding='ISO 8859-1') as f: | |||||
| row=f.readline() | |||||
| all_count=0 | |||||
| exp_count=0 | |||||
| while row: | |||||
| row=row.split("\t\t") | |||||
| all_count+=1 | |||||
| if len(row)>=3: | |||||
| words=row[-1].split() | |||||
| try: | |||||
| target=self.tag_v[str(row[-2])+".0"] | |||||
| ds.append(Instance(words=words, target=target)) | |||||
| except KeyError: | |||||
| exp_count+=1 | |||||
| else: | |||||
| exp_count+=1 | |||||
| row = f.readline() | |||||
| print("error sample count:",exp_count) | |||||
| print("all count:",all_count) | |||||
| return ds | |||||
| ''' | |||||
| def _load(self, path): | |||||
| ds = DataSet() | |||||
| csv_reader=csv.reader(open(path,encoding='utf-8')) | |||||
| all_count=0 | |||||
| real_count=0 | |||||
| for row in csv_reader: | |||||
| all_count+=1 | |||||
| if len(row)==2: | |||||
| target=self.tag_v[row[0]+".0"] | |||||
| words = clean_str(row[1], self.tokenizer, self.lower) | |||||
| if len(words)!=0: | |||||
| ds.append(Instance(words=words,target=target)) | |||||
| real_count += 1 | |||||
| print("all count:", all_count) | |||||
| print("real count:", real_count) | |||||
| return ds | |||||
| def process(self, paths: Union[str, Dict[str, str]], vocab_opt: VocabularyOption = None, | |||||
| embed_opt: EmbeddingOption = None): | |||||
| def process(self, paths: Union[str, Dict[str, str]], | |||||
| train_ds: Iterable[str] = None, | |||||
| src_vocab_op: VocabularyOption = None, | |||||
| tgt_vocab_op: VocabularyOption = None, | |||||
| embed_opt: EmbeddingOption = None, | |||||
| char_level_op=False): | |||||
| paths = check_dataloader_paths(paths) | paths = check_dataloader_paths(paths) | ||||
| datasets = {} | datasets = {} | ||||
| info = DataInfo() | |||||
| vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt) | |||||
| for name, path in paths.items(): | |||||
| dataset = self.load(path) | |||||
| datasets[name] = dataset | |||||
| vocab.from_dataset(dataset, field_name="words") | |||||
| info.vocabs = vocab | |||||
| info.datasets = datasets | |||||
| if embed_opt is not None: | |||||
| embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab) | |||||
| info.embeddings['words'] = embed | |||||
| info = DataInfo(datasets=self.load(paths)) | |||||
| src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | |||||
| tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
| if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | |||||
| _train_ds = [info.datasets[name] | |||||
| for name in train_ds] if train_ds else info.datasets.values() | |||||
| def wordtochar(words): | |||||
| chars=[] | |||||
| for word in words: | |||||
| word=word.lower() | |||||
| for char in word: | |||||
| chars.append(char) | |||||
| return chars | |||||
| input_name, target_name = 'words', 'target' | |||||
| info.vocabs={} | |||||
| #就分隔为char形式 | |||||
| if char_level_op: | |||||
| for dataset in info.datasets.values(): | |||||
| dataset.apply_field(wordtochar, field_name="words",new_field_name='chars') | |||||
| # if embed_opt is not None: | |||||
| # embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab) | |||||
| # info.embeddings['words'] = embed | |||||
| else: | |||||
| src_vocab.from_dataset(*_train_ds, field_name=input_name) | |||||
| src_vocab.index_dataset(*info.datasets.values(),field_name=input_name, new_field_name=input_name) | |||||
| info.vocabs[input_name]=src_vocab | |||||
| tgt_vocab.from_dataset(*_train_ds, field_name=target_name) | |||||
| tgt_vocab.index_dataset( | |||||
| *info.datasets.values(), | |||||
| field_name=target_name, new_field_name=target_name) | |||||
| info.vocabs[target_name]=tgt_vocab | |||||
| info.datasets['train'],info.datasets['dev']=info.datasets['train'].split(0.1, shuffle=False) | |||||
| for name, dataset in info.datasets.items(): | |||||
| dataset.set_input("words") | |||||
| dataset.set_target("target") | |||||
| return info | return info | ||||
| if __name__=="__main__": | |||||
| testloader=yelpLoader() | |||||
| # datapath = {"train": "/remote-home/ygwang/yelp_full/train.csv", | |||||
| # "test": "/remote-home/ygwang/yelp_full/test.csv"} | |||||
| #datapath={"train": "/remote-home/ygwang/yelp_full/test.csv"} | |||||
| datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv", | |||||
| "test": "/remote-home/ygwang/yelp_polarity/test.csv"} | |||||
| datainfo=testloader.process(datapath,char_level_op=True) | |||||
| len_count=0 | |||||
| for instance in datainfo.datasets["train"]: | |||||
| len_count+=len(instance["chars"]) | |||||
| ave_len=len_count/len(datainfo.datasets["train"]) | |||||
| print(ave_len) | |||||
| @@ -0,0 +1,109 @@ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from torch.autograd import Variable | |||||
| from fastNLP.modules.utils import get_embeddings | |||||
| from fastNLP.core import Const as C | |||||
| def pack_sequence(tensor_seq, padding_value=0.0): | |||||
| if len(tensor_seq) <= 0: | |||||
| return | |||||
| length = [v.size(0) for v in tensor_seq] | |||||
| max_len = max(length) | |||||
| size = [len(tensor_seq), max_len] | |||||
| size.extend(list(tensor_seq[0].size()[1:])) | |||||
| ans = torch.Tensor(*size).fill_(padding_value) | |||||
| if tensor_seq[0].data.is_cuda: | |||||
| ans = ans.cuda() | |||||
| ans = Variable(ans) | |||||
| for i, v in enumerate(tensor_seq): | |||||
| ans[i, :length[i], :] = v | |||||
| return ans | |||||
| class HANCLS(nn.Module): | |||||
| def __init__(self, init_embed, num_cls): | |||||
| super(HANCLS, self).__init__() | |||||
| self.embed = get_embeddings(init_embed) | |||||
| self.han = HAN(input_size=300, | |||||
| output_size=num_cls, | |||||
| word_hidden_size=50, word_num_layers=1, word_context_size=100, | |||||
| sent_hidden_size=50, sent_num_layers=1, sent_context_size=100 | |||||
| ) | |||||
| def forward(self, input_sents): | |||||
| # input_sents [B, num_sents, seq-len] dtype long | |||||
| # target | |||||
| B, num_sents, seq_len = input_sents.size() | |||||
| input_sents = input_sents.view(-1, seq_len) # flat | |||||
| words_embed = self.embed(input_sents) # should be [B*num-sent, seqlen , word-dim] | |||||
| words_embed = words_embed.view(B, num_sents, seq_len, -1) # recover # [B, num-sent, seqlen , word-dim] | |||||
| out = self.han(words_embed) | |||||
| return {C.OUTPUT: out} | |||||
| def predict(self, input_sents): | |||||
| x = self.forward(input_sents)[C.OUTPUT] | |||||
| return {C.OUTPUT: torch.argmax(x, 1)} | |||||
| class HAN(nn.Module): | |||||
| def __init__(self, input_size, output_size, | |||||
| word_hidden_size, word_num_layers, word_context_size, | |||||
| sent_hidden_size, sent_num_layers, sent_context_size): | |||||
| super(HAN, self).__init__() | |||||
| self.word_layer = AttentionNet(input_size, | |||||
| word_hidden_size, | |||||
| word_num_layers, | |||||
| word_context_size) | |||||
| self.sent_layer = AttentionNet(2 * word_hidden_size, | |||||
| sent_hidden_size, | |||||
| sent_num_layers, | |||||
| sent_context_size) | |||||
| self.output_layer = nn.Linear(2 * sent_hidden_size, output_size) | |||||
| self.softmax = nn.LogSoftmax(dim=1) | |||||
| def forward(self, batch_doc): | |||||
| # input is a sequence of matrix | |||||
| doc_vec_list = [] | |||||
| for doc in batch_doc: | |||||
| sent_mat = self.word_layer(doc) # doc's dim (num_sent, seq_len, word_dim) | |||||
| doc_vec_list.append(sent_mat) # sent_mat's dim (num_sent, vec_dim) | |||||
| doc_vec = self.sent_layer(pack_sequence(doc_vec_list)) | |||||
| output = self.softmax(self.output_layer(doc_vec)) | |||||
| return output | |||||
| class AttentionNet(nn.Module): | |||||
| def __init__(self, input_size, gru_hidden_size, gru_num_layers, context_vec_size): | |||||
| super(AttentionNet, self).__init__() | |||||
| self.input_size = input_size | |||||
| self.gru_hidden_size = gru_hidden_size | |||||
| self.gru_num_layers = gru_num_layers | |||||
| self.context_vec_size = context_vec_size | |||||
| # Encoder | |||||
| self.gru = nn.GRU(input_size=input_size, | |||||
| hidden_size=gru_hidden_size, | |||||
| num_layers=gru_num_layers, | |||||
| batch_first=True, | |||||
| bidirectional=True) | |||||
| # Attention | |||||
| self.fc = nn.Linear(2 * gru_hidden_size, context_vec_size) | |||||
| self.tanh = nn.Tanh() | |||||
| self.softmax = nn.Softmax(dim=1) | |||||
| # context vector | |||||
| self.context_vec = nn.Parameter(torch.Tensor(context_vec_size, 1)) | |||||
| self.context_vec.data.uniform_(-0.1, 0.1) | |||||
| def forward(self, inputs): | |||||
| # GRU part | |||||
| h_t, hidden = self.gru(inputs) # inputs's dim (batch_size, seq_len, word_dim) | |||||
| u = self.tanh(self.fc(h_t)) | |||||
| # Attention part | |||||
| alpha = self.softmax(torch.matmul(u, self.context_vec)) # u's dim (batch_size, seq_len, context_vec_size) | |||||
| output = torch.bmm(torch.transpose(h_t, 1, 2), alpha) # alpha's dim (batch_size, seq_len, 1) | |||||
| return torch.squeeze(output, dim=2) # output's dim (batch_size, 2*hidden_size, 1) | |||||
| @@ -0,0 +1,31 @@ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from fastNLP.core.const import Const as C | |||||
| from .awdlstm_module import LSTM | |||||
| from fastNLP.modules import encoder | |||||
| from fastNLP.modules.decoder.mlp import MLP | |||||
| class AWDLSTMSentiment(nn.Module): | |||||
| def __init__(self, init_embed, | |||||
| num_classes, | |||||
| hidden_dim=256, | |||||
| num_layers=1, | |||||
| nfc=128, | |||||
| wdrop=0.5): | |||||
| super(AWDLSTMSentiment,self).__init__() | |||||
| self.embed = encoder.Embedding(init_embed) | |||||
| self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True, wdrop=wdrop) | |||||
| self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes]) | |||||
| def forward(self, words): | |||||
| x_emb = self.embed(words) | |||||
| output, _ = self.lstm(x_emb) | |||||
| output = self.mlp(output[:,-1,:]) | |||||
| return {C.OUTPUT: output} | |||||
| def predict(self, words): | |||||
| output = self(words) | |||||
| _, predict = output[C.OUTPUT].max(dim=1) | |||||
| return {C.OUTPUT: predict} | |||||
| @@ -0,0 +1,86 @@ | |||||
| """ | |||||
| 轻量封装的 Pytorch LSTM 模块. | |||||
| 可在 forward 时传入序列的长度, 自动对padding做合适的处理. | |||||
| """ | |||||
| __all__ = [ | |||||
| "LSTM" | |||||
| ] | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.utils.rnn as rnn | |||||
| from fastNLP.modules.utils import initial_parameter | |||||
| from torch import autograd | |||||
| from .weight_drop import WeightDrop | |||||
| class LSTM(nn.Module): | |||||
| """ | |||||
| 别名::class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM` | |||||
| LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 | |||||
| 为1; 且可以应对DataParallel中LSTM的使用问题。 | |||||
| :param input_size: 输入 `x` 的特征维度 | |||||
| :param hidden_size: 隐状态 `h` 的特征维度. | |||||
| :param num_layers: rnn的层数. Default: 1 | |||||
| :param dropout: 层间dropout概率. Default: 0 | |||||
| :param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` | |||||
| :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 | |||||
| :(batch, seq, feature). Default: ``False`` | |||||
| :param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | |||||
| """ | |||||
| def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | |||||
| bidirectional=False, bias=True, wdrop=0.5): | |||||
| super(LSTM, self).__init__() | |||||
| self.batch_first = batch_first | |||||
| self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | |||||
| dropout=dropout, bidirectional=bidirectional) | |||||
| self.lstm = WeightDrop(self.lstm, ['weight_hh_l0'], dropout=wdrop) | |||||
| self.init_param() | |||||
| def init_param(self): | |||||
| for name, param in self.named_parameters(): | |||||
| if 'bias' in name: | |||||
| # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871 | |||||
| param.data.fill_(0) | |||||
| n = param.size(0) | |||||
| start, end = n // 4, n // 2 | |||||
| param.data[start:end].fill_(1) | |||||
| else: | |||||
| nn.init.xavier_uniform_(param) | |||||
| def forward(self, x, seq_len=None, h0=None, c0=None): | |||||
| """ | |||||
| :param x: [batch, seq_len, input_size] 输入序列 | |||||
| :param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` | |||||
| :param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None`` | |||||
| :param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None`` | |||||
| :return (output, ht) 或 output: 若 ``get_hidden=True`` [batch, seq_len, hidden_size*num_direction] 输出序列 | |||||
| 和 [batch, hidden_size*num_direction] 最后时刻隐状态. | |||||
| """ | |||||
| batch_size, max_len, _ = x.size() | |||||
| if h0 is not None and c0 is not None: | |||||
| hx = (h0, c0) | |||||
| else: | |||||
| hx = None | |||||
| if seq_len is not None and not isinstance(x, rnn.PackedSequence): | |||||
| sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) | |||||
| if self.batch_first: | |||||
| x = x[sort_idx] | |||||
| else: | |||||
| x = x[:, sort_idx] | |||||
| x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first) | |||||
| output, hx = self.lstm(x, hx) # -> [N,L,C] | |||||
| output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) | |||||
| _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||||
| if self.batch_first: | |||||
| output = output[unsort_idx] | |||||
| else: | |||||
| output = output[:, unsort_idx] | |||||
| else: | |||||
| output, hx = self.lstm(x, hx) | |||||
| return output, hx | |||||
| @@ -1 +1,90 @@ | |||||
| # TODO | |||||
| ''' | |||||
| @author: https://github.com/ahmedbesbes/character-based-cnn | |||||
| 这里借鉴了上述链接中char-cnn model的代码,改动主要为将其改动为符合fastnlp的pipline | |||||
| ''' | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from fastNLP.core.const import Const as C | |||||
| class CharacterLevelCNN(nn.Module): | |||||
| def __init__(self, args,embedding): | |||||
| super(CharacterLevelCNN, self).__init__() | |||||
| self.config=args.char_cnn_config | |||||
| self.embedding=embedding | |||||
| conv_layers = [] | |||||
| for i, conv_layer_parameter in enumerate(self.config['model_parameters'][args.model_size]['conv']): | |||||
| if i == 0: | |||||
| #in_channels = args.number_of_characters + len(args.extra_characters) | |||||
| in_channels = args.embedding_dim | |||||
| out_channels = conv_layer_parameter[0] | |||||
| else: | |||||
| in_channels, out_channels = conv_layer_parameter[0], conv_layer_parameter[0] | |||||
| if conv_layer_parameter[2] != -1: | |||||
| conv_layer = nn.Sequential(nn.Conv1d(in_channels, | |||||
| out_channels, | |||||
| kernel_size=conv_layer_parameter[1], padding=0), | |||||
| nn.ReLU(), | |||||
| nn.MaxPool1d(conv_layer_parameter[2])) | |||||
| else: | |||||
| conv_layer = nn.Sequential(nn.Conv1d(in_channels, | |||||
| out_channels, | |||||
| kernel_size=conv_layer_parameter[1], padding=0), | |||||
| nn.ReLU()) | |||||
| conv_layers.append(conv_layer) | |||||
| self.conv_layers = nn.ModuleList(conv_layers) | |||||
| input_shape = (args.batch_size, args.max_length, | |||||
| args.number_of_characters + len(args.extra_characters)) | |||||
| dimension = self._get_conv_output(input_shape) | |||||
| print('dimension :', dimension) | |||||
| fc_layer_parameter = self.config['model_parameters'][args.model_size]['fc'][0] | |||||
| fc_layers = nn.ModuleList([ | |||||
| nn.Sequential( | |||||
| nn.Linear(dimension, fc_layer_parameter), nn.Dropout(0.5)), | |||||
| nn.Sequential(nn.Linear(fc_layer_parameter, | |||||
| fc_layer_parameter), nn.Dropout(0.5)), | |||||
| nn.Linear(fc_layer_parameter, args.num_classes), | |||||
| ]) | |||||
| self.fc_layers = fc_layers | |||||
| if args.model_size == 'small': | |||||
| self._create_weights(mean=0.0, std=0.05) | |||||
| elif args.model_size == 'large': | |||||
| self._create_weights(mean=0.0, std=0.02) | |||||
| def _create_weights(self, mean=0.0, std=0.05): | |||||
| for module in self.modules(): | |||||
| if isinstance(module, nn.Conv1d) or isinstance(module, nn.Linear): | |||||
| module.weight.data.normal_(mean, std) | |||||
| def _get_conv_output(self, shape): | |||||
| input = torch.rand(shape) | |||||
| output = input.transpose(1, 2) | |||||
| # forward pass through conv layers | |||||
| for i in range(len(self.conv_layers)): | |||||
| output = self.conv_layers[i](output) | |||||
| output = output.view(output.size(0), -1) | |||||
| n_size = output.size(1) | |||||
| return n_size | |||||
| def forward(self, chars): | |||||
| input=self.embedding(chars) | |||||
| output = input.transpose(1, 2) | |||||
| # forward pass through conv layers | |||||
| for i in range(len(self.conv_layers)): | |||||
| output = self.conv_layers[i](output) | |||||
| output = output.view(output.size(0), -1) | |||||
| # forward pass through fc layers | |||||
| for i in range(len(self.fc_layers)): | |||||
| output = self.fc_layers[i](output) | |||||
| return {C.OUTPUT: output} | |||||
| @@ -1 +1,97 @@ | |||||
| # TODO | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from fastNLP.modules.utils import get_embeddings | |||||
| from fastNLP.core import Const as C | |||||
| class DPCNN(nn.Module): | |||||
| def __init__(self, init_embed, num_cls, n_filters=256, | |||||
| kernel_size=3, n_layers=7, embed_dropout=0.1, cls_dropout=0.1): | |||||
| super().__init__() | |||||
| self.region_embed = RegionEmbedding( | |||||
| init_embed, out_dim=n_filters, kernel_sizes=[1, 3, 5]) | |||||
| embed_dim = self.region_embed.embedding_dim | |||||
| self.conv_list = nn.ModuleList() | |||||
| for i in range(n_layers): | |||||
| self.conv_list.append(nn.Sequential( | |||||
| nn.ReLU(), | |||||
| nn.Conv1d(n_filters, n_filters, kernel_size, | |||||
| padding=kernel_size//2), | |||||
| nn.Conv1d(n_filters, n_filters, kernel_size, | |||||
| padding=kernel_size//2), | |||||
| )) | |||||
| self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) | |||||
| self.embed_drop = nn.Dropout(embed_dropout) | |||||
| self.classfier = nn.Sequential( | |||||
| nn.Dropout(cls_dropout), | |||||
| nn.Linear(n_filters, num_cls), | |||||
| ) | |||||
| self.reset_parameters() | |||||
| def reset_parameters(self): | |||||
| for m in self.modules(): | |||||
| if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): | |||||
| nn.init.normal_(m.weight, mean=0, std=0.01) | |||||
| if m.bias is not None: | |||||
| nn.init.normal_(m.bias, mean=0, std=0.01) | |||||
| def forward(self, words, seq_len=None): | |||||
| words = words.long() | |||||
| # get region embeddings | |||||
| x = self.region_embed(words) | |||||
| x = self.embed_drop(x) | |||||
| # not pooling on first conv | |||||
| x = self.conv_list[0](x) + x | |||||
| for conv in self.conv_list[1:]: | |||||
| x = self.pool(x) | |||||
| x = conv(x) + x | |||||
| # B, C, L => B, C | |||||
| x, _ = torch.max(x, dim=2) | |||||
| x = self.classfier(x) | |||||
| return {C.OUTPUT: x} | |||||
| def predict(self, words, seq_len=None): | |||||
| x = self.forward(words, seq_len)[C.OUTPUT] | |||||
| return {C.OUTPUT: torch.argmax(x, 1)} | |||||
| class RegionEmbedding(nn.Module): | |||||
| def __init__(self, init_embed, out_dim=300, kernel_sizes=None): | |||||
| super().__init__() | |||||
| if kernel_sizes is None: | |||||
| kernel_sizes = [5, 9] | |||||
| assert isinstance( | |||||
| kernel_sizes, list), 'kernel_sizes should be List(int)' | |||||
| self.embed = get_embeddings(init_embed) | |||||
| try: | |||||
| embed_dim = self.embed.embedding_dim | |||||
| except Exception: | |||||
| embed_dim = self.embed.embed_size | |||||
| self.region_embeds = nn.ModuleList() | |||||
| for ksz in kernel_sizes: | |||||
| self.region_embeds.append(nn.Sequential( | |||||
| nn.Conv1d(embed_dim, embed_dim, ksz, padding=ksz // 2), | |||||
| )) | |||||
| self.linears = nn.ModuleList([nn.Conv1d(embed_dim, out_dim, 1) | |||||
| for _ in range(len(kernel_sizes))]) | |||||
| self.embedding_dim = embed_dim | |||||
| def forward(self, x): | |||||
| x = self.embed(x) | |||||
| x = x.transpose(1, 2) | |||||
| # B, C, L | |||||
| out = 0 | |||||
| for conv, fc in zip(self.region_embeds, self.linears[1:]): | |||||
| conv_i = conv(x) | |||||
| out = out + fc(conv_i) | |||||
| # B, C, L | |||||
| return out | |||||
| if __name__ == '__main__': | |||||
| x = torch.randint(0, 10000, size=(5, 15), dtype=torch.long) | |||||
| model = DPCNN((10000, 300), 20) | |||||
| y = model(x) | |||||
| print(y.size(), y.mean(1), y.std(1)) | |||||
| @@ -0,0 +1,30 @@ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from fastNLP.core.const import Const as C | |||||
| from fastNLP.modules.encoder.lstm import LSTM | |||||
| from fastNLP.modules import encoder | |||||
| from fastNLP.modules.decoder.mlp import MLP | |||||
| class BiLSTMSentiment(nn.Module): | |||||
| def __init__(self, init_embed, | |||||
| num_classes, | |||||
| hidden_dim=256, | |||||
| num_layers=1, | |||||
| nfc=128): | |||||
| super(BiLSTMSentiment,self).__init__() | |||||
| self.embed = encoder.Embedding(init_embed) | |||||
| self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True) | |||||
| self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes]) | |||||
| def forward(self, words): | |||||
| x_emb = self.embed(words) | |||||
| output, _ = self.lstm(x_emb) | |||||
| output = self.mlp(output[:,-1,:]) | |||||
| return {C.OUTPUT: output} | |||||
| def predict(self, words): | |||||
| output = self(words) | |||||
| _, predict = output[C.OUTPUT].max(dim=1) | |||||
| return {C.OUTPUT: predict} | |||||
| @@ -0,0 +1,35 @@ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from fastNLP.core.const import Const as C | |||||
| from fastNLP.modules.encoder.lstm import LSTM | |||||
| from fastNLP.modules import encoder | |||||
| from fastNLP.modules.aggregator.attention import SelfAttention | |||||
| from fastNLP.modules.decoder.mlp import MLP | |||||
| class BiLSTM_SELF_ATTENTION(nn.Module): | |||||
| def __init__(self, init_embed, | |||||
| num_classes, | |||||
| hidden_dim=256, | |||||
| num_layers=1, | |||||
| attention_unit=256, | |||||
| attention_hops=1, | |||||
| nfc=128): | |||||
| super(BiLSTM_SELF_ATTENTION,self).__init__() | |||||
| self.embed = encoder.Embedding(init_embed) | |||||
| self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True) | |||||
| self.attention = SelfAttention(input_size=hidden_dim * 2 , attention_unit=attention_unit, attention_hops=attention_hops) | |||||
| self.mlp = MLP(size_layer=[hidden_dim* 2*attention_hops, nfc, num_classes]) | |||||
| def forward(self, words): | |||||
| x_emb = self.embed(words) | |||||
| output, _ = self.lstm(x_emb) | |||||
| after_attention, penalty = self.attention(output,words) | |||||
| after_attention =after_attention.view(after_attention.size(0),-1) | |||||
| output = self.mlp(after_attention) | |||||
| return {C.OUTPUT: output} | |||||
| def predict(self, words): | |||||
| output = self(words) | |||||
| _, predict = output[C.OUTPUT].max(dim=1) | |||||
| return {C.OUTPUT: predict} | |||||
| @@ -0,0 +1,99 @@ | |||||
| import torch | |||||
| from torch.nn import Parameter | |||||
| from functools import wraps | |||||
| class WeightDrop(torch.nn.Module): | |||||
| def __init__(self, module, weights, dropout=0, variational=False): | |||||
| super(WeightDrop, self).__init__() | |||||
| self.module = module | |||||
| self.weights = weights | |||||
| self.dropout = dropout | |||||
| self.variational = variational | |||||
| self._setup() | |||||
| def widget_demagnetizer_y2k_edition(*args, **kwargs): | |||||
| # We need to replace flatten_parameters with a nothing function | |||||
| # It must be a function rather than a lambda as otherwise pickling explodes | |||||
| # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION! | |||||
| # (╯°□°)╯︵ ┻━┻ | |||||
| return | |||||
| def _setup(self): | |||||
| # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN | |||||
| if issubclass(type(self.module), torch.nn.RNNBase): | |||||
| self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition | |||||
| for name_w in self.weights: | |||||
| print('Applying weight drop of {} to {}'.format(self.dropout, name_w)) | |||||
| w = getattr(self.module, name_w) | |||||
| del self.module._parameters[name_w] | |||||
| self.module.register_parameter(name_w + '_raw', Parameter(w.data)) | |||||
| def _setweights(self): | |||||
| for name_w in self.weights: | |||||
| raw_w = getattr(self.module, name_w + '_raw') | |||||
| w = None | |||||
| if self.variational: | |||||
| mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1)) | |||||
| if raw_w.is_cuda: mask = mask.cuda() | |||||
| mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True) | |||||
| w = mask.expand_as(raw_w) * raw_w | |||||
| else: | |||||
| w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training) | |||||
| setattr(self.module, name_w, w) | |||||
| def forward(self, *args): | |||||
| self._setweights() | |||||
| return self.module.forward(*args) | |||||
| if __name__ == '__main__': | |||||
| import torch | |||||
| from weight_drop import WeightDrop | |||||
| # Input is (seq, batch, input) | |||||
| x = torch.autograd.Variable(torch.randn(2, 1, 10)).cuda() | |||||
| h0 = None | |||||
| ### | |||||
| print('Testing WeightDrop') | |||||
| print('=-=-=-=-=-=-=-=-=-=') | |||||
| ### | |||||
| print('Testing WeightDrop with Linear') | |||||
| lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9) | |||||
| lin.cuda() | |||||
| run1 = [x.sum() for x in lin(x).data] | |||||
| run2 = [x.sum() for x in lin(x).data] | |||||
| print('All items should be different') | |||||
| print('Run 1:', run1) | |||||
| print('Run 2:', run2) | |||||
| assert run1[0] != run2[0] | |||||
| assert run1[1] != run2[1] | |||||
| print('---') | |||||
| ### | |||||
| print('Testing WeightDrop with LSTM') | |||||
| wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9) | |||||
| wdrnn.cuda() | |||||
| run1 = [x.sum() for x in wdrnn(x, h0)[0].data] | |||||
| run2 = [x.sum() for x in wdrnn(x, h0)[0].data] | |||||
| print('First timesteps should be equal, all others should differ') | |||||
| print('Run 1:', run1) | |||||
| print('Run 2:', run2) | |||||
| # First time step, not influenced by hidden to hidden weights, should be equal | |||||
| assert run1[0] == run2[0] | |||||
| # Second step should not | |||||
| assert run1[1] != run2[1] | |||||
| print('---') | |||||
| @@ -0,0 +1,109 @@ | |||||
| # 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 | |||||
| import os | |||||
| import sys | |||||
| sys.path.append('../../') | |||||
| os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | |||||
| os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||||
| os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | |||||
| from fastNLP.core.const import Const as C | |||||
| from fastNLP.core import LRScheduler | |||||
| import torch.nn as nn | |||||
| from fastNLP.io.dataset_loader import SSTLoader | |||||
| from reproduction.text_classification.data.yelpLoader import yelpLoader | |||||
| from reproduction.text_classification.model.HAN import HANCLS | |||||
| from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding | |||||
| from fastNLP import CrossEntropyLoss, AccuracyMetric | |||||
| from fastNLP.core.trainer import Trainer | |||||
| from torch.optim import SGD | |||||
| import torch.cuda | |||||
| from torch.optim.lr_scheduler import CosineAnnealingLR | |||||
| ##hyper | |||||
| class Config(): | |||||
| model_dir_or_name = "en-base-uncased" | |||||
| embedding_grad = False, | |||||
| train_epoch = 30 | |||||
| batch_size = 100 | |||||
| num_classes = 5 | |||||
| task = "yelp" | |||||
| #datadir = '/remote-home/lyli/fastNLP/yelp_polarity/' | |||||
| datadir = '/remote-home/ygwang/yelp_polarity/' | |||||
| datafile = {"train": "train.csv", "test": "test.csv"} | |||||
| lr = 1e-3 | |||||
| def __init__(self): | |||||
| self.datapath = {k: os.path.join(self.datadir, v) | |||||
| for k, v in self.datafile.items()} | |||||
| ops = Config() | |||||
| ##1.task相关信息:利用dataloader载入dataInfo | |||||
| datainfo = yelpLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train']) | |||||
| print(len(datainfo.datasets['train'])) | |||||
| print(len(datainfo.datasets['test'])) | |||||
| # post process | |||||
| def make_sents(words): | |||||
| sents = [words] | |||||
| return sents | |||||
| for dataset in datainfo.datasets.values(): | |||||
| dataset.apply_field(make_sents, field_name='words', new_field_name='input_sents') | |||||
| datainfo = datainfo | |||||
| datainfo.datasets['train'].set_input('input_sents') | |||||
| datainfo.datasets['test'].set_input('input_sents') | |||||
| datainfo.datasets['train'].set_target('target') | |||||
| datainfo.datasets['test'].set_target('target') | |||||
| ## 2.或直接复用fastNLP的模型 | |||||
| vocab = datainfo.vocabs['words'] | |||||
| # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)]) | |||||
| embedding = StaticEmbedding(vocab) | |||||
| print(len(vocab)) | |||||
| print(len(datainfo.vocabs['target'])) | |||||
| # model = DPCNN(init_embed=embedding, num_cls=ops.num_classes) | |||||
| model = HANCLS(init_embed=embedding, num_cls=ops.num_classes) | |||||
| ## 3. 声明loss,metric,optimizer | |||||
| loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) | |||||
| metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) | |||||
| optimizer = SGD([param for param in model.parameters() if param.requires_grad == True], | |||||
| lr=ops.lr, momentum=0.9, weight_decay=0) | |||||
| callbacks = [] | |||||
| callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) | |||||
| device = 'cuda:0' if torch.cuda.is_available() else 'cpu' | |||||
| print(device) | |||||
| for ds in datainfo.datasets.values(): | |||||
| ds.apply_field(len, C.INPUT, C.INPUT_LEN) | |||||
| ds.set_input(C.INPUT, C.INPUT_LEN) | |||||
| ds.set_target(C.TARGET) | |||||
| ## 4.定义train方法 | |||||
| def train(model, datainfo, loss, metrics, optimizer, num_epochs=ops.train_epoch): | |||||
| trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||||
| metrics=[metrics], dev_data=datainfo.datasets['test'], device=device, | |||||
| check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, | |||||
| n_epochs=num_epochs) | |||||
| print(trainer.train()) | |||||
| if __name__ == "__main__": | |||||
| train(model, datainfo, loss, metric, optimizer) | |||||
| @@ -0,0 +1,69 @@ | |||||
| # 这个模型需要在pytorch=0.4下运行,weight_drop不支持1.0 | |||||
| # 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 | |||||
| import os | |||||
| os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | |||||
| os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||||
| import torch.nn as nn | |||||
| from data.IMDBLoader import IMDBLoader | |||||
| from fastNLP.modules.encoder.embedding import StaticEmbedding | |||||
| from model.awd_lstm import AWDLSTMSentiment | |||||
| from fastNLP.core.const import Const as C | |||||
| from fastNLP import CrossEntropyLoss, AccuracyMetric | |||||
| from fastNLP import Trainer, Tester | |||||
| from torch.optim import Adam | |||||
| from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
| import argparse | |||||
| class Config(): | |||||
| train_epoch= 10 | |||||
| lr=0.001 | |||||
| num_classes=2 | |||||
| hidden_dim=256 | |||||
| num_layers=1 | |||||
| nfc=128 | |||||
| wdrop=0.5 | |||||
| task_name = "IMDB" | |||||
| datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"} | |||||
| save_model_path="./result_IMDB_test/" | |||||
| opt=Config() | |||||
| # load data | |||||
| dataloader=IMDBLoader() | |||||
| datainfo=dataloader.process(opt.datapath) | |||||
| # print(datainfo.datasets["train"]) | |||||
| # print(datainfo) | |||||
| # define model | |||||
| vocab=datainfo.vocabs['words'] | |||||
| embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True) | |||||
| model=AWDLSTMSentiment(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, nfc=opt.nfc, wdrop=opt.wdrop) | |||||
| # define loss_function and metrics | |||||
| loss=CrossEntropyLoss() | |||||
| metrics=AccuracyMetric() | |||||
| optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr) | |||||
| def train(datainfo, model, optimizer, loss, metrics, opt): | |||||
| trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||||
| metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1, | |||||
| n_epochs=opt.train_epoch, save_path=opt.save_model_path) | |||||
| trainer.train() | |||||
| if __name__ == "__main__": | |||||
| train(datainfo, model, optimizer, loss, metrics, opt) | |||||
| @@ -0,0 +1,205 @@ | |||||
| # 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 | |||||
| import os | |||||
| os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | |||||
| os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||||
| import sys | |||||
| sys.path.append('../..') | |||||
| from fastNLP.core.const import Const as C | |||||
| import torch.nn as nn | |||||
| from data.yelpLoader import yelpLoader | |||||
| from data.sstLoader import sst2Loader | |||||
| from data.IMDBLoader import IMDBLoader | |||||
| from model.char_cnn import CharacterLevelCNN | |||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
| from fastNLP.models.cnn_text_classification import CNNText | |||||
| from fastNLP.modules.encoder.embedding import CNNCharEmbedding,StaticEmbedding,StackEmbedding,LSTMCharEmbedding | |||||
| from fastNLP import CrossEntropyLoss, AccuracyMetric | |||||
| from fastNLP.core.trainer import Trainer | |||||
| from torch.optim import SGD | |||||
| from torch.autograd import Variable | |||||
| import torch | |||||
| from fastNLP import BucketSampler | |||||
| ##hyper | |||||
| #todo 这里加入fastnlp的记录 | |||||
| class Config(): | |||||
| model_dir_or_name="en-base-uncased" | |||||
| embedding_grad= False, | |||||
| bert_embedding_larers= '4,-2,-1' | |||||
| train_epoch= 50 | |||||
| num_classes=2 | |||||
| task= "IMDB" | |||||
| #yelp_p | |||||
| datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv", | |||||
| "test": "/remote-home/ygwang/yelp_polarity/test.csv"} | |||||
| #IMDB | |||||
| #datapath = {"train": "/remote-home/ygwang/IMDB_data/train.csv", | |||||
| # "test": "/remote-home/ygwang/IMDB_data/test.csv"} | |||||
| # sst | |||||
| # datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv", | |||||
| # "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"} | |||||
| lr=0.01 | |||||
| batch_size=128 | |||||
| model_size="large" | |||||
| number_of_characters=69 | |||||
| extra_characters='' | |||||
| max_length=1014 | |||||
| char_cnn_config={ | |||||
| "alphabet": { | |||||
| "en": { | |||||
| "lower": { | |||||
| "alphabet": "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}", | |||||
| "number_of_characters": 69 | |||||
| }, | |||||
| "both": { | |||||
| "alphabet": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}", | |||||
| "number_of_characters": 95 | |||||
| } | |||||
| } | |||||
| }, | |||||
| "model_parameters": { | |||||
| "small": { | |||||
| "conv": [ | |||||
| #依次是channel,kennnel_size,maxpooling_size | |||||
| [256,7,3], | |||||
| [256,7,3], | |||||
| [256,3,-1], | |||||
| [256,3,-1], | |||||
| [256,3,-1], | |||||
| [256,3,3] | |||||
| ], | |||||
| "fc": [1024,1024] | |||||
| }, | |||||
| "large":{ | |||||
| "conv":[ | |||||
| [1024, 7, 3], | |||||
| [1024, 7, 3], | |||||
| [1024, 3, -1], | |||||
| [1024, 3, -1], | |||||
| [1024, 3, -1], | |||||
| [1024, 3, 3] | |||||
| ], | |||||
| "fc": [2048,2048] | |||||
| } | |||||
| }, | |||||
| "data": { | |||||
| "text_column": "SentimentText", | |||||
| "label_column": "Sentiment", | |||||
| "max_length": 1014, | |||||
| "num_of_classes": 2, | |||||
| "encoding": None, | |||||
| "chunksize": 50000, | |||||
| "max_rows": 100000, | |||||
| "preprocessing_steps": ["lower", "remove_hashtags", "remove_urls", "remove_user_mentions"] | |||||
| }, | |||||
| "training": { | |||||
| "batch_size": 128, | |||||
| "learning_rate": 0.01, | |||||
| "epochs": 10, | |||||
| "optimizer": "sgd" | |||||
| } | |||||
| } | |||||
| ops=Config | |||||
| ##1.task相关信息:利用dataloader载入dataInfo | |||||
| #dataloader=sst2Loader() | |||||
| #dataloader=IMDBLoader() | |||||
| dataloader=yelpLoader(fine_grained=True) | |||||
| datainfo=dataloader.process(ops.datapath,char_level_op=True) | |||||
| char_vocab=ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"] | |||||
| ops.number_of_characters=len(char_vocab) | |||||
| ops.embedding_dim=ops.number_of_characters | |||||
| #chartoindex | |||||
| def chartoindex(chars): | |||||
| max_seq_len=ops.max_length | |||||
| zero_index=len(char_vocab) | |||||
| char_index_list=[] | |||||
| for char in chars: | |||||
| if char in char_vocab: | |||||
| char_index_list.append(char_vocab.index(char)) | |||||
| else: | |||||
| #<unk>和<pad>均使用最后一个作为embbeding | |||||
| char_index_list.append(zero_index) | |||||
| if len(char_index_list) > max_seq_len: | |||||
| char_index_list = char_index_list[:max_seq_len] | |||||
| elif 0 < len(char_index_list) < max_seq_len: | |||||
| char_index_list = char_index_list+[zero_index]*(max_seq_len-len(char_index_list)) | |||||
| elif len(char_index_list) == 0: | |||||
| char_index_list=[zero_index]*max_seq_len | |||||
| return char_index_list | |||||
| for dataset in datainfo.datasets.values(): | |||||
| dataset.apply_field(chartoindex,field_name='chars',new_field_name='chars') | |||||
| datainfo.datasets['train'].set_input('chars') | |||||
| datainfo.datasets['test'].set_input('chars') | |||||
| datainfo.datasets['train'].set_target('target') | |||||
| datainfo.datasets['test'].set_target('target') | |||||
| ##2. 定义/组装模型,这里可以随意,就如果是fastNLP封装好的,类似CNNText就直接用初始化调用就好了,这里只是给出一个伪框架表示占位,在这里建立符合fastNLP输入输出规范的model | |||||
| class ModelFactory(nn.Module): | |||||
| """ | |||||
| 用于拼装embedding,encoder,decoder 以及设计forward过程 | |||||
| :param embedding: embbeding model | |||||
| :param encoder: encoder model | |||||
| :param decoder: decoder model | |||||
| """ | |||||
| def __int__(self,embedding,encoder,decoder,**kwargs): | |||||
| super(ModelFactory,self).__init__() | |||||
| self.embedding=embedding | |||||
| self.encoder=encoder | |||||
| self.decoder=decoder | |||||
| def forward(self,x): | |||||
| return {C.OUTPUT:None} | |||||
| ## 2.或直接复用fastNLP的模型 | |||||
| #vocab=datainfo.vocabs['words'] | |||||
| vocab_label=datainfo.vocabs['target'] | |||||
| ''' | |||||
| # emded_char=CNNCharEmbedding(vocab) | |||||
| # embed_word = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True) | |||||
| # embedding=StackEmbedding([emded_char, embed_word]) | |||||
| # cnn_char_embed = CNNCharEmbedding(vocab) | |||||
| # lstm_char_embed = LSTMCharEmbedding(vocab) | |||||
| # embedding = StackEmbedding([cnn_char_embed, lstm_char_embed]) | |||||
| ''' | |||||
| #one-hot embedding | |||||
| embedding_weight= Variable(torch.zeros(len(char_vocab)+1, len(char_vocab))) | |||||
| for i in range(len(char_vocab)): | |||||
| embedding_weight[i][i]=1 | |||||
| embedding=nn.Embedding(num_embeddings=len(char_vocab)+1,embedding_dim=len(char_vocab),padding_idx=len(char_vocab),_weight=embedding_weight) | |||||
| for para in embedding.parameters(): | |||||
| para.requires_grad=False | |||||
| #CNNText太过于简单 | |||||
| #model=CNNText(init_embed=embedding, num_classes=ops.num_classes) | |||||
| model=CharacterLevelCNN(ops,embedding) | |||||
| ## 3. 声明loss,metric,optimizer | |||||
| loss=CrossEntropyLoss | |||||
| metric=AccuracyMetric | |||||
| optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], lr=ops.lr) | |||||
| ## 4.定义train方法 | |||||
| def train(model,datainfo,loss,metrics,optimizer,num_epochs=100): | |||||
| trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'), | |||||
| metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=0, check_code_level=-1, | |||||
| n_epochs=num_epochs) | |||||
| print(trainer.train()) | |||||
| if __name__=="__main__": | |||||
| #print(vocab_label) | |||||
| #print(datainfo.datasets["train"]) | |||||
| train(model,datainfo,loss,metric,optimizer,num_epochs=ops.train_epoch) | |||||
| @@ -0,0 +1,120 @@ | |||||
| # 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 | |||||
| import torch.cuda | |||||
| from fastNLP.core.utils import cache_results | |||||
| from torch.optim import SGD | |||||
| from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR | |||||
| from fastNLP.core.trainer import Trainer | |||||
| from fastNLP import CrossEntropyLoss, AccuracyMetric | |||||
| from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding | |||||
| from reproduction.text_classification.model.dpcnn import DPCNN | |||||
| from data.yelpLoader import yelpLoader | |||||
| from fastNLP.core.sampler import BucketSampler | |||||
| import torch.nn as nn | |||||
| from fastNLP.core import LRScheduler | |||||
| from fastNLP.core.const import Const as C | |||||
| from fastNLP.core.vocabulary import VocabularyOption | |||||
| from utils.util_init import set_rng_seeds | |||||
| import os | |||||
| os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | |||||
| os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||||
| os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | |||||
| # hyper | |||||
| class Config(): | |||||
| seed = 12345 | |||||
| model_dir_or_name = "dpcnn-yelp-p" | |||||
| embedding_grad = True | |||||
| train_epoch = 30 | |||||
| batch_size = 100 | |||||
| task = "yelp_p" | |||||
| #datadir = 'workdir/datasets/SST' | |||||
| datadir = 'workdir/datasets/yelp_polarity' | |||||
| # datadir = 'workdir/datasets/yelp_full' | |||||
| #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"} | |||||
| datafile = {"train": "train.csv", "test": "test.csv"} | |||||
| lr = 1e-3 | |||||
| src_vocab_op = VocabularyOption(max_size=100000) | |||||
| embed_dropout = 0.3 | |||||
| cls_dropout = 0.1 | |||||
| weight_decay = 1e-5 | |||||
| def __init__(self): | |||||
| self.datadir = os.path.join(os.environ['HOME'], self.datadir) | |||||
| self.datapath = {k: os.path.join(self.datadir, v) | |||||
| for k, v in self.datafile.items()} | |||||
| ops = Config() | |||||
| set_rng_seeds(ops.seed) | |||||
| print('RNG SEED: {}'.format(ops.seed)) | |||||
| # 1.task相关信息:利用dataloader载入dataInfo | |||||
| #datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train']) | |||||
| @cache_results(ops.model_dir_or_name+'-data-cache') | |||||
| def load_data(): | |||||
| datainfo = yelpLoader(fine_grained=True, lower=True).process( | |||||
| paths=ops.datapath, train_ds=['train'], src_vocab_op=ops.src_vocab_op) | |||||
| for ds in datainfo.datasets.values(): | |||||
| ds.apply_field(len, C.INPUT, C.INPUT_LEN) | |||||
| ds.set_input(C.INPUT, C.INPUT_LEN) | |||||
| ds.set_target(C.TARGET) | |||||
| embedding = StaticEmbedding( | |||||
| datainfo.vocabs['words'], model_dir_or_name='en-glove-840b-300', requires_grad=ops.embedding_grad, | |||||
| normalize=False | |||||
| ) | |||||
| return datainfo, embedding | |||||
| datainfo, embedding = load_data() | |||||
| # 2.或直接复用fastNLP的模型 | |||||
| # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)]) | |||||
| print(datainfo) | |||||
| print(datainfo.datasets['train'][0]) | |||||
| model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]), | |||||
| embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout) | |||||
| print(model) | |||||
| # 3. 声明loss,metric,optimizer | |||||
| loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET) | |||||
| metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET) | |||||
| optimizer = SGD([param for param in model.parameters() if param.requires_grad == True], | |||||
| lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay) | |||||
| callbacks = [] | |||||
| # callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) | |||||
| callbacks.append( | |||||
| LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch < | |||||
| ops.train_epoch * 0.8 else ops.lr * 0.1)) | |||||
| ) | |||||
| # callbacks.append( | |||||
| # FitlogCallback(data=datainfo.datasets, verbose=1) | |||||
| # ) | |||||
| device = 'cuda:0' if torch.cuda.is_available() else 'cpu' | |||||
| print(device) | |||||
| # 4.定义train方法 | |||||
| trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||||
| sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), | |||||
| metrics=[metric], | |||||
| dev_data=datainfo.datasets['test'], device=device, | |||||
| check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, | |||||
| n_epochs=ops.train_epoch, num_workers=4) | |||||
| if __name__ == "__main__": | |||||
| print(trainer.train()) | |||||
| @@ -0,0 +1,66 @@ | |||||
| # 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 | |||||
| import os | |||||
| os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | |||||
| os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||||
| import torch.nn as nn | |||||
| from data.IMDBLoader import IMDBLoader | |||||
| from fastNLP.modules.encoder.embedding import StaticEmbedding | |||||
| from model.lstm import BiLSTMSentiment | |||||
| from fastNLP.core.const import Const as C | |||||
| from fastNLP import CrossEntropyLoss, AccuracyMetric | |||||
| from fastNLP import Trainer, Tester | |||||
| from torch.optim import Adam | |||||
| from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
| import argparse | |||||
| class Config(): | |||||
| train_epoch= 10 | |||||
| lr=0.001 | |||||
| num_classes=2 | |||||
| hidden_dim=256 | |||||
| num_layers=1 | |||||
| nfc=128 | |||||
| task_name = "IMDB" | |||||
| datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"} | |||||
| save_model_path="./result_IMDB_test/" | |||||
| opt=Config() | |||||
| # load data | |||||
| dataloader=IMDBLoader() | |||||
| datainfo=dataloader.process(opt.datapath) | |||||
| # print(datainfo.datasets["train"]) | |||||
| # print(datainfo) | |||||
| # define model | |||||
| vocab=datainfo.vocabs['words'] | |||||
| embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True) | |||||
| model=BiLSTMSentiment(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, nfc=opt.nfc) | |||||
| # define loss_function and metrics | |||||
| loss=CrossEntropyLoss() | |||||
| metrics=AccuracyMetric() | |||||
| optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr) | |||||
| def train(datainfo, model, optimizer, loss, metrics, opt): | |||||
| trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||||
| metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1, | |||||
| n_epochs=opt.train_epoch, save_path=opt.save_model_path) | |||||
| trainer.train() | |||||
| if __name__ == "__main__": | |||||
| train(datainfo, model, optimizer, loss, metrics, opt) | |||||
| @@ -0,0 +1,68 @@ | |||||
| # 首先需要加入以下的路径到环境变量,因为当前只对内部测试开放,所以需要手动申明一下路径 | |||||
| import os | |||||
| os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | |||||
| os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||||
| import torch.nn as nn | |||||
| from data.IMDBLoader import IMDBLoader | |||||
| from fastNLP.modules.encoder.embedding import StaticEmbedding | |||||
| from model.lstm_self_attention import BiLSTM_SELF_ATTENTION | |||||
| from fastNLP.core.const import Const as C | |||||
| from fastNLP import CrossEntropyLoss, AccuracyMetric | |||||
| from fastNLP import Trainer, Tester | |||||
| from torch.optim import Adam | |||||
| from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
| import argparse | |||||
| class Config(): | |||||
| train_epoch= 10 | |||||
| lr=0.001 | |||||
| num_classes=2 | |||||
| hidden_dim=256 | |||||
| num_layers=1 | |||||
| attention_unit=256 | |||||
| attention_hops=1 | |||||
| nfc=128 | |||||
| task_name = "IMDB" | |||||
| datapath={"train":"IMDB_data/train.csv", "test":"IMDB_data/test.csv"} | |||||
| save_model_path="./result_IMDB_test/" | |||||
| opt=Config() | |||||
| # load data | |||||
| dataloader=IMDBLoader() | |||||
| datainfo=dataloader.process(opt.datapath) | |||||
| # print(datainfo.datasets["train"]) | |||||
| # print(datainfo) | |||||
| # define model | |||||
| vocab=datainfo.vocabs['words'] | |||||
| embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300', requires_grad=True) | |||||
| model=BiLSTM_SELF_ATTENTION(init_embed=embed, num_classes=opt.num_classes, hidden_dim=opt.hidden_dim, num_layers=opt.num_layers, attention_unit=opt.attention_unit, attention_hops=opt.attention_hops, nfc=opt.nfc) | |||||
| # define loss_function and metrics | |||||
| loss=CrossEntropyLoss() | |||||
| metrics=AccuracyMetric() | |||||
| optimizer= Adam([param for param in model.parameters() if param.requires_grad==True], lr=opt.lr) | |||||
| def train(datainfo, model, optimizer, loss, metrics, opt): | |||||
| trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||||
| metrics=metrics, dev_data=datainfo.datasets['test'], device=0, check_code_level=-1, | |||||
| n_epochs=opt.train_epoch, save_path=opt.save_model_path) | |||||
| trainer.train() | |||||
| if __name__ == "__main__": | |||||
| train(datainfo, model, optimizer, loss, metrics, opt) | |||||
| @@ -0,0 +1,11 @@ | |||||
| import numpy | |||||
| import torch | |||||
| import random | |||||
| def set_rng_seeds(seed): | |||||
| random.seed(seed) | |||||
| numpy.random.seed(seed) | |||||
| torch.random.manual_seed(seed) | |||||
| torch.cuda.manual_seed_all(seed) | |||||
| # print('RNG_SEED {}'.format(seed)) | |||||
| @@ -59,4 +59,13 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
| else: | else: | ||||
| raise TypeError(f"paths only supports str and dict. not {type(paths)}.") | raise TypeError(f"paths only supports str and dict. not {type(paths)}.") | ||||
| def get_tokenizer(): | |||||
| try: | |||||
| import spacy | |||||
| spacy.prefer_gpu() | |||||
| en = spacy.load('en') | |||||
| print('use spacy tokenizer') | |||||
| return lambda x: [w.text for w in en.tokenizer(x)] | |||||
| except Exception as e: | |||||
| print('use raw tokenizer') | |||||
| return lambda x: x.split() | |||||