@@ -494,14 +494,15 @@ class Trainer(object):
         self.callback_manager = CallbackManager(env={"trainer": self},
                                                 callbacks=callbacks)
-    def train(self, load_best_model=True, on_exception='ignore'):
+    def train(self, load_best_model=True, on_exception='auto'):
         """
         Use this function to make the Trainer start training.
         :param bool load_best_model: only takes effect when dev_data was provided at initialization; if True, the trainer reloads the parameters
             of the model that performed best on dev before returning.
         :param str on_exception: whether to keep raising an exception that occurs during training after it has been handled by :py:class:Callback 's on_exception().
-            Supports 'ignore' and 'raise': 'ignore' catches the exception, and code written after Trainer.train() keeps running; 'raise' re-raises the exception.
+            Supports 'ignore', 'raise' and 'auto': 'ignore' catches the exception, and code written after Trainer.train() keeps running; 'raise' re-raises the exception;
+            'auto' ignores the two exception types CallbackException and KeyboardInterrupt, and raises any other exception.
         :return dict: returns a dict,
             containing the following items::
@@ -530,12 +531,16 @@ class Trainer(object):
             self.callback_manager.on_train_begin()
             self._train()
             self.callback_manager.on_train_end()
-        except (CallbackException, KeyboardInterrupt, Exception) as e:
+        except Exception as e:
             self.callback_manager.on_exception(e)
-            if on_exception=='raise':
+            if on_exception == 'auto':
+                if not isinstance(e, (CallbackException, KeyboardInterrupt)):
+                    raise e
+            elif on_exception == 'raise':
                 raise e
-        if self.dev_data is not None and hasattr(self, 'best_dev_perf'):
+        if self.dev_data is not None and self.best_dev_perf is not None:
             print(
                 "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
                 self.tester._format_eval_results(self.best_dev_perf), )
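For reference, a minimal sketch of how the new on_exception modes behave from user code (the Trainer construction arguments are placeholders taken from elsewhere in this PR):

    trainer = Trainer(train_data, model, optimizer=optimizer, dev_data=dev_data, metrics=metric)
    trainer.train(on_exception='auto')      # swallow CallbackException / KeyboardInterrupt, re-raise anything else
    # trainer.train(on_exception='raise')   # always re-raise after Callback.on_exception() has run
    # trainer.train(on_exception='ignore')  # never re-raise; code after train() keeps running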
@@ -4,7 +4,7 @@ The utils module implements many tools needed inside and outside fastNLP.
 __all__ = [
     "cache_results",
     "seq_len_to_mask",
-    "Example",
+    "Option",
 ]
 import _pickle
@@ -6,10 +6,10 @@ __all__ = [
 from functools import wraps
 from collections import Counter
 from .dataset import DataSet
-from .utils import Example
+from .utils import Option
-class VocabularyOption(Example):
+class VocabularyOption(Option):
     def __init__(self,
                  max_size=None,
                  min_freq=None,
@@ -10,10 +10,10 @@ import numpy as np
 from ..core.vocabulary import Vocabulary
 from .base_loader import BaseLoader
-from ..core.utils import Example
+from ..core.utils import Option
-class EmbeddingOption(Example):
+class EmbeddingOption(Option):
     def __init__(self,
                  embed_filepath=None,
                  dtype=np.float32,
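For context, these Option subclasses act as keyword-argument bundles: later in this diff they are unpacked straight into the objects they configure. A hedged sketch (paths and values are illustrative):

    char_vocab_opt = VocabularyOption(min_freq=3)
    char_vocab = Vocabulary(**char_vocab_opt)                      # an Option unpacks like a dict
    char_embed_opt = EmbeddingOption(embed_filepath='/path/to/char_vectors.txt')
    char_embed = EmbedLoader.load_with_vocab(**char_embed_opt, vocab=char_vocab)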
@@ -6,6 +6,9 @@ from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP import Vocabulary
+from fastNLP import Const
+from reproduction.utils import check_dataloader_paths
+from functools import partial
 class SigHanLoader(DataSetLoader):
     """
@@ -20,27 +23,43 @@ class SigHanLoader(DataSetLoader):
         chars: list(str), each element is an index (the index of the corresponding character)
         target: list(int), varies with the encoding_type
-    :param target_type: the type of the target; the following two are currently supported: "bmes", "pointer"
+    :param target_type: the type of the target; the following two are currently supported: "bmes", "shift_relay"
     """
     def __init__(self, target_type:str):
         super().__init__()
-        if target_type.lower() not in ('bmes', 'pointer'):
-            raise ValueError("target_type only supports 'bmes', 'pointer'.")
+        if target_type.lower() not in ('bmes', 'shift_relay'):
+            raise ValueError("target_type only supports 'bmes', 'shift_relay'.")
         self.target_type = target_type
         if target_type=='bmes':
             self._word_len_to_target = self._word_len_to_bems
+        elif target_type=='shift_relay':
+            self._word_len_to_target = self._word_lens_to_relay
+    @staticmethod
+    def _word_lens_to_relay(word_lens: Iterator[int]):
+        """
+        Converts word lengths [1, 2, 3, ...] into [0, 1, 0, 2, 1, 0, ...] (the value at a word's first character is its length minus one,
+        counting down to 0 at its last character);
+        :param word_lens:
+        :return: {'target': , 'end_seg_mask':, 'start_seg_mask':}
+        """
+        tags = []
+        end_seg_mask = []
+        start_seg_mask = []
+        for word_len in word_lens:
+            tags.extend([idx for idx in range(word_len - 1, -1, -1)])
+            end_seg_mask.extend([0] * (word_len - 1) + [1])
+            start_seg_mask.extend([1] + [0] * (word_len - 1))
+        return {'target': tags, 'end_seg_mask': end_seg_mask, 'start_seg_mask': start_seg_mask}
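A quick sanity check of what this helper produces, traced by hand from the loop above:

    out = SigHanLoader._word_lens_to_relay([1, 2, 3])
    # out['target']         == [0, 1, 0, 2, 1, 0]   # counts down to 0 at each word's last character
    # out['end_seg_mask']   == [1, 0, 1, 0, 0, 1]   # 1 where a word ends
    # out['start_seg_mask'] == [1, 1, 0, 1, 0, 0]   # 1 where a word starts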
     @staticmethod
-    def _word_len_to_bems(word_lens:Iterator[int])->List[str]:
+    def _word_len_to_bems(word_lens:Iterator[int])->Dict[str, List[str]]:
         """
         :param word_lens: the length of each word
-        :return: the corresponding BMES tags as str
+        :return:
         """
         tags = []
         for word_len in word_lens:
@@ -51,7 +70,7 @@ class SigHanLoader(DataSetLoader):
                 for _ in range(word_len-2):
                     tags.append('M')
                 tags.append('E')
-        return tags
+        return {'target':tags}
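The corresponding sketch for the BMES target; the single-character ('S') and word-initial ('B') branches sit outside this hunk, so this assumes the usual BMES convention:

    out = SigHanLoader._word_len_to_bems([1, 2, 3])
    # out['target'] == ['S', 'B', 'E', 'B', 'M', 'E']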
     @staticmethod
     def _gen_bigram(chars:List[str])->List[str]:
@@ -71,11 +90,15 @@ class SigHanLoader(DataSetLoader):
         dataset = DataSet()
         with open(path, 'r', encoding='utf-8') as f:
             for line in f:
+                line = line.strip()
+                if not line:  # skip empty lines
+                    continue
                 parts = line.split()
                 word_lens = map(len, parts)
-                chars = list(line)
+                chars = list(''.join(parts))
                 tags = self._word_len_to_target(word_lens)
-                dataset.append(Instance(raw_chars=chars, target=tags))
+                assert len(chars)==len(tags['target'])
+                dataset.append(Instance(raw_chars=chars, **tags, seq_len=len(chars)))
         if len(dataset)==0:
             raise RuntimeError(f"{path} has no valid data.")
         if bigram:
@@ -84,7 +107,7 @@ class SigHanLoader(DataSetLoader):
     def process(self, paths: Union[str, Dict[str, str]], char_vocab_opt:VocabularyOption=None,
                 char_embed_opt:EmbeddingOption=None, bigram_vocab_opt:VocabularyOption=None,
-                bigram_embed_opt:EmbeddingOption=None):
+                bigram_embed_opt:EmbeddingOption=None, L:int=4):
         """
         The supported data format is one sample per line, with words separated by spaces. For example
@@ -113,7 +136,7 @@ class SigHanLoader(DataSetLoader):
         data = SigHanLoader('bmes').process('path/to/cws/')  # will try to read train.txt, test.txt and dev.txt in that directory
         # the result contains data.vocabs['chars']: a Vocabulary object
         #     data.vocabs['target']: a Vocabulary object
-        #     data.embeddings['chars']: an Embedding object. Only present when a path to pretrained word vectors was given
+        #     data.embeddings['chars']: an Embedding object, only when a pretrained embedding path was given;
         #     data.datasets['train']: a DataSet object
         #         with the following fields:
         #             raw_chars: list[str], each element is one Chinese character
@@ -132,79 +155,95 @@ class SigHanLoader(DataSetLoader):
         :param bigram_vocab_opt: parameters for building the bigram vocabulary. Bigrams are not used by default; only when this parameter is given do the instances carry a bigrams field.
             It is a List[int] with the same length as chars for each instance; the bigrams of abcde are ab bc cd de e<eos>
         :param bigram_embed_opt: parameters for loading pretrained bigram embeddings; only effective when bigram_vocab_opt is passed
+        :param L: the segment length passed through when target_type is shift_relay
         :return:
         """
         # it is recommended to validate paths with check_dataloader_paths
         paths = check_dataloader_paths(paths)
         datasets = {}
+        data = DataInfo()
         bigram = bigram_vocab_opt is not None
         for name, path in paths.items():
             dataset = self.load(path, bigram=bigram)
             datasets[name] = dataset
+        input_fields = []
+        target_fields = []
         # build the vocabularies
         char_vocab = Vocabulary(min_freq=2) if char_vocab_opt is None else Vocabulary(**char_vocab_opt)
         char_vocab.from_dataset(datasets['train'], field_name='raw_chars')
         char_vocab.index_dataset(*datasets.values(), field_name='raw_chars', new_field_name='chars')
+        data.vocabs[Const.CHAR_INPUT] = char_vocab
+        input_fields.extend([Const.CHAR_INPUT, Const.INPUT_LEN, Const.TARGET])
+        target_fields.append(Const.TARGET)
         # build the target
         if self.target_type == 'bmes':
             target_vocab = Vocabulary(unknown=None, padding=None)
             target_vocab.add_word_lst(['B']*4+['M']*3+['E']*2+['S'])
             target_vocab.index_dataset(*datasets.values(), field_name='target')
+            data.vocabs[Const.TARGET] = target_vocab
+        if char_embed_opt is not None:
+            char_embed = EmbedLoader.load_with_vocab(**char_embed_opt, vocab=char_vocab)
+            data.embeddings['chars'] = char_embed
         if bigram:
             bigram_vocab = Vocabulary(**bigram_vocab_opt)
             bigram_vocab.from_dataset(datasets['train'], field_name='bigrams')
             bigram_vocab.index_dataset(*datasets.values(), field_name='bigrams')
+            data.vocabs['bigrams'] = bigram_vocab
             if bigram_embed_opt is not None:
-                pass
+                bigram_embed = EmbedLoader.load_with_vocab(**bigram_embed_opt, vocab=bigram_vocab)
+                data.embeddings['bigrams'] = bigram_embed
+            input_fields.append('bigrams')
+        if self.target_type == 'shift_relay':
+            func = partial(self._clip_target, L=L)
+            for name, dataset in datasets.items():
+                res = dataset.apply_field(func, field_name='target')
+                relay_target = [res_i[0] for res_i in res]
+                relay_mask = [res_i[1] for res_i in res]
+                dataset.add_field('relay_target', relay_target, is_input=True, is_target=False, ignore_type=False)
+                dataset.add_field('relay_mask', relay_mask, is_input=True, is_target=False, ignore_type=False)
+        if self.target_type == 'shift_relay':
+            input_fields.extend(['end_seg_mask'])
+            target_fields.append('start_seg_mask')
+        # register the datasets in the DataInfo
+        for name, dataset in datasets.items():
+            dataset.set_input(*input_fields)
+            dataset.set_target(*target_fields)
+            data.datasets[name] = dataset
+        return data
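Putting process() together, a hedged usage sketch for the new shift_relay path (the path and options are placeholders; field names follow the code above, assuming fastNLP's usual Const names):

    loader = SigHanLoader(target_type='shift_relay')
    data = loader.process('path/to/cws/', bigram_vocab_opt=VocabularyOption(min_freq=3), L=4)
    # per instance, data.datasets['train'] now carries
    #   inputs : chars, seq_len, target, bigrams, relay_target, relay_mask, end_seg_mask
    #   targets: target, start_seg_mask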
+    @staticmethod
+    def _clip_target(target:List[int], L:int):
+        """
+        Only used when target_type is shift_relay.
+        :param target: List[int]
+        :param L:
+        :return:
+        """
+        relay_target_i = []
+        tmp = []
+        for j in range(len(target) - 1):
+            tmp.append(target[j])
+            if target[j] > target[j + 1]:
+                pass
+            else:
+                relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]])
+                tmp = []
+        # handle the trailing, unfinished part
+        if len(tmp) == 0:
+            relay_target_i.append(0)
+        else:
+            tmp.append(target[-1])
+            relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]])
+        relay_mask_i = []
+        j = 0
+        while j < len(target):
+            seg_len = target[j] + 1
+            if target[j] < L:
+                relay_mask_i.extend([0] * (seg_len))
+            else:
+                relay_mask_i.extend([1] * (seg_len - L) + [0] * L)
+            j = seg_len + j
+        return relay_target_i, relay_mask_i
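A small sketch of what the clipping yields, traced by hand from the two static helpers above:

    tags = SigHanLoader._word_lens_to_relay([5])['target']       # [4, 3, 2, 1, 0]
    relay_target, relay_mask = SigHanLoader._clip_target(tags, L=3)
    # relay_target == [0, 1, 2, 2, 2]   # how far back this segment started, capped at L-1
    # (relay_mask marks the positions that have to be reached through the relay state)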
@@ -0,0 +1,44 @@
+from fastNLP.core.metrics import MetricBase
+class RelayMetric(MetricBase):
+    def __init__(self, pred=None, pred_mask=None, target=None, start_seg_mask=None):
+        super().__init__()
+        self._init_param_map(pred=pred, pred_mask=pred_mask, target=target, start_seg_mask=start_seg_mask)
+        self.tp = 0
+        self.rec = 0
+        self.pre = 0
+    def evaluate(self, pred, pred_mask, target, start_seg_mask):
+        """
+        Accumulate the statistics for each batch.
+        :param pred: the prediction: (length-1) of the segment starting at the current position
+        :param pred_mask: 1 where a segment is predicted to start at the current position
+        :param target: (length-1) of the gold segment starting at the current position
+        :param start_seg_mask: 1 where a gold segment starts at the current position
+        :return:
+        """
+        self.tp += ((pred.long().eq(target.long())).__and__(pred_mask.byte().__and__(start_seg_mask.byte()))).sum().item()
+        self.rec += start_seg_mask.sum().item()
+        self.pre += pred_mask.sum().item()
+    def get_metric(self, reset=True):
+        """
+        Compute the performance once all data has been processed.
+        :param reset:
+        :return:
+        """
+        pre = self.tp/(self.pre + 1e-12)
+        rec = self.tp/(self.rec + 1e-12)
+        f = 2*pre*rec/(1e-12 + pre + rec)
+        if reset:
+            self.tp = 0
+            self.rec = 0
+            self.pre = 0
+            self.bigger_than_L = 0
+        return {'f': round(f, 6), 'pre': round(pre, 6), 'rec': round(rec, 6)}
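A tiny worked example of the counting above, using the field semantics from the loader (gold segmentation [ab][c], model predicting three single-character segments):

    # gold  "ab c":  target = [[1, 0, 0]], start_seg_mask = [[1, 0, 1]]
    # model "a b c": pred   = [[0, 0, 0]], pred_mask      = [[1, 1, 1]]
    # tp = 1 (only the final single-character segment matches),
    # pre = tp / sum(pred_mask) = 1/3, rec = tp / sum(start_seg_mask) = 1/2,
    # f = 2*pre*rec / (pre+rec) = 0.4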
@@ -0,0 +1,74 @@
+from torch import nn
+import torch
+from fastNLP.modules import Embedding
+import numpy as np
+from reproduction.seqence_labelling.cws.model.module import FeatureFunMax, SemiCRFShiftRelay
+from fastNLP.modules import LSTM
+class ShiftRelayCWSModel(nn.Module):
+    """
+    This model can be used for Chinese word segmentation.
+    It exposes two methods:
+        forward(chars, bigrams, seq_len) -> {'loss': batch_size,}
+        predict(chars, bigrams) -> {'pred': batch_size x max_len, 'pred_mask': batch_size x max_len}
+            pred is the predicted length of the segment starting at each position; pred_mask is 1 only where a prediction is made
+    :param char_embed: a pretrained Embedding, or the shape of the embedding
+    :param bigram_embed: a pretrained Embedding, or the shape of the embedding
+    :param hidden_size: hidden size of the LSTM
+    :param num_layers: number of LSTM layers
+    :param L: segment size of SemiCRFShiftRelay
+    :param num_bigram_per_char: number of bigrams attached to each character
+    :param drop_p: dropout rate
+    """
+    def __init__(self, char_embed:Embedding, bigram_embed:Embedding, hidden_size:int=400, num_layers:int=1,
+                 L:int=6, num_bigram_per_char:int=1, drop_p:float=0.2):
+        super().__init__()
+        self.char_embedding = Embedding(char_embed, dropout=drop_p)
+        self._pretrained_embed = False
+        if isinstance(char_embed, np.ndarray):
+            self._pretrained_embed = True
+        self.bigram_embedding = Embedding(bigram_embed, dropout=drop_p)
+        self.lstm = LSTM(100 * (num_bigram_per_char + 1), hidden_size // 2, num_layers=num_layers, bidirectional=True,
+                         batch_first=True)
+        self.feature_fn = FeatureFunMax(hidden_size, L)
+        self.semi_crf_relay = SemiCRFShiftRelay(L)
+        self.feat_drop = nn.Dropout(drop_p)
+        self.reset_param()
+        # self.feature_fn.reset_parameters()
+    def reset_param(self):
+        for name, param in self.named_parameters():
+            if 'embedding' in name and self._pretrained_embed:
+                continue
+            if 'bias_hh' in name:
+                nn.init.constant_(param, 0)
+            elif 'bias_ih' in name:
+                nn.init.constant_(param, 1)
+            elif len(param.size()) < 2:
+                nn.init.uniform_(param, -0.1, 0.1)
+            else:
+                nn.init.xavier_uniform_(param)
+    def get_feats(self, chars, bigrams, seq_len):
+        batch_size, max_len = chars.size()
+        chars = self.char_embedding(chars)
+        bigrams = self.bigram_embedding(bigrams)
+        bigrams = bigrams.view(bigrams.size(0), max_len, -1)
+        chars = torch.cat([chars, bigrams], dim=-1)
+        feats, _ = self.lstm(chars, seq_len)
+        feats = self.feat_drop(feats)
+        logits, relay_logits = self.feature_fn(feats)
+        return logits, relay_logits
+    def forward(self, chars, bigrams, relay_target, relay_mask, end_seg_mask, seq_len):
+        logits, relay_logits = self.get_feats(chars, bigrams, seq_len)
+        loss = self.semi_crf_relay(logits, relay_logits, relay_target, relay_mask, end_seg_mask, seq_len)
+        return {'loss':loss}
+    def predict(self, chars, bigrams, seq_len):
+        logits, relay_logits = self.get_feats(chars, bigrams, seq_len)
+        pred, pred_mask = self.semi_crf_relay.predict(logits, relay_logits, seq_len)
+        return {'pred': pred, 'pred_mask': pred_mask}
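As a rough orientation (shapes follow the docstrings above; the tensors themselves are hypothetical): the Trainer drives forward() with the input fields produced by SigHanLoader, while evaluation goes through predict() and RelayMetric:

    # chars, bigrams: LongTensor (batch_size, max_len); seq_len: LongTensor (batch_size,)
    # training  : forward(chars, bigrams, relay_target, relay_mask, end_seg_mask, seq_len) -> {'loss': (batch_size,)}
    # evaluation: predict(chars, bigrams, seq_len) -> {'pred': (batch_size, max_len), 'pred_mask': (batch_size, max_len)}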
@@ -0,0 +1,198 @@
+from torch import nn
+import torch
+from fastNLP.modules import Embedding
+import numpy as np
+class SemiCRFShiftRelay(nn.Module):
+    """
+    This module is a decoder, but
+    """
+    def __init__(self, L):
+        """
+        :param L: maximum segment length, not counting the relay node
+        """
+        if L<2:
+            raise RuntimeError()
+        super().__init__()
+        self.L = L
+    def forward(self, logits, relay_logits, relay_target, relay_mask, end_seg_mask, seq_len):
+        """
+        A relay node is a position whose segment does not end within the next L characters. The relay state slides forward one position at a time.
+        :param logits: batch_size x max_len x L, scores of the L segments ending at the current position; index 0 of the last dim is the segment of length 1 (the position itself)
+        :param relay_logits: batch_size x max_len, score that none of the next L-1 positions ends the segment starting here
+        :param relay_target: batch_size x max_len, where each position's segment starts; stays at L-1 once the distance exceeds L. For a word of
+            length 5 and L=3: [0, 1, 2, 2, 2]
+        :param relay_mask: batch_size x max_len, 1 wherever a relay is needed; for a word of length 5 and L=3 it is [1, 1, 1, 0, 0]
+        :param end_seg_mask: batch_size x max_len, 1 where a segment ends.
+        :param seq_len: batch_size, sentence lengths
+        :return: loss: batch_size,
+        """
+        batch_size, max_len, L = logits.size()
+        # score of being a relay node at the current time step
+        relay_scores = logits.new_zeros(batch_size, max_len)
+        # score of a segment ending at the current time step
+        scores = logits.new_zeros(batch_size, max_len+1)
+        # score of the gold path
+        gold_scores = relay_logits[:, 0].masked_fill(relay_mask[:, 0].eq(0), 0) + \
+                      logits[:, 0, 0].masked_fill(end_seg_mask[:, 0].eq(0), 0)
+        # initialization
+        scores[:, 1] = logits[:, 0, 0]
+        batch_i = torch.arange(batch_size).to(logits.device).long()
+        relay_scores[:, 0] = relay_logits[:, 0]
+        last_relay_index = max_len - self.L
+        for t in range(1, max_len):
+            real_L = min(t+1, L)
+            flip_logits_t = logits[:, t, :real_L].flip(dims=[1])  # after flipping, index 0 is the segment with length-1 == real_L-1
+            # update relay_scores
+            if t<last_relay_index:
+                # (1) transition from a normal position
+                tmp1 = relay_logits[:, t] + scores[:, t]  # batch_size
+                # (2) transition from the relay state
+                tmp2 = relay_logits[:, t] + relay_scores[:, t-1]  # batch_size
+                tmp1 = torch.stack([tmp1, tmp2], dim=0)
+                relay_scores[:, t] = torch.logsumexp(tmp1, dim=0)
+            # update scores
+            # (1) transitions from earlier positions
+            tmp1 = scores[:, t-real_L+1:t+1] + flip_logits_t  # batch_size x L
+            if t>self.L-1:
+                # (2) transition from the relay state
+                tmp2 = relay_scores[:, t-self.L]  # batch_size
+                tmp2 = tmp2 + flip_logits_t[:, 0]  # batch_size
+                tmp1 = torch.cat([tmp1, tmp2.unsqueeze(-1)], dim=-1)
+            scores[:, t+1] = torch.logsumexp(tmp1, dim=-1)  # update the score of the current time step
+            # accumulate the gold score
+            seg_i = relay_target[:, t]  # batch_size
+            gold_segment_scores = logits[:, t][(batch_i, seg_i)].masked_fill(end_seg_mask[:, t].eq(0), 0)  # batch_size, score of the gold segment ending here
+            relay_score = relay_logits[:, t].masked_fill(relay_mask[:, t].eq(0), 0)
+            gold_scores = gold_scores + relay_score + gold_segment_scores
+        all_scores = scores.gather(dim=1, index=seq_len.unsqueeze(1)).squeeze(1)  # batch_size
+        return all_scores - gold_scores
+    def predict(self, logits, relay_logits, seq_len):
+        """
+        A relay node is a position whose segment does not end within the next L characters. The relay state slides forward L-1 positions.
+        :param logits: batch_size x max_len x L, scores of the L segments to the left of the current position; index 0 of the last dim is the segment of length 1 (the position itself)
+        :param relay_logits: batch_size x max_len, score that none of the next L-1 positions ends the segment starting here
+        :param seq_len: batch_size, sentence lengths
+        :return: pred: batch_size x max_len, (length-1) of the segment starting at each position; pred_mask is 1 where a segment is predicted to start
+        """
+        batch_size, max_len, L = logits.size()
+        # best score of being a relay node at the current time step
+        max_relay_scores = logits.new_zeros(batch_size, max_len)
+        relay_bt = seq_len.new_zeros(batch_size, max_len)  # whether the current best comes from the relay state
+        # best score of a segment ending at the current time step
+        max_scores = logits.new_zeros(batch_size, max_len+1)
+        bt = seq_len.new_zeros(batch_size, max_len)
+        # initialization
+        max_scores[:, 1] = logits[:, 0, 0]
+        max_relay_scores[:, 0] = relay_logits[:, 0]
+        last_relay_index = max_len - self.L
+        for t in range(1, max_len):
+            real_L = min(t+1, L)
+            flip_logits_t = logits[:, t, :real_L].flip(dims=[1])  # after flipping, index 0 is the segment with length-1 == real_L-1
+            # update the relay scores
+            if t<last_relay_index:
+                # (1) transition from a normal position
+                tmp1 = relay_logits[:, t] + max_scores[:, t]
+                # (2) transition from the relay state
+                tmp2 = relay_logits[:, t] + max_relay_scores[:, t-1]  # batch_size
+                # the last L positions of each sample can no longer be relay nodes
+                tmp2 = tmp2.masked_fill(seq_len.le(t+L), float('-inf'))
+                mask_i = tmp1.lt(tmp2)  # 1 where the transition comes from the relay state
+                relay_bt[:, t].masked_fill_(mask_i, 1)
+                max_relay_scores[:, t] = torch.max(tmp1, tmp2)
+            # update the scores
+            # (1) transitions from earlier positions
+            tmp1 = max_scores[:, t-real_L+1:t+1] + flip_logits_t  # batch_size x L
+            tmp1 = tmp1.flip(dims=[1])  # index 0 now stands for the segment of length 1
+            if self.L-1<t:
+                # (2) transition from the relay state
+                tmp2 = max_relay_scores[:, t-self.L]  # batch_size
+                tmp2 = tmp2 + flip_logits_t[:, 0]
+                tmp1 = torch.cat([tmp1, tmp2.unsqueeze(-1)], dim=-1)
+            # take the maximum
+            max_score, pt = torch.max(tmp1, dim=1)
+            max_scores[:, t+1] = max_score
+            # mask_i = pt.ge(self.L)
+            bt[:, t] = pt  # assuming L=3, values 0,1,2,3 stand for the segments [t, t], [t-1, t], [t-2, t], [t-self.L(relay), t]
+        # decode the result
+        pred = np.zeros((batch_size, max_len), dtype=int)
+        pred_mask = np.zeros((batch_size, max_len), dtype=int)
+        seq_len = seq_len.tolist()
+        bt = bt.tolist()
+        relay_bt = relay_bt.tolist()
+        for b in range(batch_size):
+            seq_len_i = seq_len[b]
+            bt_i = bt[b][:seq_len_i]
+            relay_bt_i = relay_bt[b][:seq_len_i]
+            j = seq_len_i - 1
+            assert relay_bt_i[j]!=1
+            while j>-1:
+                if bt_i[j]==self.L:
+                    seg_start_pos = j
+                    j = j-self.L
+                    while relay_bt_i[j]!=0 and j>-1:
+                        j = j - 1
+                    pred[b, j] = seg_start_pos - j
+                    pred_mask[b, j] = 1
+                else:
+                    length = bt_i[j]
+                    j = j - bt_i[j]
+                    pred_mask[b, j] = 1
+                    pred[b, j] = length
+                j = j - 1
+        return torch.LongTensor(pred).to(logits.device), torch.LongTensor(pred_mask).to(logits.device)
+class FeatureFunMax(nn.Module):
+    def __init__(self, hidden_size:int, L:int):
+        """
+        The function that computes the semi-CRF features. Given input of shape batch_size x max_len x hidden_size, it outputs scores of shape
+        batch_size x max_len x L, plus relay scores of shape batch_size x max_len. The difference between the two is described in the paper (TODO: add reference)
+        :param hidden_size: dimension of the input features
+        :param L: maximum segment length, not counting the relay node.
+        """
+        super().__init__()
+        self.end_fc = nn.Linear(hidden_size, 1, bias=False)
+        self.whole_w = nn.Parameter(torch.randn(L, hidden_size))
+        self.relay_fc = nn.Linear(hidden_size, 1)
+        self.length_bias = nn.Parameter(torch.randn(L))
+        self.L = L
+    def forward(self, logits):
+        """
+        :param logits: batch_size x max_len x hidden_size
+        :return: batch_size x max_len x L  # the last dim scores the segments to the left; index 0 is the segment of length 1
+                 batch_size x max_len      # score that none of the next L-1 positions ends the segment starting here
+        """
+        batch_size, max_len, hidden_size = logits.size()
+        # start_scores = self.start_fc(logits)  # batch_size x max_len x 1, score of each position being a start
+        tmp = logits.new_zeros(batch_size, max_len+self.L-1, hidden_size)
+        tmp[:, -max_len:] = logits
+        # batch_size x max_len x hidden_size x (self.L) -> batch_size x max_len x (self.L) x hidden_size
+        start_logits = tmp.unfold(dimension=1, size=self.L, step=1).transpose(2, 3).flip(dims=[2])
+        end_scores = self.end_fc(logits)  # batch_size x max_len x 1
+        # compute the relay features
+        relay_tmp = logits.new_zeros(batch_size, max_len, hidden_size)
+        relay_tmp[:, :-self.L] = logits[:, self.L:]
+        # batch_size x max_len x hidden_size
+        relay_logits_max = torch.max(relay_tmp, logits)  # end - start
+        logits_max = torch.max(logits.unsqueeze(2), start_logits)  # batch_size x max_len x L x hidden_size
+        whole_scores = (logits_max*self.whole_w).sum(dim=-1)  # batch_size x max_len x self.L
+        # whole_scores = self.whole_fc().squeeze(-1)  # bz x max_len x self.L
+        # batch_size x max_len
+        relay_scores = self.relay_fc(relay_logits_max).squeeze(-1)
+        return whole_scores+end_scores+self.length_bias.view(1, 1, -1), relay_scores
@@ -0,0 +1,17 @@
+import unittest
+from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader
+from fastNLP.core.vocabulary import VocabularyOption
+class TestCWSDataLoader(unittest.TestCase):
+    def test_case1(self):
+        cws_loader = SigHanLoader(target_type='bmes')
+        data = cws_loader.process('pku_demo.txt')
+        print(data.datasets)
+    def test_calse2(self):
+        cws_loader = SigHanLoader(target_type='bmes')
+        data = cws_loader.process('pku_demo.txt', bigram_vocab_opt=VocabularyOption())
+        print(data.datasets)
@@ -0,0 +1,68 @@
+import os
+from fastNLP import cache_results
+from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader
+from reproduction.seqence_labelling.cws.model.model import ShiftRelayCWSModel
+from fastNLP.io.embed_loader import EmbeddingOption
+from fastNLP.core.vocabulary import VocabularyOption
+from fastNLP import Trainer
+from torch.optim import Adam
+from fastNLP import BucketSampler
+from fastNLP import GradientClipCallback
+from reproduction.seqence_labelling.cws.model.metric import RelayMetric
+# use fastNLP's automatic caching, but note it can only cache results smaller than 4GB
+@cache_results(None)
+def prepare_data():
+    data = SigHanLoader(target_type='shift_relay').process(file_dir, char_embed_opt=char_embed_opt,
+                                                            bigram_vocab_opt=bigram_vocab_opt,
+                                                            bigram_embed_opt=bigram_embed_opt,
+                                                            L=L)
+    return data
+#########hyper
+L = 4
+hidden_size = 200
+num_layers = 1
+drop_p = 0.2
+lr = 0.02
+#########hyper
+device = 0
+# !!!! Never put full paths here, since that would expose your user names on the server, which is risky. Always use relative
+# paths; ideally put the data under your reproduction directory and add it to .gitignore
+file_dir = '/path/to/pku'
+char_embed_path = '/path/to/1grams_t3_m50_corpus.txt'
+bigram_embed_path = 'path/to/2grams_t3_m50_corpus.txt'
+bigram_vocab_opt = VocabularyOption(min_freq=3)
+char_embed_opt = EmbeddingOption(embed_filepath=char_embed_path)
+bigram_embed_opt = EmbeddingOption(embed_filepath=bigram_embed_path)
+data_name = os.path.basename(file_dir)
+cache_fp = 'caches/{}.pkl'.format(data_name)
+data = prepare_data(_cache_fp=cache_fp, _refresh=False)
+model = ShiftRelayCWSModel(char_embed=data.embeddings['chars'], bigram_embed=data.embeddings['bigrams'],
+                           hidden_size=hidden_size, num_layers=num_layers,
+                           L=L, num_bigram_per_char=1, drop_p=drop_p)
+sampler = BucketSampler(batch_size=32)
+optimizer = Adam(model.parameters(), lr=lr)
+clipper = GradientClipCallback(clip_value=5, clip_type='value')
+callbacks = [clipper]
+# if pretrain:
+#     fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until)
+#     callbacks.append(fixer)
+trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None,
+                  batch_size=32, sampler=sampler, update_every=5,
+                  n_epochs=3, print_every=5,
+                  dev_data=data.datasets['dev'], metrics=RelayMetric(), metric_key='f',
+                  validate_every=-1, save_path=None,
+                  prefetch=True, use_tqdm=True, device=device,
+                  callbacks=callbacks,
+                  check_code_level=0)
+trainer.train()
@@ -0,0 +1,51 @@
+import os
+from typing import Union, Dict
+def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
+    """
+    Check that the file paths passed to a dataloader are valid. If they are, return a dict that contains at least the key
+    'train', similar to the result below
+    {
+        'train': '/some/path/to/',  # always present; the vocabulary should be built on this split, the remaining files only need to be processed and indexed.
+        'test': 'xxx'  # may or may not be present
+        ...
+    }
+    If paths is invalid, the corresponding error is raised directly
+    :param paths: the path(s)
+    :return:
+    """
+    if isinstance(paths, str):
+        if os.path.isfile(paths):
+            return {'train': paths}
+        elif os.path.isdir(paths):
+            train_fp = os.path.join(paths, 'train.txt')
+            if not os.path.isfile(train_fp):
+                raise FileNotFoundError(f"train.txt is not found in folder {paths}.")
+            files = {'train': train_fp}
+            for filename in ['test.txt', 'dev.txt']:
+                fp = os.path.join(paths, filename)
+                if os.path.isfile(fp):
+                    files[filename.split('.')[0]] = fp
+            return files
+        else:
+            raise FileNotFoundError(f"{paths} is not a valid file path.")
+    elif isinstance(paths, dict):
+        if paths:
+            if 'train' not in paths:
+                raise KeyError("You have to include `train` in your dict.")
+            for key, value in paths.items():
+                if isinstance(key, str) and isinstance(value, str):
+                    if not os.path.isfile(value):
+                        raise TypeError(f"{value} is not a valid file.")
+                else:
+                    raise TypeError("All keys and values in paths should be str.")
+            return paths
+        else:
+            raise ValueError("Empty paths is not allowed.")
+    else:
+        raise TypeError(f"paths only supports str and dict. not {type(paths)}.")