
Sequence labeling: SemiCRFRelay-based Chinese word segmentation.

tags/v0.4.10
yh_cc · 6 years ago · parent commit d71f0eef13
12 changed files with 566 additions and 70 deletions
  1. fastNLP/core/trainer.py (+10 -5)
  2. fastNLP/core/utils.py (+1 -1)
  3. fastNLP/core/vocabulary.py (+2 -2)
  4. fastNLP/io/embed_loader.py (+2 -2)
  5. reproduction/seqence_labelling/cws/data/CWSDataLoader.py (+99 -60)
  6. reproduction/seqence_labelling/cws/model/metric.py (+44 -0)
  7. reproduction/seqence_labelling/cws/model/model.py (+74 -0)
  8. reproduction/seqence_labelling/cws/model/module.py (+198 -0)
  9. reproduction/seqence_labelling/cws/test/__init__.py (+0 -0)
  10. reproduction/seqence_labelling/cws/test/test_CWSDataLoader.py (+17 -0)
  11. reproduction/seqence_labelling/cws/train_shift_relay.py (+68 -0)
  12. reproduction/utils.py (+51 -0)

fastNLP/core/trainer.py (+10 -5)

@@ -494,14 +494,15 @@ class Trainer(object):
         self.callback_manager = CallbackManager(env={"trainer": self},
                                                 callbacks=callbacks)
 
-    def train(self, load_best_model=True, on_exception='ignore'):
+    def train(self, load_best_model=True, on_exception='auto'):
         """
         Call this function to start training with the Trainer.
 
         :param bool load_best_model: Only takes effect when dev_data was provided at initialization; if True, the trainer
             reloads the model parameters that performed best on dev before returning.
         :param str on_exception: Whether to keep raising an exception encountered during training after it has been handled
             by the on_exception() of :py:class:Callback.
-            Supports 'ignore' and 'raise': 'ignore' catches the exception, so code written after Trainer.train() keeps running; 'raise' re-raises it.
+            Supports 'ignore', 'raise' and 'auto': 'ignore' catches the exception, so code written after Trainer.train() keeps running; 'raise' re-raises it;
+            'auto' ignores the following two exceptions, CallbackException and KeyboardInterrupt, and raises any other exception.
         :return dict: returns a dict,
             containing the following::
 
@@ -530,12 +531,16 @@ class Trainer(object):
             self.callback_manager.on_train_begin()
             self._train()
             self.callback_manager.on_train_end()
-        except (CallbackException, KeyboardInterrupt, Exception) as e:
+        except Exception as e:
             self.callback_manager.on_exception(e)
-            if on_exception=='raise':
+            if on_exception == 'auto':
+                if not isinstance(e, (CallbackException, KeyboardInterrupt)):
+                    raise e
+            elif on_exception == 'raise':
                 raise e
 
-        if self.dev_data is not None and hasattr(self, 'best_dev_perf'):
+        if self.dev_data is not None and self.best_dev_perf is not None:
             print(
                 "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
                 self.tester._format_eval_results(self.best_dev_perf), )
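
A minimal usage sketch of the new default (not part of the commit; the Trainer construction is elided): with on_exception='auto', a KeyboardInterrupt or a CallbackException raised by a callback ends training quietly and the code after train() still runs, while any other error is re-raised.

    # sketch only: assumes a Trainer built as in train_shift_relay.py below
    results = trainer.train(load_best_model=True, on_exception='auto')
    # pass on_exception='raise' to re-raise every exception,
    # or 'ignore' to swallow them all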


fastNLP/core/utils.py (+1 -1)

@@ -4,7 +4,7 @@ utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户
 __all__ = [
     "cache_results",
     "seq_len_to_mask",
-    "Example",
+    "Option",
 ]
 
 import _pickle


fastNLP/core/vocabulary.py (+2 -2)

@@ -6,10 +6,10 @@ __all__ = [
 from functools import wraps
 from collections import Counter
 from .dataset import DataSet
-from .utils import Example
+from .utils import Option
 
 
-class VocabularyOption(Example):
+class VocabularyOption(Option):
     def __init__(self,
                  max_size=None,
                  min_freq=None,


fastNLP/io/embed_loader.py (+2 -2)

@@ -10,10 +10,10 @@ import numpy as np
 
 from ..core.vocabulary import Vocabulary
 from .base_loader import BaseLoader
-from ..core.utils import Example
+from ..core.utils import Option
 
 
-class EmbeddingOption(Example):
+class EmbeddingOption(Option):
     def __init__(self,
                  embed_filepath=None,
                  dtype=np.float32,


reproduction/seqence_labelling/Chinese_Word_Segmentation/data/CWSDataLoader.py → reproduction/seqence_labelling/cws/data/CWSDataLoader.py (+99 -60)

@@ -6,6 +6,9 @@ from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP import Vocabulary
+from fastNLP import Const
+from reproduction.utils import check_dataloader_paths
+from functools import partial
 
 
 class SigHanLoader(DataSetLoader):
     """
@@ -20,27 +23,43 @@ class SigHanLoader(DataSetLoader):
             chars: list(str), each element is an index (the index of the corresponding Chinese character)
             target: list(int), varies depending on the encoding_type
 
-    :param target_type: the type of target; the following two are currently supported: "bmes", "pointer"
+    :param target_type: the type of target; the following two are currently supported: "bmes", "shift_relay"
     """
 
     def __init__(self, target_type:str):
         super().__init__()
 
-        if target_type.lower() not in ('bmes', 'pointer'):
-            raise ValueError("target_type only supports 'bmes', 'pointer'.")
+        if target_type.lower() not in ('bmes', 'shift_relay'):
+            raise ValueError("target_type only supports 'bmes', 'shift_relay'.")
 
         self.target_type = target_type
         if target_type=='bmes':
            self._word_len_to_target = self._word_len_to_bems
+        elif target_type=='shift_relay':
+            self._word_len_to_target = self._word_lens_to_relay
+
+    @staticmethod
+    def _word_lens_to_relay(word_lens: Iterator[int]):
+        """
+        Converts word lengths [1, 2, 3, ...] into [0, 1, 0, 2, 1, 0, ...] (the value at a word's first character indicates how long the segment is);
+        :param word_lens:
+        :return: {'target': , 'end_seg_mask':, 'start_seg_mask':}
+        """
+        tags = []
+        end_seg_mask = []
+        start_seg_mask = []
+        for word_len in word_lens:
+            tags.extend([idx for idx in range(word_len - 1, -1, -1)])
+            end_seg_mask.extend([0] * (word_len - 1) + [1])
+            start_seg_mask.extend([1] + [0] * (word_len - 1))
+        return {'target': tags, 'end_seg_mask': end_seg_mask, 'start_seg_mask': start_seg_mask}
 
     @staticmethod
-    def _word_len_to_bems(word_lens:Iterator[int])->List[str]:
+    def _word_len_to_bems(word_lens:Iterator[int])->Dict[str, List[str]]:
         """
 
         :param word_lens: the length of each word
-        :return: the corresponding BMES str
+        :return:
         """
         tags = []
         for word_len in word_lens:
@@ -51,7 +70,7 @@ class SigHanLoader(DataSetLoader):
             for _ in range(word_len-2):
                 tags.append('M')
             tags.append('E')
-        return tags
+        return {'target':tags}
 
     @staticmethod
     def _gen_bigram(chars:List[str])->List[str]:
@@ -71,11 +90,15 @@
         dataset = DataSet()
         with open(path, 'r', encoding='utf-8') as f:
             for line in f:
+                line = line.strip()
+                if not line:  # skip empty lines
+                    continue
                 parts = line.split()
                 word_lens = map(len, parts)
-                chars = list(line)
+                chars = list(''.join(parts))
                 tags = self._word_len_to_target(word_lens)
-                dataset.append(Instance(raw_chars=chars, target=tags))
+                assert len(chars)==len(tags['target'])
+                dataset.append(Instance(raw_chars=chars, **tags, seq_len=len(chars)))
         if len(dataset)==0:
             raise RuntimeError(f"{path} has no valid data.")
         if bigram:
@@ -84,7 +107,7 @@
 
     def process(self, paths: Union[str, Dict[str, str]], char_vocab_opt:VocabularyOption=None,
                 char_embed_opt:EmbeddingOption=None, bigram_vocab_opt:VocabularyOption=None,
-                bigram_embed_opt:EmbeddingOption=None):
+                bigram_embed_opt:EmbeddingOption=None, L:int=4):
         """
         The supported data format is one sample per line, with words separated by whitespace. For example
 
@@ -113,7 +136,7 @@
             data = SigHanLoader('bmes').process('path/to/cws/')  # will try to read train.txt, test.txt and dev.txt under that folder
             # the result contains data.vocabs['chars']: a Vocabulary object
             #                      data.vocabs['target']: a Vocabulary object
-            #                      data.embeddings['chars']: an Embedding object; only present when a path to pretrained word vectors is provided
+            #                      data.embeddings['chars']: an Embedding object, only when a pretrained embedding path is provided;
             #                      data.datasets['train']: a DataSet object
             #                            with the fields:
             #                                raw_chars: list[str], each element is a Chinese character
@@ -132,79 +155,95 @@
         :param bigram_vocab_opt: options for building the bigram vocabulary; bigrams are not used by default, and the bigrams field
             only exists when this parameter is given. It is a List[int] of the same length as chars per instance; the bigrams of abcde are ab bc cd de e<eos>
         :param bigram_embed_opt: options for loading pretrained bigram embeddings; only effective when bigram_vocab_opt is passed
+        :param L: the segment length used when target_type is shift_relay
         :return:
         """
         # everyone is encouraged to use check_data_loader_paths to validate paths
         paths = check_dataloader_paths(paths)
         datasets = {}
+        data = DataInfo()
         bigram = bigram_vocab_opt is not None
         for name, path in paths.items():
             dataset = self.load(path, bigram=bigram)
             datasets[name] = dataset
+        input_fields = []
+        target_fields = []
         # build the vocabularies
         char_vocab = Vocabulary(min_freq=2) if char_vocab_opt is None else Vocabulary(**char_vocab_opt)
         char_vocab.from_dataset(datasets['train'], field_name='raw_chars')
         char_vocab.index_dataset(*datasets.values(), field_name='raw_chars', new_field_name='chars')
+        data.vocabs[Const.CHAR_INPUT] = char_vocab
+        input_fields.extend([Const.CHAR_INPUT, Const.INPUT_LEN, Const.TARGET])
+        target_fields.append(Const.TARGET)
         # build the target
         if self.target_type == 'bmes':
             target_vocab = Vocabulary(unknown=None, padding=None)
             target_vocab.add_word_lst(['B']*4+['M']*3+['E']*2+['S'])
             target_vocab.index_dataset(*datasets.values(), field_name='target')
+            data.vocabs[Const.TARGET] = target_vocab
+        if char_embed_opt is not None:
+            char_embed = EmbedLoader.load_with_vocab(**char_embed_opt, vocab=char_vocab)
+            data.embeddings['chars'] = char_embed
         if bigram:
             bigram_vocab = Vocabulary(**bigram_vocab_opt)
             bigram_vocab.from_dataset(datasets['train'], field_name='bigrams')
             bigram_vocab.index_dataset(*datasets.values(), field_name='bigrams')
+            data.vocabs['bigrams'] = bigram_vocab
             if bigram_embed_opt is not None:
-                pass
+                bigram_embed = EmbedLoader.load_with_vocab(**bigram_embed_opt, vocab=bigram_vocab)
+                data.embeddings['bigrams'] = bigram_embed
+            input_fields.append('bigrams')
+        if self.target_type == 'shift_relay':
+            func = partial(self._clip_target, L=L)
+            for name, dataset in datasets.items():
+                res = dataset.apply_field(func, field_name='target')
+                relay_target = [res_i[0] for res_i in res]
+                relay_mask = [res_i[1] for res_i in res]
+                dataset.add_field('relay_target', relay_target, is_input=True, is_target=False, ignore_type=False)
+                dataset.add_field('relay_mask', relay_mask, is_input=True, is_target=False, ignore_type=False)
+        if self.target_type == 'shift_relay':
+            input_fields.extend(['end_seg_mask'])
+            target_fields.append('start_seg_mask')
+        # add the datasets to the DataInfo
+        for name, dataset in datasets.items():
+            dataset.set_input(*input_fields)
+            dataset.set_target(*target_fields)
+            data.datasets[name] = dataset
+
+        return data
+
+    @staticmethod
+    def _clip_target(target:List[int], L:int):
+        """
+        Only used when target_type is shift_relay.
+        :param target: List[int]
+        :param L:
+        :return:
+        """
+        relay_target_i = []
+        tmp = []
+        for j in range(len(target) - 1):
+            tmp.append(target[j])
+            if target[j] > target[j + 1]:
+                pass
+            else:
+                relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]])
+                tmp = []
+        # handle the unfinished tail
+        if len(tmp) == 0:
+            relay_target_i.append(0)
+        else:
+            tmp.append(target[-1])
+            relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]])
+        relay_mask_i = []
+        j = 0
+        while j < len(target):
+            seg_len = target[j] + 1
+            if target[j] < L:
+                relay_mask_i.extend([0] * (seg_len))
+            else:
+                relay_mask_i.extend([1] * (seg_len - L) + [0] * L)
+            j = seg_len + j
+        return relay_target_i, relay_mask_i
-
-import os
-
-
-def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
-    """
-    Check that the paths passed to the dataloader are valid. If they are, return a dict containing at least the key 'train', similar to
-    {
-        'train': '/some/path/to/',  # always present; the vocabulary should be built on this one, the other files only need to be processed and indexed.
-        'test': 'xxx'  # may or may not be present
-        ...
-    }
-    If paths is invalid, the corresponding error is raised directly.
-
-    :param paths: the paths
-    :return:
-    """
-    if isinstance(paths, str):
-        if os.path.isfile(paths):
-            return {'train': paths}
-        elif os.path.isdir(paths):
-            train_fp = os.path.join(paths, 'train.txt')
-            if not os.path.isfile(train_fp):
-                raise FileNotFoundError(f"train.txt is not found in folder {paths}.")
-            files = {'train': train_fp}
-            for filename in ['test.txt', 'dev.txt']:
-                fp = os.path.join(paths, filename)
-                if os.path.isfile(fp):
-                    files[filename.split('.')[0]] = fp
-            return files
-        else:
-            raise FileNotFoundError(f"{paths} is not a valid file path.")
-
-    elif isinstance(paths, dict):
-        if paths:
-            if 'train' not in paths:
-                raise KeyError("You have to include `train` in your dict.")
-            for key, value in paths.items():
-                if isinstance(key, str) and isinstance(value, str):
-                    if not os.path.isfile(value):
-                        raise TypeError(f"{value} is not a valid file.")
-                else:
-                    raise TypeError("All keys and values in paths should be str.")
-            return paths
-        else:
-            raise ValueError("Empty paths is not allowed.")
-    else:
-        raise TypeError(f"paths only supports str and dict. not {type(paths)}.")
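
To make the shift_relay encoding concrete, here is a small worked example (illustration only, not part of the commit); the word lengths 1, 2, 4 and L=3 are arbitrary values:

    word_lens = [1, 2, 4]
    out = SigHanLoader._word_lens_to_relay(word_lens)
    # out['target']         == [0, 1, 0, 3, 2, 1, 0]   # countdown to the end of each word
    # out['end_seg_mask']   == [1, 0, 1, 0, 0, 0, 1]   # 1 on the last character of each word
    # out['start_seg_mask'] == [1, 1, 0, 1, 0, 0, 0]   # 1 on the first character of each word
    relay_target, relay_mask = SigHanLoader._clip_target(out['target'], L=3)
    # relay_target == [0, 0, 1, 0, 1, 2, 2]   # distance back to the word start, capped at L-1
    # relay_mask   == [0, 0, 0, 1, 0, 0, 0]   # 1 on the first (len-L) characters of words longer than L, i.e. where a relay is needed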



reproduction/seqence_labelling/cws/model/metric.py (+44 -0)

@@ -0,0 +1,44 @@

from fastNLP.core.metrics import MetricBase


class RelayMetric(MetricBase):
    def __init__(self, pred=None, pred_mask=None, target=None, start_seg_mask=None):
        super().__init__()
        self._init_param_map(pred=pred, pred_mask=pred_mask, target=target, start_seg_mask=start_seg_mask)
        self.tp = 0
        self.rec = 0
        self.pre = 0

    def evaluate(self, pred, pred_mask, target, start_seg_mask):
        """
        Accumulate the statistics for each batch.

        :param pred: the prediction, i.e. (length-1) of the segment starting at the current position
        :param pred_mask: 1 where a segment is predicted to start at the current position
        :param target: (length-1) of the gold segment starting at the current position
        :param start_seg_mask: 1 where a gold segment starts at the current position
        :return:
        """
        self.tp += ((pred.long().eq(target.long())).__and__(pred_mask.byte().__and__(start_seg_mask.byte()))).sum().item()
        self.rec += start_seg_mask.sum().item()
        self.pre += pred_mask.sum().item()

    def get_metric(self, reset=True):
        """
        Compute the performance once all data has been processed.
        :param reset:
        :return:
        """
        pre = self.tp/(self.pre + 1e-12)
        rec = self.tp/(self.rec + 1e-12)
        f = 2*pre*rec/(1e-12 + pre + rec)

        if reset:
            self.tp = 0
            self.rec = 0
            self.pre = 0
            self.bigger_than_L = 0

        return {'f': round(f, 6), 'pre': round(pre, 6), 'rec': round(rec, 6)}
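
As a quick sanity check of the bookkeeping above (toy numbers, not from the commit): a position only counts as a true positive when a segment start is both predicted and in the gold answer and the predicted length matches.

    tp, gold_starts, pred_starts = 6, 8, 10
    pre = tp / (pred_starts + 1e-12)         # 0.6
    rec = tp / (gold_starts + 1e-12)         # 0.75
    f = 2 * pre * rec / (1e-12 + pre + rec)  # ~0.667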

reproduction/seqence_labelling/cws/model/model.py (+74 -0)

@@ -0,0 +1,74 @@
from torch import nn
import torch
from fastNLP.modules import Embedding
import numpy as np
from reproduction.seqence_labelling.cws.model.module import FeatureFunMax, SemiCRFShiftRelay
from fastNLP.modules import LSTM

class ShiftRelayCWSModel(nn.Module):
    """
    A model for Chinese word segmentation.
    It exposes two methods:
        forward(chars, bigrams, seq_len) -> {'loss': batch_size,}
        predict(chars, bigrams, seq_len) -> {'pred': batch_size x max_len, 'pred_mask': batch_size x max_len}
    pred is the predicted (length-1) of the segment starting at each position; pred_mask is 1 only where a prediction is made.

    :param char_embed: a pretrained Embedding, or the shape of the embedding
    :param bigram_embed: a pretrained Embedding, or the shape of the embedding
    :param hidden_size: hidden size of the LSTM
    :param num_layers: number of LSTM layers
    :param L: segment size of SemiCRFShiftRelay
    :param num_bigram_per_char: number of bigrams per character
    :param drop_p: dropout rate
    """
    def __init__(self, char_embed:Embedding, bigram_embed:Embedding, hidden_size:int=400, num_layers:int=1,
                 L:int=6, num_bigram_per_char:int=1, drop_p:float=0.2):
        super().__init__()
        self.char_embedding = Embedding(char_embed, dropout=drop_p)
        self._pretrained_embed = False
        if isinstance(char_embed, np.ndarray):
            self._pretrained_embed = True
        self.bigram_embedding = Embedding(bigram_embed, dropout=drop_p)
        self.lstm = LSTM(100 * (num_bigram_per_char + 1), hidden_size // 2, num_layers=num_layers, bidirectional=True,
                         batch_first=True)
        self.feature_fn = FeatureFunMax(hidden_size, L)
        self.semi_crf_relay = SemiCRFShiftRelay(L)
        self.feat_drop = nn.Dropout(drop_p)
        self.reset_param()
        # self.feature_fn.reset_parameters()

    def reset_param(self):
        for name, param in self.named_parameters():
            if 'embedding' in name and self._pretrained_embed:
                continue
            if 'bias_hh' in name:
                nn.init.constant_(param, 0)
            elif 'bias_ih' in name:
                nn.init.constant_(param, 1)
            elif len(param.size()) < 2:
                nn.init.uniform_(param, -0.1, 0.1)
            else:
                nn.init.xavier_uniform_(param)

    def get_feats(self, chars, bigrams, seq_len):
        batch_size, max_len = chars.size()
        chars = self.char_embedding(chars)
        bigrams = self.bigram_embedding(bigrams)
        bigrams = bigrams.view(bigrams.size(0), max_len, -1)
        chars = torch.cat([chars, bigrams], dim=-1)
        feats, _ = self.lstm(chars, seq_len)
        feats = self.feat_drop(feats)
        logits, relay_logits = self.feature_fn(feats)

        return logits, relay_logits

    def forward(self, chars, bigrams, relay_target, relay_mask, end_seg_mask, seq_len):
        logits, relay_logits = self.get_feats(chars, bigrams, seq_len)
        loss = self.semi_crf_relay(logits, relay_logits, relay_target, relay_mask, end_seg_mask, seq_len)
        return {'loss':loss}

    def predict(self, chars, bigrams, seq_len):
        logits, relay_logits = self.get_feats(chars, bigrams, seq_len)
        pred, pred_mask = self.semi_crf_relay.predict(logits, relay_logits, seq_len)
        return {'pred': pred, 'pred_mask': pred_mask}


reproduction/seqence_labelling/cws/model/module.py (+198 -0)

@@ -0,0 +1,198 @@
from torch import nn
import torch
from fastNLP.modules import Embedding
import numpy as np

class SemiCRFShiftRelay(nn.Module):
    """
    This module is the decoder (a semi-CRF with shift-relay nodes).

    """
    def __init__(self, L):
        """

        :param L: maximum segment length, not counting the relay node
        """
        if L<2:
            raise RuntimeError()
        super().__init__()
        self.L = L

    def forward(self, logits, relay_logits, relay_target, relay_mask, end_seg_mask, seq_len):
        """
        A relay node means that none of the following L characters ends its segment. The relay state slides forward by one position.

        :param logits: batch_size x max_len x L, scores of the L segments ending at the current position; index 0 of the last dimension is the segment of length 1 (the position itself)
        :param relay_logits: batch_size x max_len, score that none of the following L-1 positions ends the segment starting here
        :param relay_target: batch_size x max_len, how far back the segment of each position starts; stays at L-1 once it exceeds L. E.g. for a word of
            length 5 and L=3: [0, 1, 2, 2, 2]
        :param relay_mask: batch_size x max_len, 1 where a relay is needed; for a word of length 5 and L=3: [1, 1, 1, 0, 0]
        :param end_seg_mask: batch_size x max_len, 1 where a segment ends.
        :param seq_len: batch_size, sentence lengths
        :return: loss: batch_size,
        """
        batch_size, max_len, L = logits.size()

        # score of being a relay node at the current step
        relay_scores = logits.new_zeros(batch_size, max_len)
        # score of a segment ending at the current step
        scores = logits.new_zeros(batch_size, max_len+1)
        # score of the gold path
        gold_scores = relay_logits[:, 0].masked_fill(relay_mask[:, 0].eq(0), 0) + \
                      logits[:, 0, 0].masked_fill(end_seg_mask[:, 0].eq(0), 0)
        # initialization
        scores[:, 1] = logits[:, 0, 0]
        batch_i = torch.arange(batch_size).to(logits.device).long()
        relay_scores[:, 0] = relay_logits[:, 0]
        last_relay_index = max_len - self.L
        for t in range(1, max_len):
            real_L = min(t+1, L)
            flip_logits_t = logits[:, t, :real_L].flip(dims=[1])  # after the flip, index 0 is the segment reaching back real_L-1 positions
            # update relay_scores
            if t<last_relay_index:
                # (1) transition from a normal position
                tmp1 = relay_logits[:, t] + scores[:, t]  # batch_size
                # (2) transition from a relay
                tmp2 = relay_logits[:, t] + relay_scores[:, t-1]  # batch_size
                tmp1 = torch.stack([tmp1, tmp2], dim=0)
                relay_scores[:, t] = torch.logsumexp(tmp1, dim=0)
            # update scores
            # (1) transitions from earlier positions
            tmp1 = scores[:, t-real_L+1:t+1] + flip_logits_t  # batch_size x L
            if t>self.L-1:
                # (2) transition from a relay
                tmp2 = relay_scores[:, t-self.L]  # batch_size
                tmp2 = tmp2 + flip_logits_t[:, 0]  # batch_size
                tmp1 = torch.cat([tmp1, tmp2.unsqueeze(-1)], dim=-1)
            scores[:, t+1] = torch.logsumexp(tmp1, dim=-1)  # update the score of the current step

            # accumulate the gold score
            seg_i = relay_target[:, t]  # batch_size
            gold_segment_scores = logits[:, t][(batch_i, seg_i)].masked_fill(end_seg_mask[:, t].eq(0), 0)  # batch_size, score of the gold segment ending here (looking back 0 to L positions)
            relay_score = relay_logits[:, t].masked_fill(relay_mask[:, t].eq(0), 0)
            gold_scores = gold_scores + relay_score + gold_segment_scores
        all_scores = scores.gather(dim=1, index=seq_len.unsqueeze(1)).squeeze(1)  # batch_size
        return all_scores - gold_scores

    def predict(self, logits, relay_logits, seq_len):
        """
        A relay node means that none of the following L characters ends its segment. The relay state slides forward by L-1 positions.

        :param logits: batch_size x max_len x L, scores of the L segments ending at the current position; index 0 of the last dimension is the segment of length 1 (the position itself)
        :param relay_logits: batch_size x max_len, score that none of the following L-1 positions ends the segment starting here
        :param seq_len: batch_size, sentence lengths
        :return: pred: batch_size x max_len, (length-1) of the segment starting at each position; pred_mask is 1 where a segment is predicted to start
        """
        batch_size, max_len, L = logits.size()
        # score of being a relay node at the current step
        max_relay_scores = logits.new_zeros(batch_size, max_len)
        relay_bt = seq_len.new_zeros(batch_size, max_len)  # whether the current result comes from a relay
        # score of a segment ending at the current step
        max_scores = logits.new_zeros(batch_size, max_len+1)
        bt = seq_len.new_zeros(batch_size, max_len)
        # initialization
        max_scores[:, 1] = logits[:, 0, 0]
        max_relay_scores[:, 0] = relay_logits[:, 0]
        last_relay_index = max_len - self.L
        for t in range(1, max_len):
            real_L = min(t+1, L)
            flip_logits_t = logits[:, t, :real_L].flip(dims=[1])  # after the flip, index 0 is the segment reaching back real_L-1 positions
            # update relay_scores
            if t<last_relay_index:
                # (1) transition from a normal position
                tmp1 = relay_logits[:, t] + max_scores[:, t]
                # (2) transition from a relay
                tmp2 = relay_logits[:, t] + max_relay_scores[:, t-1]  # batch_size
                # the last L positions of each sample can no longer be a relay
                tmp2 = tmp2.masked_fill(seq_len.le(t+L), float('-inf'))
                mask_i = tmp1.lt(tmp2)  # 1 where the transition comes from a relay
                relay_bt[:, t].masked_fill_(mask_i, 1)
                max_relay_scores[:, t] = torch.max(tmp1, tmp2)

            # update scores
            # (1) transitions from earlier positions
            tmp1 = max_scores[:, t-real_L+1:t+1] + flip_logits_t  # batch_size x L
            tmp1 = tmp1.flip(dims=[1])  # index 0 stands for the segment of length 1
            if self.L-1<t:
                # (2) transition from a relay
                tmp2 = max_relay_scores[:, t-self.L]  # batch_size
                tmp2 = tmp2 + flip_logits_t[:, 0]
                tmp1 = torch.cat([tmp1, tmp2.unsqueeze(-1)], dim=-1)
            # take the larger one
            max_score, pt = torch.max(tmp1, dim=1)
            max_scores[:, t+1] = max_score
            # mask_i = pt.ge(self.L)
            bt[:, t] = pt  # with L=3, the values 0,1,2,3 stand for [t, t], [t-1, t], [t-2, t], [t-self.L(relay), t] respectively
        # decode the result
        pred = np.zeros((batch_size, max_len), dtype=int)
        pred_mask = np.zeros((batch_size, max_len), dtype=int)
        seq_len = seq_len.tolist()
        bt = bt.tolist()
        relay_bt = relay_bt.tolist()
        for b in range(batch_size):
            seq_len_i = seq_len[b]
            bt_i = bt[b][:seq_len_i]
            relay_bt_i = relay_bt[b][:seq_len_i]
            j = seq_len_i - 1
            assert relay_bt_i[j]!=1
            while j>-1:
                if bt_i[j]==self.L:
                    seg_start_pos = j
                    j = j-self.L
                    while relay_bt_i[j]!=0 and j>-1:
                        j = j - 1
                    pred[b, j] = seg_start_pos - j
                    pred_mask[b, j] = 1
                else:
                    length = bt_i[j]
                    j = j - bt_i[j]
                    pred_mask[b, j] = 1
                    pred[b, j] = length
                j = j - 1

        return torch.LongTensor(pred).to(logits.device), torch.LongTensor(pred_mask).to(logits.device)


class FeatureFunMax(nn.Module):
    def __init__(self, hidden_size:int, L:int):
        """
        Computes the features for the semi-CRF. Given input of shape batch_size x max_len x hidden_size, it outputs
        scores of shape batch_size x max_len x L, plus relay scores of shape batch_size x max_len. For the difference
        between the two, see the paper. TODO add reference

        :param hidden_size: dimension of the input features
        :param L: maximum length of a segment, not counting the relay node.
        """
        super().__init__()

        self.end_fc = nn.Linear(hidden_size, 1, bias=False)
        self.whole_w = nn.Parameter(torch.randn(L, hidden_size))
        self.relay_fc = nn.Linear(hidden_size, 1)
        self.length_bias = nn.Parameter(torch.randn(L))
        self.L = L

    def forward(self, logits):
        """

        :param logits: batch_size x max_len x hidden_size
        :return: batch_size x max_len x L  # the last dimension holds the scores of the segments to the left; index 0 is the segment of length 1
                 batch_size x max_len,  # score that none of the following L-1 positions is an end
        """
        batch_size, max_len, hidden_size = logits.size()
        # start_scores = self.start_fc(logits)  # batch_size x max_len x 1  # score of each position being a start
        tmp = logits.new_zeros(batch_size, max_len+self.L-1, hidden_size)
        tmp[:, -max_len:] = logits
        # batch_size x max_len x hidden_size x (self.L) -> batch_size x max_len x (self.L) x hidden_size
        start_logits = tmp.unfold(dimension=1, size=self.L, step=1).transpose(2, 3).flip(dims=[2])
        end_scores = self.end_fc(logits)  # batch_size x max_len x 1
        # compute the relay features
        relay_tmp = logits.new_zeros(batch_size, max_len, hidden_size)
        relay_tmp[:, :-self.L] = logits[:, self.L:]
        # batch_size x max_len x hidden_size
        relay_logits_max = torch.max(relay_tmp, logits)  # end - start
        logits_max = torch.max(logits.unsqueeze(2), start_logits)  # batch_size x max_len x L x hidden_size
        whole_scores = (logits_max*self.whole_w).sum(dim=-1)  # batch_size x max_len x self.L
        # whole_scores = self.whole_fc().squeeze(-1)  # bz x max_len x self.L
        # batch_size x max_len
        relay_scores = self.relay_fc(relay_logits_max).squeeze(-1)
        return whole_scores+end_scores+self.length_bias.view(1, 1, -1), relay_scores
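
The unfold/flip step in FeatureFunMax.forward is easy to misread, so here is a minimal sketch (illustration only; hidden_size=1 and a length-5 toy sequence are assumptions chosen purely for readability). After left-padding with L-1 zeros, unfolding along the length dimension and flipping the window, index k of the window dimension holds the hidden state from k positions back:

    import torch

    L, max_len = 3, 5
    logits = torch.arange(1., 6.).view(1, max_len, 1)   # batch_size=1, hidden_size=1
    tmp = logits.new_zeros(1, max_len + L - 1, 1)
    tmp[:, -max_len:] = logits                           # left-pad with L-1 zeros
    start_logits = tmp.unfold(dimension=1, size=L, step=1).transpose(2, 3).flip(dims=[2])
    print(start_logits.squeeze(-1).squeeze(0))
    # tensor([[1., 0., 0.],
    #         [2., 1., 0.],
    #         [3., 2., 1.],
    #         [4., 3., 2.],
    #         [5., 4., 3.]])
    # row t, column k is the hidden state at position t-k (0 where t-k < 0)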

reproduction/seqence_labelling/cws/test/__init__.py (+0 -0)


reproduction/seqence_labelling/cws/test/test_CWSDataLoader.py (+17 -0)

@@ -0,0 +1,17 @@


import unittest
from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader
from fastNLP.core.vocabulary import VocabularyOption


class TestCWSDataLoader(unittest.TestCase):
    def test_case1(self):
        cws_loader = SigHanLoader(target_type='bmes')
        data = cws_loader.process('pku_demo.txt')
        print(data.datasets)

    def test_case2(self):
        cws_loader = SigHanLoader(target_type='bmes')
        data = cws_loader.process('pku_demo.txt', bigram_vocab_opt=VocabularyOption())
        print(data.datasets)

reproduction/seqence_labelling/cws/train_shift_relay.py (+68 -0)

@@ -0,0 +1,68 @@

import os

from fastNLP import cache_results
from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader
from reproduction.seqence_labelling.cws.model.model import ShiftRelayCWSModel
from fastNLP.io.embed_loader import EmbeddingOption
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP import Trainer
from torch.optim import Adam
from fastNLP import BucketSampler
from fastNLP import GradientClipCallback
from reproduction.seqence_labelling.cws.model.metric import RelayMetric


# use fastNLP's automatic caching mechanism; note that it can only cache results smaller than 4GB
@cache_results(None)
def prepare_data():
    data = SigHanLoader(target_type='shift_relay').process(file_dir, char_embed_opt=char_embed_opt,
                                                            bigram_vocab_opt=bigram_vocab_opt,
                                                            bigram_embed_opt=bigram_embed_opt,
                                                            L=L)
    return data

#########hyper
L = 4
hidden_size = 200
num_layers = 1
drop_p = 0.2
lr = 0.02

#########hyper
device = 0

# !!!! Never put full paths here, since they would expose your username on the server, which is risky. Always use
# relative paths; ideally put the data under your reproduction directory and add it to .gitignore.
file_dir = '/path/to/pku'
char_embed_path = '/path/to/1grams_t3_m50_corpus.txt'
bigram_embed_path = 'path/to/2grams_t3_m50_corpus.txt'
bigram_vocab_opt = VocabularyOption(min_freq=3)
char_embed_opt = EmbeddingOption(embed_filepath=char_embed_path)
bigram_embed_opt = EmbeddingOption(embed_filepath=bigram_embed_path)

data_name = os.path.basename(file_dir)
cache_fp = 'caches/{}.pkl'.format(data_name)

data = prepare_data(_cache_fp=cache_fp, _refresh=False)

model = ShiftRelayCWSModel(char_embed=data.embeddings['chars'], bigram_embed=data.embeddings['bigrams'],
                           hidden_size=hidden_size, num_layers=num_layers,
                           L=L, num_bigram_per_char=1, drop_p=drop_p)

sampler = BucketSampler(batch_size=32)
optimizer = Adam(model.parameters(), lr=lr)
clipper = GradientClipCallback(clip_value=5, clip_type='value')
callbacks = [clipper]
# if pretrain:
#     fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until)
#     callbacks.append(fixer)
trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None,
                  batch_size=32, sampler=sampler, update_every=5,
                  n_epochs=3, print_every=5,
                  dev_data=data.datasets['dev'], metrics=RelayMetric(), metric_key='f',
                  validate_every=-1, save_path=None,
                  prefetch=True, use_tqdm=True, device=device,
                  callbacks=callbacks,
                  check_code_level=0)
trainer.train()

reproduction/utils.py (+51 -0)

@@ -0,0 +1,51 @@
import os

from typing import Union, Dict


def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
    """
    Check that the paths passed to the dataloader are valid. If they are, return a dict containing at least the key 'train', similar to
    {
        'train': '/some/path/to/',  # always present; the vocabulary should be built on this one, the other files only need to be processed and indexed.
        'test': 'xxx'  # may or may not be present
        ...
    }
    If paths is invalid, the corresponding error is raised directly.

    :param paths: the paths
    :return:
    """
    if isinstance(paths, str):
        if os.path.isfile(paths):
            return {'train': paths}
        elif os.path.isdir(paths):
            train_fp = os.path.join(paths, 'train.txt')
            if not os.path.isfile(train_fp):
                raise FileNotFoundError(f"train.txt is not found in folder {paths}.")
            files = {'train': train_fp}
            for filename in ['test.txt', 'dev.txt']:
                fp = os.path.join(paths, filename)
                if os.path.isfile(fp):
                    files[filename.split('.')[0]] = fp
            return files
        else:
            raise FileNotFoundError(f"{paths} is not a valid file path.")

    elif isinstance(paths, dict):
        if paths:
            if 'train' not in paths:
                raise KeyError("You have to include `train` in your dict.")
            for key, value in paths.items():
                if isinstance(key, str) and isinstance(value, str):
                    if not os.path.isfile(value):
                        raise TypeError(f"{value} is not a valid file.")
                else:
                    raise TypeError("All keys and values in paths should be str.")
            return paths
        else:
            raise ValueError("Empty paths is not allowed.")
    else:
        raise TypeError(f"paths only supports str and dict. not {type(paths)}.")
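
A short usage sketch of this helper (the directory is a placeholder): it normalizes either a single path or a dict into the {'train': ..., ...} form the loaders expect.

    from reproduction.utils import check_dataloader_paths

    paths = check_dataloader_paths('/path/to/pku')            # folder containing train.txt (+ dev.txt/test.txt)
    # -> {'train': '/path/to/pku/train.txt', 'test': '/path/to/pku/test.txt', 'dev': '/path/to/pku/dev.txt'}
    paths = check_dataloader_paths('/path/to/pku/train.txt')  # a single file becomes the train split
    # -> {'train': '/path/to/pku/train.txt'}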


