@@ -690,11 +690,11 @@ class Trainer(object): | |||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | |||
and self.dev_data is not None: | |||
eval_res = self._do_validation(epoch=epoch, step=self.step) | |||
eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||
self.n_steps) + \ | |||
self.tester._format_eval_results(eval_res) | |||
eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step, | |||
self.n_steps) | |||
# pbar.write(eval_str + '\n') | |||
self.logger.info(eval_str + '\n') | |||
self.logger.info(eval_str) | |||
self.logger.info(self.tester._format_eval_results(eval_res)+'\n') | |||
# ================= mini-batch end ==================== # | |||
# lr decay; early stopping | |||
@@ -907,7 +907,7 @@ def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAUL | |||
info_str += '\n' | |||
else: | |||
info_str += 'There is no target field.' | |||
print(info_str) | |||
logger.info(info_str) | |||
_check_forward_error(forward_func=forward_func, dataset=dataset, | |||
batch_x=batch_x, check_level=check_level) | |||
refined_batch_x = _build_args(forward_func, **batch_x) | |||
@@ -67,8 +67,8 @@ class BertEmbedding(ContextualEmbedding): | |||
model_url = _get_embedding_url('bert', model_dir_or_name.lower()) | |||
model_dir = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
model_dir = os.path.expanduser(os.path.abspath(model_dir_or_name)) | |||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||
model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
@@ -59,7 +59,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||
model_url = _get_embedding_url('elmo', model_dir_or_name.lower()) | |||
model_dir = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||
model_dir = model_dir_or_name | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
@@ -70,10 +70,10 @@ class StaticEmbedding(TokenEmbedding): | |||
model_url = _get_embedding_url('static', model_dir_or_name.lower()) | |||
model_path = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isfile(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
model_path = model_dir_or_name | |||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
model_path = _get_file_name_base_on_postfix(model_dir_or_name, '.txt') | |||
elif os.path.isfile(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||
model_path = os.path.abspath(os.path.expanduser(model_dir_or_name)) | |||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||
model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt') | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
@@ -94,7 +94,7 @@ class StaticEmbedding(TokenEmbedding): | |||
no_create_entry=truncated_vocab._is_word_no_create_entry(word)) | |||
# 只限制在train里面的词语使用min_freq筛选 | |||
if kwargs.get('only_train_min_freq', False): | |||
if kwargs.get('only_train_min_freq', False) and model_dir_or_name is not None: | |||
for word in truncated_vocab.word_count.keys(): | |||
if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word]<min_freq: | |||
truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]), | |||
@@ -114,8 +114,8 @@ class StaticEmbedding(TokenEmbedding): | |||
lowered_vocab.add_word(word.lower(), no_create_entry=True) | |||
else: | |||
lowered_vocab.add_word(word.lower()) # 先加入需要创建entry的 | |||
print(f"All word in the vocab have been lowered before finding pretrained vectors. There are {len(vocab)} " | |||
f"words, {len(lowered_vocab)} unique lowered words.") | |||
print(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} " | |||
f"unique lowered words.") | |||
if model_path: | |||
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method) | |||
else: | |||
@@ -222,7 +222,8 @@ class DataBundle: | |||
:param bool flag: 将field_name的target状态设置为flag | |||
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一 | |||
行的数据进行类型和维度推断本列的数据的类型和维度。 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:return self | |||
""" | |||
for field_name in field_names: | |||
@@ -241,16 +242,61 @@ class DataBundle: | |||
:param str field_name: | |||
:param str new_field_name: | |||
:param bool ignore_miss_dataset: 若DataBundle中的DataSet的 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
if dataset.has_field(field_name=field_name): | |||
dataset.copy_field(field_name=field_name, new_field_name=new_field_name) | |||
elif ignore_miss_dataset: | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
return self | |||
def apply_field(self, func, field_name:str, new_field_name:str, ignore_miss_dataset=True, **kwargs): | |||
""" | |||
对DataBundle中所有的dataset使用apply方法 | |||
:param callable func: input是instance中名为 `field_name` 的field的内容。 | |||
:param str field_name: 传入func的是哪个field。 | |||
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param optional kwargs: 支持输入is_input,is_target,ignore_type | |||
1. is_input: bool, 如果为True则将名为 `new_field_name` 的field设置为input | |||
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target | |||
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
if dataset.has_field(field_name=field_name): | |||
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, **kwargs) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
return self | |||
def apply(self, func, new_field_name:str, **kwargs): | |||
""" | |||
对DataBundle中所有的dataset使用apply方法 | |||
:param callable func: input是instance中名为 `field_name` 的field的内容。 | |||
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param optional kwargs: 支持输入is_input,is_target,ignore_type | |||
1. is_input: bool, 如果为True则将名为 `new_field_name` 的field设置为input | |||
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target | |||
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
dataset.apply(func, new_field_name=new_field_name, **kwargs) | |||
return self | |||
def __repr__(self): | |||
_str = 'In total {} datasets:\n'.format(len(self.datasets)) | |||
for name, dataset in self.datasets.items(): | |||
@@ -1,13 +1,15 @@ | |||
from .loader import Loader | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
import glob | |||
import os | |||
import time | |||
import shutil | |||
import random | |||
class CWSLoader(Loader): | |||
""" | |||
分词任务数据加载器, | |||
SigHan2005的数据可以用xxx下载并预处理 | |||
CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: | |||
Example:: | |||
@@ -24,9 +26,16 @@ class CWSLoader(Loader): | |||
"上海 浦东 开发 与 法制 建设 同步" | |||
"新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" | |||
"..." | |||
:param: str dataset_name: data的名称,支持pku, msra, cityu(繁体), as(繁体), None | |||
""" | |||
def __init__(self): | |||
def __init__(self, dataset_name: str = None):
    """Create a CWSLoader.

    :param str dataset_name: one of 'pku', 'msra', 'cityu' (traditional),
        'as' (traditional); any other value (including None) means no
        dataset is associated and download() will be a no-op.
    """
    super().__init__()
    name_map = {'pku': 'cws-pku', 'msra': 'cws-msra', 'as': 'cws-as', 'cityu': 'cws-cityu'}
    # Unknown names fall back to None, same as the original membership test.
    self.dataset_name = name_map.get(dataset_name, None)
def _load(self, path:str): | |||
ds = DataSet() | |||
@@ -37,5 +46,42 @@ class CWSLoader(Loader): | |||
ds.append(Instance(raw_words=line)) | |||
return ds | |||
def download(self, output_dir=None): | |||
raise RuntimeError("You can refer {} for sighan2005's data downloading.") | |||
def download(self, dev_ratio=0.1, re_download=False)->str:
    """
    Download the sighan2005 dataset chosen at construction time and return the
    local directory holding it. If you use this dataset, please cite:
    Thomas Emerson, The Second International Chinese Word Segmentation
    Bakeoff, 2005. More information: http://sighan.cs.uchicago.edu/bakeoff2005/

    :param float dev_ratio: if the downloaded data contains no dev set, the
        fraction of train.txt that is randomly split off into dev.txt.
        0 disables the split.
    :param bool re_download: whether to re-download the data (and therefore
        re-do the train/dev split).
    :return: str, path of the data directory (None when no dataset_name was set)
    """
    if self.dataset_name is None:
        return None
    data_dir = self._get_dataset_path(dataset_name=self.dataset_name)
    # Peek at the mtime of any one file to estimate when the data was fetched.
    modify_time = 0
    for filepath in glob.glob(os.path.join(data_dir, '*')):
        modify_time = os.stat(filepath).st_mtime
        break
    if time.time() - modify_time > 1 and re_download:  # ugly heuristic: an mtime within ~1s means the files were only just downloaded
        shutil.rmtree(data_dir)
        data_dir = self._get_dataset_path(dataset_name=self.dataset_name)

    if not os.path.exists(os.path.join(data_dir, 'dev.txt')):
        if dev_ratio > 0:
            assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
            try:
                # Route each train line to either the replacement train file
                # (middle_file.txt) or dev.txt, then swap the files into place.
                with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \
                        open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \
                        open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2:
                    for line in f:
                        if random.random() < dev_ratio:
                            f2.write(line)
                        else:
                            f1.write(line)
                os.remove(os.path.join(data_dir, 'train.txt'))
                os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt'))
            finally:
                # Clean up the temp file even if the split failed part-way.
                if os.path.exists(os.path.join(data_dir, 'middle_file.txt')):
                    os.remove(os.path.join(data_dir, 'middle_file.txt'))

    return data_dir
@@ -21,6 +21,7 @@ __all__ = [ | |||
"MsraNERPipe", | |||
"WeiboNERPipe", | |||
"PeopleDailyPipe", | |||
"Conll2003Pipe", | |||
"MatchingBertPipe", | |||
"RTEBertPipe", | |||
@@ -41,3 +42,4 @@ from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe | |||
from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ | |||
MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe | |||
from .pipe import Pipe | |||
from .conll import Conll2003Pipe |
@@ -19,16 +19,14 @@ class _NERPipe(Pipe): | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | |||
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 | |||
""" | |||
def __init__(self, encoding_type: str = 'bio', lower: bool = False, target_pad_val=0): | |||
def __init__(self, encoding_type: str = 'bio', lower: bool = False): | |||
if encoding_type == 'bio': | |||
self.convert_tag = iob2 | |||
else: | |||
self.convert_tag = lambda words: iob2bioes(iob2(words)) | |||
self.lower = lower | |||
self.target_pad_val = int(target_pad_val) | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
""" | |||
@@ -58,7 +56,6 @@ class _NERPipe(Pipe): | |||
target_fields = [Const.TARGET, Const.INPUT_LEN] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.set_pad_val(Const.TARGET, self.target_pad_val) | |||
dataset.add_seq_len(Const.INPUT) | |||
data_bundle.set_input(*input_fields) | |||
@@ -86,7 +83,6 @@ class Conll2003NERPipe(_NERPipe): | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | |||
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 | |||
""" | |||
def process_from_file(self, paths) -> DataBundle: | |||
@@ -103,7 +99,7 @@ class Conll2003NERPipe(_NERPipe): | |||
class Conll2003Pipe(Pipe): | |||
def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False, target_pad_val=0): | |||
def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False): | |||
""" | |||
经过该Pipe后,DataSet中的内容如下 | |||
@@ -119,7 +115,6 @@ class Conll2003Pipe(Pipe): | |||
:param str chunk_encoding_type: 支持bioes, bio。 | |||
:param str ner_encoding_type: 支持bioes, bio。 | |||
:param bool lower: 是否将words列小写化后再建立词表 | |||
:param int target_pad_val: pos, ner, chunk列的padding值 | |||
""" | |||
if chunk_encoding_type == 'bio': | |||
self.chunk_convert_tag = iob2 | |||
@@ -130,7 +125,6 @@ class Conll2003Pipe(Pipe): | |||
else: | |||
self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags)) | |||
self.lower = lower | |||
self.target_pad_val = int(target_pad_val) | |||
def process(self, data_bundle)->DataBundle: | |||
""" | |||
@@ -166,9 +160,6 @@ class Conll2003Pipe(Pipe): | |||
target_fields = ['pos', 'ner', 'chunk', Const.INPUT_LEN] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.set_pad_val('pos', self.target_pad_val) | |||
dataset.set_pad_val('ner', self.target_pad_val) | |||
dataset.set_pad_val('chunk', self.target_pad_val) | |||
dataset.add_seq_len(Const.INPUT) | |||
data_bundle.set_input(*input_fields) | |||
@@ -202,7 +193,6 @@ class OntoNotesNERPipe(_NERPipe): | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | |||
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 | |||
""" | |||
def process_from_file(self, paths): | |||
@@ -220,15 +210,13 @@ class _CNNERPipe(Pipe): | |||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target, seq_len。 | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 | |||
""" | |||
def __init__(self, encoding_type: str = 'bio', target_pad_val=0): | |||
def __init__(self, encoding_type: str = 'bio'): | |||
if encoding_type == 'bio': | |||
self.convert_tag = iob2 | |||
else: | |||
self.convert_tag = lambda words: iob2bioes(iob2(words)) | |||
self.target_pad_val = int(target_pad_val) | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
""" | |||
@@ -261,7 +249,6 @@ class _CNNERPipe(Pipe): | |||
target_fields = [Const.TARGET, Const.INPUT_LEN] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.set_pad_val(Const.TARGET, self.target_pad_val) | |||
dataset.add_seq_len(Const.CHAR_INPUT) | |||
data_bundle.set_input(*input_fields) | |||
@@ -324,7 +311,6 @@ class WeiboNERPipe(_CNNERPipe): | |||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 | |||
""" | |||
def process_from_file(self, paths=None) -> DataBundle: | |||
data_bundle = WeiboNERLoader().load(paths) | |||
@@ -0,0 +1,246 @@ | |||
from .pipe import Pipe | |||
from .. import DataBundle | |||
from ..loader import CWSLoader | |||
from ... import Const | |||
from itertools import chain | |||
from .utils import _indexize | |||
import re | |||
def _word_lens_to_bmes(word_lens): | |||
""" | |||
:param list word_lens: List[int], 每个词语的长度 | |||
:return: List[str], BMES的序列 | |||
""" | |||
tags = [] | |||
for word_len in word_lens: | |||
if word_len==1: | |||
tags.append('S') | |||
else: | |||
tags.append('B') | |||
tags.extend(['M']*(word_len-2)) | |||
tags.append('E') | |||
return tags | |||
def _word_lens_to_segapp(word_lens): | |||
""" | |||
:param list word_lens: List[int], 每个词语的长度 | |||
:return: List[str], BMES的序列 | |||
""" | |||
tags = [] | |||
for word_len in word_lens: | |||
if word_len==1: | |||
tags.append('SEG') | |||
else: | |||
tags.extend(['APP']*(word_len-1)) | |||
tags.append('SEG') | |||
return tags | |||
def _alpha_span_to_special_tag(span): | |||
""" | |||
将span替换成特殊的字符 | |||
:param str span: | |||
:return: | |||
""" | |||
if 'oo' == span.lower(): # speical case when represent 2OO8 | |||
return span | |||
if len(span) == 1: | |||
return span | |||
else: | |||
return '<ENG>' | |||
def _find_and_replace_alpha_spans(line): | |||
""" | |||
传入原始句子,替换其中的字母为特殊标记 | |||
:param str line:原始数据 | |||
:return: str | |||
""" | |||
new_line = '' | |||
pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%,.。!<-“])' | |||
prev_end = 0 | |||
for match in re.finditer(pattern, line): | |||
start, end = match.span() | |||
span = line[start:end] | |||
new_line += line[prev_end:start] + _alpha_span_to_special_tag(span) | |||
prev_end = end | |||
new_line += line[prev_end:] | |||
return new_line | |||
def _digit_span_to_special_tag(span): | |||
""" | |||
:param str span: 需要替换的str | |||
:return: | |||
""" | |||
if span[0] == '0' and len(span) > 2: | |||
return '<NUM>' | |||
decimal_point_count = 0 # one might have more than one decimal pointers | |||
for idx, char in enumerate(span): | |||
if char == '.' or char == '﹒' or char == '·': | |||
decimal_point_count += 1 | |||
if span[-1] == '.' or span[-1] == '﹒' or span[ | |||
-1] == '·': # last digit being decimal point means this is not a number | |||
if decimal_point_count == 1: | |||
return span | |||
else: | |||
return '<UNKDGT>' | |||
if decimal_point_count == 1: | |||
return '<DEC>' | |||
elif decimal_point_count > 1: | |||
return '<UNKDGT>' | |||
else: | |||
return '<NUM>' | |||
def _find_and_replace_digit_spans(line): | |||
# only consider words start with number, contains '.', characters. | |||
# If ends with space, will be processed | |||
# If ends with Chinese character, will be processed | |||
# If ends with or contains english char, not handled. | |||
# floats are replaced by <DEC> | |||
# otherwise unkdgt | |||
new_line = '' | |||
pattern = '\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%,。!<-“])' | |||
prev_end = 0 | |||
for match in re.finditer(pattern, line): | |||
start, end = match.span() | |||
span = line[start:end] | |||
new_line += line[prev_end:start] + _digit_span_to_special_tag(span) | |||
prev_end = end | |||
new_line += line[prev_end:] | |||
return new_line | |||
class CWSPipe(Pipe):
    """
    Preprocess CWS (Chinese word segmentation) data. After processing, each
    DataSet has the following structure:

    .. csv-table::
       :header: "raw_words", "chars", "target", "bigrams", "trigrams", "seq_len"

       "共同 创造 美好...", "[2, 3, 4...]", "[0, 2, 0, 2,...]", "[10, 4, 1,...]","[6, 4, 1,...]", 13
       "2001年 新年 钟声...", "[8, 9, 9, 7, ...]", "[0, 1, 1, 1, 2...]", "[11, 12, ...]","[3, 9, ...]", 20
       "...", "[...]","[...]", "[...]","[...]", .

    The bigrams/trigrams columns are only present when the corresponding
    flag is True.

    :param str,None dataset_name: one of 'pku', 'msra', 'cityu', 'as', or None
    :param str encoding_type: 'bmes' or 'segapp'. For "我 来自 复旦大学...",
        bmes tags are [S, B, E, B, M, M, E...]; segapp tags are
        [seg, app, seg, app, app, app, seg, ...]
    :param bool replace_num_alpha: whether to replace digit/alpha spans with
        special tags such as <NUM>/<ENG>.
    :param bool bigrams: add a bigram column: ['复', '旦', '大', '学', ...] -> ["复旦", "旦大", ...]
    :param bool trigrams: add a trigram column: ['复', '旦', '大', '学', ...] -> ["复旦大", "旦大学", ...]
    """

    def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False):
        if encoding_type == 'bmes':
            self.word_lens_to_tags = _word_lens_to_bmes
        else:
            self.word_lens_to_tags = _word_lens_to_segapp

        self.dataset_name = dataset_name
        self.bigrams = bigrams
        self.trigrams = trigrams
        self.replace_num_alpha = replace_num_alpha

    def _tokenize(self, data_bundle):
        """
        Split each entry of the 'chars' column into individual characters,
        keeping special tags such as <ENG>/<NUM> together as one
        pseudo-character. e.g. "共同 创造 美好.." -> [[共, 同], [创, 造], [...], ]

        :param data_bundle:
        :return:
        """
        def split_word_into_chars(raw_chars):
            words = raw_chars.split()
            chars = []
            for word in words:
                char = []
                subchar = []
                for c in word:
                    if c == '<':
                        subchar.append(c)
                        continue
                    if c == '>' and subchar and subchar[0] == '<':
                        # Bug fix: include the closing '>' in the joined tag.
                        # The original joined '<ENG' (without '>') and then
                        # emitted '>' as a separate character, splitting every
                        # special tag in two; it also raised IndexError on a
                        # '>' with no preceding '<'.
                        subchar.append(c)
                        char.append(''.join(subchar))
                        subchar = []
                        continue
                    if subchar:
                        subchar.append(c)
                    else:
                        char.append(c)
                char.extend(subchar)
                chars.append(char)
            return chars

        for name, dataset in data_bundle.datasets.items():
            dataset.apply_field(split_word_into_chars, field_name=Const.CHAR_INPUT,
                                new_field_name=Const.CHAR_INPUT)
        return data_bundle

    def process(self, data_bundle: DataBundle) -> DataBundle:
        """
        The DataSets to process must contain a raw_words column:

        .. csv-table::
           :header: "raw_words"

           "上海 浦东 开发 与 法制 建设 同步"
           "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )"
           "..."

        :param data_bundle:
        :return:
        """
        data_bundle.copy_field(Const.RAW_WORD, Const.CHAR_INPUT)

        if self.replace_num_alpha:
            data_bundle.apply_field(_find_and_replace_alpha_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
            data_bundle.apply_field(_find_and_replace_digit_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)

        self._tokenize(data_bundle)

        for name, dataset in data_bundle.datasets.items():
            # Derive the tag sequence from the per-word lengths, then flatten
            # the per-word character lists into one character sequence.
            dataset.apply_field(lambda chars: self.word_lens_to_tags(map(len, chars)), field_name=Const.CHAR_INPUT,
                                new_field_name=Const.TARGET)
            dataset.apply_field(lambda chars: list(chain(*chars)), field_name=Const.CHAR_INPUT,
                                new_field_name=Const.CHAR_INPUT)

        input_field_names = [Const.CHAR_INPUT]
        if self.bigrams:
            for name, dataset in data_bundle.datasets.items():
                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
                                    field_name=Const.CHAR_INPUT, new_field_name='bigrams')
            input_field_names.append('bigrams')
        if self.trigrams:
            for name, dataset in data_bundle.datasets.items():
                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
                                    field_name=Const.CHAR_INPUT, new_field_name='trigrams')
            input_field_names.append('trigrams')

        _indexize(data_bundle, input_field_names, Const.TARGET)

        input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
        target_fields = [Const.TARGET, Const.INPUT_LEN]
        for name, dataset in data_bundle.datasets.items():
            dataset.add_seq_len(Const.CHAR_INPUT)

        data_bundle.set_input(*input_fields)
        data_bundle.set_target(*target_fields)

        return data_bundle

    def process_from_file(self, paths=None) -> DataBundle:
        """
        :param str paths: path(s) to load; mutually exclusive with the
            dataset_name given at construction time.
        :return:
        """
        if self.dataset_name is None and paths is None:
            raise RuntimeError("You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.")
        if self.dataset_name is not None and paths is not None:
            raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously")
        data_bundle = CWSLoader(self.dataset_name).load(paths)
        return self.process(data_bundle)
@@ -1,249 +0,0 @@ | |||
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||
from typing import Union, Dict, List, Iterator | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import Vocabulary | |||
from fastNLP import Const | |||
from reproduction.utils import check_dataloader_paths | |||
from functools import partial | |||
class SigHanLoader(DataSetLoader): | |||
""" | |||
任务相关的说明可以在这里找到http://sighan.cs.uchicago.edu/ | |||
支持的数据格式为,一行一句,不同的word用空格隔开。如下例 | |||
共同 创造 美好 的 新 世纪 —— 二○○一年 新年 | |||
女士 们 , 先生 们 , 同志 们 , 朋友 们 : | |||
读取sighan中的数据集,返回的DataSet将包含以下的内容fields: | |||
raw_chars: list(str), 每个元素是一个汉字 | |||
chars: list(str), 每个元素是一个index(汉字对应的index) | |||
target: list(int), 根据不同的encoding_type会有不同的变化 | |||
:param target_type: target的类型,当前支持以下的两种: "bmes", "shift_relay" | |||
""" | |||
def __init__(self, target_type:str): | |||
super().__init__() | |||
if target_type.lower() not in ('bmes', 'shift_relay'): | |||
raise ValueError("target_type only supports 'bmes', 'shift_relay'.") | |||
self.target_type = target_type | |||
if target_type=='bmes': | |||
self._word_len_to_target = self._word_len_to_bems | |||
elif target_type=='shift_relay': | |||
self._word_len_to_target = self._word_lens_to_relay | |||
@staticmethod | |||
def _word_lens_to_relay(word_lens: Iterator[int]): | |||
""" | |||
[1, 2, 3, ..] 转换为[0, 1, 0, 2, 1, 0,](start指示seg有多长); | |||
:param word_lens: | |||
:return: {'target': , 'end_seg_mask':, 'start_seg_mask':} | |||
""" | |||
tags = [] | |||
end_seg_mask = [] | |||
start_seg_mask = [] | |||
for word_len in word_lens: | |||
tags.extend([idx for idx in range(word_len - 1, -1, -1)]) | |||
end_seg_mask.extend([0] * (word_len - 1) + [1]) | |||
start_seg_mask.extend([1] + [0] * (word_len - 1)) | |||
return {'target': tags, 'end_seg_mask': end_seg_mask, 'start_seg_mask': start_seg_mask} | |||
@staticmethod | |||
def _word_len_to_bems(word_lens:Iterator[int])->Dict[str, List[str]]: | |||
""" | |||
:param word_lens: 每个word的长度 | |||
:return: | |||
""" | |||
tags = [] | |||
for word_len in word_lens: | |||
if word_len==1: | |||
tags.append('S') | |||
else: | |||
tags.append('B') | |||
for _ in range(word_len-2): | |||
tags.append('M') | |||
tags.append('E') | |||
return {'target':tags} | |||
@staticmethod | |||
def _gen_bigram(chars:List[str])->List[str]: | |||
""" | |||
:param chars: | |||
:return: | |||
""" | |||
return [c1+c2 for c1, c2 in zip(chars, chars[1:]+['<eos>'])] | |||
def load(self, path:str, bigram:bool=False)->DataSet: | |||
""" | |||
:param path: str | |||
:param bigram: 是否使用bigram feature | |||
:return: | |||
""" | |||
dataset = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
if not line: # 去掉空行 | |||
continue | |||
parts = line.split() | |||
word_lens = map(len, parts) | |||
chars = list(''.join(parts)) | |||
tags = self._word_len_to_target(word_lens) | |||
assert len(chars)==len(tags['target']) | |||
dataset.append(Instance(raw_chars=chars, **tags, seq_len=len(chars))) | |||
if len(dataset)==0: | |||
raise RuntimeError(f"{path} has no valid data.") | |||
if bigram: | |||
dataset.apply_field(self._gen_bigram, field_name='raw_chars', new_field_name='bigrams') | |||
return dataset | |||
def process(self, paths: Union[str, Dict[str, str]], char_vocab_opt:VocabularyOption=None, | |||
char_embed_opt:EmbeddingOption=None, bigram_vocab_opt:VocabularyOption=None, | |||
bigram_embed_opt:EmbeddingOption=None, L:int=4): | |||
""" | |||
支持的数据格式为一行一个sample,并且用空格隔开不同的词语。例如 | |||
Option:: | |||
共同 创造 美好 的 新 世纪 —— 二○○一年 新年 贺词 | |||
( 二○○○年 十二月 三十一日 ) ( 附 图片 1 张 ) | |||
女士 们 , 先生 们 , 同志 们 , 朋友 们 : | |||
paths支持两种格式,第一种是str,第二种是Dict[str, str]. | |||
Option:: | |||
# 1. str类型 | |||
# 1.1 传入具体的文件路径 | |||
data = SigHanLoader('bmes').process('/path/to/cws/data.txt') # 将读取data.txt的内容 | |||
# 包含以下的内容data.vocabs['chars']:Vocabulary对象, | |||
# data.vocabs['target']: Vocabulary对象,根据encoding_type可能会没有该值 | |||
# data.embeddings['chars']: Embedding对象. 只有提供了预训练的词向量的路径才有该项 | |||
# data.datasets['train']: DataSet对象 | |||
# 包含的field有: | |||
# raw_chars: list[str], 每个元素是一个汉字 | |||
# chars: list[int], 每个元素是汉字对应的index | |||
# target: list[int], 根据encoding_type有对应的变化 | |||
# 1.2 传入一个目录, 里面必须包含train.txt文件 | |||
data = SigHanLoader('bmes').process('path/to/cws/') #将尝试在该目录下读取 train.txt, test.txt以及dev.txt | |||
# 包含以下的内容data.vocabs['chars']: Vocabulary对象 | |||
# data.vocabs['target']:Vocabulary对象 | |||
# data.embeddings['chars']: 仅在提供了预训练embedding路径的情况下,为Embedding对象; | |||
# data.datasets['train']: DataSet对象 | |||
# 包含的field有: | |||
# raw_chars: list[str], 每个元素是一个汉字 | |||
# chars: list[int], 每个元素是汉字对应的index | |||
# target: list[int], 根据encoding_type有对应的变化 | |||
# data.datasets['dev']: DataSet对象,如果文件夹下包含了dev.txt;内容与data.datasets['train']一样 | |||
# 2. dict类型, key是文件的名称,value是对应的读取路径. 必须包含'train'这个key | |||
paths = {'train': '/path/to/train/train.txt', 'test':'/path/to/test/test.txt', 'dev':'/path/to/dev/dev.txt'} | |||
data = SigHanLoader(paths).process(paths) | |||
# 结果与传入目录时是一致的,但是可以传入多个数据集。data.datasets中的key将与这里传入的一致 | |||
:param paths: 支持传入目录,文件路径,以及dict。 | |||
:param char_vocab_opt: 用于构建chars的vocabulary参数,默认为min_freq=2 | |||
:param char_embed_opt: 用于读取chars的Embedding的参数,默认不读取pretrained的embedding | |||
:param bigram_vocab_opt: 用于构建bigram的vocabulary参数,默认不使用bigram, 仅在指定该参数的情况下会带有bigrams这个field。 | |||
为List[int], 每个instance长度与chars一样, abcde的bigram为ab bc cd de e<eos> | |||
:param bigram_embed_opt: 用于读取预训练bigram的参数,仅在传入bigram_vocab_opt有效 | |||
:param L: 当target_type为shift_relay时传入的segment长度 | |||
:return: | |||
""" | |||
# 推荐大家使用这个check_data_loader_paths进行paths的验证 | |||
paths = check_dataloader_paths(paths) | |||
datasets = {} | |||
data = DataBundle() | |||
bigram = bigram_vocab_opt is not None | |||
for name, path in paths.items(): | |||
dataset = self.load(path, bigram=bigram) | |||
datasets[name] = dataset | |||
input_fields = [] | |||
target_fields = [] | |||
# 创建vocab | |||
char_vocab = Vocabulary(min_freq=2) if char_vocab_opt is None else Vocabulary(**char_vocab_opt) | |||
char_vocab.from_dataset(datasets['train'], field_name='raw_chars') | |||
char_vocab.index_dataset(*datasets.values(), field_name='raw_chars', new_field_name='chars') | |||
data.vocabs[Const.CHAR_INPUT] = char_vocab | |||
input_fields.extend([Const.CHAR_INPUT, Const.INPUT_LEN, Const.TARGET]) | |||
target_fields.append(Const.TARGET) | |||
# 创建target | |||
if self.target_type == 'bmes': | |||
target_vocab = Vocabulary(unknown=None, padding=None) | |||
target_vocab.add_word_lst(['B']*4+['M']*3+['E']*2+['S']) | |||
target_vocab.index_dataset(*datasets.values(), field_name='target') | |||
data.vocabs[Const.TARGET] = target_vocab | |||
if char_embed_opt is not None: | |||
char_embed = EmbedLoader.load_with_vocab(**char_embed_opt, vocab=char_vocab) | |||
data.embeddings['chars'] = char_embed | |||
if bigram: | |||
bigram_vocab = Vocabulary(**bigram_vocab_opt) | |||
bigram_vocab.from_dataset(datasets['train'], field_name='bigrams') | |||
bigram_vocab.index_dataset(*datasets.values(), field_name='bigrams') | |||
data.vocabs['bigrams'] = bigram_vocab | |||
if bigram_embed_opt is not None: | |||
bigram_embed = EmbedLoader.load_with_vocab(**bigram_embed_opt, vocab=bigram_vocab) | |||
data.embeddings['bigrams'] = bigram_embed | |||
input_fields.append('bigrams') | |||
if self.target_type == 'shift_relay': | |||
func = partial(self._clip_target, L=L) | |||
for name, dataset in datasets.items(): | |||
res = dataset.apply_field(func, field_name='target') | |||
relay_target = [res_i[0] for res_i in res] | |||
relay_mask = [res_i[1] for res_i in res] | |||
dataset.add_field('relay_target', relay_target, is_input=True, is_target=False, ignore_type=False) | |||
dataset.add_field('relay_mask', relay_mask, is_input=True, is_target=False, ignore_type=False) | |||
if self.target_type == 'shift_relay': | |||
input_fields.extend(['end_seg_mask']) | |||
target_fields.append('start_seg_mask') | |||
# 将dataset加入DataInfo | |||
for name, dataset in datasets.items(): | |||
dataset.set_input(*input_fields) | |||
dataset.set_target(*target_fields) | |||
data.datasets[name] = dataset | |||
return data | |||
@staticmethod
def _clip_target(target:List[int], L:int):
    """
    Clip a shift-relay target sequence to the model's maximum segment length
    and build the matching relay mask.

    Only used when target_type is 'shift_relay'.

    :param target: List[int], per-character countdown tags (distance to the end of the
        current word, e.g. word lens [1, 2] -> [0, 1, 0])
    :param L: maximum segment length handled by the ShiftRelay model
    :return: (relay_target, relay_mask), two lists the same length as ``target``
    """
    relay_target_i = []
    tmp = []
    for j in range(len(target) - 1):
        tmp.append(target[j])
        if target[j] > target[j + 1]:
            # still inside the current segment: the countdown keeps decreasing
            pass
        else:
            # segment boundary reached: emit the reversed countdown, clipped at L - 1
            relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]])
            tmp = []
    # handle the trailing, unfinished part
    if len(tmp) == 0:
        relay_target_i.append(0)
    else:
        tmp.append(target[-1])
        relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]])
    relay_mask_i = []
    j = 0
    while j < len(target):
        seg_len = target[j] + 1
        if target[j] < L:
            relay_mask_i.extend([0] * (seg_len))
        else:
            # positions farther than L from the segment end must be handled by relay (mask 1)
            relay_mask_i.extend([1] * (seg_len - L) + [0] * L)
        j = seg_len + j
    return relay_target_i, relay_mask_i
@@ -0,0 +1,202 @@ | |||
from fastNLP.io.pipe import Pipe | |||
from fastNLP.io import DataBundle | |||
from fastNLP.io.loader import CWSLoader | |||
from fastNLP import Const | |||
from itertools import chain | |||
from fastNLP.io.pipe.utils import _indexize | |||
from functools import partial | |||
from fastNLP.io.pipe.cws import _find_and_replace_alpha_spans, _find_and_replace_digit_spans | |||
def _word_lens_to_relay(word_lens): | |||
""" | |||
[1, 2, 3, ..] 转换为[0, 1, 0, 2, 1, 0,](start指示seg有多长); | |||
:param word_lens: | |||
:return: | |||
""" | |||
tags = [] | |||
for word_len in word_lens: | |||
tags.extend([idx for idx in range(word_len - 1, -1, -1)]) | |||
return tags | |||
def _word_lens_to_end_seg_mask(word_lens): | |||
""" | |||
[1, 2, 3, ..] 转换为[0, 1, 0, 2, 1, 0,](start指示seg有多长); | |||
:param word_lens: | |||
:return: | |||
""" | |||
end_seg_mask = [] | |||
for word_len in word_lens: | |||
end_seg_mask.extend([0] * (word_len - 1) + [1]) | |||
return end_seg_mask | |||
def _word_lens_to_start_seg_mask(word_lens): | |||
""" | |||
[1, 2, 3, ..] 转换为[0, 1, 0, 2, 1, 0,](start指示seg有多长); | |||
:param word_lens: | |||
:return: | |||
""" | |||
start_seg_mask = [] | |||
for word_len in word_lens: | |||
start_seg_mask.extend([1] + [0] * (word_len - 1)) | |||
return start_seg_mask | |||
class CWSShiftRelayPipe(Pipe):
    """
    Pipe that prepares Chinese word segmentation data for the ShiftRelay model.

    :param str,None dataset_name: one of 'pku', 'msra', 'cityu', 'as', or None
    :param int L: hyper-parameter of the ShiftRelay model (maximum segment length)
    :param bool replace_num_alpha: whether to replace digit/letter spans with special tokens
    :param bool bigrams: whether to add a 'bigrams' column; bigrams are built as
        ['复', '旦', '大', '学', ...] -> ["复旦", "旦大", ...]
    """
    def __init__(self, dataset_name=None, L=5, replace_num_alpha=True, bigrams=True):
        super().__init__()
        self.dataset_name = dataset_name
        self.bigrams = bigrams
        self.replace_num_alpha = replace_num_alpha
        self.L = L

    def _tokenize(self, data_bundle):
        """
        Split the 'chars' column into per-word character lists,
        e.g. "共同 创造 美好.." -> [[共, 同], [创, 造], [...], ]

        Special tokens wrapped in '<...>' (inserted by the num/alpha
        replacement) are kept as single units.

        :param data_bundle:
        :return: data_bundle, modified in place
        """
        def split_word_into_chars(raw_chars):
            words = raw_chars.split()
            chars = []
            for word in words:
                char = []
                subchar = []
                for c in word:
                    if c == '<':
                        subchar.append(c)
                        continue
                    if c == '>' and subchar and subchar[0] == '<':
                        # close a '<...>' special token and keep it whole
                        # (the old code dropped the '>' into a separate char
                        # and crashed on a bare '>' with no opening '<')
                        subchar.append(c)
                        char.append(''.join(subchar))
                        subchar = []
                        continue
                    if subchar:
                        subchar.append(c)
                    else:
                        char.append(c)
                # flush an unterminated '<...' prefix as ordinary characters
                char.extend(subchar)
                chars.append(char)
            return chars

        for name, dataset in data_bundle.datasets.items():
            dataset.apply_field(split_word_into_chars, field_name=Const.CHAR_INPUT,
                                new_field_name=Const.CHAR_INPUT)
        return data_bundle

    def process(self, data_bundle: DataBundle) -> DataBundle:
        """
        The DataSets to process need to contain a raw_words column:

        .. csv-table::
            :header: "raw_words"

            "上海 浦东 开发 与 法制 建设 同步"
            "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )"
            "..."

        :param data_bundle:
        :return: the processed DataBundle
        """
        data_bundle.copy_field(Const.RAW_WORD, Const.CHAR_INPUT)
        if self.replace_num_alpha:
            data_bundle.apply_field(_find_and_replace_alpha_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
            data_bundle.apply_field(_find_and_replace_digit_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
        self._tokenize(data_bundle)
        input_field_names = [Const.CHAR_INPUT]
        target_field_names = []
        for name, dataset in data_bundle.datasets.items():
            # per-character countdown tags and segment-boundary masks
            dataset.apply_field(lambda chars: _word_lens_to_relay(map(len, chars)), field_name=Const.CHAR_INPUT,
                                new_field_name=Const.TARGET)
            dataset.apply_field(lambda chars: _word_lens_to_start_seg_mask(map(len, chars)),
                                field_name=Const.CHAR_INPUT, new_field_name='start_seg_mask')
            dataset.apply_field(lambda chars: _word_lens_to_end_seg_mask(map(len, chars)),
                                field_name=Const.CHAR_INPUT, new_field_name='end_seg_mask')
            # flatten per-word character lists into one flat character sequence
            dataset.apply_field(lambda chars: list(chain(*chars)), field_name=Const.CHAR_INPUT,
                                new_field_name=Const.CHAR_INPUT)
        # appended once, outside the loop (the old code duplicated these per dataset)
        target_field_names.append('start_seg_mask')
        input_field_names.append('end_seg_mask')
        if self.bigrams:
            for name, dataset in data_bundle.datasets.items():
                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
                                    field_name=Const.CHAR_INPUT, new_field_name='bigrams')
            input_field_names.append('bigrams')
        # only index fields that actually exist (the old code always indexed 'bigrams',
        # which fails when bigrams=False)
        _indexize(data_bundle, ['chars', 'bigrams'] if self.bigrams else ['chars'], [])
        func = partial(_clip_target, L=self.L)
        for name, dataset in data_bundle.datasets.items():
            # clip relay targets at L and build the corresponding relay masks
            res = dataset.apply_field(func, field_name='target')
            relay_target = [res_i[0] for res_i in res]
            relay_mask = [res_i[1] for res_i in res]
            dataset.add_field('relay_target', relay_target, is_input=True, is_target=False, ignore_type=False)
            dataset.add_field('relay_mask', relay_mask, is_input=True, is_target=False, ignore_type=False)
        input_field_names.append('relay_target')
        input_field_names.append('relay_mask')
        input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
        target_fields = [Const.TARGET, Const.INPUT_LEN] + target_field_names
        for name, dataset in data_bundle.datasets.items():
            dataset.add_seq_len(Const.CHAR_INPUT)
        data_bundle.set_input(*input_fields)
        data_bundle.set_target(*target_fields)
        return data_bundle

    def process_from_file(self, paths=None) -> DataBundle:
        """
        Load the dataset (from ``paths`` or by ``dataset_name``) and process it.

        :param str paths: optional path(s) to the raw data files
        :return: the processed DataBundle
        :raises RuntimeError: if neither or both of ``paths`` / ``dataset_name`` are given
        """
        if self.dataset_name is None and paths is None:
            raise RuntimeError("You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.")
        if self.dataset_name is not None and paths is not None:
            raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously")
        data_bundle = CWSLoader(self.dataset_name).load(paths)
        return self.process(data_bundle)
def _clip_target(target, L:int): | |||
""" | |||
只有在target_type为shift_relay的使用 | |||
:param target: List[int] | |||
:param L: | |||
:return: | |||
""" | |||
relay_target_i = [] | |||
tmp = [] | |||
for j in range(len(target) - 1): | |||
tmp.append(target[j]) | |||
if target[j] > target[j + 1]: | |||
pass | |||
else: | |||
relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]]) | |||
tmp = [] | |||
# 处理未结束的部分 | |||
if len(tmp) == 0: | |||
relay_target_i.append(0) | |||
else: | |||
tmp.append(target[-1]) | |||
relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]]) | |||
relay_mask_i = [] | |||
j = 0 | |||
while j < len(target): | |||
seg_len = target[j] + 1 | |||
if target[j] < L: | |||
relay_mask_i.extend([0] * (seg_len)) | |||
else: | |||
relay_mask_i.extend([1] * (seg_len - L) + [0] * L) | |||
j = seg_len + j | |||
return relay_target_i, relay_mask_i |
@@ -0,0 +1,60 @@ | |||
import torch | |||
from fastNLP.modules import LSTM | |||
from fastNLP.modules import allowed_transitions, ConditionalRandomField | |||
from fastNLP import seq_len_to_mask | |||
from torch import nn | |||
from fastNLP import Const | |||
import torch.nn.functional as F | |||
class BiLSTMCRF(nn.Module):
    """
    Char (+ optional bigram/trigram) embeddings -> BiLSTM -> linear -> CRF,
    used as a BMES sequence tagger for Chinese word segmentation.

    :param char_embed: character embedding module exposing an ``embed_size`` attribute
    :param int hidden_size: total hidden size of the bidirectional LSTM
    :param int num_layers: number of LSTM layers
    :param target_vocab: BMES tag vocabulary; required to size the output layer and CRF
    :param bigram_embed: optional bigram embedding module
    :param trigram_embed: optional trigram embedding module
    :param float dropout: dropout applied to the LSTM output
    """
    def __init__(self, char_embed, hidden_size, num_layers, target_vocab=None, bigram_embed=None, trigram_embed=None,
                 dropout=0.5):
        super().__init__()
        if target_vocab is None:
            # The old default (None) crashed later with `len(None)` and left
            # `self.crf` undefined; fail fast with a clear message instead.
            raise ValueError("`target_vocab` is required to build the output layer and the CRF.")
        embed_size = char_embed.embed_size
        self.char_embed = char_embed
        if bigram_embed:
            embed_size += bigram_embed.embed_size
        self.bigram_embed = bigram_embed
        if trigram_embed:
            embed_size += trigram_embed.embed_size
        self.trigram_embed = trigram_embed

        self.lstm = LSTM(embed_size, hidden_size=hidden_size // 2, bidirectional=True, batch_first=True,
                         num_layers=num_layers)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(hidden_size, len(target_vocab))
        transitions = allowed_transitions(target_vocab, include_start_end=True, encoding_type='bmes')
        self.crf = ConditionalRandomField(num_tags=len(target_vocab), allowed_transitions=transitions)

    def _forward(self, chars, bigrams, trigrams, seq_len, target=None):
        # Embed and concatenate all available features along the feature dim.
        chars = self.char_embed(chars)
        if bigrams is not None:
            chars = torch.cat([chars, self.bigram_embed(bigrams)], dim=-1)
        if trigrams is not None:
            chars = torch.cat([chars, self.trigram_embed(trigrams)], dim=-1)
        output, _ = self.lstm(chars, seq_len)
        output = self.dropout(output)
        output = F.log_softmax(self.fc(output), dim=-1)
        mask = seq_len_to_mask(seq_len)
        if target is None:
            # inference: Viterbi decoding under the CRF
            pred, _ = self.crf.viterbi_decode(output, mask)
            return {Const.OUTPUT: pred}
        # training: CRF negative log-likelihood
        loss = self.crf(output, tags=target, mask=mask)
        return {Const.LOSS: loss}

    def forward(self, chars, seq_len, target, bigrams=None, trigrams=None):
        return self._forward(chars, bigrams, trigrams, seq_len, target)

    def predict(self, chars, seq_len, bigrams=None, trigrams=None):
        return self._forward(chars, bigrams, trigrams, seq_len)
@@ -1,7 +1,5 @@ | |||
from torch import nn | |||
import torch | |||
from fastNLP.embeddings import Embedding | |||
import numpy as np | |||
from reproduction.seqence_labelling.cws.model.module import FeatureFunMax, SemiCRFShiftRelay | |||
from fastNLP.modules import LSTM | |||
@@ -21,25 +19,21 @@ class ShiftRelayCWSModel(nn.Module): | |||
:param num_bigram_per_char: 每个character对应的bigram的数量 | |||
:param drop_p: Dropout的大小 | |||
""" | |||
def __init__(self, char_embed:Embedding, bigram_embed:Embedding, hidden_size:int=400, num_layers:int=1, | |||
L:int=6, num_bigram_per_char:int=1, drop_p:float=0.2): | |||
def __init__(self, char_embed, bigram_embed, hidden_size:int=400, num_layers:int=1, L:int=6, drop_p:float=0.2): | |||
super().__init__() | |||
self.char_embedding = Embedding(char_embed, dropout=drop_p) | |||
self._pretrained_embed = False | |||
if isinstance(char_embed, np.ndarray): | |||
self._pretrained_embed = True | |||
self.bigram_embedding = Embedding(bigram_embed, dropout=drop_p) | |||
self.lstm = LSTM(100 * (num_bigram_per_char + 1), hidden_size // 2, num_layers=num_layers, bidirectional=True, | |||
self.char_embedding = char_embed | |||
self.bigram_embedding = bigram_embed | |||
self.lstm = LSTM(char_embed.embed_size+bigram_embed.embed_size, hidden_size // 2, num_layers=num_layers, | |||
bidirectional=True, | |||
batch_first=True) | |||
self.feature_fn = FeatureFunMax(hidden_size, L) | |||
self.semi_crf_relay = SemiCRFShiftRelay(L) | |||
self.feat_drop = nn.Dropout(drop_p) | |||
self.reset_param() | |||
# self.feature_fn.reset_parameters() | |||
def reset_param(self): | |||
for name, param in self.named_parameters(): | |||
if 'embedding' in name and self._pretrained_embed: | |||
if 'embedding' in name: | |||
continue | |||
if 'bias_hh' in name: | |||
nn.init.constant_(param, 0) | |||
@@ -51,10 +45,8 @@ class ShiftRelayCWSModel(nn.Module): | |||
nn.init.xavier_uniform_(param) | |||
def get_feats(self, chars, bigrams, seq_len): | |||
batch_size, max_len = chars.size() | |||
chars = self.char_embedding(chars) | |||
bigrams = self.bigram_embedding(bigrams) | |||
bigrams = bigrams.view(bigrams.size(0), max_len, -1) | |||
chars = torch.cat([chars, bigrams], dim=-1) | |||
feats, _ = self.lstm(chars, seq_len) | |||
feats = self.feat_drop(feats) |
@@ -0,0 +1,52 @@ | |||
import sys
sys.path.append('../../..')
from fastNLP.io.pipe.cws import CWSPipe
from reproduction.seqence_labelling.cws.model.bilstm_crf_cws import BiLSTMCRF
from fastNLP import Trainer, cache_results
from fastNLP.embeddings import StaticEmbedding
from fastNLP import EvaluateCallback, BucketSampler, SpanFPreRecMetric, GradientClipCallback
from torch.optim import Adagrad
###########hyper
dataname = 'pku'
hidden_size = 400
num_layers = 1
lr = 0.05
###########hyper
# Cache the processed data and embeddings to '<dataname>.pkl' so reruns skip preprocessing.
@cache_results('{}.pkl'.format(dataname), _refresh=False)
def get_data():
    """Load + process the CWS dataset and build pretrained char/bigram embeddings."""
    data_bundle = CWSPipe(dataset_name=dataname, bigrams=True, trigrams=False).process_from_file()
    # NOTE(review): embedding paths are user-home-relative — confirm they exist on the target machine.
    char_embed = StaticEmbedding(data_bundle.get_vocab('chars'), dropout=0.33, word_dropout=0.01,
                                 model_dir_or_name='~/exps/CWS/pretrain/vectors/1grams_t3_m50_corpus.txt')
    bigram_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'), dropout=0.33,min_freq=3, word_dropout=0.01,
                                 model_dir_or_name='~/exps/CWS/pretrain/vectors/2grams_t3_m50_corpus.txt')
    return data_bundle, char_embed, bigram_embed
