diff --git a/fastNLP/io/data_loader/sst.py b/fastNLP/io/data_loader/sst.py index 021a79b7..8d0d005f 100644 --- a/fastNLP/io/data_loader/sst.py +++ b/fastNLP/io/data_loader/sst.py @@ -5,10 +5,8 @@ from ..base_loader import DataInfo, DataSetLoader from ...core.vocabulary import VocabularyOption, Vocabulary from ...core.dataset import DataSet from ...core.instance import Instance -from ..embed_loader import EmbeddingOption, EmbedLoader +from ..utils import check_dataloader_paths, get_tokenizer -spacy.prefer_gpu() -sptk = spacy.load('en') class SSTLoader(DataSetLoader): URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' @@ -37,6 +35,7 @@ class SSTLoader(DataSetLoader): tag_v['0'] = tag_v['1'] tag_v['4'] = tag_v['3'] self.tag_v = tag_v + self.tokenizer = get_tokenizer() def _load(self, path): """ @@ -55,29 +54,37 @@ class SSTLoader(DataSetLoader): ds.append(Instance(words=words, target=tag)) return ds - @staticmethod - def _get_one(data, subtree): + def _get_one(self, data, subtree): tree = Tree.fromstring(data) if subtree: - return [([x.text for x in sptk.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees() ] - return [([x.text for x in sptk.tokenizer(' '.join(tree.leaves()))], tree.label())] + return [([x.text for x in self.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees() ] + return [([x.text for x in self.tokenizer(' '.join(tree.leaves()))], tree.label())] def process(self, - paths, - train_ds: Iterable[str] = None, + paths, train_subtree=True, src_vocab_op: VocabularyOption = None, - tgt_vocab_op: VocabularyOption = None, - src_embed_op: EmbeddingOption = None): + tgt_vocab_op: VocabularyOption = None,): + paths = check_dataloader_paths(paths) input_name, target_name = 'words', 'target' src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) tgt_vocab = Vocabulary(unknown=None, padding=None) \ if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) - info = DataInfo(datasets=self.load(paths)) - _train_ds = [info.datasets[name] - for name in train_ds] if train_ds else info.datasets.values() - src_vocab.from_dataset(*_train_ds, field_name=input_name) - tgt_vocab.from_dataset(*_train_ds, field_name=target_name) + info = DataInfo() + origin_subtree = self.subtree + self.subtree = train_subtree + info.datasets['train'] = self._load(paths['train']) + self.subtree = origin_subtree + for n, p in paths.items(): + if n != 'train': + info.datasets[n] = self._load(p) + + src_vocab.from_dataset( + info.datasets['train'], + field_name=input_name, + no_create_entry_dataset=[ds for n, ds in info.datasets.items() if n != 'train']) + tgt_vocab.from_dataset(info.datasets['train'], field_name=target_name) + src_vocab.index_dataset( *info.datasets.values(), field_name=input_name, new_field_name=input_name) @@ -89,10 +96,5 @@ class SSTLoader(DataSetLoader): target_name: tgt_vocab } - if src_embed_op is not None: - src_embed_op.vocab = src_vocab - init_emb = EmbedLoader.load_with_vocab(**src_embed_op) - info.embeddings[input_name] = init_emb - return info diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py new file mode 100644 index 00000000..a7d2de85 --- /dev/null +++ b/fastNLP/io/utils.py @@ -0,0 +1,69 @@ +import os + +from typing import Union, Dict + + +def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: + """ + 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果 + { + 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 + 'test': 'xxx' # 可能有,也可能没有 + ... + } + 如果paths为不合法的,将直接进行raise相应的错误 + + :param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 + 中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 + :return: + """ + if isinstance(paths, str): + if os.path.isfile(paths): + return {'train': paths} + elif os.path.isdir(paths): + filenames = os.listdir(paths) + files = {} + for filename in filenames: + path_pair = None + if 'train' in filename: + path_pair = ('train', filename) + if 'dev' in filename: + if path_pair: + raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0])) + path_pair = ('dev', filename) + if 'test' in filename: + if path_pair: + raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0])) + path_pair = ('test', filename) + if path_pair: + files[path_pair[0]] = os.path.join(paths, path_pair[1]) + return files + else: + raise FileNotFoundError(f"{paths} is not a valid file path.") + + elif isinstance(paths, dict): + if paths: + if 'train' not in paths: + raise KeyError("You have to include `train` in your dict.") + for key, value in paths.items(): + if isinstance(key, str) and isinstance(value, str): + if not os.path.isfile(value): + raise TypeError(f"{value} is not a valid file.") + else: + raise TypeError("All keys and values in paths should be str.") + return paths + else: + raise ValueError("Empty paths is not allowed.") + else: + raise TypeError(f"paths only supports str and dict. not {type(paths)}.") + +def get_tokenizer(): + try: + import spacy + spacy.prefer_gpu() + en = spacy.load('en') + print('use spacy tokenizer') + return lambda x: [w.text for w in en.tokenizer(x)] + except Exception as e: + print('use raw tokenizer') + return lambda x: x.split() diff --git a/reproduction/seqence_labelling/ner/model/dilated_cnn.py b/reproduction/seqence_labelling/ner/model/dilated_cnn.py index cd2fa64b..a4e02159 100644 --- a/reproduction/seqence_labelling/ner/model/dilated_cnn.py +++ b/reproduction/seqence_labelling/ner/model/dilated_cnn.py @@ -8,16 +8,23 @@ from fastNLP.core.const import Const as C class IDCNN(nn.Module): - def __init__(self, init_embed, char_embed, + def __init__(self, + init_embed, + char_embed, num_cls, repeats, num_layers, num_filters, kernel_size, use_crf=False, use_projection=False, block_loss=False, input_dropout=0.3, hidden_dropout=0.2, inner_dropout=0.0): super(IDCNN, self).__init__() self.word_embeddings = Embedding(init_embed) - self.char_embeddings = Embedding(char_embed) - embedding_size = self.word_embeddings.embedding_dim + \ - self.char_embeddings.embedding_dim + + if char_embed is None: + self.char_embeddings = None + embedding_size = self.word_embeddings.embedding_dim + else: + self.char_embeddings = Embedding(char_embed) + embedding_size = self.word_embeddings.embedding_dim + \ + self.char_embeddings.embedding_dim self.conv0 = nn.Sequential( nn.Conv1d(in_channels=embedding_size, @@ -31,7 +38,7 @@ class IDCNN(nn.Module): block = [] for layer_i in range(num_layers): - dilated = 2 ** layer_i + dilated = 2 ** layer_i if layer_i+1 < num_layers else 1 block.append(nn.Conv1d( in_channels=num_filters, out_channels=num_filters, @@ -67,11 +74,24 @@ class IDCNN(nn.Module): self.crf = ConditionalRandomField( num_tags=num_cls) if use_crf else None self.block_loss = block_loss + self.reset_parameters() - def forward(self, words, chars, seq_len, target=None): - e1 = self.word_embeddings(words) - e2 = self.char_embeddings(chars) - x = torch.cat((e1, e2), dim=-1) # b,l,h + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): + nn.init.xavier_normal_(m.weight, gain=1) + if m.bias is not None: + nn.init.normal_(m.bias, mean=0, std=0.01) + + def forward(self, words, seq_len, target=None, chars=None): + if self.char_embeddings is None: + x = self.word_embeddings(words) + else: + if chars is None: + raise ValueError('must provide chars for model with char embedding') + e1 = self.word_embeddings(words) + e2 = self.char_embeddings(chars) + x = torch.cat((e1, e2), dim=-1) # b,l,h mask = seq_len_to_mask(seq_len) x = x.transpose(1, 2) # b,h,l @@ -84,21 +104,24 @@ class IDCNN(nn.Module): def compute_loss(y, t, mask): if self.crf is not None and target is not None: - loss = self.crf(y, t, mask) + loss = self.crf(y.transpose(1, 2), t, mask) else: t.masked_fill_(mask == 0, -100) loss = F.cross_entropy(y, t, ignore_index=-100) return loss - if self.block_loss: - losses = [compute_loss(o, target, mask) for o in output] - loss = sum(losses) + if target is not None: + if self.block_loss: + losses = [compute_loss(o, target, mask) for o in output] + loss = sum(losses) + else: + loss = compute_loss(output[-1], target, mask) else: - loss = compute_loss(output[-1], target, mask) + loss = None scores = output[-1] if self.crf is not None: - pred = self.crf.viterbi_decode(scores, target, mask) + pred, _ = self.crf.viterbi_decode(scores.transpose(1, 2), mask) else: pred = scores.max(1)[1] * mask.long() @@ -107,5 +130,13 @@ class IDCNN(nn.Module): C.OUTPUT: pred, } - def predict(self, words, chars, seq_len): - return self.forward(words, chars, seq_len)[C.OUTPUT] + def predict(self, words, seq_len, chars=None): + res = self.forward( + words=words, + seq_len=seq_len, + chars=chars, + target=None + )[C.OUTPUT] + return { + C.OUTPUT: res + } diff --git a/reproduction/seqence_labelling/ner/train_idcnn.py b/reproduction/seqence_labelling/ner/train_idcnn.py new file mode 100644 index 00000000..1781c763 --- /dev/null +++ b/reproduction/seqence_labelling/ner/train_idcnn.py @@ -0,0 +1,99 @@ +from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader +from fastNLP.core.callback import FitlogCallback, LRScheduler +from fastNLP import GradientClipCallback +from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR +from torch.optim import SGD, Adam +from fastNLP import Const +from fastNLP import RandomSampler, BucketSampler +from fastNLP import SpanFPreRecMetric +from fastNLP import Trainer +from reproduction.seqence_labelling.ner.model.dilated_cnn import IDCNN +from fastNLP.core.utils import Option +from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding +from fastNLP.core.utils import cache_results +import sys +import torch.cuda +import os +os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' +os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + +encoding_type = 'bioes' + + +def get_path(path): + return os.path.join(os.environ['HOME'], path) + +data_path = get_path('workdir/datasets/ontonotes-v4') + +ops = Option( + batch_size=128, + num_epochs=100, + lr=3e-4, + repeats=3, + num_layers=3, + num_filters=400, + use_crf=True, + gradient_clip=5, +) + +@cache_results('ontonotes-cache') +def load_data(): + + data = OntoNoteNERDataLoader(encoding_type=encoding_type).process(data_path, + lower=True) + + # char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30], + # kernel_sizes=[3]) + + word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], + model_dir_or_name='en-glove-840b-300', + requires_grad=True) + return data, [word_embed] + +data, embeds = load_data() +print(data.datasets['train'][0]) +print(list(data.vocabs.keys())) + +for ds in data.datasets.values(): + ds.rename_field('cap_words', 'chars') + ds.set_input('chars') + +word_embed = embeds[0] +char_embed = CNNCharEmbedding(data.vocabs['cap_words']) +# for ds in data.datasets: +# ds.rename_field('') + +print(data.vocabs[Const.TARGET].word2idx) + +model = IDCNN(init_embed=word_embed, + char_embed=char_embed, + num_cls=len(data.vocabs[Const.TARGET]), + repeats=ops.repeats, + num_layers=ops.num_layers, + num_filters=ops.num_filters, + kernel_size=3, + use_crf=ops.use_crf, use_projection=True, + block_loss=True, + input_dropout=0.33, hidden_dropout=0.2, inner_dropout=0.2) + +print(model) + +callbacks = [GradientClipCallback(clip_value=ops.gradient_clip, clip_type='norm'),] + +optimizer = Adam(model.parameters(), lr=ops.lr, weight_decay=0) +# scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))) +# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 15))) +# optimizer = SWATS(model.parameters(), verbose=True) +# optimizer = Adam(model.parameters(), lr=0.005) + +device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + +trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, + sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), + device=device, dev_data=data.datasets['dev'], batch_size=ops.batch_size, + metrics=SpanFPreRecMetric( + tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type), + check_code_level=-1, + callbacks=callbacks, num_workers=2, n_epochs=ops.num_epochs) +trainer.train() diff --git a/reproduction/text_classification/data/yelpLoader.py b/reproduction/text_classification/data/yelpLoader.py index d97f9399..704c29e5 100644 --- a/reproduction/text_classification/data/yelpLoader.py +++ b/reproduction/text_classification/data/yelpLoader.py @@ -8,18 +8,7 @@ from fastNLP.io.base_loader import DataInfo from fastNLP.io.embed_loader import EmbeddingOption from fastNLP.io.file_reader import _read_json from typing import Union, Dict -from reproduction.utils import check_dataloader_paths - - -def get_tokenizer(): - try: - import spacy - en = spacy.load('en') - print('use spacy tokenizer') - return lambda x: [w.text for w in en.tokenizer(x)] - except Exception as e: - print('use raw tokenizer') - return lambda x: x.split() +from reproduction.utils import check_dataloader_paths, get_tokenizer def clean_str(sentence, tokenizer, char_lower=False): """ diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py index 294a0742..9664bf75 100644 --- a/reproduction/text_classification/train_dpcnn.py +++ b/reproduction/text_classification/train_dpcnn.py @@ -9,6 +9,7 @@ from fastNLP import CrossEntropyLoss, AccuracyMetric from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding from reproduction.text_classification.model.dpcnn import DPCNN from data.yelpLoader import yelpLoader +from fastNLP.core.sampler import BucketSampler import torch.nn as nn from fastNLP.core import LRScheduler from fastNLP.core.const import Const as C @@ -28,19 +29,20 @@ class Config(): embedding_grad = True train_epoch = 30 batch_size = 100 - num_classes = 2 task = "yelp_p" - #datadir = '/remote-home/yfshao/workdir/datasets/SST' - datadir = '/remote-home/yfshao/workdir/datasets/yelp_polarity' + #datadir = 'workdir/datasets/SST' + datadir = 'workdir/datasets/yelp_polarity' + # datadir = 'workdir/datasets/yelp_full' #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"} datafile = {"train": "train.csv", "test": "test.csv"} lr = 1e-3 - src_vocab_op = VocabularyOption() + src_vocab_op = VocabularyOption(max_size=100000) embed_dropout = 0.3 cls_dropout = 0.1 - weight_decay = 1e-4 + weight_decay = 1e-5 def __init__(self): + self.datadir = os.path.join(os.environ['HOME'], self.datadir) self.datapath = {k: os.path.join(self.datadir, v) for k, v in self.datafile.items()} @@ -53,6 +55,8 @@ print('RNG SEED: {}'.format(ops.seed)) # 1.task相关信息:利用dataloader载入dataInfo #datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train']) + + @cache_results(ops.model_dir_or_name+'-data-cache') def load_data(): datainfo = yelpLoader(fine_grained=True, lower=True).process( @@ -61,28 +65,23 @@ def load_data(): ds.apply_field(len, C.INPUT, C.INPUT_LEN) ds.set_input(C.INPUT, C.INPUT_LEN) ds.set_target(C.TARGET) - return datainfo + embedding = StaticEmbedding( + datainfo.vocabs['words'], model_dir_or_name='en-glove-840b-300', requires_grad=ops.embedding_grad, + normalize=False + ) + return datainfo, embedding -datainfo = load_data() + +datainfo, embedding = load_data() # 2.或直接复用fastNLP的模型 -vocab = datainfo.vocabs['words'] # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)]) -#embedding = StaticEmbedding(vocab) -embedding = StaticEmbedding( - vocab, model_dir_or_name='en-word2vec-300', requires_grad=ops.embedding_grad, - normalize=False -) -print(len(datainfo.datasets['train'])) -print(len(datainfo.datasets['test'])) +print(datainfo) print(datainfo.datasets['train'][0]) -print(len(vocab)) -print(len(datainfo.vocabs['target'])) - -model = DPCNN(init_embed=embedding, num_cls=ops.num_classes, +model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]), embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout) print(model) @@ -93,11 +92,11 @@ optimizer = SGD([param for param in model.parameters() if param.requires_grad == lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay) callbacks = [] -callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) -# callbacks.append -# LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch < -# ops.train_epoch * 0.8 else ops.lr * 0.1)) -# ) +# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) +callbacks.append( + LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch < + ops.train_epoch * 0.8 else ops.lr * 0.1)) +) # callbacks.append( # FitlogCallback(data=datainfo.datasets, verbose=1) @@ -109,6 +108,7 @@ print(device) # 4.定义train方法 trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), metrics=[metric], dev_data=datainfo.datasets['test'], device=device, check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, diff --git a/reproduction/utils.py b/reproduction/utils.py index 4f0d021e..a7d2de85 100644 --- a/reproduction/utils.py +++ b/reproduction/utils.py @@ -57,4 +57,13 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: else: raise TypeError(f"paths only supports str and dict. not {type(paths)}.") - +def get_tokenizer(): + try: + import spacy + spacy.prefer_gpu() + en = spacy.load('en') + print('use spacy tokenizer') + return lambda x: [w.text for w in en.tokenizer(x)] + except Exception as e: + print('use raw tokenizer') + return lambda x: x.split()