From d71f0eef139dc6d78a8ca220e375383d5698f6b9 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Tue, 4 Jun 2019 23:40:46 +0800
Subject: [PATCH] SemiCRFRelay Chinese word segmentation for sequence
 labelling.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/trainer.py                            |  15 +-
 fastNLP/core/utils.py                              |   2 +-
 fastNLP/core/vocabulary.py                         |   4 +-
 fastNLP/io/embed_loader.py                         |   4 +-
 .../data/CWSDataLoader.py                          | 159 ++++++++------
 .../seqence_labelling/cws/model/metric.py          |  44 ++++
 .../seqence_labelling/cws/model/model.py           |  74 +++++++
 .../seqence_labelling/cws/model/module.py          | 198 ++++++++++++++++++
 .../seqence_labelling/cws/test/__init__.py         |   0
 .../cws/test/test_CWSDataLoader.py                 |  17 ++
 .../cws/train_shift_relay.py                       |  68 ++++++
 reproduction/utils.py                              |  51 +++++
 12 files changed, 566 insertions(+), 70 deletions(-)
 rename reproduction/seqence_labelling/{Chinese_Word_Segmentation => cws}/data/CWSDataLoader.py (62%)
 create mode 100644 reproduction/seqence_labelling/cws/model/metric.py
 create mode 100644 reproduction/seqence_labelling/cws/model/model.py
 create mode 100644 reproduction/seqence_labelling/cws/model/module.py
 create mode 100644 reproduction/seqence_labelling/cws/test/__init__.py
 create mode 100644 reproduction/seqence_labelling/cws/test/test_CWSDataLoader.py
 create mode 100644 reproduction/seqence_labelling/cws/train_shift_relay.py
 create mode 100644 reproduction/utils.py

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 40e5a5c1..57a31a69 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -494,14 +494,15 @@ class Trainer(object):
         self.callback_manager = CallbackManager(env={"trainer": self},
                                                 callbacks=callbacks)
 
-    def train(self, load_best_model=True, on_exception='ignore'):
+    def train(self, load_best_model=True, on_exception='auto'):
         """
         Call this method to make the Trainer start training.
 
         :param bool load_best_model: Takes effect only when dev_data was provided at initialization. If True, the
             trainer reloads the model parameters that performed best on the dev set before returning.
         :param str on_exception: Whether to keep raising an exception encountered during training after it has been
             handled by on_exception() of :py:class:Callback .
-            Supports 'ignore' and 'raise': 'ignore' catches the exception, and code written after Trainer.train() keeps running; 'raise' re-raises the exception.
+            Supports 'ignore', 'raise' and 'auto': 'ignore' catches the exception, and code written after Trainer.train() keeps running; 'raise' re-raises the exception;
+            'auto' ignores the following two exceptions, CallbackException and KeyboardInterrupt, and raises any other exception.
        :return dict: Returns a dict with the following entries::

@@ -530,12 +531,16 @@ class Trainer(object):
             self.callback_manager.on_train_begin()
             self._train()
             self.callback_manager.on_train_end()
-        except (CallbackException, KeyboardInterrupt, Exception) as e:
+
+        except BaseException as e:
             self.callback_manager.on_exception(e)
-            if on_exception=='raise':
+            if on_exception == 'auto':
+                if not isinstance(e, (CallbackException, KeyboardInterrupt)):
+                    raise e
+            elif on_exception == 'raise':
                 raise e
 
-        if self.dev_data is not None and hasattr(self, 'best_dev_perf'):
+        if self.dev_data is not None and self.best_dev_perf is not None:
             print(
                 "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
                 self.tester._format_eval_results(self.best_dev_perf), )
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index 58436a35..9dab47b5 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -4,7 +4,7 @@ The utils module implements many utilities needed inside and outside fastNLP. A
 __all__ = [
     "cache_results",
     "seq_len_to_mask",
-    "Example",
+    "Option",
 ]
 
 import _pickle
diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index 0cf45049..bca28e10 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -6,10 +6,10 @@ __all__ = [
 from functools import wraps
 from collections import Counter
 from .dataset import DataSet
-from .utils import Example
+from .utils import Option
 
 
-class VocabularyOption(Example):
+class VocabularyOption(Option):
     def __init__(self,
                  max_size=None,
                  min_freq=None,
diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py
index 93861258..bc37777e 100644
--- a/fastNLP/io/embed_loader.py
+++ b/fastNLP/io/embed_loader.py
@@ -10,10 +10,10 @@ import numpy as np
 from ..core.vocabulary import Vocabulary
 from .base_loader import BaseLoader
 
-from ..core.utils import Example
+from ..core.utils import Option
 
 
-class EmbeddingOption(Example):
+class EmbeddingOption(Option):
     def __init__(self,
                  embed_filepath=None,
                  dtype=np.float32,
diff --git a/reproduction/seqence_labelling/Chinese_Word_Segmentation/data/CWSDataLoader.py b/reproduction/seqence_labelling/cws/data/CWSDataLoader.py
similarity index 62%
rename from reproduction/seqence_labelling/Chinese_Word_Segmentation/data/CWSDataLoader.py
rename to reproduction/seqence_labelling/cws/data/CWSDataLoader.py
index 1000c204..e8440289 100644
--- a/reproduction/seqence_labelling/Chinese_Word_Segmentation/data/CWSDataLoader.py
+++ b/reproduction/seqence_labelling/cws/data/CWSDataLoader.py
@@ -6,6 +6,9 @@
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP import Vocabulary
+from fastNLP import Const
+from reproduction.utils import check_dataloader_paths
+from functools import partial
 
 
 class SigHanLoader(DataSetLoader):
     """
@@ -20,27 +23,43 @@ class SigHanLoader(DataSetLoader):
         chars: list(str), each element is an index (the index of the corresponding character)
         target: list(int), varies with the encoding_type
 
-    :param target_type: the type of target; the following two are currently supported: "bmes", "pointer"
+    :param target_type: the type of target; the following two are currently supported: "bmes", "shift_relay"
     """
 
     def __init__(self, target_type:str):
         super().__init__()
-        if target_type.lower() not in ('bmes', 'pointer'):
-            raise ValueError("target_type only supports 'bmes', 'pointer'.")
+        if target_type.lower() not in ('bmes', 'shift_relay'):
+            raise ValueError("target_type only supports 'bmes', 'shift_relay'.")
         self.target_type = target_type
         if target_type=='bmes':
             self._word_len_to_target = self._word_len_to_bems
+        elif target_type=='shift_relay':
+            self._word_len_to_target = 
self._word_lens_to_relay - + @staticmethod + def _word_lens_to_relay(word_lens: Iterator[int]): + """ + [1, 2, 3, ..] 转换为[0, 1, 0, 2, 1, 0,](start指示seg有多长); + :param word_lens: + :return: {'target': , 'end_seg_mask':, 'start_seg_mask':} + """ + tags = [] + end_seg_mask = [] + start_seg_mask = [] + for word_len in word_lens: + tags.extend([idx for idx in range(word_len - 1, -1, -1)]) + end_seg_mask.extend([0] * (word_len - 1) + [1]) + start_seg_mask.extend([1] + [0] * (word_len - 1)) + return {'target': tags, 'end_seg_mask': end_seg_mask, 'start_seg_mask': start_seg_mask} @staticmethod - def _word_len_to_bems(word_lens:Iterator[int])->List[str]: + def _word_len_to_bems(word_lens:Iterator[int])->Dict[str, List[str]]: """ :param word_lens: 每个word的长度 - :return: 返回对应的BMES的str + :return: """ tags = [] for word_len in word_lens: @@ -51,7 +70,7 @@ class SigHanLoader(DataSetLoader): for _ in range(word_len-2): tags.append('M') tags.append('E') - return tags + return {'target':tags} @staticmethod def _gen_bigram(chars:List[str])->List[str]: @@ -71,11 +90,15 @@ class SigHanLoader(DataSetLoader): dataset = DataSet() with open(path, 'r', encoding='utf-8') as f: for line in f: + line = line.strip() + if not line: # 去掉空行 + continue parts = line.split() word_lens = map(len, parts) - chars = list(line) + chars = list(''.join(parts)) tags = self._word_len_to_target(word_lens) - dataset.append(Instance(raw_chars=chars, target=tags)) + assert len(chars)==len(tags['target']) + dataset.append(Instance(raw_chars=chars, **tags, seq_len=len(chars))) if len(dataset)==0: raise RuntimeError(f"{path} has no valid data.") if bigram: @@ -84,7 +107,7 @@ class SigHanLoader(DataSetLoader): def process(self, paths: Union[str, Dict[str, str]], char_vocab_opt:VocabularyOption=None, char_embed_opt:EmbeddingOption=None, bigram_vocab_opt:VocabularyOption=None, - bigram_embed_opt:EmbeddingOption=None): + bigram_embed_opt:EmbeddingOption=None, L:int=4): """ 支持的数据格式为一行一个sample,并且用空格隔开不同的词语。例如 @@ -113,7 +136,7 @@ class SigHanLoader(DataSetLoader): data = SigHanLoader('bmes').process('path/to/cws/') #将尝试在该目录下读取 train.txt, test.txt以及dev.txt # 包含以下的内容data.vocabs['chars']: Vocabulary对象 # data.vocabs['target']:Vocabulary对象 - # data.embeddings['chars']: Embedding对象. 
只有提供了预训练的词向量的路径才有该项 + # data.embeddings['chars']: 仅在提供了预训练embedding路径的情况下,为Embedding对象; # data.datasets['train']: DataSet对象 # 包含的field有: # raw_chars: list[str], 每个元素是一个汉字 @@ -132,79 +155,95 @@ class SigHanLoader(DataSetLoader): :param bigram_vocab_opt: 用于构建bigram的vocabulary参数,默认不使用bigram, 仅在指定该参数的情况下会带有bigrams这个field。 为List[int], 每个instance长度与chars一样, abcde的bigram为ab bc cd de e :param bigram_embed_opt: 用于读取预训练bigram的参数,仅在传入bigram_vocab_opt有效 + :param L: 当target_type为shift_relay时传入的segment长度 :return: """ # 推荐大家使用这个check_data_loader_paths进行paths的验证 paths = check_dataloader_paths(paths) datasets = {} + data = DataInfo() bigram = bigram_vocab_opt is not None for name, path in paths.items(): dataset = self.load(path, bigram=bigram) datasets[name] = dataset + input_fields = [] + target_fields = [] # 创建vocab char_vocab = Vocabulary(min_freq=2) if char_vocab_opt is None else Vocabulary(**char_vocab_opt) char_vocab.from_dataset(datasets['train'], field_name='raw_chars') char_vocab.index_dataset(*datasets.values(), field_name='raw_chars', new_field_name='chars') + data.vocabs[Const.CHAR_INPUT] = char_vocab + input_fields.extend([Const.CHAR_INPUT, Const.INPUT_LEN, Const.TARGET]) + target_fields.append(Const.TARGET) # 创建target if self.target_type == 'bmes': target_vocab = Vocabulary(unknown=None, padding=None) target_vocab.add_word_lst(['B']*4+['M']*3+['E']*2+['S']) target_vocab.index_dataset(*datasets.values(), field_name='target') + data.vocabs[Const.TARGET] = target_vocab + if char_embed_opt is not None: + char_embed = EmbedLoader.load_with_vocab(**char_embed_opt, vocab=char_vocab) + data.embeddings['chars'] = char_embed if bigram: bigram_vocab = Vocabulary(**bigram_vocab_opt) bigram_vocab.from_dataset(datasets['train'], field_name='bigrams') bigram_vocab.index_dataset(*datasets.values(), field_name='bigrams') + data.vocabs['bigrams'] = bigram_vocab if bigram_embed_opt is not None: - pass - - - + bigram_embed = EmbedLoader.load_with_vocab(**bigram_embed_opt, vocab=bigram_vocab) + data.embeddings['bigrams'] = bigram_embed + input_fields.append('bigrams') + if self.target_type == 'shift_relay': + func = partial(self._clip_target, L=L) + for name, dataset in datasets.items(): + res = dataset.apply_field(func, field_name='target') + relay_target = [res_i[0] for res_i in res] + relay_mask = [res_i[1] for res_i in res] + dataset.add_field('relay_target', relay_target, is_input=True, is_target=False, ignore_type=False) + dataset.add_field('relay_mask', relay_mask, is_input=True, is_target=False, ignore_type=False) + if self.target_type == 'shift_relay': + input_fields.extend(['end_seg_mask']) + target_fields.append('start_seg_mask') + # 将dataset加入DataInfo + for name, dataset in datasets.items(): + dataset.set_input(*input_fields) + dataset.set_target(*target_fields) + data.datasets[name] = dataset + + return data -import os + @staticmethod + def _clip_target(target:List[int], L:int): + """ -def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: - """ - 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果 - { - 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 - 'test': 'xxx' # 可能有,也可能没有 - ... 
- } - 如果paths为不合法的,将直接进行raise相应的错误 - - :param paths: 路径 - :return: - """ - if isinstance(paths, str): - if os.path.isfile(paths): - return {'train': paths} - elif os.path.isdir(paths): - train_fp = os.path.join(paths, 'train.txt') - if not os.path.isfile(train_fp): - raise FileNotFoundError(f"train.txt is not found in folder {paths}.") - files = {'train': train_fp} - for filename in ['test.txt', 'dev.txt']: - fp = os.path.join(paths, filename) - if os.path.isfile(fp): - files[filename.split('.')[0]] = fp - return files - else: - raise FileNotFoundError(f"{paths} is not a valid file path.") - - elif isinstance(paths, dict): - if paths: - if 'train' not in paths: - raise KeyError("You have to include `train` in your dict.") - for key, value in paths.items(): - if isinstance(key, str) and isinstance(value, str): - if not os.path.isfile(value): - raise TypeError(f"{value} is not a valid file.") - else: - raise TypeError("All keys and values in paths should be str.") - return paths + 只有在target_type为shift_relay的使用 + :param target: List[int] + :param L: + :return: + """ + relay_target_i = [] + tmp = [] + for j in range(len(target) - 1): + tmp.append(target[j]) + if target[j] > target[j + 1]: + pass + else: + relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]]) + tmp = [] + # 处理未结束的部分 + if len(tmp) == 0: + relay_target_i.append(0) else: - raise ValueError("Empty paths is not allowed.") - else: - raise TypeError(f"paths only supports str and dict. not {type(paths)}.") - + tmp.append(target[-1]) + relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]]) + relay_mask_i = [] + j = 0 + while j < len(target): + seg_len = target[j] + 1 + if target[j] < L: + relay_mask_i.extend([0] * (seg_len)) + else: + relay_mask_i.extend([1] * (seg_len - L) + [0] * L) + j = seg_len + j + return relay_target_i, relay_mask_i diff --git a/reproduction/seqence_labelling/cws/model/metric.py b/reproduction/seqence_labelling/cws/model/metric.py new file mode 100644 index 00000000..d68e3473 --- /dev/null +++ b/reproduction/seqence_labelling/cws/model/metric.py @@ -0,0 +1,44 @@ + +from fastNLP.core.metrics import MetricBase + + +class RelayMetric(MetricBase): + def __init__(self, pred=None, pred_mask=None, target=None, start_seg_mask=None): + super().__init__() + self._init_param_map(pred=pred, pred_mask=pred_mask, target=target, start_seg_mask=start_seg_mask) + self.tp = 0 + self.rec = 0 + self.pre = 0 + + def evaluate(self, pred, pred_mask, target, start_seg_mask): + """ + 给定每个batch,累计一下结果。 + + :param pred: 预测的结果,为当前位置的开始的segment的(长度-1) + :param pred_mask: 当前位置预测有segment开始 + :param target: 当前位置开始的segment的(长度-1) + :param start_seg_mask: 当前有segment结束 + :return: + """ + self.tp += ((pred.long().eq(target.long())).__and__(pred_mask.byte().__and__(start_seg_mask.byte()))).sum().item() + self.rec += start_seg_mask.sum().item() + self.pre += pred_mask.sum().item() + + def get_metric(self, reset=True): + """ + 在所有数据都计算结束之后,得到performance + + :param reset: + :return: + """ + pre = self.tp/(self.pre + 1e-12) + rec = self.tp/(self.rec + 1e-12) + f = 2*pre*rec/(1e-12 + pre + rec) + + if reset: + self.tp = 0 + self.rec = 0 + self.pre = 0 + self.bigger_than_L = 0 + + return {'f': round(f, 6), 'pre': round(pre, 6), 'rec': round(rec, 6)} diff --git a/reproduction/seqence_labelling/cws/model/model.py b/reproduction/seqence_labelling/cws/model/model.py new file mode 100644 index 00000000..bdd9002d --- /dev/null +++ b/reproduction/seqence_labelling/cws/model/model.py @@ -0,0 +1,74 @@ +from torch import nn +import torch 
+from fastNLP.modules import Embedding +import numpy as np +from reproduction.seqence_labelling.cws.model.module import FeatureFunMax, SemiCRFShiftRelay +from fastNLP.modules import LSTM + +class ShiftRelayCWSModel(nn.Module): + """ + 该模型可以用于进行分词操作 + 包含两个方法, + forward(chars, bigrams, seq_len) -> {'loss': batch_size,} + predict(chars, bigrams) -> {'pred': batch_size x max_len, 'pred_mask': batch_size x max_len} + pred是对当前segment的长度预测,pred_mask是仅在有预测的位置为1 + + :param char_embed: 预训练的Embedding或者embedding的shape + :param bigram_embed: 预训练的Embedding或者embedding的shape + :param hidden_size: LSTM的隐藏层大小 + :param num_layers: LSTM的层数 + :param L: SemiCRFShiftRelay的segment大小 + :param num_bigram_per_char: 每个character对应的bigram的数量 + :param drop_p: Dropout的大小 + """ + def __init__(self, char_embed:Embedding, bigram_embed:Embedding, hidden_size:int=400, num_layers:int=1, + L:int=6, num_bigram_per_char:int=1, drop_p:float=0.2): + super().__init__() + self.char_embedding = Embedding(char_embed, dropout=drop_p) + self._pretrained_embed = False + if isinstance(char_embed, np.ndarray): + self._pretrained_embed = True + self.bigram_embedding = Embedding(bigram_embed, dropout=drop_p) + self.lstm = LSTM(100 * (num_bigram_per_char + 1), hidden_size // 2, num_layers=num_layers, bidirectional=True, + batch_first=True) + self.feature_fn = FeatureFunMax(hidden_size, L) + self.semi_crf_relay = SemiCRFShiftRelay(L) + self.feat_drop = nn.Dropout(drop_p) + self.reset_param() + # self.feature_fn.reset_parameters() + + def reset_param(self): + for name, param in self.named_parameters(): + if 'embedding' in name and self._pretrained_embed: + continue + if 'bias_hh' in name: + nn.init.constant_(param, 0) + elif 'bias_ih' in name: + nn.init.constant_(param, 1) + elif len(param.size()) < 2: + nn.init.uniform_(param, -0.1, 0.1) + else: + nn.init.xavier_uniform_(param) + + def get_feats(self, chars, bigrams, seq_len): + batch_size, max_len = chars.size() + chars = self.char_embedding(chars) + bigrams = self.bigram_embedding(bigrams) + bigrams = bigrams.view(bigrams.size(0), max_len, -1) + chars = torch.cat([chars, bigrams], dim=-1) + feats, _ = self.lstm(chars, seq_len) + feats = self.feat_drop(feats) + logits, relay_logits = self.feature_fn(feats) + + return logits, relay_logits + + def forward(self, chars, bigrams, relay_target, relay_mask, end_seg_mask, seq_len): + logits, relay_logits = self.get_feats(chars, bigrams, seq_len) + loss = self.semi_crf_relay(logits, relay_logits, relay_target, relay_mask, end_seg_mask, seq_len) + return {'loss':loss} + + def predict(self, chars, bigrams, seq_len): + logits, relay_logits = self.get_feats(chars, bigrams, seq_len) + pred, pred_mask = self.semi_crf_relay.predict(logits, relay_logits, seq_len) + return {'pred': pred, 'pred_mask': pred_mask} + diff --git a/reproduction/seqence_labelling/cws/model/module.py b/reproduction/seqence_labelling/cws/model/module.py new file mode 100644 index 00000000..6cd8b5e3 --- /dev/null +++ b/reproduction/seqence_labelling/cws/model/module.py @@ -0,0 +1,198 @@ +from torch import nn +import torch +from fastNLP.modules import Embedding +import numpy as np + +class SemiCRFShiftRelay(nn.Module): + """ + 该模块是一个decoder,但 + + """ + def __init__(self, L): + """ + + :param L: 不包含relay的长度 + """ + if L<2: + raise RuntimeError() + super().__init__() + self.L = L + + def forward(self, logits, relay_logits, relay_target, relay_mask, end_seg_mask, seq_len): + """ + relay node是接下来L个字都不是它的结束。relay的状态是往后滑动1个位置 + + :param logits: batch_size x max_len x L, 
当前位置往左边L个segment的分数,最后一维的0是长度为1的segment(即本身) + :param relay_logits: batch_size x max_len, 当前位置是接下来L-1个位置都不是终点的分数 + :param relay_target: batch_size x max_len 每个位置他的segment在哪里开始的。如果超过L,则一直保持为L-1。比如长度为 + 5的词,L=3, [0, 1, 2, 2, 2] + :param relay_mask: batch_size x max_len, 在需要relay的地方为1, 长度为5的词, L=3时,为[1, 1, 1, 0, 0] + :param end_seg_mask: batch_size x max_len, segment结束的地方为1。 + :param seq_len: batch_size, 句子的长度 + :return: loss: batch_size, + """ + batch_size, max_len, L = logits.size() + + # 当前时刻为relay node的分数是多少 + relay_scores = logits.new_zeros(batch_size, max_len) + # 当前时刻结束的分数是多少 + scores = logits.new_zeros(batch_size, max_len+1) + # golden的分数 + gold_scores = relay_logits[:, 0].masked_fill(relay_mask[:, 0].eq(0), 0) + \ + logits[:, 0, 0].masked_fill(end_seg_mask[:, 0].eq(0), 0) + # 初始化 + scores[:, 1] = logits[:, 0, 0] + batch_i = torch.arange(batch_size).to(logits.device).long() + relay_scores[:, 0] = relay_logits[:, 0] + last_relay_index = max_len - self.L + for t in range(1, max_len): + real_L = min(t+1, L) + flip_logits_t = logits[:, t, :real_L].flip(dims=[1]) # flip之后低0个位置为real_L-1的segment + # 计算relay_scores的更新 + if tself.L-1: + # (2)从relay跳转过来的 + tmp2 = relay_scores[:, t-self.L] # batch_size + tmp2 = tmp2 + flip_logits_t[:, 0] # batch_size + tmp1 = torch.cat([tmp1, tmp2.unsqueeze(-1)], dim=-1) + scores[:, t+1] = torch.logsumexp(tmp1, dim=-1) # 更新当前时刻的分数 + + # 计算golden + seg_i = relay_target[:, t] # batch_size + gold_segment_scores = logits[:, t][(batch_i, seg_i)].masked_fill(end_seg_mask[:, t].eq(0), 0) # batch_size, 后向从0到L长度的segment的分数 + relay_score = relay_logits[:, t].masked_fill(relay_mask[:, t].eq(0), 0) + gold_scores = gold_scores + relay_score + gold_segment_scores + all_scores = scores.gather(dim=1, index=seq_len.unsqueeze(1)).squeeze(1) # batch_size + return all_scores - gold_scores + + def predict(self, logits, relay_logits, seq_len): + """ + relay node是接下来L个字都不是它的结束。relay的状态是往后滑动L-1个位置 + + :param logits: batch_size x max_len x L, 当前位置左边L个segment的分数,最后一维的0是长度为1的segment(即本身) + :param relay_logits: batch_size x max_len, 当前位置是接下来L-1个位置都不是终点的分数 + :param seq_len: batch_size, 句子的长度 + :return: pred: batch_size x max_len以该点开始的segment的(长度-1); pred_mask为1的地方预测有segment开始 + """ + batch_size, max_len, L = logits.size() + # 当前时刻为relay node的分数是多少 + max_relay_scores = logits.new_zeros(batch_size, max_len) + relay_bt = seq_len.new_zeros(batch_size, max_len) # 当前结果是否来自于relay的结果 + # 当前时刻结束的分数是多少 + max_scores = logits.new_zeros(batch_size, max_len+1) + bt = seq_len.new_zeros(batch_size, max_len) + # 初始化 + max_scores[:, 1] = logits[:, 0, 0] + max_relay_scores[:, 0] = relay_logits[:, 0] + last_relay_index = max_len - self.L + for t in range(1, max_len): + real_L = min(t+1, L) + flip_logits_t = logits[:, t, :real_L].flip(dims=[1]) # flip之后低0个位置为real_L-1的segment + # 计算relay_scores的更新 + if t-1: + if bt_i[j]==self.L: + seg_start_pos = j + j = j-self.L + while relay_bt_i[j]!=0 and j>-1: + j = j - 1 + pred[b, j] = seg_start_pos - j + pred_mask[b, j] = 1 + else: + length = bt_i[j] + j = j - bt_i[j] + pred_mask[b, j] = 1 + pred[b, j] = length + j = j - 1 + + return torch.LongTensor(pred).to(logits.device), torch.LongTensor(pred_mask).to(logits.device) + + + +class FeatureFunMax(nn.Module): + def __init__(self, hidden_size:int, L:int): + """ + 用于计算semi-CRF特征的函数。给定batch_size x max_len x hidden_size形状的输入,输出为batch_size x max_len x L的 + 分数,以及batch_size x max_len的relay的分数。两者的区别参考论文 TODO 补充 + + :param hidden_size: 输入特征的维度大小 + :param L: 不包含relay node的segment的长度大小。 + """ + super().__init__() + + self.end_fc = 
nn.Linear(hidden_size, 1, bias=False)
+        self.whole_w = nn.Parameter(torch.randn(L, hidden_size))
+        self.relay_fc = nn.Linear(hidden_size, 1)
+        self.length_bias = nn.Parameter(torch.randn(L))
+        self.L = L
+    def forward(self, logits):
+        """
+
+        :param logits: batch_size x max_len x hidden_size
+        :return: batch_size x max_len x L  # the last dim scores the L segments ending here; index 0 is the length-1 segment
+            batch_size x max_len,  # score that none of the following L-1 positions is an end
+
+        """
+        batch_size, max_len, hidden_size = logits.size()
+        # start_scores = self.start_fc(logits) # batch_size x max_len x 1  # score of each position as a start
+        tmp = logits.new_zeros(batch_size, max_len+self.L-1, hidden_size)
+        tmp[:, -max_len:] = logits
+        # batch_size x max_len x hidden_size x (self.L) -> batch_size x max_len x (self.L) x hidden_size
+        start_logits = tmp.unfold(dimension=1, size=self.L, step=1).transpose(2, 3).flip(dims=[2])
+        end_scores = self.end_fc(logits)  # batch_size x max_len x 1
+        # compute the relay features
+        relay_tmp = logits.new_zeros(batch_size, max_len, hidden_size)
+        relay_tmp[:, :-self.L] = logits[:, self.L:]
+        # batch_size x max_len x hidden_size
+        relay_logits_max = torch.max(relay_tmp, logits)  # end - start
+        logits_max = torch.max(logits.unsqueeze(2), start_logits)  # batch_size x max_len x L x hidden_size
+        whole_scores = (logits_max*self.whole_w).sum(dim=-1)  # batch_size x max_len x self.L
+        # whole_scores = self.whole_fc().squeeze(-1) # bz x max_len x self.L
+        # batch_size x max_len
+        relay_scores = self.relay_fc(relay_logits_max).squeeze(-1)
+        return whole_scores+end_scores+self.length_bias.view(1, 1, -1), relay_scores
diff --git a/reproduction/seqence_labelling/cws/test/__init__.py b/reproduction/seqence_labelling/cws/test/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/seqence_labelling/cws/test/test_CWSDataLoader.py b/reproduction/seqence_labelling/cws/test/test_CWSDataLoader.py
new file mode 100644
index 00000000..0b9cb633
--- /dev/null
+++ b/reproduction/seqence_labelling/cws/test/test_CWSDataLoader.py
@@ -0,0 +1,17 @@
+
+
+import unittest
+from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader
+from fastNLP.core.vocabulary import VocabularyOption
+
+
+class TestCWSDataLoader(unittest.TestCase):
+    def test_case1(self):
+        cws_loader = SigHanLoader(target_type='bmes')
+        data = cws_loader.process('pku_demo.txt')
+        print(data.datasets)
+
+    def test_case2(self):
+        cws_loader = SigHanLoader(target_type='bmes')
+        data = cws_loader.process('pku_demo.txt', bigram_vocab_opt=VocabularyOption())
+        print(data.datasets)
\ No newline at end of file
diff --git a/reproduction/seqence_labelling/cws/train_shift_relay.py b/reproduction/seqence_labelling/cws/train_shift_relay.py
new file mode 100644
index 00000000..ed512252
--- /dev/null
+++ b/reproduction/seqence_labelling/cws/train_shift_relay.py
@@ -0,0 +1,68 @@
+
+import os
+
+from fastNLP import cache_results
+from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader
+from reproduction.seqence_labelling.cws.model.model import ShiftRelayCWSModel
+from fastNLP.io.embed_loader import EmbeddingOption
+from fastNLP.core.vocabulary import VocabularyOption
+from fastNLP import Trainer
+from torch.optim import Adam
+from fastNLP import BucketSampler
+from fastNLP import GradientClipCallback
+from reproduction.seqence_labelling.cws.model.metric import RelayMetric
+
+
+# Piggyback on fastNLP's automatic result caching; note that it can only cache results under 4GB
+@cache_results(None)
+def prepare_data():
+    data = SigHanLoader(target_type='shift_relay').process(file_dir, char_embed_opt=char_embed_opt,
bigram_vocab_opt=bigram_vocab_opt,
+                                                           bigram_embed_opt=bigram_embed_opt,
+                                                           L=L)
+    return data
+
+#########hyper
+L = 4
+hidden_size = 200
+num_layers = 1
+drop_p = 0.2
+lr = 0.02
+
+#########hyper
+device = 0
+
+# !!!! Never put full paths here, because they would expose your usernames on the server, which is risky. Always use
+# relative paths; ideally put the data under your reproduction directory and add it to .gitignore
+file_dir = '/path/to/pku'
+char_embed_path = '/path/to/1grams_t3_m50_corpus.txt'
+bigram_embed_path = 'path/to/2grams_t3_m50_corpus.txt'
+bigram_vocab_opt = VocabularyOption(min_freq=3)
+char_embed_opt = EmbeddingOption(embed_filepath=char_embed_path)
+bigram_embed_opt = EmbeddingOption(embed_filepath=bigram_embed_path)
+
+data_name = os.path.basename(file_dir)
+cache_fp = 'caches/{}.pkl'.format(data_name)
+
+data = prepare_data(_cache_fp=cache_fp, _refresh=False)
+
+model = ShiftRelayCWSModel(char_embed=data.embeddings['chars'], bigram_embed=data.embeddings['bigrams'],
+                           hidden_size=hidden_size, num_layers=num_layers,
+                           L=L, num_bigram_per_char=1, drop_p=drop_p)
+
+sampler = BucketSampler(batch_size=32)
+optimizer = Adam(model.parameters(), lr=lr)
+clipper = GradientClipCallback(clip_value=5, clip_type='value')
+callbacks = [clipper]
+# if pretrain:
+#     fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until)
+#     callbacks.append(fixer)
+trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None,
+                  batch_size=32, sampler=sampler, update_every=5,
+                  n_epochs=3, print_every=5,
+                  dev_data=data.datasets['dev'], metrics=RelayMetric(), metric_key='f',
+                  validate_every=-1, save_path=None,
+                  prefetch=True, use_tqdm=True, device=device,
+                  callbacks=callbacks,
+                  check_code_level=0)
+trainer.train()
\ No newline at end of file
diff --git a/reproduction/utils.py b/reproduction/utils.py
new file mode 100644
index 00000000..58883b43
--- /dev/null
+++ b/reproduction/utils.py
@@ -0,0 +1,51 @@
+import os
+
+from typing import Union, Dict
+
+
+def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
+    """
+    Check that the file paths passed to a dataloader are valid. If they are, return a dict that contains at least the
+    'train' key, similar to the following
+    {
+        'train': '/some/path/to/',  # always present; the vocabulary should be built on this one, the remaining
+                                    # files should only need to be processed and indexed.
+        'test': 'xxx'  # may or may not be present
+        ...
+    }
+    If paths is invalid, the corresponding error is raised directly
+
+    :param paths: the path(s)
+    :return:
+    """
+    if isinstance(paths, str):
+        if os.path.isfile(paths):
+            return {'train': paths}
+        elif os.path.isdir(paths):
+            train_fp = os.path.join(paths, 'train.txt')
+            if not os.path.isfile(train_fp):
+                raise FileNotFoundError(f"train.txt is not found in folder {paths}.")
+            files = {'train': train_fp}
+            for filename in ['test.txt', 'dev.txt']:
+                fp = os.path.join(paths, filename)
+                if os.path.isfile(fp):
+                    files[filename.split('.')[0]] = fp
+            return files
+        else:
+            raise FileNotFoundError(f"{paths} is not a valid file path.")
+
+    elif isinstance(paths, dict):
+        if paths:
+            if 'train' not in paths:
+                raise KeyError("You have to include `train` in your dict.")
+            for key, value in paths.items():
+                if isinstance(key, str) and isinstance(value, str):
+                    if not os.path.isfile(value):
+                        raise TypeError(f"{value} is not a valid file.")
+                else:
+                    raise TypeError("All keys and values in paths should be str.")
+            return paths
+        else:
+            raise ValueError("Empty paths is not allowed.")
+    else:
+        raise TypeError(f"paths only supports str and dict. not {type(paths)}.")
+
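
The trainer.py hunks at the top of this patch implement the new on_exception='auto' policy. A minimal standalone sketch of that dispatch rule follows; the CallbackException stand-in is local to the sketch (the real class lives in fastNLP.core.callback):

class CallbackException(Exception):
    """Stand-in for fastNLP.core.callback.CallbackException."""

def handle_exception(e, on_exception='auto'):
    # Mirrors the patched Trainer.train(): 'auto' swallows CallbackException and
    # KeyboardInterrupt (after the callbacks have seen them) and re-raises anything
    # else; 'raise' always re-raises; 'ignore' never does.
    if on_exception == 'auto':
        if not isinstance(e, (CallbackException, KeyboardInterrupt)):
            raise e
    elif on_exception == 'raise':
        raise e

handle_exception(CallbackException('early stop'))  # swallowed under 'auto'
try:
    handle_exception(RuntimeError('boom'))  # re-raised under 'auto'
except RuntimeError:
    pass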
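The shift-relay target encoding built by SigHanLoader._word_lens_to_relay above can also be exercised in isolation. Here is a dependency-free sketch of the same conversion, matching the example in its docstring:

def word_lens_to_relay(word_lens):
    # Mirror of SigHanLoader._word_lens_to_relay: each character is tagged with a
    # countdown of the characters remaining in its word; end_seg_mask marks
    # word-final characters, start_seg_mask marks word-initial ones.
    tags, end_seg_mask, start_seg_mask = [], [], []
    for word_len in word_lens:
        tags.extend(range(word_len - 1, -1, -1))          # length 3 -> [2, 1, 0]
        end_seg_mask.extend([0] * (word_len - 1) + [1])
        start_seg_mask.extend([1] + [0] * (word_len - 1))
    return {'target': tags, 'end_seg_mask': end_seg_mask, 'start_seg_mask': start_seg_mask}

out = word_lens_to_relay([1, 2, 3])
assert out['target'] == [0, 1, 0, 2, 1, 0]
assert out['end_seg_mask'] == [1, 0, 1, 0, 0, 1]
assert out['start_seg_mask'] == [1, 1, 0, 1, 0, 0]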
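RelayMetric above counts a true positive only where a predicted segment start coincides with a gold start and the predicted (length-1) matches the target. A toy list-based check of that counting rule, as a plain-Python stand-in for the tensor version:

def relay_counts(pred, pred_mask, target, start_seg_mask):
    # tp: positions that are predicted starts AND gold starts AND length-correct;
    # the recall/precision denominators are the gold and predicted start counts.
    tp = sum(1 for p, pm, t, sm in zip(pred, pred_mask, target, start_seg_mask)
             if pm and sm and p == t)
    return tp, sum(start_seg_mask), sum(pred_mask)

# Gold segmentation [3, 2]: target is the countdown encoding, starts at 0 and 3.
tp, rec, pre = relay_counts(pred=[2, 0, 0, 1, 0], pred_mask=[1, 0, 0, 1, 0],
                            target=[2, 1, 0, 1, 0], start_seg_mask=[1, 0, 0, 1, 0])
assert (tp, rec, pre) == (2, 2, 2)  # get_metric() would report pre = rec = f = 1.0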
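Finally, a short usage sketch for the check_dataloader_paths helper added in reproduction/utils.py; the temporary directory stands in for a real dataset folder:

import os
import tempfile

from reproduction.utils import check_dataloader_paths

with tempfile.TemporaryDirectory() as folder:
    for name in ('train.txt', 'dev.txt'):
        open(os.path.join(folder, name), 'w').close()

    # A directory expands to whichever of train.txt/dev.txt/test.txt exist
    # (train.txt is mandatory and becomes the 'train' entry).
    paths = check_dataloader_paths(folder)
    assert set(paths) == {'train', 'dev'}

    # A single file is treated as the training set.
    assert check_dataloader_paths(paths['train']) == {'train': paths['train']}

    # A dict passes through after validation; it must contain the 'train' key.
    assert check_dataloader_paths({'train': paths['train']}) == {'train': paths['train']}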