data_bundle, char_embed, bigram_embed = get_data()
print(data_bundle)
# BiLSTM-CRF tagger over char + bigram embeddings (no trigrams).
model = BiLSTMCRF(char_embed, hidden_size, num_layers, target_vocab=data_bundle.get_vocab('target'), bigram_embed=bigram_embed,
                  trigram_embed=None, dropout=0.3)
model.cuda()
callbacks = []
# evaluate on the test set after each validation; clip gradients by value
callbacks.append(EvaluateCallback(data_bundle.get_dataset('test')))
callbacks.append(GradientClipCallback(clip_type='value', clip_value=5))
optimizer = Adagrad(model.parameters(), lr=lr)
metrics = []
# span-level F1 over BMES-decoded segments
metric1 = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'), encoding_type='bmes')
metrics.append(metric1)
trainer = Trainer(data_bundle.get_dataset('train'), model, optimizer=optimizer, loss=None,
                  batch_size=128, sampler=BucketSampler(), update_every=1,
                  num_workers=1, n_epochs=10, print_every=5,
                  dev_data=data_bundle.get_dataset('dev'),
                  metrics=metrics,
                  metric_key=None,
                  validate_every=-1, save_path=None, use_tqdm=True, device=0,
                  callbacks=callbacks, check_code_level=0, dev_batch_size=128)
