diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py index c5784f59..d8ba1ad1 100644 --- a/fastNLP/core/sampler.py +++ b/fastNLP/core/sampler.py @@ -62,16 +62,27 @@ class BucketSampler(Sampler): 带Bucket的 `Random Sampler`. 可以随机地取出长度相似的元素 :param int num_buckets: bucket的数量 - :param int batch_size: batch的大小 + :param int batch_size: batch的大小. 默认为None,Trainer在调用BucketSampler时,会将该值正确设置,如果是非Trainer场景使用,需 + 要显示传递该值 :param str seq_len_field_name: 对应序列长度的 `field` 的名字 """ - def __init__(self, num_buckets=10, batch_size=32, seq_len_field_name='seq_len'): + def __init__(self, num_buckets=10, batch_size=None, seq_len_field_name='seq_len'): self.num_buckets = num_buckets self.batch_size = batch_size self.seq_len_field_name = seq_len_field_name - + + def set_batch_size(self, batch_size): + """ + + :param int batch_size: 每个batch的大小 + :return: + """ + self.batch_size = batch_size + def __call__(self, data_set): + if self.batch_size is None: + raise RuntimeError("batch_size is None.") seq_lens = data_set.get_all_fields()[self.seq_len_field_name].content total_sample_num = len(seq_lens) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index eabda99c..8fa44438 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -443,6 +443,8 @@ class Trainer(object): if sampler is None: sampler = RandomSampler() + elif hasattr(sampler, 'set_batch_size'): + sampler.set_batch_size(batch_size) if isinstance(train_data, DataSet): self.data_iterator = DataSetIter( diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py index 62793836..5d61c16a 100644 --- a/fastNLP/io/base_loader.py +++ b/fastNLP/io/base_loader.py @@ -111,7 +111,7 @@ def _uncompress(src, dst): class DataBundle: """ - 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 + 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。 :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict diff --git a/fastNLP/models/sequence_labeling.py b/fastNLP/models/sequence_labeling.py index 8e6a5db1..506ebdc6 100644 --- a/fastNLP/models/sequence_labeling.py +++ b/fastNLP/models/sequence_labeling.py @@ -3,17 +3,76 @@ """ __all__ = [ "SeqLabeling", - "AdvSeqLabel" + "AdvSeqLabel", + "BiLSTMCRF" ] import torch import torch.nn as nn +import torch.nn.functional as F from .base_model import BaseModel from ..modules import decoder, encoder from ..modules.decoder.crf import allowed_transitions from ..core.utils import seq_len_to_mask from ..core.const import Const as C +from ..modules import LSTM +from ..modules import get_embeddings +from ..modules import ConditionalRandomField + + +class BiLSTMCRF(BaseModel): + """ + 结构为BiLSTM + FC + Dropout + CRF. 
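A minimal usage sketch for the BucketSampler change above: batch_size now defaults to None, Trainer injects the real value through set_batch_size(), and standalone use must pass it explicitly (the toy DataSet below is purely illustrative):

```python
from fastNLP import DataSet
from fastNLP.core.sampler import BucketSampler

# Toy dataset with a 'seq_len' field, just enough to drive the sampler.
ds = DataSet({'words': [[1] * l for l in [3, 7, 2, 9, 5, 4]],
              'seq_len': [3, 7, 2, 9, 5, 4]})

# Inside Trainer the value is filled in automatically:
#   sampler = BucketSampler(num_buckets=2)           # batch_size left as None
#   Trainer(..., sampler=sampler, batch_size=32)     # Trainer calls sampler.set_batch_size(32)

# Outside Trainer it must be set by hand, otherwise __call__ raises RuntimeError.
sampler = BucketSampler(num_buckets=2, batch_size=2)
indices = sampler(ds)   # a permutation of indices 0..5, grouped by similar seq_len
```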
+ TODO 补充文档 + :param embed: tuple: + :param num_classes: + :param num_layers: + :param hidden_size: + :param dropout: + :param target_vocab: + :param encoding_type: + """ + def __init__(self, embed, num_classes, num_layers=1, hidden_size=100, dropout=0.5, + target_vocab=None, encoding_type=None): + super().__init__() + self.embed = get_embeddings(embed) + + if num_layers>1: + self.lstm = LSTM(embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True, + batch_first=True, dropout=dropout) + else: + self.lstm = LSTM(embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True, + batch_first=True) + + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(hidden_size, num_classes) + + trans = None + if target_vocab is not None and encoding_type is not None: + trans = allowed_transitions(target_vocab.idx2word, encoding_type=encoding_type, include_start_end=True) + + self.crf = ConditionalRandomField(num_classes, include_start_end_trans=True, allowed_transitions=trans) + + def _forward(self, words, seq_len=None, target=None): + words = self.embed(words) + feats = self.lstm(words, seq_len=seq_len) + feats = self.fc(feats) + feats = self.dropout(feats) + logits = F.log_softmax(feats, dim=-1) + mask = seq_len_to_mask(seq_len) + if target is None: + pred, _ = self.crf.viterbi_decode(logits, mask) + return {C.OUTPUT:pred} + else: + loss = self.crf(logits, target, mask).mean() + return {C.LOSS:loss} + + def forward(self, words, seq_len, target): + return self._forward(words, seq_len, target) + + def predict(self, words, seq_len): + return self._forward(words, seq_len) class SeqLabeling(BaseModel): diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index 050a423a..0639959b 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -689,7 +689,7 @@ class BertEmbedding(ContextualEmbedding): outputs = self.model(words) outputs = torch.cat([*outputs], dim=-1) - return self.dropout(words) + return self.dropout(outputs) @property def requires_grad(self): diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py index 5e599a65..695f56b8 100644 --- a/fastNLP/modules/encoder/lstm.py +++ b/fastNLP/modules/encoder/lstm.py @@ -10,10 +10,6 @@ import torch import torch.nn as nn import torch.nn.utils.rnn as rnn -from ..utils import initial_parameter -from torch import autograd - - class LSTM(nn.Module): """ 别名::class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM` diff --git a/reproduction/Biaffine_parser/cfg.cfg b/reproduction/legacy/Biaffine_parser/cfg.cfg similarity index 100% rename from reproduction/Biaffine_parser/cfg.cfg rename to reproduction/legacy/Biaffine_parser/cfg.cfg diff --git a/reproduction/Biaffine_parser/infer.py b/reproduction/legacy/Biaffine_parser/infer.py similarity index 100% rename from reproduction/Biaffine_parser/infer.py rename to reproduction/legacy/Biaffine_parser/infer.py diff --git a/reproduction/Biaffine_parser/main.py b/reproduction/legacy/Biaffine_parser/main.py similarity index 100% rename from reproduction/Biaffine_parser/main.py rename to reproduction/legacy/Biaffine_parser/main.py diff --git a/reproduction/Biaffine_parser/run.py b/reproduction/legacy/Biaffine_parser/run.py similarity index 100% rename from reproduction/Biaffine_parser/run.py rename to reproduction/legacy/Biaffine_parser/run.py diff --git a/reproduction/Biaffine_parser/util.py b/reproduction/legacy/Biaffine_parser/util.py similarity index 100% 
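Since the BiLSTMCRF docstring above is still marked TODO, here is a hedged sketch (with a made-up tag set) of how the target_vocab and encoding_type arguments feed the CRF through allowed_transitions; this mirrors the pattern the class uses internally:

```python
from fastNLP import Vocabulary
from fastNLP.modules import ConditionalRandomField
from fastNLP.modules.decoder.crf import allowed_transitions

# Illustrative tag vocabulary; no unknown/padding so indices map one-to-one to tags.
tag_vocab = Vocabulary(unknown=None, padding=None)
tag_vocab.add_word_lst(['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC'])

# allowed_transitions() turns the tag scheme into (from_idx, to_idx) pairs the CRF may use,
# e.g. under 'bio' encoding a transition O -> I-PER is excluded.
trans = allowed_transitions(tag_vocab.idx2word, encoding_type='bio', include_start_end=True)
crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=True, allowed_transitions=trans)
```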
rename from reproduction/Biaffine_parser/util.py rename to reproduction/legacy/Biaffine_parser/util.py diff --git a/reproduction/Chinese_word_segmentation/__init__.py b/reproduction/legacy/Chinese_word_segmentation/__init__.py similarity index 100% rename from reproduction/Chinese_word_segmentation/__init__.py rename to reproduction/legacy/Chinese_word_segmentation/__init__.py diff --git a/reproduction/Chinese_word_segmentation/cws.cfg b/reproduction/legacy/Chinese_word_segmentation/cws.cfg similarity index 100% rename from reproduction/Chinese_word_segmentation/cws.cfg rename to reproduction/legacy/Chinese_word_segmentation/cws.cfg diff --git a/reproduction/Chinese_word_segmentation/cws_io/__init__.py b/reproduction/legacy/Chinese_word_segmentation/cws_io/__init__.py similarity index 100% rename from reproduction/Chinese_word_segmentation/cws_io/__init__.py rename to reproduction/legacy/Chinese_word_segmentation/cws_io/__init__.py diff --git a/reproduction/Chinese_word_segmentation/cws_io/cws_reader.py b/reproduction/legacy/Chinese_word_segmentation/cws_io/cws_reader.py similarity index 100% rename from reproduction/Chinese_word_segmentation/cws_io/cws_reader.py rename to reproduction/legacy/Chinese_word_segmentation/cws_io/cws_reader.py diff --git a/reproduction/Chinese_word_segmentation/models/__init__.py b/reproduction/legacy/Chinese_word_segmentation/models/__init__.py similarity index 100% rename from reproduction/Chinese_word_segmentation/models/__init__.py rename to reproduction/legacy/Chinese_word_segmentation/models/__init__.py diff --git a/reproduction/Chinese_word_segmentation/models/cws_model.py b/reproduction/legacy/Chinese_word_segmentation/models/cws_model.py similarity index 98% rename from reproduction/Chinese_word_segmentation/models/cws_model.py rename to reproduction/legacy/Chinese_word_segmentation/models/cws_model.py index b41ad87d..0d10d2e5 100644 --- a/reproduction/Chinese_word_segmentation/models/cws_model.py +++ b/reproduction/legacy/Chinese_word_segmentation/models/cws_model.py @@ -4,7 +4,7 @@ from torch import nn from fastNLP.models.base_model import BaseModel from fastNLP.modules.decoder.mlp import MLP -from reproduction.Chinese_word_segmentation.utils import seq_lens_to_mask +from reproduction.legacy.Chinese_word_segmentation.utils import seq_lens_to_mask class CWSBiLSTMEncoder(BaseModel): diff --git a/reproduction/Chinese_word_segmentation/models/cws_transformer.py b/reproduction/legacy/Chinese_word_segmentation/models/cws_transformer.py similarity index 97% rename from reproduction/Chinese_word_segmentation/models/cws_transformer.py rename to reproduction/legacy/Chinese_word_segmentation/models/cws_transformer.py index e8ae5ecc..ae8a5a7f 100644 --- a/reproduction/Chinese_word_segmentation/models/cws_transformer.py +++ b/reproduction/legacy/Chinese_word_segmentation/models/cws_transformer.py @@ -9,7 +9,7 @@ from torch import nn import torch # from fastNLP.modules.encoder.transformer import TransformerEncoder -from reproduction.Chinese_word_segmentation.models.transformer import TransformerEncoder +from reproduction.legacy.Chinese_word_segmentation.models import TransformerEncoder from fastNLP.modules.decoder.crf import ConditionalRandomField,seq_len_to_byte_mask from fastNLP.modules.decoder.crf import allowed_transitions @@ -79,7 +79,7 @@ class TransformerCWS(nn.Module): return {'pred': probs, 'seq_lens':seq_lens} -from reproduction.Chinese_word_segmentation.models.dilated_transformer import TransformerDilateEncoder +from 
reproduction.legacy.Chinese_word_segmentation.models import TransformerDilateEncoder class TransformerDilatedCWS(nn.Module): def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, diff --git a/reproduction/Chinese_word_segmentation/process/__init__.py b/reproduction/legacy/Chinese_word_segmentation/process/__init__.py similarity index 100% rename from reproduction/Chinese_word_segmentation/process/__init__.py rename to reproduction/legacy/Chinese_word_segmentation/process/__init__.py diff --git a/reproduction/Chinese_word_segmentation/process/cws_processor.py b/reproduction/legacy/Chinese_word_segmentation/process/cws_processor.py similarity index 99% rename from reproduction/Chinese_word_segmentation/process/cws_processor.py rename to reproduction/legacy/Chinese_word_segmentation/process/cws_processor.py index 614d9ef5..1f64bed2 100644 --- a/reproduction/Chinese_word_segmentation/process/cws_processor.py +++ b/reproduction/legacy/Chinese_word_segmentation/process/cws_processor.py @@ -4,7 +4,7 @@ import re from fastNLP.api.processor import Processor from fastNLP.core.dataset import DataSet from fastNLP.core.vocabulary import Vocabulary -from reproduction.Chinese_word_segmentation.process.span_converter import SpanConverter +from reproduction.legacy.Chinese_word_segmentation.process.span_converter import SpanConverter _SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' diff --git a/reproduction/Chinese_word_segmentation/process/span_converter.py b/reproduction/legacy/Chinese_word_segmentation/process/span_converter.py similarity index 100% rename from reproduction/Chinese_word_segmentation/process/span_converter.py rename to reproduction/legacy/Chinese_word_segmentation/process/span_converter.py diff --git a/reproduction/Chinese_word_segmentation/utils.py b/reproduction/legacy/Chinese_word_segmentation/utils.py similarity index 100% rename from reproduction/Chinese_word_segmentation/utils.py rename to reproduction/legacy/Chinese_word_segmentation/utils.py diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/README.md b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/README.md similarity index 100% rename from reproduction/LSTM+self_attention_sentiment_analysis/README.md rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/README.md diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/Word2Idx.py b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/Word2Idx.py similarity index 100% rename from reproduction/LSTM+self_attention_sentiment_analysis/Word2Idx.py rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/Word2Idx.py diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/config.cfg b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/config.cfg similarity index 100% rename from reproduction/LSTM+self_attention_sentiment_analysis/config.cfg rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/config.cfg diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/dataloader.py b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/dataloader.py similarity index 100% rename from reproduction/LSTM+self_attention_sentiment_analysis/dataloader.py rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/dataloader.py diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/example.py b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/example.py similarity index 100% rename from 
reproduction/LSTM+self_attention_sentiment_analysis/example.py rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/example.py diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/main.py b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/main.py similarity index 100% rename from reproduction/LSTM+self_attention_sentiment_analysis/main.py rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/main.py diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/predict.py b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/predict.py similarity index 100% rename from reproduction/LSTM+self_attention_sentiment_analysis/predict.py rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/predict.py diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/prepare.py b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/prepare.py similarity index 100% rename from reproduction/LSTM+self_attention_sentiment_analysis/prepare.py rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/prepare.py diff --git a/reproduction/POS_tagging/pos_processor.py b/reproduction/legacy/POS_tagging/pos_processor.py similarity index 100% rename from reproduction/POS_tagging/pos_processor.py rename to reproduction/legacy/POS_tagging/pos_processor.py diff --git a/reproduction/POS_tagging/pos_reader.py b/reproduction/legacy/POS_tagging/pos_reader.py similarity index 100% rename from reproduction/POS_tagging/pos_reader.py rename to reproduction/legacy/POS_tagging/pos_reader.py diff --git a/reproduction/POS_tagging/pos_tag.cfg b/reproduction/legacy/POS_tagging/pos_tag.cfg similarity index 100% rename from reproduction/POS_tagging/pos_tag.cfg rename to reproduction/legacy/POS_tagging/pos_tag.cfg diff --git a/reproduction/POS_tagging/train_pos_tag.py b/reproduction/legacy/POS_tagging/train_pos_tag.py similarity index 100% rename from reproduction/POS_tagging/train_pos_tag.py rename to reproduction/legacy/POS_tagging/train_pos_tag.py diff --git a/reproduction/POS_tagging/utils.py b/reproduction/legacy/POS_tagging/utils.py similarity index 100% rename from reproduction/POS_tagging/utils.py rename to reproduction/legacy/POS_tagging/utils.py diff --git a/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py b/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py new file mode 100644 index 00000000..cec5ab76 --- /dev/null +++ b/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py @@ -0,0 +1,115 @@ + + +from fastNLP.io.base_loader import DataSetLoader, DataBundle +from fastNLP.io import ConllLoader +from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2 +from fastNLP import Const +from reproduction.utils import check_dataloader_paths +from fastNLP import Vocabulary + +class ChineseNERLoader(DataSetLoader): + """ + 读取中文命名实体数据集,包括PeopleDaily, MSRA-NER, Weibo。数据在这里可以找到https://github.com/OYE93/Chinese-NLP-Corpus/tree/master/NER + 请确保输入数据的格式如下, 共两列,第一列为字,第二列为标签,不同句子以空行隔开 + 我 O + 们 O + 变 O + 而 O + 以 O + 书 O + 会 O + ... 
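For the bio/bioes option this loader exposes: each converter in _tag_converters takes a flat tag list and returns a rewritten list, so the conversion chain can be tried on a toy sentence (the expected values in the comments are illustrative):

```python
from reproduction.seqence_labelling.ner.data.utils import iob2, iob2bioes

tags = ['I-PER', 'I-PER', 'O', 'I-LOC']   # IOB1-style input
tags = iob2(tags)                          # -> ['B-PER', 'I-PER', 'O', 'B-LOC'] (normalized to BIO)
tags = iob2bioes(tags)                     # -> ['B-PER', 'E-PER', 'O', 'S-LOC'] (BIOES)
```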
+ + """ + def __init__(self, encoding_type:str='bioes'): + """ + + :param str encoding_type: 支持bio和bioes格式 + """ + super().__init__() + self._loader = ConllLoader(headers=['raw_chars', 'target'], indexes=[0, 1]) + + assert encoding_type in ('bio', 'bioes') + + self._tag_converters = [iob2] + if encoding_type == 'bioes': + self._tag_converters.append(iob2bioes) + + def load(self, path:str): + dataset = self._loader.load(path) + def convert_tag_schema(tags): + for converter in self._tag_converters: + tags = converter(tags) + return tags + if self._tag_converters: + dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET) + return dataset + + def process(self, paths, bigrams=False, trigrams=False): + """ + + :param paths: + :param bool, bigrams: 是否包含生成bigram feature, [a, b, c, d] -> [ab, bc, cd, d] + :param bool, trigrams: 是否包含trigram feature,[a, b, c, d] -> [abc, bcd, cd, d] + :return: DataBundle + 包含以下的fields + raw_chars: List[str] + chars: List[int] + seq_len: int, 字的长度 + bigrams: List[int], optional + trigrams: List[int], optional + target: List[int] + """ + paths = check_dataloader_paths(paths) + data = DataBundle() + input_fields = [Const.CHAR_INPUT, Const.INPUT_LEN, Const.TARGET] + target_fields = [Const.TARGET, Const.INPUT_LEN] + + for name, path in paths.items(): + dataset = self.load(path) + if bigrams: + dataset.apply_field(lambda raw_chars: [c1+c2 for c1, c2 in zip(raw_chars, raw_chars[1:]+[''])], + field_name='raw_chars', new_field_name='bigrams') + + if trigrams: + dataset.apply_field(lambda raw_chars: [c1+c2+c3 for c1, c2, c3 in zip(raw_chars, + raw_chars[1:]+[''], + raw_chars[2:]+['']*2)], + field_name='raw_chars', new_field_name='trigrams') + data.datasets[name] = dataset + + char_vocab = Vocabulary().from_dataset(data.datasets['train'], field_name='raw_chars', + no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train']) + char_vocab.index_dataset(*data.datasets.values(), field_name='raw_chars', new_field_name=Const.CHAR_INPUT) + data.vocabs[Const.CHAR_INPUT] = char_vocab + + target_vocab = Vocabulary(unknown=None, padding=None).from_dataset(data.datasets['train'], field_name=Const.TARGET) + target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET) + data.vocabs[Const.TARGET] = target_vocab + + if bigrams: + bigram_vocab = Vocabulary().from_dataset(data.datasets['train'], field_name='bigrams', + no_create_entry_dataset=[dataset for name, dataset in + data.datasets.items() if name != 'train']) + bigram_vocab.index_dataset(*data.datasets.values(), field_name='bigrams', new_field_name='bigrams') + data.vocabs['bigrams'] = bigram_vocab + input_fields.append('bigrams') + + if trigrams: + trigram_vocab = Vocabulary().from_dataset(data.datasets['train'], field_name='trigrams', + no_create_entry_dataset=[dataset for name, dataset in + data.datasets.items() if name != 'train']) + trigram_vocab.index_dataset(*data.datasets.values(), field_name='trigrams', new_field_name='trigrams') + data.vocabs['trigrams'] = trigram_vocab + input_fields.append('trigrams') + + for name, dataset in data.datasets.items(): + dataset.add_seq_len(Const.CHAR_INPUT) + dataset.set_input(*input_fields) + dataset.set_target(*target_fields) + + return data + + + + diff --git a/reproduction/seqence_labelling/chinese_ner/data/__init__.py b/reproduction/seqence_labelling/chinese_ner/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/reproduction/seqence_labelling/chinese_ner/train_cn_ner.py 
b/reproduction/seqence_labelling/chinese_ner/train_cn_ner.py new file mode 100644 index 00000000..1993898b --- /dev/null +++ b/reproduction/seqence_labelling/chinese_ner/train_cn_ner.py @@ -0,0 +1,94 @@ + + + +from reproduction.seqence_labelling.chinese_ner.data.ChineseNER import ChineseNERLoader +from fastNLP.modules.encoder.embedding import StaticEmbedding + +from torch import nn +import torch +from fastNLP.modules import get_embeddings +from fastNLP.modules import LSTM +from fastNLP.modules import ConditionalRandomField +from fastNLP.modules import allowed_transitions +import torch.nn.functional as F +from fastNLP import seq_len_to_mask +from fastNLP.core.const import Const as C +from fastNLP import SpanFPreRecMetric, Trainer +from fastNLP import cache_results + +class CNBiLSTMCRFNER(nn.Module): + def __init__(self, char_embed, num_classes, bigram_embed=None, trigram_embed=None, num_layers=1, hidden_size=100, + dropout=0.5, target_vocab=None, encoding_type=None): + super().__init__() + + self.char_embed = get_embeddings(char_embed) + embed_size = self.char_embed.embedding_dim + if bigram_embed: + self.bigram_embed = get_embeddings(bigram_embed) + embed_size += self.bigram_embed.embedding_dim + if trigram_embed: + self.trigram_ebmbed = get_embeddings(trigram_embed) + embed_size += self.bigram_embed.embedding_dim + + if num_layers>1: + self.lstm = LSTM(embed_size, num_layers=num_layers, hidden_size=hidden_size//2, bidirectional=True, + batch_first=True, dropout=dropout) + else: + self.lstm = LSTM(embed_size, num_layers=num_layers, hidden_size=hidden_size//2, bidirectional=True, + batch_first=True) + + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(hidden_size, num_classes) + + trans = None + if target_vocab is not None and encoding_type is not None: + trans = allowed_transitions(target_vocab.idx2word, encoding_type=encoding_type, include_start_end=True) + + self.crf = ConditionalRandomField(num_classes, include_start_end_trans=True, allowed_transitions=trans) + + def _forward(self, chars, bigrams=None, trigrams=None, seq_len=None, target=None): + chars = self.char_embed(chars) + if hasattr(self, 'bigram_embed'): + bigrams = self.bigram_embed(bigrams) + chars = torch.cat((chars, bigrams), dim=-1) + if hasattr(self, 'trigram_embed'): + trigrams = self.trigram_embed(trigrams) + chars = torch.cat((chars, trigrams), dim=-1) + feats, _ = self.lstm(chars, seq_len=seq_len) + feats = self.fc(feats) + feats = self.dropout(feats) + logits = F.log_softmax(feats, dim=-1) + mask = seq_len_to_mask(seq_len) + if target is None: + pred, _ = self.crf.viterbi_decode(logits, mask) + return {C.OUTPUT: pred} + else: + loss = self.crf(logits, target, mask).mean() + return {C.LOSS:loss} + + def forward(self, chars, target, bigrams=None, trigrams=None, seq_len=None): + return self._forward(chars, bigrams, trigrams, seq_len, target) + + def predict(self, chars, seq_len=None, bigrams=None, trigrams=None): + return self._forward(chars, bigrams, trigrams, seq_len) + +# data_bundle = pickle.load(open('caches/msra.pkl', 'rb')) +@cache_results('caches/msra.pkl', _refresh=False) +def get_data(): + data_bundle = ChineseNERLoader().process('/remote-home/hyan01/exps/fastNLP/others/data/MSRA-NER', bigrams=True) + char_embed = StaticEmbedding(data_bundle.vocabs['chars'], + model_dir_or_name='/remote-home/hyan01/exps/CWS/pretrain/vectors/1grams_t3_m50_corpus.txt') + bigram_embed = StaticEmbedding(data_bundle.vocabs['bigrams'], + model_dir_or_name='/remote-home/hyan01/exps/CWS/pretrain/vectors/2gram_t3_m50_merge.txt') 
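As a side note on the bigrams=True call in get_data() above: the bigram/trigram features indexed by those StaticEmbeddings are built by the apply_field lambdas in ChineseNERLoader.process; a pure-Python rendering of that construction:

```python
raw_chars = ['我', '们', '变', '而']

bigrams = [c1 + c2 for c1, c2 in zip(raw_chars, raw_chars[1:] + [''])]
# -> ['我们', '们变', '变而', '而']   (the last char is paired with an empty string)

trigrams = [c1 + c2 + c3 for c1, c2, c3 in
            zip(raw_chars, raw_chars[1:] + [''], raw_chars[2:] + [''] * 2)]
# -> ['我们变', '们变而', '变而', '而']
```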
+ return data_bundle, char_embed, bigram_embed +data_bundle, char_embed, bigram_embed = get_data() +print(data_bundle) +# exit(0) +data_bundle.datasets['train'].set_input('target') +data_bundle.datasets['dev'].set_input('target') +model = CNBiLSTMCRFNER(char_embed, num_classes=len(data_bundle.vocabs['target']), bigram_embed=bigram_embed) + +Trainer(data_bundle.datasets['train'], model, batch_size=640, + metrics=SpanFPreRecMetric(data_bundle.vocabs['target'], encoding_type='bioes'), + num_workers=2, dev_data=data_bundle. datasets['dev'], device=3).train() + diff --git a/reproduction/seqence_labelling/ner/train_ontonote.py b/reproduction/seqence_labelling/ner/train_ontonote.py index 6548cb9f..33f015d8 100644 --- a/reproduction/seqence_labelling/ner/train_ontonote.py +++ b/reproduction/seqence_labelling/ner/train_ontonote.py @@ -12,54 +12,72 @@ from fastNLP import Const from torch.optim import SGD, Adam from torch.optim.lr_scheduler import LambdaLR from fastNLP import GradientClipCallback +from fastNLP.core.vocabulary import VocabularyOption from fastNLP.core.callback import FitlogCallback, LRScheduler -from reproduction.seqence_labelling.ner.model.swats import SWATS +from functools import partial +from torch import nn +from fastNLP import cache_results import fitlog fitlog.debug() +fitlog.set_log_dir('logs/') + +fitlog.add_hyper_in_file(__file__) +#######hyper +normalize = False +divide_std = True +lower = False +lr = 0.015 +dropout = 0.5 +batch_size = 20 +init_method = 'default' +job_embed = False +data_name = 'ontonote' +#######hyper + + +init_method = {'default': None, + 'xavier': partial(nn.init.xavier_normal_, gain=0.02), + 'normal': partial(nn.init.normal_, std=0.02) + }[init_method] + from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader encoding_type = 'bioes' -data = OntoNoteNERDataLoader(encoding_type=encoding_type).process('/hdd/fudanNLP/fastNLP/others/data/v4/english', - lower=True) - -import joblib -raw_data = joblib.load('/hdd/fudanNLP/fastNLP/others/NER-with-LS/data/ontonotes_with_data.joblib') -def convert_to_ids(raw_words): - ids = [] - for word in raw_words: - id = raw_data['word_to_id'][word] - id = raw_data['id_to_emb_map'][id] - ids.append(id) - return ids -word_embed = raw_data['emb_matrix'] -for name, dataset in data.datasets.items(): - dataset.apply_field(convert_to_ids, field_name='raw_words', new_field_name=Const.INPUT) +@cache_results('caches/ontonotes.pkl') +def cache(): + data = OntoNoteNERDataLoader(encoding_type=encoding_type).process('../../../../others/data/v4/english', + lower=lower, + word_vocab_opt=VocabularyOption(min_freq=1)) + char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30], + kernel_sizes=[3]) + word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], + model_dir_or_name='/remote-home/hyan01/fastnlp_caches/glove.6B.100d/glove.6B.100d.txt', + requires_grad=True, + normalize=normalize, + init_method=init_method) + return data, char_embed, word_embed +data, char_embed, word_embed = cache() print(data) -char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30], - kernel_sizes=[3]) -# word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], -# model_dir_or_name='/hdd/fudanNLP/pretrain_vectors/glove.6B.100d.txt', -# requires_grad=True) model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=1200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], - encoding_type=encoding_type) + 
encoding_type=encoding_type, dropout=dropout) -callbacks = [GradientClipCallback(clip_value=5, clip_type='value'), - FitlogCallback(data.datasets['test'], verbose=1)] +callbacks = [ + GradientClipCallback(clip_value=5, clip_type='value'), + FitlogCallback(data.datasets['test'], verbose=1) + ] -optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9) +optimizer = SGD(model.parameters(), lr=lr, momentum=0.9) scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))) callbacks.append(scheduler) -# optimizer = SWATS(model.parameters(), verbose=True) -# optimizer = Adam(model.parameters(), lr=0.005) -trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(num_buckets=100), - device=0, dev_data=data.datasets['dev'], batch_size=10, +trainer = Trainer(train_data=data.datasets['dev'][:100], model=model, optimizer=optimizer, sampler=None, + device=0, dev_data=data.datasets['dev'][:100], batch_size=batch_size, metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type), callbacks=callbacks, num_workers=1, n_epochs=100) trainer.train() \ No newline at end of file diff --git a/reproduction/text_classification/data/IMDBLoader.py b/reproduction/text_classification/data/IMDBLoader.py index d57ee41b..94244431 100644 --- a/reproduction/text_classification/data/IMDBLoader.py +++ b/reproduction/text_classification/data/IMDBLoader.py @@ -10,7 +10,6 @@ from fastNLP import Const from functools import partial from reproduction.utils import check_dataloader_paths, get_tokenizer - class IMDBLoader(DataSetLoader): """ 读取IMDB数据集,DataSet包含以下fields: @@ -51,6 +50,7 @@ class IMDBLoader(DataSetLoader): datasets = {} info = DataBundle() + paths = check_dataloader_paths(paths) for name, path in paths.items(): dataset = self.load(path) datasets[name] = dataset diff --git a/reproduction/text_classification/model/lstm.py b/reproduction/text_classification/model/lstm.py index 388f3f1c..93b4d8a9 100644 --- a/reproduction/text_classification/model/lstm.py +++ b/reproduction/text_classification/model/lstm.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn from fastNLP.core.const import Const as C from fastNLP.modules.encoder.lstm import LSTM -from fastNLP.modules import encoder +from fastNLP.modules import get_embeddings from fastNLP.modules.decoder.mlp import MLP @@ -13,14 +13,14 @@ class BiLSTMSentiment(nn.Module): num_layers=1, nfc=128): super(BiLSTMSentiment,self).__init__() - self.embed = encoder.Embedding(init_embed) + self.embed = get_embeddings(init_embed) self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True) - self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes]) + self.mlp = MLP(size_layer=[hidden_dim*2, nfc, num_classes]) def forward(self, words): x_emb = self.embed(words) output, _ = self.lstm(x_emb) - output = self.mlp(output[:,-1,:]) + output = self.mlp(torch.max(output, dim=1)[0]) return {C.OUTPUT: output} def predict(self, words): diff --git a/reproduction/text_classification/train_bert.py b/reproduction/text_classification/train_bert.py index e69de29b..4db54958 100644 --- a/reproduction/text_classification/train_bert.py +++ b/reproduction/text_classification/train_bert.py @@ -0,0 +1,33 @@ +import sys +sys.path.append('../../') + +from reproduction.text_classification.data.IMDBLoader import IMDBLoader +from fastNLP.modules.encoder.embedding import BertEmbedding +from reproduction.text_classification.model.lstm import 
BiLSTMSentiment +from fastNLP import Trainer +from fastNLP import CrossEntropyLoss, AccuracyMetric +from fastNLP import cache_results +from fastNLP import Tester + +# 对返回结果进行缓存,下一次运行就会自动跳过预处理 +@cache_results('imdb.pkl') +def get_data(): + data_bundle = IMDBLoader().process('imdb/') + return data_bundle +data_bundle = get_data() + +print(data_bundle) + +# 删除超过512, 但由于英语中会把word进行word piece处理,所以截取的时候做一点的裕量 +data_bundle.datasets['train'].drop(lambda x:len(x['words'])>400) +data_bundle.datasets['dev'].drop(lambda x:len(x['words'])>400) +data_bundle.datasets['test'].drop(lambda x:len(x['words'])>400) +bert_embed = BertEmbedding(data_bundle.vocabs['words'], requires_grad=False, + model_dir_or_name="en-base") +model = BiLSTMSentiment(bert_embed, len(data_bundle.vocabs['target'])) + +Trainer(data_bundle.datasets['train'], model, optimizer=None, loss=CrossEntropyLoss(), device=0, + batch_size=10, dev_data=data_bundle.datasets['dev'], metrics=AccuracyMetric()).train() + +# 在测试集上测试一下效果 +Tester(data_bundle.datasets['test'], model, batch_size=32, metrics=AccuracyMetric()).test() \ No newline at end of file
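One more illustration, for the change in reproduction/text_classification/model/lstm.py above: the sentence representation moves from the last time step (which is often a padding position in a padded batch) to max-over-time pooling. A minimal tensor sketch of the two choices:

```python
import torch

output = torch.randn(2, 7, 8)             # (batch, seq_len, hidden*2) as returned by the BiLSTM

last_step = output[:, -1, :]               # old: take the final position only
max_pool = torch.max(output, dim=1)[0]     # new: element-wise max over all positions

print(last_step.shape, max_pool.shape)     # both torch.Size([2, 8]); only the pooling differs
```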