@@ -690,11 +690,11 @@ class Trainer(object): | |||||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | ||||
and self.dev_data is not None: | and self.dev_data is not None: | ||||
eval_res = self._do_validation(epoch=epoch, step=self.step) | eval_res = self._do_validation(epoch=epoch, step=self.step) | ||||
eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||||
self.n_steps) + \ | |||||
self.tester._format_eval_results(eval_res) | |||||
eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step, | |||||
self.n_steps) | |||||
# pbar.write(eval_str + '\n') | # pbar.write(eval_str + '\n') | ||||
self.logger.info(eval_str + '\n') | |||||
self.logger.info(eval_str) | |||||
self.logger.info(self.tester._format_eval_results(eval_res)+'\n') | |||||
# ================= mini-batch end ==================== # | # ================= mini-batch end ==================== # | ||||
# lr decay; early stopping | # lr decay; early stopping | ||||
@@ -907,7 +907,7 @@ def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAUL | |||||
info_str += '\n' | info_str += '\n' | ||||
else: | else: | ||||
info_str += 'There is no target field.' | info_str += 'There is no target field.' | ||||
print(info_str) | |||||
logger.info(info_str) | |||||
_check_forward_error(forward_func=forward_func, dataset=dataset, | _check_forward_error(forward_func=forward_func, dataset=dataset, | ||||
batch_x=batch_x, check_level=check_level) | batch_x=batch_x, check_level=check_level) | ||||
refined_batch_x = _build_args(forward_func, **batch_x) | refined_batch_x = _build_args(forward_func, **batch_x) | ||||
@@ -67,8 +67,8 @@ class BertEmbedding(ContextualEmbedding): | |||||
model_url = _get_embedding_url('bert', model_dir_or_name.lower()) | model_url = _get_embedding_url('bert', model_dir_or_name.lower()) | ||||
model_dir = cached_path(model_url, name='embedding') | model_dir = cached_path(model_url, name='embedding') | ||||
# 检查是否存在 | # 检查是否存在 | ||||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||||
model_dir = os.path.expanduser(os.path.abspath(model_dir_or_name)) | |||||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||||
model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) | |||||
else: | else: | ||||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | raise ValueError(f"Cannot recognize {model_dir_or_name}.") | ||||
@@ -59,7 +59,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
model_url = _get_embedding_url('elmo', model_dir_or_name.lower()) | model_url = _get_embedding_url('elmo', model_dir_or_name.lower()) | ||||
model_dir = cached_path(model_url, name='embedding') | model_dir = cached_path(model_url, name='embedding') | ||||
# 检查是否存在 | # 检查是否存在 | ||||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||||
model_dir = model_dir_or_name | model_dir = model_dir_or_name | ||||
else: | else: | ||||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | raise ValueError(f"Cannot recognize {model_dir_or_name}.") | ||||
@@ -70,10 +70,10 @@ class StaticEmbedding(TokenEmbedding): | |||||
model_url = _get_embedding_url('static', model_dir_or_name.lower()) | model_url = _get_embedding_url('static', model_dir_or_name.lower()) | ||||
model_path = cached_path(model_url, name='embedding') | model_path = cached_path(model_url, name='embedding') | ||||
# 检查是否存在 | # 检查是否存在 | ||||
elif os.path.isfile(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||||
model_path = model_dir_or_name | |||||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||||
model_path = _get_file_name_base_on_postfix(model_dir_or_name, '.txt') | |||||
elif os.path.isfile(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||||
model_path = os.path.abspath(os.path.expanduser(model_dir_or_name)) | |||||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||||
model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt') | |||||
else: | else: | ||||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | raise ValueError(f"Cannot recognize {model_dir_or_name}.") | ||||
@@ -94,7 +94,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
no_create_entry=truncated_vocab._is_word_no_create_entry(word)) | no_create_entry=truncated_vocab._is_word_no_create_entry(word)) | ||||
# 只限制在train里面的词语使用min_freq筛选 | # 只限制在train里面的词语使用min_freq筛选 | ||||
if kwargs.get('only_train_min_freq', False): | |||||
if kwargs.get('only_train_min_freq', False) and model_dir_or_name is not None: | |||||
for word in truncated_vocab.word_count.keys(): | for word in truncated_vocab.word_count.keys(): | ||||
if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word]<min_freq: | if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word]<min_freq: | ||||
truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]), | truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]), | ||||
@@ -114,8 +114,8 @@ class StaticEmbedding(TokenEmbedding): | |||||
lowered_vocab.add_word(word.lower(), no_create_entry=True) | lowered_vocab.add_word(word.lower(), no_create_entry=True) | ||||
else: | else: | ||||
lowered_vocab.add_word(word.lower()) # 先加入需要创建entry的 | lowered_vocab.add_word(word.lower()) # 先加入需要创建entry的 | ||||
print(f"All word in the vocab have been lowered before finding pretrained vectors. There are {len(vocab)} " | |||||
f"words, {len(lowered_vocab)} unique lowered words.") | |||||
print(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} " | |||||
f"unique lowered words.") | |||||
if model_path: | if model_path: | ||||
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method) | embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method) | ||||
else: | else: | ||||
@@ -222,7 +222,8 @@ class DataBundle: | |||||
:param bool flag: 将field_name的target状态设置为flag | :param bool flag: 将field_name的target状态设置为flag | ||||
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一 | :param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一 | ||||
行的数据进行类型和维度推断本列的数据的类型和维度。 | 行的数据进行类型和维度推断本列的数据的类型和维度。 | ||||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错 | |||||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||||
如果为False,则报错 | |||||
:return self | :return self | ||||
""" | """ | ||||
for field_name in field_names: | for field_name in field_names: | ||||
@@ -241,16 +242,61 @@ class DataBundle: | |||||
:param str field_name: | :param str field_name: | ||||
:param str new_field_name: | :param str new_field_name: | ||||
:param bool ignore_miss_dataset: 若DataBundle中的DataSet的 | |||||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||||
如果为False,则报错 | |||||
:return: self | :return: self | ||||
""" | """ | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if dataset.has_field(field_name=field_name): | if dataset.has_field(field_name=field_name): | ||||
dataset.copy_field(field_name=field_name, new_field_name=new_field_name) | dataset.copy_field(field_name=field_name, new_field_name=new_field_name) | ||||
elif ignore_miss_dataset: | |||||
elif not ignore_miss_dataset: | |||||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||||
return self | |||||
def apply_field(self, func, field_name:str, new_field_name:str, ignore_miss_dataset=True, **kwargs): | |||||
""" | |||||
Call apply_field on every DataSet in the DataBundle.
:param callable func: receives, for each instance, the content of the field named `field_name`.
:param str field_name: the field whose content is passed to func.
:param str new_field_name: the field into which func's return value is stored; a field with the same name is
overwritten. If None, no new field is created.
:param bool ignore_miss_dataset: if True, a DataSet that does not contain the field is silently skipped;
if False, a KeyError is raised.
:param optional kwargs: supports is_input, is_target, ignore_type
1. is_input: bool, if True the field `new_field_name` is set as input
2. is_target: bool, if True the field `new_field_name` is set as target
3. ignore_type: bool, if True the field `new_field_name` has ignore_type enabled, skipping type checks
""" | |||||
for name, dataset in self.datasets.items(): | |||||
if dataset.has_field(field_name=field_name): | |||||
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, **kwargs) | |||||
elif not ignore_miss_dataset: | |||||
raise KeyError(f"{field_name} not found DataSet:{name}.") | raise KeyError(f"{field_name} not found DataSet:{name}.") | ||||
return self | return self | ||||
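A hedged usage sketch of the new apply_field (the bundle object and the raw_words field are assumptions for illustration):
# add a word-count field to every DataSet that has 'raw_words';
# datasets missing the field are skipped because ignore_miss_dataset defaults to True
data_bundle.apply_field(lambda words: len(words), field_name='raw_words', new_field_name='word_count', is_input=True)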
def apply(self, func, new_field_name:str, **kwargs): | |||||
""" | |||||
Call apply on every DataSet in the DataBundle.
:param callable func: receives each Instance of the DataSet (not a single field) and returns the new value.
:param str new_field_name: the field into which func's return value is stored; a field with the same name is
overwritten. If None, no new field is created.
:param optional kwargs: supports is_input, is_target, ignore_type
1. is_input: bool, if True the field `new_field_name` is set as input
2. is_target: bool, if True the field `new_field_name` is set as target
3. ignore_type: bool, if True the field `new_field_name` has ignore_type enabled, skipping type checks
""" | |||||
for name, dataset in self.datasets.items(): | |||||
dataset.apply(func, new_field_name=new_field_name, **kwargs) | |||||
return self | |||||
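A matching sketch for apply, where func receives the whole Instance rather than a single field (the field names are assumptions):
# build a new field from two assumed existing fields on every DataSet in the bundle
data_bundle.apply(lambda ins: ins['premise'] + ins['hypothesis'], new_field_name='combined', is_input=True)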
def __repr__(self): | def __repr__(self): | ||||
_str = 'In total {} datasets:\n'.format(len(self.datasets)) | _str = 'In total {} datasets:\n'.format(len(self.datasets)) | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
@@ -1,13 +1,15 @@ | |||||
from .loader import Loader | from .loader import Loader | ||||
from ...core.dataset import DataSet | from ...core.dataset import DataSet | ||||
from ...core.instance import Instance | from ...core.instance import Instance | ||||
import glob | |||||
import os | |||||
import time | |||||
import shutil | |||||
import random | |||||
class CWSLoader(Loader): | class CWSLoader(Loader): | ||||
""" | """ | ||||
分词任务数据加载器, | |||||
SigHan2005的数据可以用xxx下载并预处理 | |||||
CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: | CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: | ||||
Example:: | Example:: | ||||
@@ -24,9 +26,16 @@ class CWSLoader(Loader): | |||||
"上海 浦东 开发 与 法制 建设 同步" | "上海 浦东 开发 与 法制 建设 同步" | ||||
"新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" | "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" | ||||
"..." | "..." | ||||
:param str dataset_name: name of the dataset; supports 'pku', 'msra', 'cityu' (traditional Chinese), 'as' (traditional Chinese), or None
""" | """ | ||||
def __init__(self): | |||||
def __init__(self, dataset_name:str=None): | |||||
super().__init__() | super().__init__() | ||||
datanames = {'pku': 'cws-pku', 'msra':'cws-msra', 'as':'cws-as', 'cityu':'cws-cityu'} | |||||
if dataset_name in datanames: | |||||
self.dataset_name = datanames[dataset_name] | |||||
else: | |||||
self.dataset_name = None | |||||
def _load(self, path:str): | def _load(self, path:str): | ||||
ds = DataSet() | ds = DataSet() | ||||
@@ -37,5 +46,42 @@ class CWSLoader(Loader): | |||||
ds.append(Instance(raw_words=line)) | ds.append(Instance(raw_words=line)) | ||||
return ds | return ds | ||||
def download(self, output_dir=None): | |||||
raise RuntimeError("You can refer {} for sighan2005's data downloading.") | |||||
def download(self, dev_ratio=0.1, re_download=False)->str: | |||||
""" | |||||
If you use this dataset, please cite: Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff,
2005. More information is available at http://sighan.cs.uchicago.edu/bakeoff2005/
:param float dev_ratio: if the data contains no dev set, the fraction of train to split off as dev. If 0, no dev set is created.
:param bool re_download: whether to delete and re-download the data (and hence re-split it).
:return: str, the directory containing the data
""" | |||||
if self.dataset_name is None: | |||||
return None | |||||
data_dir = self._get_dataset_path(dataset_name=self.dataset_name) | |||||
modify_time = 0 | |||||
for filepath in glob.glob(os.path.join(data_dir, '*')): | |||||
modify_time = os.stat(filepath).st_mtime | |||||
break | |||||
if time.time() - modify_time > 1 and re_download: # crude check of whether the files were downloaded just now
shutil.rmtree(data_dir) | |||||
data_dir = self._get_dataset_path(dataset_name=self.dataset_name) | |||||
if not os.path.exists(os.path.join(data_dir, 'dev.txt')): | |||||
if dev_ratio > 0: | |||||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||||
try: | |||||
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \ | |||||
open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ | |||||
open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: | |||||
for line in f: | |||||
if random.random() < dev_ratio: | |||||
f2.write(line) | |||||
else: | |||||
f1.write(line) | |||||
os.remove(os.path.join(data_dir, 'train.txt')) | |||||
os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt')) | |||||
finally: | |||||
if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): | |||||
os.remove(os.path.join(data_dir, 'middle_file.txt')) | |||||
return data_dir |
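A usage sketch, assuming the 'pku' corpus is reachable through the dataset cache and that Loader.load accepts the returned directory:
loader = CWSLoader(dataset_name='pku')
data_dir = loader.download(dev_ratio=0.1, re_download=False)   # fetch (or reuse) the corpus and split off 10% of train as dev.txt
data_bundle = loader.load(data_dir)                            # read the train/dev/test files found under data_dir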
@@ -21,6 +21,7 @@ __all__ = [ | |||||
"MsraNERPipe", | "MsraNERPipe", | ||||
"WeiboNERPipe", | "WeiboNERPipe", | ||||
"PeopleDailyPipe", | "PeopleDailyPipe", | ||||
"Conll2003Pipe", | |||||
"MatchingBertPipe", | "MatchingBertPipe", | ||||
"RTEBertPipe", | "RTEBertPipe", | ||||
@@ -41,3 +42,4 @@ from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe | |||||
from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ | from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ | ||||
MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe | MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe | ||||
from .pipe import Pipe | from .pipe import Pipe | ||||
from .conll import Conll2003Pipe |
@@ -19,16 +19,14 @@ class _NERPipe(Pipe): | |||||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | ||||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | ||||
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 | |||||
""" | """ | ||||
def __init__(self, encoding_type: str = 'bio', lower: bool = False, target_pad_val=0): | |||||
def __init__(self, encoding_type: str = 'bio', lower: bool = False): | |||||
if encoding_type == 'bio': | if encoding_type == 'bio': | ||||
self.convert_tag = iob2 | self.convert_tag = iob2 | ||||
else: | else: | ||||
self.convert_tag = lambda words: iob2bioes(iob2(words)) | self.convert_tag = lambda words: iob2bioes(iob2(words)) | ||||
self.lower = lower | self.lower = lower | ||||
self.target_pad_val = int(target_pad_val) | |||||
def process(self, data_bundle: DataBundle) -> DataBundle: | def process(self, data_bundle: DataBundle) -> DataBundle: | ||||
""" | """ | ||||
@@ -58,7 +56,6 @@ class _NERPipe(Pipe): | |||||
target_fields = [Const.TARGET, Const.INPUT_LEN] | target_fields = [Const.TARGET, Const.INPUT_LEN] | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
dataset.set_pad_val(Const.TARGET, self.target_pad_val) | |||||
dataset.add_seq_len(Const.INPUT) | dataset.add_seq_len(Const.INPUT) | ||||
data_bundle.set_input(*input_fields) | data_bundle.set_input(*input_fields) | ||||
@@ -86,7 +83,6 @@ class Conll2003NERPipe(_NERPipe): | |||||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | ||||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | ||||
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 | |||||
""" | """ | ||||
def process_from_file(self, paths) -> DataBundle: | def process_from_file(self, paths) -> DataBundle: | ||||
@@ -103,7 +99,7 @@ class Conll2003NERPipe(_NERPipe): | |||||
class Conll2003Pipe(Pipe): | class Conll2003Pipe(Pipe): | ||||
def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False, target_pad_val=0): | |||||
def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False): | |||||
""" | """ | ||||
经过该Pipe后,DataSet中的内容如下 | 经过该Pipe后,DataSet中的内容如下 | ||||
@@ -119,7 +115,6 @@ class Conll2003Pipe(Pipe): | |||||
:param str chunk_encoding_type: 支持bioes, bio。 | :param str chunk_encoding_type: 支持bioes, bio。 | ||||
:param str ner_encoding_type: 支持bioes, bio。 | :param str ner_encoding_type: 支持bioes, bio。 | ||||
:param bool lower: 是否将words列小写化后再建立词表 | :param bool lower: 是否将words列小写化后再建立词表 | ||||
:param int target_pad_val: pos, ner, chunk列的padding值 | |||||
""" | """ | ||||
if chunk_encoding_type == 'bio': | if chunk_encoding_type == 'bio': | ||||
self.chunk_convert_tag = iob2 | self.chunk_convert_tag = iob2 | ||||
@@ -130,7 +125,6 @@ class Conll2003Pipe(Pipe): | |||||
else: | else: | ||||
self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags)) | self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags)) | ||||
self.lower = lower | self.lower = lower | ||||
self.target_pad_val = int(target_pad_val) | |||||
def process(self, data_bundle)->DataBundle: | def process(self, data_bundle)->DataBundle: | ||||
""" | """ | ||||
@@ -166,9 +160,6 @@ class Conll2003Pipe(Pipe): | |||||
target_fields = ['pos', 'ner', 'chunk', Const.INPUT_LEN] | target_fields = ['pos', 'ner', 'chunk', Const.INPUT_LEN] | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
dataset.set_pad_val('pos', self.target_pad_val) | |||||
dataset.set_pad_val('ner', self.target_pad_val) | |||||
dataset.set_pad_val('chunk', self.target_pad_val) | |||||
dataset.add_seq_len(Const.INPUT) | dataset.add_seq_len(Const.INPUT) | ||||
data_bundle.set_input(*input_fields) | data_bundle.set_input(*input_fields) | ||||
@@ -202,7 +193,6 @@ class OntoNotesNERPipe(_NERPipe): | |||||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | ||||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | ||||
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 | |||||
""" | """ | ||||
def process_from_file(self, paths): | def process_from_file(self, paths): | ||||
@@ -220,15 +210,13 @@ class _CNNERPipe(Pipe): | |||||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target, seq_len。 | target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target, seq_len。 | ||||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | ||||
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 | |||||
""" | """ | ||||
def __init__(self, encoding_type: str = 'bio', target_pad_val=0): | |||||
def __init__(self, encoding_type: str = 'bio'): | |||||
if encoding_type == 'bio': | if encoding_type == 'bio': | ||||
self.convert_tag = iob2 | self.convert_tag = iob2 | ||||
else: | else: | ||||
self.convert_tag = lambda words: iob2bioes(iob2(words)) | self.convert_tag = lambda words: iob2bioes(iob2(words)) | ||||
self.target_pad_val = int(target_pad_val) | |||||
def process(self, data_bundle: DataBundle) -> DataBundle: | def process(self, data_bundle: DataBundle) -> DataBundle: | ||||
""" | """ | ||||
@@ -261,7 +249,6 @@ class _CNNERPipe(Pipe): | |||||
target_fields = [Const.TARGET, Const.INPUT_LEN] | target_fields = [Const.TARGET, Const.INPUT_LEN] | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
dataset.set_pad_val(Const.TARGET, self.target_pad_val) | |||||
dataset.add_seq_len(Const.CHAR_INPUT) | dataset.add_seq_len(Const.CHAR_INPUT) | ||||
data_bundle.set_input(*input_fields) | data_bundle.set_input(*input_fields) | ||||
@@ -324,7 +311,6 @@ class WeiboNERPipe(_CNNERPipe): | |||||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | ||||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | ||||
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。 | |||||
""" | """ | ||||
def process_from_file(self, paths=None) -> DataBundle: | def process_from_file(self, paths=None) -> DataBundle: | ||||
data_bundle = WeiboNERLoader().load(paths) | data_bundle = WeiboNERLoader().load(paths) | ||||
@@ -0,0 +1,246 @@ | |||||
from .pipe import Pipe | |||||
from .. import DataBundle | |||||
from ..loader import CWSLoader | |||||
from ... import Const | |||||
from itertools import chain | |||||
from .utils import _indexize | |||||
import re | |||||
def _word_lens_to_bmes(word_lens): | |||||
""" | |||||
:param list word_lens: List[int], the length of each word
:return: List[str], the BMES tag sequence
""" | |||||
tags = [] | |||||
for word_len in word_lens: | |||||
if word_len==1: | |||||
tags.append('S') | |||||
else: | |||||
tags.append('B') | |||||
tags.extend(['M']*(word_len-2)) | |||||
tags.append('E') | |||||
return tags | |||||
def _word_lens_to_segapp(word_lens): | |||||
""" | |||||
:param list word_lens: List[int], the length of each word
:return: List[str], the SEG/APP tag sequence
""" | |||||
tags = [] | |||||
for word_len in word_lens: | |||||
if word_len==1: | |||||
tags.append('SEG') | |||||
else: | |||||
tags.extend(['APP']*(word_len-1)) | |||||
tags.append('SEG') | |||||
return tags | |||||
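For concreteness, the two tagging schemes applied to the word-length sequence [2, 1, 3]:
assert _word_lens_to_bmes([2, 1, 3]) == ['B', 'E', 'S', 'B', 'M', 'E']
assert _word_lens_to_segapp([2, 1, 3]) == ['APP', 'SEG', 'SEG', 'APP', 'APP', 'SEG']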
def _alpha_span_to_special_tag(span): | |||||
""" | |||||
将span替换成特殊的字符 | |||||
:param str span: | |||||
:return: | |||||
""" | |||||
if 'oo' == span.lower(): # special case: letters standing in for digits, e.g. 2OO8
return span | |||||
if len(span) == 1: | |||||
return span | |||||
else: | |||||
return '<ENG>' | |||||
def _find_and_replace_alpha_spans(line): | |||||
""" | |||||
传入原始句子,替换其中的字母为特殊标记 | |||||
:param str line:原始数据 | |||||
:return: str | |||||
""" | |||||
new_line = '' | |||||
pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%,.。!<-“])' | |||||
prev_end = 0 | |||||
for match in re.finditer(pattern, line): | |||||
start, end = match.span() | |||||
span = line[start:end] | |||||
new_line += line[prev_end:start] + _alpha_span_to_special_tag(span) | |||||
prev_end = end | |||||
new_line += line[prev_end:] | |||||
return new_line | |||||
def _digit_span_to_special_tag(span): | |||||
""" | |||||
:param str span: 需要替换的str | |||||
:return: | |||||
""" | |||||
if span[0] == '0' and len(span) > 2: | |||||
return '<NUM>' | |||||
decimal_point_count = 0 # a span might contain more than one decimal point
for idx, char in enumerate(span): | |||||
if char == '.' or char == '﹒' or char == '·': | |||||
decimal_point_count += 1 | |||||
if span[-1] == '.' or span[-1] == '﹒' or span[-1] == '·': # a trailing decimal point means the span is not a plain number
if decimal_point_count == 1: | |||||
return span | |||||
else: | |||||
return '<UNKDGT>' | |||||
if decimal_point_count == 1: | |||||
return '<DEC>' | |||||
elif decimal_point_count > 1: | |||||
return '<UNKDGT>' | |||||
else: | |||||
return '<NUM>' | |||||
def _find_and_replace_digit_spans(line): | |||||
# Only consider spans that start with a digit and may contain '.', '﹒' or '·'.
# Spans followed by a space or a Chinese character are processed;
# spans that contain or are followed by English letters are left unchanged.
# Floats are replaced by <DEC>, integers by <NUM>, anything else by <UNKDGT>.
new_line = '' | |||||
pattern = '\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%,。!<-“])' | |||||
prev_end = 0 | |||||
for match in re.finditer(pattern, line): | |||||
start, end = match.span() | |||||
span = line[start:end] | |||||
new_line += line[prev_end:start] + _digit_span_to_special_tag(span) | |||||
prev_end = end | |||||
new_line += line[prev_end:] | |||||
return new_line | |||||
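A small sketch of how the two replacement helpers behave on a raw line (the example sentence is made up):
line = 'GDP增长8.5%,ABC公司在2008年成立'
line = _find_and_replace_alpha_spans(line)   # multi-letter spans followed by a Chinese char or listed punctuation become '<ENG>'
line = _find_and_replace_digit_spans(line)   # '8.5' -> '<DEC>', '2008' -> '<NUM>'
# line is now '<ENG>增长<DEC>%,<ENG>公司在<NUM>年成立'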
class CWSPipe(Pipe): | |||||
""" | |||||
对CWS数据进行预处理, 处理之后的数据,具备以下的结构 | |||||
.. csv-table:: | |||||
:header: "raw_words", "chars", "target", "bigrams", "trigrams", "seq_len" | |||||
"共同 创造 美好...", "[2, 3, 4...]", "[0, 2, 0, 2,...]", "[10, 4, 1,...]","[6, 4, 1,...]", 13 | |||||
"2001年 新年 钟声...", "[8, 9, 9, 7, ...]", "[0, 1, 1, 1, 2...]", "[11, 12, ...]","[3, 9, ...]", 20 | |||||
"...", "[...]","[...]", "[...]","[...]", . | |||||
The bigrams and trigrams columns are only present when the corresponding flags are True.
:param str,None dataset_name: supports 'pku', 'msra', 'cityu', 'as', or None
:param str encoding_type: either 'bmes' or 'segapp'. For "我 来自 复旦大学...", the bmes tags are [S, B, E, B, M, M, E...];
the segapp tags are [SEG, APP, SEG, APP, APP, APP, SEG, ...]
:param bool replace_num_alpha: whether to replace digit and letter spans with special tokens.
:param bool bigrams: whether to add a bigrams column, built as ['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]
:param bool trigrams: whether to add a trigrams column, built as ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...]
""" | |||||
def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False): | |||||
if encoding_type=='bmes': | |||||
self.word_lens_to_tags = _word_lens_to_bmes | |||||
else: | |||||
self.word_lens_to_tags = _word_lens_to_segapp | |||||
self.dataset_name = dataset_name | |||||
self.bigrams = bigrams | |||||
self.trigrams = trigrams | |||||
self.replace_num_alpha = replace_num_alpha | |||||
def _tokenize(self, data_bundle): | |||||
""" | |||||
将data_bundle中的'chars'列切分成一个一个的word. | |||||
例如输入是"共同 创造 美好.."->[[共, 同], [创, 造], [...], ] | |||||
:param data_bundle: | |||||
:return: | |||||
""" | |||||
def split_word_into_chars(raw_chars): | |||||
words = raw_chars.split() | |||||
chars = [] | |||||
for word in words: | |||||
char = [] | |||||
subchar = [] | |||||
for c in word: | |||||
if c=='<': | |||||
subchar.append(c) | |||||
continue | |||||
if c=='>' and subchar and subchar[0]=='<':
subchar.append(c)   # keep the closing '>' so the whole special tag stays one pseudo-character
char.append(''.join(subchar))
subchar = []
continue
if subchar: | |||||
subchar.append(c) | |||||
else: | |||||
char.append(c) | |||||
char.extend(subchar) | |||||
chars.append(char) | |||||
return chars | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.apply_field(split_word_into_chars, field_name=Const.CHAR_INPUT, | |||||
new_field_name=Const.CHAR_INPUT) | |||||
return data_bundle | |||||
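A sketch of what split_word_into_chars produces (with the '>' handling guarded as above; inputs are illustrative):
split_word_into_chars('上海 浦东')    # -> [['上', '海'], ['浦', '东']]
split_word_into_chars('<NUM>年 电')   # special tags such as '<NUM>' are kept as single pseudo-characters: [['<NUM>', '年'], ['电']]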
def process(self, data_bundle: DataBundle) -> DataBundle: | |||||
""" | |||||
可以处理的DataSet需要包含raw_words列 | |||||
.. csv-table:: | |||||
:header: "raw_words" | |||||
"上海 浦东 开发 与 法制 建设 同步" | |||||
"新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" | |||||
"..." | |||||
:param data_bundle: | |||||
:return: | |||||
""" | |||||
data_bundle.copy_field(Const.RAW_WORD, Const.CHAR_INPUT) | |||||
if self.replace_num_alpha: | |||||
data_bundle.apply_field(_find_and_replace_alpha_spans, Const.CHAR_INPUT, Const.CHAR_INPUT) | |||||
data_bundle.apply_field(_find_and_replace_digit_spans, Const.CHAR_INPUT, Const.CHAR_INPUT) | |||||
self._tokenize(data_bundle) | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.apply_field(lambda chars:self.word_lens_to_tags(map(len, chars)), field_name=Const.CHAR_INPUT, | |||||
new_field_name=Const.TARGET) | |||||
dataset.apply_field(lambda chars:list(chain(*chars)), field_name=Const.CHAR_INPUT, | |||||
new_field_name=Const.CHAR_INPUT) | |||||
input_field_names = [Const.CHAR_INPUT] | |||||
if self.bigrams: | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.apply_field(lambda chars: [c1+c2 for c1, c2 in zip(chars, chars[1:]+['<eos>'])], | |||||
field_name=Const.CHAR_INPUT, new_field_name='bigrams') | |||||
input_field_names.append('bigrams') | |||||
if self.trigrams: | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.apply_field(lambda chars: [c1+c2+c3 for c1, c2, c3 in zip(chars, chars[1:]+['<eos>'], chars[2:]+['<eos>']*2)], | |||||
field_name=Const.CHAR_INPUT, new_field_name='trigrams') | |||||
input_field_names.append('trigrams') | |||||
_indexize(data_bundle, input_field_names, Const.TARGET) | |||||
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names | |||||
target_fields = [Const.TARGET, Const.INPUT_LEN] | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.add_seq_len(Const.CHAR_INPUT) | |||||
data_bundle.set_input(*input_fields) | |||||
data_bundle.set_target(*target_fields) | |||||
return data_bundle | |||||
def process_from_file(self, paths=None) -> DataBundle: | |||||
""" | |||||
:param str paths: | |||||
:return: | |||||
""" | |||||
if self.dataset_name is None and paths is None: | |||||
raise RuntimeError("You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.") | |||||
if self.dataset_name is not None and paths is not None: | |||||
raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously") | |||||
data_bundle = CWSLoader(self.dataset_name).load(paths) | |||||
return self.process(data_bundle) |
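An end-to-end usage sketch (the dataset name and the resulting fields are assumptions about the produced bundle):
pipe = CWSPipe(dataset_name='pku', encoding_type='bmes', bigrams=True)
data_bundle = pipe.process_from_file()        # download + preprocess the assumed 'pku' corpus
train_ds = data_bundle.get_dataset('train')   # chars/bigrams/target/seq_len already indexed and set as input/target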
@@ -1,249 +0,0 @@ | |||||
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | |||||
from fastNLP.core.vocabulary import VocabularyOption | |||||
from fastNLP.io.data_bundle import DataSetLoader, DataBundle | |||||
from typing import Union, Dict, List, Iterator | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
from fastNLP import Vocabulary | |||||
from fastNLP import Const | |||||
from reproduction.utils import check_dataloader_paths | |||||
from functools import partial | |||||
class SigHanLoader(DataSetLoader): | |||||
""" | |||||
任务相关的说明可以在这里找到http://sighan.cs.uchicago.edu/ | |||||
支持的数据格式为,一行一句,不同的word用空格隔开。如下例 | |||||
共同 创造 美好 的 新 世纪 —— 二○○一年 新年 | |||||
女士 们 , 先生 们 , 同志 们 , 朋友 们 : | |||||
读取sighan中的数据集,返回的DataSet将包含以下的内容fields: | |||||
raw_chars: list(str), 每个元素是一个汉字 | |||||
chars: list(str), 每个元素是一个index(汉字对应的index) | |||||
target: list(int), 根据不同的encoding_type会有不同的变化 | |||||
:param target_type: target的类型,当前支持以下的两种: "bmes", "shift_relay" | |||||
""" | |||||
def __init__(self, target_type:str): | |||||
super().__init__() | |||||
if target_type.lower() not in ('bmes', 'shift_relay'): | |||||
raise ValueError("target_type only supports 'bmes', 'shift_relay'.") | |||||
self.target_type = target_type | |||||
if target_type=='bmes': | |||||
self._word_len_to_target = self._word_len_to_bems | |||||
elif target_type=='shift_relay': | |||||
self._word_len_to_target = self._word_lens_to_relay | |||||
@staticmethod | |||||
def _word_lens_to_relay(word_lens: Iterator[int]): | |||||
""" | |||||
[1, 2, 3, ..] 转换为[0, 1, 0, 2, 1, 0,](start指示seg有多长); | |||||
:param word_lens: | |||||
:return: {'target': , 'end_seg_mask':, 'start_seg_mask':} | |||||
""" | |||||
tags = [] | |||||
end_seg_mask = [] | |||||
start_seg_mask = [] | |||||
for word_len in word_lens: | |||||
tags.extend([idx for idx in range(word_len - 1, -1, -1)]) | |||||
end_seg_mask.extend([0] * (word_len - 1) + [1]) | |||||
start_seg_mask.extend([1] + [0] * (word_len - 1)) | |||||
return {'target': tags, 'end_seg_mask': end_seg_mask, 'start_seg_mask': start_seg_mask} | |||||
@staticmethod | |||||
def _word_len_to_bems(word_lens:Iterator[int])->Dict[str, List[str]]: | |||||
""" | |||||
:param word_lens: 每个word的长度 | |||||
:return: | |||||
""" | |||||
tags = [] | |||||
for word_len in word_lens: | |||||
if word_len==1: | |||||
tags.append('S') | |||||
else: | |||||
tags.append('B') | |||||
for _ in range(word_len-2): | |||||
tags.append('M') | |||||
tags.append('E') | |||||
return {'target':tags} | |||||
@staticmethod | |||||
def _gen_bigram(chars:List[str])->List[str]: | |||||
""" | |||||
:param chars: | |||||
:return: | |||||
""" | |||||
return [c1+c2 for c1, c2 in zip(chars, chars[1:]+['<eos>'])] | |||||
def load(self, path:str, bigram:bool=False)->DataSet: | |||||
""" | |||||
:param path: str | |||||
:param bigram: 是否使用bigram feature | |||||
:return: | |||||
""" | |||||
dataset = DataSet() | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
if not line: # 去掉空行 | |||||
continue | |||||
parts = line.split() | |||||
word_lens = map(len, parts) | |||||
chars = list(''.join(parts)) | |||||
tags = self._word_len_to_target(word_lens) | |||||
assert len(chars)==len(tags['target']) | |||||
dataset.append(Instance(raw_chars=chars, **tags, seq_len=len(chars))) | |||||
if len(dataset)==0: | |||||
raise RuntimeError(f"{path} has no valid data.") | |||||
if bigram: | |||||
dataset.apply_field(self._gen_bigram, field_name='raw_chars', new_field_name='bigrams') | |||||
return dataset | |||||
def process(self, paths: Union[str, Dict[str, str]], char_vocab_opt:VocabularyOption=None, | |||||
char_embed_opt:EmbeddingOption=None, bigram_vocab_opt:VocabularyOption=None, | |||||
bigram_embed_opt:EmbeddingOption=None, L:int=4): | |||||
""" | |||||
支持的数据格式为一行一个sample,并且用空格隔开不同的词语。例如 | |||||
Option:: | |||||
共同 创造 美好 的 新 世纪 —— 二○○一年 新年 贺词 | |||||
( 二○○○年 十二月 三十一日 ) ( 附 图片 1 张 ) | |||||
女士 们 , 先生 们 , 同志 们 , 朋友 们 : | |||||
paths支持两种格式,第一种是str,第二种是Dict[str, str]. | |||||
Option:: | |||||
# 1. str类型 | |||||
# 1.1 传入具体的文件路径 | |||||
data = SigHanLoader('bmes').process('/path/to/cws/data.txt') # 将读取data.txt的内容 | |||||
# 包含以下的内容data.vocabs['chars']:Vocabulary对象, | |||||
# data.vocabs['target']: Vocabulary对象,根据encoding_type可能会没有该值 | |||||
# data.embeddings['chars']: Embedding对象. 只有提供了预训练的词向量的路径才有该项 | |||||
# data.datasets['train']: DataSet对象 | |||||
# 包含的field有: | |||||
# raw_chars: list[str], 每个元素是一个汉字 | |||||
# chars: list[int], 每个元素是汉字对应的index | |||||
# target: list[int], 根据encoding_type有对应的变化 | |||||
# 1.2 传入一个目录, 里面必须包含train.txt文件 | |||||
data = SigHanLoader('bmes').process('path/to/cws/') #将尝试在该目录下读取 train.txt, test.txt以及dev.txt | |||||
# 包含以下的内容data.vocabs['chars']: Vocabulary对象 | |||||
# data.vocabs['target']:Vocabulary对象 | |||||
# data.embeddings['chars']: 仅在提供了预训练embedding路径的情况下,为Embedding对象; | |||||
# data.datasets['train']: DataSet对象 | |||||
# 包含的field有: | |||||
# raw_chars: list[str], 每个元素是一个汉字 | |||||
# chars: list[int], 每个元素是汉字对应的index | |||||
# target: list[int], 根据encoding_type有对应的变化 | |||||
# data.datasets['dev']: DataSet对象,如果文件夹下包含了dev.txt;内容与data.datasets['train']一样 | |||||
# 2. dict类型, key是文件的名称,value是对应的读取路径. 必须包含'train'这个key | |||||
paths = {'train': '/path/to/train/train.txt', 'test':'/path/to/test/test.txt', 'dev':'/path/to/dev/dev.txt'} | |||||
data = SigHanLoader(paths).process(paths) | |||||
# 结果与传入目录时是一致的,但是可以传入多个数据集。data.datasets中的key将与这里传入的一致 | |||||
:param paths: 支持传入目录,文件路径,以及dict。 | |||||
:param char_vocab_opt: 用于构建chars的vocabulary参数,默认为min_freq=2 | |||||
:param char_embed_opt: 用于读取chars的Embedding的参数,默认不读取pretrained的embedding | |||||
:param bigram_vocab_opt: 用于构建bigram的vocabulary参数,默认不使用bigram, 仅在指定该参数的情况下会带有bigrams这个field。 | |||||
为List[int], 每个instance长度与chars一样, abcde的bigram为ab bc cd de e<eos> | |||||
:param bigram_embed_opt: 用于读取预训练bigram的参数,仅在传入bigram_vocab_opt有效 | |||||
:param L: 当target_type为shift_relay时传入的segment长度 | |||||
:return: | |||||
""" | |||||
# 推荐大家使用这个check_data_loader_paths进行paths的验证 | |||||
paths = check_dataloader_paths(paths) | |||||
datasets = {} | |||||
data = DataBundle() | |||||
bigram = bigram_vocab_opt is not None | |||||
for name, path in paths.items(): | |||||
dataset = self.load(path, bigram=bigram) | |||||
datasets[name] = dataset | |||||
input_fields = [] | |||||
target_fields = [] | |||||
# 创建vocab | |||||
char_vocab = Vocabulary(min_freq=2) if char_vocab_opt is None else Vocabulary(**char_vocab_opt) | |||||
char_vocab.from_dataset(datasets['train'], field_name='raw_chars') | |||||
char_vocab.index_dataset(*datasets.values(), field_name='raw_chars', new_field_name='chars') | |||||
data.vocabs[Const.CHAR_INPUT] = char_vocab | |||||
input_fields.extend([Const.CHAR_INPUT, Const.INPUT_LEN, Const.TARGET]) | |||||
target_fields.append(Const.TARGET) | |||||
# 创建target | |||||
if self.target_type == 'bmes': | |||||
target_vocab = Vocabulary(unknown=None, padding=None) | |||||
target_vocab.add_word_lst(['B']*4+['M']*3+['E']*2+['S']) | |||||
target_vocab.index_dataset(*datasets.values(), field_name='target') | |||||
data.vocabs[Const.TARGET] = target_vocab | |||||
if char_embed_opt is not None: | |||||
char_embed = EmbedLoader.load_with_vocab(**char_embed_opt, vocab=char_vocab) | |||||
data.embeddings['chars'] = char_embed | |||||
if bigram: | |||||
bigram_vocab = Vocabulary(**bigram_vocab_opt) | |||||
bigram_vocab.from_dataset(datasets['train'], field_name='bigrams') | |||||
bigram_vocab.index_dataset(*datasets.values(), field_name='bigrams') | |||||
data.vocabs['bigrams'] = bigram_vocab | |||||
if bigram_embed_opt is not None: | |||||
bigram_embed = EmbedLoader.load_with_vocab(**bigram_embed_opt, vocab=bigram_vocab) | |||||
data.embeddings['bigrams'] = bigram_embed | |||||
input_fields.append('bigrams') | |||||
if self.target_type == 'shift_relay': | |||||
func = partial(self._clip_target, L=L) | |||||
for name, dataset in datasets.items(): | |||||
res = dataset.apply_field(func, field_name='target') | |||||
relay_target = [res_i[0] for res_i in res] | |||||
relay_mask = [res_i[1] for res_i in res] | |||||
dataset.add_field('relay_target', relay_target, is_input=True, is_target=False, ignore_type=False) | |||||
dataset.add_field('relay_mask', relay_mask, is_input=True, is_target=False, ignore_type=False) | |||||
if self.target_type == 'shift_relay': | |||||
input_fields.extend(['end_seg_mask']) | |||||
target_fields.append('start_seg_mask') | |||||
# 将dataset加入DataInfo | |||||
for name, dataset in datasets.items(): | |||||
dataset.set_input(*input_fields) | |||||
dataset.set_target(*target_fields) | |||||
data.datasets[name] = dataset | |||||
return data | |||||
@staticmethod | |||||
def _clip_target(target:List[int], L:int): | |||||
""" | |||||
只有在target_type为shift_relay的使用 | |||||
:param target: List[int] | |||||
:param L: | |||||
:return: | |||||
""" | |||||
relay_target_i = [] | |||||
tmp = [] | |||||
for j in range(len(target) - 1): | |||||
tmp.append(target[j]) | |||||
if target[j] > target[j + 1]: | |||||
pass | |||||
else: | |||||
relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]]) | |||||
tmp = [] | |||||
# 处理未结束的部分 | |||||
if len(tmp) == 0: | |||||
relay_target_i.append(0) | |||||
else: | |||||
tmp.append(target[-1]) | |||||
relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]]) | |||||
relay_mask_i = [] | |||||
j = 0 | |||||
while j < len(target): | |||||
seg_len = target[j] + 1 | |||||
if target[j] < L: | |||||
relay_mask_i.extend([0] * (seg_len)) | |||||
else: | |||||
relay_mask_i.extend([1] * (seg_len - L) + [0] * L) | |||||
j = seg_len + j | |||||
return relay_target_i, relay_mask_i | |||||
@@ -0,0 +1,202 @@ | |||||
from fastNLP.io.pipe import Pipe | |||||
from fastNLP.io import DataBundle | |||||
from fastNLP.io.loader import CWSLoader | |||||
from fastNLP import Const | |||||
from itertools import chain | |||||
from fastNLP.io.pipe.utils import _indexize | |||||
from functools import partial | |||||
from fastNLP.io.pipe.cws import _find_and_replace_alpha_spans, _find_and_replace_digit_spans | |||||
def _word_lens_to_relay(word_lens): | |||||
""" | |||||
Converts word lengths such as [1, 2, 3, ...] into [0, 1, 0, 2, 1, 0, ...]: each position holds the number of characters remaining in its word.
:param word_lens: | |||||
:return: | |||||
""" | |||||
tags = [] | |||||
for word_len in word_lens: | |||||
tags.extend([idx for idx in range(word_len - 1, -1, -1)]) | |||||
return tags | |||||
def _word_lens_to_end_seg_mask(word_lens): | |||||
""" | |||||
Converts word lengths such as [1, 2, 3, ...] into a 0/1 mask with 1 on the last character of each word.
:param word_lens: | |||||
:return: | |||||
""" | |||||
end_seg_mask = [] | |||||
for word_len in word_lens: | |||||
end_seg_mask.extend([0] * (word_len - 1) + [1]) | |||||
return end_seg_mask | |||||
def _word_lens_to_start_seg_mask(word_lens): | |||||
""" | |||||
Converts word lengths such as [1, 2, 3, ...] into a 0/1 mask with 1 on the first character of each word.
:param word_lens: | |||||
:return: | |||||
""" | |||||
start_seg_mask = [] | |||||
for word_len in word_lens: | |||||
start_seg_mask.extend([1] + [0] * (word_len - 1)) | |||||
return start_seg_mask | |||||
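A worked sketch for the word-length sequence [1, 2, 3]: target counts down the characters remaining in each word, and the two masks mark word boundaries:
assert _word_lens_to_relay([1, 2, 3]) == [0, 1, 0, 2, 1, 0]
assert _word_lens_to_end_seg_mask([1, 2, 3]) == [1, 0, 1, 0, 0, 1]     # 1 on the last character of each word
assert _word_lens_to_start_seg_mask([1, 2, 3]) == [1, 1, 0, 1, 0, 0]   # 1 on the first character of each word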
class CWSShiftRelayPipe(Pipe): | |||||
""" | |||||
:param str,None dataset_name: supports 'pku', 'msra', 'cityu', 'as', or None
:param int L: hyper-parameter of the ShiftRelay model
:param bool replace_num_alpha: whether to replace digit and letter spans with special tokens.
:param bool bigrams: whether to add a bigrams column, built as ['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]
""" | |||||
def __init__(self, dataset_name=None, L=5, replace_num_alpha=True, bigrams=True): | |||||
self.dataset_name = dataset_name | |||||
self.bigrams = bigrams | |||||
self.replace_num_alpha = replace_num_alpha | |||||
self.L = L | |||||
def _tokenize(self, data_bundle): | |||||
""" | |||||
将data_bundle中的'chars'列切分成一个一个的word. | |||||
例如输入是"共同 创造 美好.."->[[共, 同], [创, 造], [...], ] | |||||
:param data_bundle: | |||||
:return: | |||||
""" | |||||
def split_word_into_chars(raw_chars): | |||||
words = raw_chars.split() | |||||
chars = [] | |||||
for word in words: | |||||
char = [] | |||||
subchar = [] | |||||
for c in word: | |||||
if c=='<': | |||||
subchar.append(c) | |||||
continue | |||||
if c=='>' and subchar and subchar[0]=='<':
subchar.append(c)   # keep the closing '>' so the whole special tag stays one pseudo-character
char.append(''.join(subchar))
subchar = []
continue
if subchar: | |||||
subchar.append(c) | |||||
else: | |||||
char.append(c) | |||||
char.extend(subchar) | |||||
chars.append(char) | |||||
return chars | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.apply_field(split_word_into_chars, field_name=Const.CHAR_INPUT, | |||||
new_field_name=Const.CHAR_INPUT) | |||||
return data_bundle | |||||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||||
""" | |||||
可以处理的DataSet需要包含raw_words列 | |||||
.. csv-table:: | |||||
:header: "raw_words" | |||||
"上海 浦东 开发 与 法制 建设 同步" | |||||
"新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" | |||||
"..." | |||||
:param data_bundle: | |||||
:return: | |||||
""" | |||||
data_bundle.copy_field(Const.RAW_WORD, Const.CHAR_INPUT) | |||||
if self.replace_num_alpha: | |||||
data_bundle.apply_field(_find_and_replace_alpha_spans, Const.CHAR_INPUT, Const.CHAR_INPUT) | |||||
data_bundle.apply_field(_find_and_replace_digit_spans, Const.CHAR_INPUT, Const.CHAR_INPUT) | |||||
self._tokenize(data_bundle) | |||||
input_field_names = [Const.CHAR_INPUT] | |||||
target_field_names = [] | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.apply_field(lambda chars:_word_lens_to_relay(map(len, chars)), field_name=Const.CHAR_INPUT, | |||||
new_field_name=Const.TARGET) | |||||
dataset.apply_field(lambda chars:_word_lens_to_start_seg_mask(map(len, chars)), field_name=Const.CHAR_INPUT, | |||||
new_field_name='start_seg_mask') | |||||
dataset.apply_field(lambda chars:_word_lens_to_end_seg_mask(map(len, chars)), field_name=Const.CHAR_INPUT, | |||||
new_field_name='end_seg_mask') | |||||
dataset.apply_field(lambda chars:list(chain(*chars)), field_name=Const.CHAR_INPUT, | |||||
new_field_name=Const.CHAR_INPUT) | |||||
target_field_names.append('start_seg_mask') | |||||
input_field_names.append('end_seg_mask') | |||||
if self.bigrams: | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.apply_field(lambda chars: [c1+c2 for c1, c2 in zip(chars, chars[1:]+['<eos>'])], | |||||
field_name=Const.CHAR_INPUT, new_field_name='bigrams') | |||||
input_field_names.append('bigrams') | |||||
_indexize(data_bundle, ['chars', 'bigrams'], []) | |||||
func = partial(_clip_target, L=self.L) | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
res = dataset.apply_field(func, field_name='target') | |||||
relay_target = [res_i[0] for res_i in res] | |||||
relay_mask = [res_i[1] for res_i in res] | |||||
dataset.add_field('relay_target', relay_target, is_input=True, is_target=False, ignore_type=False) | |||||
dataset.add_field('relay_mask', relay_mask, is_input=True, is_target=False, ignore_type=False) | |||||
input_field_names.append('relay_target') | |||||
input_field_names.append('relay_mask') | |||||
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names | |||||
target_fields = [Const.TARGET, Const.INPUT_LEN] + target_field_names | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.add_seq_len(Const.CHAR_INPUT) | |||||
data_bundle.set_input(*input_fields) | |||||
data_bundle.set_target(*target_fields) | |||||
return data_bundle | |||||
def process_from_file(self, paths=None) -> DataBundle: | |||||
""" | |||||
:param str paths: | |||||
:return: | |||||
""" | |||||
if self.dataset_name is None and paths is None: | |||||
raise RuntimeError("You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.") | |||||
if self.dataset_name is not None and paths is not None: | |||||
raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously") | |||||
data_bundle = CWSLoader(self.dataset_name).load(paths) | |||||
return self.process(data_bundle) | |||||
def _clip_target(target, L:int): | |||||
""" | |||||
只有在target_type为shift_relay的使用 | |||||
:param target: List[int] | |||||
:param L: | |||||
:return: | |||||
""" | |||||
relay_target_i = [] | |||||
tmp = [] | |||||
for j in range(len(target) - 1): | |||||
tmp.append(target[j]) | |||||
if target[j] > target[j + 1]: | |||||
pass | |||||
else: | |||||
relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]]) | |||||
tmp = [] | |||||
# 处理未结束的部分 | |||||
if len(tmp) == 0: | |||||
relay_target_i.append(0) | |||||
else: | |||||
tmp.append(target[-1]) | |||||
relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]]) | |||||
relay_mask_i = [] | |||||
j = 0 | |||||
while j < len(target): | |||||
seg_len = target[j] + 1 | |||||
if target[j] < L: | |||||
relay_mask_i.extend([0] * (seg_len)) | |||||
else: | |||||
relay_mask_i.extend([1] * (seg_len - L) + [0] * L) | |||||
j = seg_len + j | |||||
return relay_target_i, relay_mask_i |
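A worked sketch of _clip_target with L=2 on the relay target [0, 1, 0, 2, 1, 0] from the example above:
relay_target, relay_mask = _clip_target([0, 1, 0, 2, 1, 0], L=2)
assert relay_target == [0, 0, 1, 0, 1, 1]   # per-character distance from the word start, clipped at L-1
assert relay_mask == [0, 0, 0, 1, 0, 0]     # 1 where a character sits L or more positions before its word's end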
@@ -0,0 +1,60 @@ | |||||
import torch | |||||
from fastNLP.modules import LSTM | |||||
from fastNLP.modules import allowed_transitions, ConditionalRandomField | |||||
from fastNLP import seq_len_to_mask | |||||
from torch import nn | |||||
from fastNLP import Const | |||||
import torch.nn.functional as F | |||||
class BiLSTMCRF(nn.Module): | |||||
def __init__(self, char_embed, hidden_size, num_layers, target_vocab=None, bigram_embed=None, trigram_embed=None, | |||||
dropout=0.5): | |||||
super().__init__() | |||||
embed_size = char_embed.embed_size | |||||
self.char_embed = char_embed | |||||
if bigram_embed: | |||||
embed_size += bigram_embed.embed_size | |||||
self.bigram_embed = bigram_embed | |||||
if trigram_embed: | |||||
embed_size += trigram_embed.embed_size | |||||
self.trigram_embed = trigram_embed | |||||
self.lstm = LSTM(embed_size, hidden_size=hidden_size//2, bidirectional=True, batch_first=True, | |||||
num_layers=num_layers) | |||||
self.dropout = nn.Dropout(p=dropout) | |||||
self.fc = nn.Linear(hidden_size, len(target_vocab)) | |||||
transitions = None | |||||
if target_vocab: | |||||
transitions = allowed_transitions(target_vocab, include_start_end=True, encoding_type='bmes') | |||||
self.crf = ConditionalRandomField(num_tags=len(target_vocab), allowed_transitions=transitions) | |||||
def _forward(self, chars, bigrams, trigrams, seq_len, target=None): | |||||
chars = self.char_embed(chars) | |||||
if bigrams is not None: | |||||
bigrams = self.bigram_embed(bigrams) | |||||
chars = torch.cat([chars, bigrams], dim=-1) | |||||
if trigrams is not None: | |||||
trigrams = self.trigram_embed(trigrams) | |||||
chars = torch.cat([chars, trigrams], dim=-1) | |||||
output, _ = self.lstm(chars, seq_len) | |||||
output = self.dropout(output) | |||||
output = self.fc(output) | |||||
output = F.log_softmax(output, dim=-1) | |||||
mask = seq_len_to_mask(seq_len) | |||||
if target is None: | |||||
pred, _ = self.crf.viterbi_decode(output, mask) | |||||
return {Const.OUTPUT:pred} | |||||
else: | |||||
loss = self.crf.forward(output, tags=target, mask=mask) | |||||
return {Const.LOSS:loss} | |||||
def forward(self, chars, seq_len, target, bigrams=None, trigrams=None): | |||||
return self._forward(chars, bigrams, trigrams, seq_len, target) | |||||
def predict(self, chars, seq_len, bigrams=None, trigrams=None): | |||||
return self._forward(chars, bigrams, trigrams, seq_len) |
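A shape-level usage sketch of BiLSTMCRF (the embedding modules, vocabulary, and batch tensors are assumptions; any fastNLP embedding exposing embed_size should fit):
model = BiLSTMCRF(char_embed, hidden_size=400, num_layers=1, target_vocab=target_vocab, bigram_embed=bigram_embed)
out = model(chars=chars_batch, seq_len=seq_len_batch, target=target_batch, bigrams=bigrams_batch)
loss = out[Const.LOSS]   # training path returns the CRF negative log-likelihood
pred = model.predict(chars=chars_batch, seq_len=seq_len_batch, bigrams=bigrams_batch)[Const.OUTPUT]   # Viterbi decode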
@@ -1,7 +1,5 @@ | |||||
from torch import nn | from torch import nn | ||||
import torch | import torch | ||||
from fastNLP.embeddings import Embedding | |||||
import numpy as np | |||||
from reproduction.seqence_labelling.cws.model.module import FeatureFunMax, SemiCRFShiftRelay | from reproduction.seqence_labelling.cws.model.module import FeatureFunMax, SemiCRFShiftRelay | ||||
from fastNLP.modules import LSTM | from fastNLP.modules import LSTM | ||||
@@ -21,25 +19,21 @@ class ShiftRelayCWSModel(nn.Module): | |||||
:param num_bigram_per_char: 每个character对应的bigram的数量 | :param num_bigram_per_char: 每个character对应的bigram的数量 | ||||
:param drop_p: Dropout的大小 | :param drop_p: Dropout的大小 | ||||
""" | """ | ||||
def __init__(self, char_embed:Embedding, bigram_embed:Embedding, hidden_size:int=400, num_layers:int=1, | |||||
L:int=6, num_bigram_per_char:int=1, drop_p:float=0.2): | |||||
def __init__(self, char_embed, bigram_embed, hidden_size:int=400, num_layers:int=1, L:int=6, drop_p:float=0.2): | |||||
super().__init__() | super().__init__() | ||||
self.char_embedding = Embedding(char_embed, dropout=drop_p) | |||||
self._pretrained_embed = False | |||||
if isinstance(char_embed, np.ndarray): | |||||
self._pretrained_embed = True | |||||
self.bigram_embedding = Embedding(bigram_embed, dropout=drop_p) | |||||
self.lstm = LSTM(100 * (num_bigram_per_char + 1), hidden_size // 2, num_layers=num_layers, bidirectional=True, | |||||
self.char_embedding = char_embed | |||||
self.bigram_embedding = bigram_embed | |||||
self.lstm = LSTM(char_embed.embed_size+bigram_embed.embed_size, hidden_size // 2, num_layers=num_layers, | |||||
bidirectional=True, | |||||
batch_first=True) | batch_first=True) | ||||
self.feature_fn = FeatureFunMax(hidden_size, L) | self.feature_fn = FeatureFunMax(hidden_size, L) | ||||
self.semi_crf_relay = SemiCRFShiftRelay(L) | self.semi_crf_relay = SemiCRFShiftRelay(L) | ||||
self.feat_drop = nn.Dropout(drop_p) | self.feat_drop = nn.Dropout(drop_p) | ||||
self.reset_param() | self.reset_param() | ||||
# self.feature_fn.reset_parameters() | |||||
def reset_param(self): | def reset_param(self): | ||||
for name, param in self.named_parameters(): | for name, param in self.named_parameters(): | ||||
if 'embedding' in name and self._pretrained_embed: | |||||
if 'embedding' in name: | |||||
continue | continue | ||||
if 'bias_hh' in name: | if 'bias_hh' in name: | ||||
nn.init.constant_(param, 0) | nn.init.constant_(param, 0) | ||||
@@ -51,10 +45,8 @@ class ShiftRelayCWSModel(nn.Module): | |||||
nn.init.xavier_uniform_(param) | nn.init.xavier_uniform_(param) | ||||
def get_feats(self, chars, bigrams, seq_len): | def get_feats(self, chars, bigrams, seq_len): | ||||
batch_size, max_len = chars.size() | |||||
chars = self.char_embedding(chars) | chars = self.char_embedding(chars) | ||||
bigrams = self.bigram_embedding(bigrams) | bigrams = self.bigram_embedding(bigrams) | ||||
bigrams = bigrams.view(bigrams.size(0), max_len, -1) | |||||
chars = torch.cat([chars, bigrams], dim=-1) | chars = torch.cat([chars, bigrams], dim=-1) | ||||
feats, _ = self.lstm(chars, seq_len) | feats, _ = self.lstm(chars, seq_len) | ||||
feats = self.feat_drop(feats) | feats = self.feat_drop(feats) |
@@ -0,0 +1,52 @@ | |||||
import sys | |||||
sys.path.append('../../..') | |||||
from fastNLP.io.pipe.cws import CWSPipe | |||||
from reproduction.seqence_labelling.cws.model.bilstm_crf_cws import BiLSTMCRF | |||||
from fastNLP import Trainer, cache_results | |||||
from fastNLP.embeddings import StaticEmbedding | |||||
from fastNLP import EvaluateCallback, BucketSampler, SpanFPreRecMetric, GradientClipCallback | |||||
from torch.optim import Adagrad | |||||
###########hyper | |||||
dataname = 'pku' | |||||
hidden_size = 400 | |||||
num_layers = 1 | |||||
lr = 0.05 | |||||
###########hyper | |||||
@cache_results('{}.pkl'.format(dataname), _refresh=False) | |||||
def get_data(): | |||||
data_bundle = CWSPipe(dataset_name=dataname, bigrams=True, trigrams=False).process_from_file() | |||||
char_embed = StaticEmbedding(data_bundle.get_vocab('chars'), dropout=0.33, word_dropout=0.01, | |||||
model_dir_or_name='~/exps/CWS/pretrain/vectors/1grams_t3_m50_corpus.txt') | |||||
bigram_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'), dropout=0.33,min_freq=3, word_dropout=0.01, | |||||
model_dir_or_name='~/exps/CWS/pretrain/vectors/2grams_t3_m50_corpus.txt') | |||||
return data_bundle, char_embed, bigram_embed | |||||
data_bundle, char_embed, bigram_embed = get_data() | |||||
print(data_bundle) | |||||
model = BiLSTMCRF(char_embed, hidden_size, num_layers, target_vocab=data_bundle.get_vocab('target'), bigram_embed=bigram_embed, | |||||
trigram_embed=None, dropout=0.3) | |||||
model.cuda() | |||||
callbacks = [] | |||||
callbacks.append(EvaluateCallback(data_bundle.get_dataset('test'))) | |||||
callbacks.append(GradientClipCallback(clip_type='value', clip_value=5)) | |||||
optimizer = Adagrad(model.parameters(), lr=lr) | |||||
metrics = [] | |||||
metric1 = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'), encoding_type='bmes') | |||||
metrics.append(metric1) | |||||
trainer = Trainer(data_bundle.get_dataset('train'), model, optimizer=optimizer, loss=None, | |||||
batch_size=128, sampler=BucketSampler(), update_every=1, | |||||
num_workers=1, n_epochs=10, print_every=5, | |||||
dev_data=data_bundle.get_dataset('dev'), | |||||
metrics=metrics, | |||||
metric_key=None, | |||||
validate_every=-1, save_path=None, use_tqdm=True, device=0, | |||||
callbacks=callbacks, check_code_level=0, dev_batch_size=128) | |||||
trainer.train() |
@@ -1,64 +1,53 @@ | |||||
import os | |||||
import sys | |||||
sys.path.append('../../..') | |||||
from fastNLP import cache_results | from fastNLP import cache_results | ||||
from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader | |||||
from reproduction.seqence_labelling.cws.model.model import ShiftRelayCWSModel | |||||
from fastNLP.io.embed_loader import EmbeddingOption | |||||
from fastNLP.core.vocabulary import VocabularyOption | |||||
from reproduction.seqence_labelling.cws.data.cws_shift_pipe import CWSShiftRelayPipe | |||||
from reproduction.seqence_labelling.cws.model.bilstm_shift_relay import ShiftRelayCWSModel | |||||
from fastNLP import Trainer | from fastNLP import Trainer | ||||
from torch.optim import Adam | from torch.optim import Adam | ||||
from fastNLP import BucketSampler | from fastNLP import BucketSampler | ||||
from fastNLP import GradientClipCallback | from fastNLP import GradientClipCallback | ||||
from reproduction.seqence_labelling.cws.model.metric import RelayMetric | from reproduction.seqence_labelling.cws.model.metric import RelayMetric | ||||
# 借助一下fastNLP的自动缓存机制,但是只能缓存4G以下的结果 | |||||
@cache_results(None) | |||||
def prepare_data(): | |||||
data = SigHanLoader(target_type='shift_relay').process(file_dir, char_embed_opt=char_embed_opt, | |||||
bigram_vocab_opt=bigram_vocab_opt, | |||||
bigram_embed_opt=bigram_embed_opt, | |||||
L=L) | |||||
return data | |||||
from fastNLP.embeddings import StaticEmbedding | |||||
from fastNLP import EvaluateCallback | |||||
#########hyper | #########hyper | ||||
L = 4 | L = 4 | ||||
hidden_size = 200 | hidden_size = 200 | ||||
num_layers = 1 | num_layers = 1 | ||||
drop_p = 0.2 | drop_p = 0.2 | ||||
lr = 0.02 | |||||
lr = 0.008 | |||||
data_name = 'pku' | |||||
#########hyper | #########hyper | ||||
device = 0 | device = 0 | ||||
# !!!!这里千万不要放完全路径,因为这样会暴露你们在服务器上的用户名,比较危险。所以一定要使用相对路径,最好把数据放到 | |||||
# 你们的reproduction路径下,然后设置.gitignore | |||||
file_dir = '/path/to/' | |||||
char_embed_path = '/pretrain/vectors/1grams_t3_m50_corpus.txt' | |||||
bigram_embed_path = '/pretrain/vectors/2grams_t3_m50_corpus.txt' | |||||
bigram_vocab_opt = VocabularyOption(min_freq=3) | |||||
char_embed_opt = EmbeddingOption(embed_filepath=char_embed_path) | |||||
bigram_embed_opt = EmbeddingOption(embed_filepath=bigram_embed_path) | |||||
data_name = os.path.basename(file_dir) | |||||
cache_fp = 'caches/{}.pkl'.format(data_name) | cache_fp = 'caches/{}.pkl'.format(data_name) | ||||
@cache_results(_cache_fp=cache_fp, _refresh=True) # cache the result to cache_fp so the next run loads it directly instead of recomputing | |||||
def prepare_data(): | |||||
data_bundle = CWSShiftRelayPipe(dataset_name=data_name, L=L).process_from_file() | |||||
# pretrained character embedding and bigram embedding | |||||
char_embed = StaticEmbedding(data_bundle.get_vocab('chars'), dropout=0.5, word_dropout=0.01, | |||||
model_dir_or_name='~/exps/CWS/pretrain/vectors/1grams_t3_m50_corpus.txt') | |||||
bigram_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'), dropout=0.5, min_freq=3, word_dropout=0.01, | |||||
model_dir_or_name='~/exps/CWS/pretrain/vectors/2grams_t3_m50_corpus.txt') | |||||
data = prepare_data(_cache_fp=cache_fp, _refresh=True) | |||||
return data_bundle, char_embed, bigram_embed | |||||
model = ShiftRelayCWSModel(char_embed=data.embeddings['chars'], bigram_embed=data.embeddings['bigrams'], | |||||
hidden_size=hidden_size, num_layers=num_layers, | |||||
L=L, num_bigram_per_char=1, drop_p=drop_p) | |||||
data, char_embed, bigram_embed = prepare_data() | |||||
sampler = BucketSampler(batch_size=32) | |||||
model = ShiftRelayCWSModel(char_embed=char_embed, bigram_embed=bigram_embed, | |||||
hidden_size=hidden_size, num_layers=num_layers, drop_p=drop_p, L=L) | |||||
sampler = BucketSampler() | |||||
optimizer = Adam(model.parameters(), lr=lr) | optimizer = Adam(model.parameters(), lr=lr) | ||||
clipper = GradientClipCallback(clip_value=5, clip_type='value') | |||||
callbacks = [clipper] | |||||
# if pretrain: | |||||
# fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until) | |||||
# callbacks.append(fixer) | |||||
trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None, batch_size=32, sampler=sampler, | |||||
update_every=5, n_epochs=3, print_every=5, dev_data=data.datasets['dev'], metrics=RelayMetric(), | |||||
clipper = GradientClipCallback(clip_value=5, clip_type='value') # clip overly large gradients | |||||
evaluator = EvaluateCallback(data.get_dataset('test')) # additionally evaluate on the test set | |||||
callbacks = [clipper, evaluator] | |||||
trainer = Trainer(data.get_dataset('train'), model, optimizer=optimizer, loss=None, batch_size=128, sampler=sampler, | |||||
update_every=1, n_epochs=10, print_every=5, dev_data=data.get_dataset('dev'), metrics=RelayMetric(), | |||||
metric_key='f', validate_every=-1, save_path=None, use_tqdm=True, device=device, callbacks=callbacks, | metric_key='f', validate_every=-1, save_path=None, use_tqdm=True, device=device, callbacks=callbacks, | ||||
check_code_level=0) | |||||
check_code_level=0, num_workers=1) | |||||
trainer.train() | trainer.train() |
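The @cache_results decorator used above stores the decorated function's return value on disk so that later runs load it instead of recomputing. A minimal sketch of the idea, assuming a plain pickle-based cache (not fastNLP's actual implementation, which supports more options):

```python
import os
import pickle
from functools import wraps

def cache_results(_cache_fp, _refresh=False):
    """Toy result-caching decorator: load from _cache_fp if present, else run and save."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if os.path.exists(_cache_fp) and not _refresh:
                with open(_cache_fp, 'rb') as f:
                    return pickle.load(f)        # reuse the cached result
            result = func(*args, **kwargs)
            cache_dir = os.path.dirname(_cache_fp)
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)
            with open(_cache_fp, 'wb') as f:
                pickle.dump(result, f)           # save for the next run
            return result
        return wrapper
    return decorator
```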
@@ -8,11 +8,10 @@ import torch.nn.functional as F | |||||
from fastNLP import Const | from fastNLP import Const | ||||
class CNNBiLSTMCRF(nn.Module): | class CNNBiLSTMCRF(nn.Module): | ||||
def __init__(self, embed, char_embed, hidden_size, num_layers, tag_vocab, dropout=0.5, encoding_type='bioes'): | |||||
def __init__(self, embed, hidden_size, num_layers, tag_vocab, dropout=0.5, encoding_type='bioes'): | |||||
super().__init__() | super().__init__() | ||||
self.embedding = embed | self.embedding = embed | ||||
self.char_embedding = char_embed | |||||
self.lstm = LSTM(input_size=self.embedding.embedding_dim+self.char_embedding.embedding_dim, | |||||
self.lstm = LSTM(input_size=self.embedding.embedding_dim, | |||||
hidden_size=hidden_size//2, num_layers=num_layers, | hidden_size=hidden_size//2, num_layers=num_layers, | ||||
bidirectional=True, batch_first=True) | bidirectional=True, batch_first=True) | ||||
self.fc = nn.Linear(hidden_size, len(tag_vocab)) | self.fc = nn.Linear(hidden_size, len(tag_vocab)) | ||||
@@ -32,9 +31,7 @@ class CNNBiLSTMCRF(nn.Module): | |||||
nn.init.zeros_(param) | nn.init.zeros_(param) | ||||
def _forward(self, words, seq_len, target=None): | def _forward(self, words, seq_len, target=None): | ||||
word_embeds = self.embedding(words) | |||||
char_embeds = self.char_embedding(words) | |||||
words = torch.cat((word_embeds, char_embeds), dim=-1) | |||||
words = self.embedding(words) | |||||
outputs, _ = self.lstm(words, seq_len) | outputs, _ = self.lstm(words, seq_len) | ||||
self.dropout(outputs) | self.dropout(outputs) | ||||
@@ -1,7 +1,7 @@ | |||||
import sys | import sys | ||||
sys.path.append('../../..') | sys.path.append('../../..') | ||||
from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding | |||||
from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding, StackEmbedding | |||||
from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF | from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF | ||||
from fastNLP import Trainer | from fastNLP import Trainer | ||||
@@ -22,7 +22,7 @@ def load_data(): | |||||
paths = {'test':"NER/corpus/CoNLL-2003/eng.testb", | paths = {'test':"NER/corpus/CoNLL-2003/eng.testb", | ||||
'train':"NER/corpus/CoNLL-2003/eng.train", | 'train':"NER/corpus/CoNLL-2003/eng.train", | ||||
'dev':"NER/corpus/CoNLL-2003/eng.testa"} | 'dev':"NER/corpus/CoNLL-2003/eng.testa"} | ||||
data = Conll2003NERPipe(encoding_type=encoding_type, target_pad_val=0).process_from_file(paths) | |||||
data = Conll2003NERPipe(encoding_type=encoding_type).process_from_file(paths) | |||||
return data | return data | ||||
data = load_data() | data = load_data() | ||||
print(data) | print(data) | ||||
@@ -33,8 +33,9 @@ word_embed = StaticEmbedding(vocab=data.get_vocab('words'), | |||||
model_dir_or_name='en-glove-6b-100d', | model_dir_or_name='en-glove-6b-100d', | ||||
requires_grad=True, lower=True, word_dropout=0.01, dropout=0.5) | requires_grad=True, lower=True, word_dropout=0.01, dropout=0.5) | ||||
word_embed.embedding.weight.data = word_embed.embedding.weight.data/word_embed.embedding.weight.data.std() | word_embed.embedding.weight.data = word_embed.embedding.weight.data/word_embed.embedding.weight.data.std() | ||||
embed = StackEmbedding([word_embed, char_embed]) | |||||
model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], | |||||
model = CNNBiLSTMCRF(embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], | |||||
encoding_type=encoding_type) | encoding_type=encoding_type) | ||||
callbacks = [ | callbacks = [ | ||||
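The StackEmbedding introduced here replaces the manual word/char concatenation removed from CNNBiLSTMCRF._forward. As far as I understand it, stacking simply concatenates the member embeddings along the feature dimension, roughly:

```python
import torch

# Stand-in tensors for the outputs of word_embed (GloVe 100d) and char_embed
# (CNN char embedding, 30d) on the same batch; shapes are (batch, seq_len, dim).
word_vecs = torch.randn(2, 5, 100)
char_vecs = torch.randn(2, 5, 30)
stacked = torch.cat([word_vecs, char_vecs], dim=-1)  # (2, 5, 130), fed to the LSTM
```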
@@ -2,7 +2,7 @@ import sys | |||||
sys.path.append('../../..') | sys.path.append('../../..') | ||||
from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding | |||||
from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding, StackEmbedding | |||||
from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF | from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF | ||||
from fastNLP import Trainer | from fastNLP import Trainer | ||||
@@ -35,7 +35,7 @@ def cache(): | |||||
char_embed = CNNCharEmbedding(vocab=data.vocabs['words'], embed_size=30, char_emb_size=30, filter_nums=[30], | char_embed = CNNCharEmbedding(vocab=data.vocabs['words'], embed_size=30, char_emb_size=30, filter_nums=[30], | ||||
kernel_sizes=[3], dropout=dropout) | kernel_sizes=[3], dropout=dropout) | ||||
word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], | word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], | ||||
model_dir_or_name='en-glove-100d', | |||||
model_dir_or_name='en-glove-6b-100d', | |||||
requires_grad=True, | requires_grad=True, | ||||
normalize=normalize, | normalize=normalize, | ||||
word_dropout=0.01, | word_dropout=0.01, | ||||
@@ -47,7 +47,8 @@ data, char_embed, word_embed = cache() | |||||
print(data) | print(data) | ||||
model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=1200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], | |||||
embed = StackEmbedding([word_embed, char_embed]) | |||||
model = CNNBiLSTMCRF(embed, hidden_size=1200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], | |||||
encoding_type=encoding_type, dropout=dropout) | encoding_type=encoding_type, dropout=dropout) | ||||
callbacks = [ | callbacks = [ | ||||
@@ -0,0 +1,13 @@ | |||||
import unittest | |||||
import os | |||||
from fastNLP.io.loader import CWSLoader | |||||
class CWSLoaderTest(unittest.TestCase): | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||||
def test_download(self): | |||||
dataset_names = ['pku', 'cityu', 'as', 'msra'] | |||||
for dataset_name in dataset_names: | |||||
with self.subTest(dataset_name=dataset_name): | |||||
data_bundle = CWSLoader(dataset_name=dataset_name).load() | |||||
print(data_bundle) |
@@ -0,0 +1,13 @@ | |||||
import unittest | |||||
import os | |||||
from fastNLP.io.pipe.cws import CWSPipe | |||||
class CWSPipeTest(unittest.TestCase): | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||||
def test_process_from_file(self): | |||||
dataset_names = ['pku', 'cityu', 'as', 'msra'] | |||||
for dataset_name in dataset_names: | |||||
with self.subTest(dataset_name=dataset_name): | |||||
data_bundle = CWSPipe(dataset_name=dataset_name).process_from_file() | |||||
print(data_bundle) |
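Assuming the two new test files above live somewhere under the repository's test directory (the diff does not show their paths), they can be run with the standard unittest runner, e.g.:

```python
import unittest

# Hypothetical location: 'test/io' and the file-name pattern are assumptions,
# since the diff does not show where the new test files were added.
suite = unittest.defaultTestLoader.discover('test/io', pattern='test_cws*.py')
unittest.TextTestRunner(verbosity=2).run(suite)
```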