trainer.train()
@@ -1,64 +1,53 @@ | |||
import os | |||
import sys | |||
sys.path.append('../../..') | |||
from fastNLP import cache_results | |||
from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader | |||
from reproduction.seqence_labelling.cws.model.model import ShiftRelayCWSModel | |||
from fastNLP.io.embed_loader import EmbeddingOption | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from reproduction.seqence_labelling.cws.data.cws_shift_pipe import CWSShiftRelayPipe | |||
from reproduction.seqence_labelling.cws.model.bilstm_shift_relay import ShiftRelayCWSModel | |||
from fastNLP import Trainer | |||
from torch.optim import Adam | |||
from fastNLP import BucketSampler | |||
from fastNLP import GradientClipCallback | |||
from reproduction.seqence_labelling.cws.model.metric import RelayMetric | |||
# 借助一下fastNLP的自动缓存机制,但是只能缓存4G以下的结果 | |||
@cache_results(None) | |||
def prepare_data(): | |||
data = SigHanLoader(target_type='shift_relay').process(file_dir, char_embed_opt=char_embed_opt, | |||
bigram_vocab_opt=bigram_vocab_opt, | |||
bigram_embed_opt=bigram_embed_opt, | |||
L=L) | |||
return data | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP import EvaluateCallback | |||
#########hyper | |||
L = 4 | |||
hidden_size = 200 | |||
num_layers = 1 | |||
drop_p = 0.2 | |||
lr = 0.02 | |||
lr = 0.008 | |||
data_name = 'pku' | |||
#########hyper | |||
device = 0 | |||
# !!!!这里千万不要放完全路径,因为这样会暴露你们在服务器上的用户名,比较危险。所以一定要使用相对路径,最好把数据放到 | |||
# 你们的reproduction路径下,然后设置.gitignore | |||
file_dir = '/path/to/' | |||
char_embed_path = '/pretrain/vectors/1grams_t3_m50_corpus.txt' | |||
bigram_embed_path = '/pretrain/vectors/2grams_t3_m50_corpus.txt' | |||
bigram_vocab_opt = VocabularyOption(min_freq=3) | |||
char_embed_opt = EmbeddingOption(embed_filepath=char_embed_path) | |||
bigram_embed_opt = EmbeddingOption(embed_filepath=bigram_embed_path) | |||
data_name = os.path.basename(file_dir) | |||
cache_fp = 'caches/{}.pkl'.format(data_name) | |||
@cache_results(_cache_fp=cache_fp, _refresh=True) # 将结果缓存到cache_fp中,这样下次运行就直接读取,而不需要再次运行 | |||
def prepare_data(): | |||
data_bundle = CWSShiftRelayPipe(dataset_name=data_name, L=L).process_from_file() | |||
# 预训练的character embedding和bigram embedding | |||
char_embed = StaticEmbedding(data_bundle.get_vocab('chars'), dropout=0.5, word_dropout=0.01, | |||
model_dir_or_name='~/exps/CWS/pretrain/vectors/1grams_t3_m50_corpus.txt') | |||
bigram_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'), dropout=0.5, min_freq=3, word_dropout=0.01, | |||
model_dir_or_name='~/exps/CWS/pretrain/vectors/2grams_t3_m50_corpus.txt') | |||
data = prepare_data(_cache_fp=cache_fp, _refresh=True) | |||
return data_bundle, char_embed, bigram_embed | |||
model = ShiftRelayCWSModel(char_embed=data.embeddings['chars'], bigram_embed=data.embeddings['bigrams'], | |||
hidden_size=hidden_size, num_layers=num_layers, | |||
L=L, num_bigram_per_char=1, drop_p=drop_p) | |||
data, char_embed, bigram_embed = prepare_data() | |||
sampler = BucketSampler(batch_size=32) | |||
model = ShiftRelayCWSModel(char_embed=char_embed, bigram_embed=bigram_embed, | |||
hidden_size=hidden_size, num_layers=num_layers, drop_p=drop_p, L=L) | |||
sampler = BucketSampler() | |||
optimizer = Adam(model.parameters(), lr=lr) | |||
clipper = GradientClipCallback(clip_value=5, clip_type='value') | |||
callbacks = [clipper] | |||
# if pretrain: | |||
# fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until) | |||
# callbacks.append(fixer) | |||
trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None, batch_size=32, sampler=sampler, | |||
update_every=5, n_epochs=3, print_every=5, dev_data=data.datasets['dev'], metrics=RelayMetric(), | |||
clipper = GradientClipCallback(clip_value=5, clip_type='value') # 截断太大的梯度 | |||
evaluator = EvaluateCallback(data.get_dataset('test')) # 额外测试在test集上的效果 | |||
callbacks = [clipper, evaluator] | |||
trainer = Trainer(data.get_dataset('train'), model, optimizer=optimizer, loss=None, batch_size=128, sampler=sampler, | |||
update_every=1, n_epochs=10, print_every=5, dev_data=data.get_dataset('dev'), metrics=RelayMetric(), | |||
metric_key='f', validate_every=-1, save_path=None, use_tqdm=True, device=device, callbacks=callbacks, | |||
check_code_level=0) | |||
check_code_level=0, num_workers=1) | |||
trainer.train() |
@@ -8,11 +8,10 @@ import torch.nn.functional as F | |||
from fastNLP import Const | |||
class CNNBiLSTMCRF(nn.Module): | |||
def __init__(self, embed, char_embed, hidden_size, num_layers, tag_vocab, dropout=0.5, encoding_type='bioes'): | |||
def __init__(self, embed, hidden_size, num_layers, tag_vocab, dropout=0.5, encoding_type='bioes'): | |||
super().__init__() | |||
self.embedding = embed | |||
self.char_embedding = char_embed | |||
self.lstm = LSTM(input_size=self.embedding.embedding_dim+self.char_embedding.embedding_dim, | |||
self.lstm = LSTM(input_size=self.embedding.embedding_dim, | |||
hidden_size=hidden_size//2, num_layers=num_layers, | |||
bidirectional=True, batch_first=True) | |||
self.fc = nn.Linear(hidden_size, len(tag_vocab)) | |||
@@ -32,9 +31,7 @@ class CNNBiLSTMCRF(nn.Module): | |||
nn.init.zeros_(param) | |||
def _forward(self, words, seq_len, target=None): | |||
word_embeds = self.embedding(words) | |||
char_embeds = self.char_embedding(words) | |||
words = torch.cat((word_embeds, char_embeds), dim=-1) | |||
words = self.embedding(words) | |||
outputs, _ = self.lstm(words, seq_len) | |||
self.dropout(outputs) | |||
@@ -1,7 +1,7 @@ | |||
import sys | |||
sys.path.append('../../..') | |||
from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding | |||
from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding, StackEmbedding | |||
from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF | |||
from fastNLP import Trainer | |||
@@ -22,7 +22,7 @@ def load_data(): | |||
paths = {'test':"NER/corpus/CoNLL-2003/eng.testb", | |||
'train':"NER/corpus/CoNLL-2003/eng.train", | |||
'dev':"NER/corpus/CoNLL-2003/eng.testa"} | |||
data = Conll2003NERPipe(encoding_type=encoding_type, target_pad_val=0).process_from_file(paths) | |||
data = Conll2003NERPipe(encoding_type=encoding_type).process_from_file(paths) | |||
return data | |||
data = load_data() | |||
print(data) | |||
@@ -33,8 +33,9 @@ word_embed = StaticEmbedding(vocab=data.get_vocab('words'), | |||
model_dir_or_name='en-glove-6b-100d', | |||
requires_grad=True, lower=True, word_dropout=0.01, dropout=0.5) | |||
word_embed.embedding.weight.data = word_embed.embedding.weight.data/word_embed.embedding.weight.data.std() | |||
embed = StackEmbedding([word_embed, char_embed]) | |||
model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], | |||
model = CNNBiLSTMCRF(embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], | |||
encoding_type=encoding_type) | |||
callbacks = [ | |||
@@ -2,7 +2,7 @@ import sys | |||
sys.path.append('../../..') | |||
from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding | |||
from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding, StackEmbedding | |||
from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF | |||
from fastNLP import Trainer | |||
@@ -35,7 +35,7 @@ def cache(): | |||
char_embed = CNNCharEmbedding(vocab=data.vocabs['words'], embed_size=30, char_emb_size=30, filter_nums=[30], | |||
kernel_sizes=[3], dropout=dropout) | |||
word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], | |||
model_dir_or_name='en-glove-100d', | |||
model_dir_or_name='en-glove-6b-100d', | |||
requires_grad=True, | |||
normalize=normalize, | |||
word_dropout=0.01, | |||
@@ -47,7 +47,8 @@ data, char_embed, word_embed = cache() | |||
print(data) | |||
model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=1200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], | |||
embed = StackEmbedding([word_embed, char_embed]) | |||
model = CNNBiLSTMCRF(embed, hidden_size=1200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], | |||
encoding_type=encoding_type, dropout=dropout) | |||
callbacks = [ | |||
@@ -0,0 +1,13 @@ | |||
import unittest | |||
import os | |||
from fastNLP.io.loader import CWSLoader | |||
class CWSLoaderTest(unittest.TestCase):
    """Download smoke-test: CWSLoader must load every supported dataset."""

    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_download(self):
        for name in ['pku', 'cityu', 'as', 'msra']:
            with self.subTest(dataset_name=name):
                bundle = CWSLoader(dataset_name=name).load()
                print(bundle)
@@ -0,0 +1,13 @@ | |||
import unittest | |||
import os | |||
from fastNLP.io.pipe.cws import CWSPipe | |||
class CWSPipeTest(unittest.TestCase):
    """End-to-end smoke-test: CWSPipe must process every supported dataset."""

    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_process_from_file(self):
        for name in ['pku', 'cityu', 'as', 'msra']:
            with self.subTest(dataset_name=name):
                bundle = CWSPipe(dataset_name=name).process_from_file()
                print(bundle)