
1. Add loaders and pipes for Chinese NER; 2. update the sequence_labeling code accordingly; 3. add some test code

tags/v0.4.10
yh 5 years ago
commit 511f41dda1
24 changed files with 704 additions and 440 deletions
  1. fastNLP/core/dataset.py (+5, -42)
  2. fastNLP/embeddings/bert_embedding.py (+62, -11)
  3. fastNLP/embeddings/static_embedding.py (+4, -2)
  4. fastNLP/io/__init__.py (+6, -0)
  5. fastNLP/io/data_bundle.py (+31, -8)
  6. fastNLP/io/file_utils.py (+55, -38)
  7. fastNLP/io/loader/__init__.py (+4, -0)
  8. fastNLP/io/loader/classification.py (+32, -66)
  9. fastNLP/io/loader/conll.py (+175, -3)
  10. fastNLP/io/pipe/__init__.py (+7, -1)
  11. fastNLP/io/pipe/conll.py (+132, -33)
  12. fastNLP/io/pipe/matching.py (+2, -2)
  13. fastNLP/io/pipe/utils.py (+28, -10)
  14. fastNLP/modules/encoder/bert.py (+2, -1)
  15. reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py (+0, -115)
  16. reproduction/seqence_labelling/chinese_ner/data/__init__.py (+0, -0)
  17. reproduction/seqence_labelling/chinese_ner/train_bert.py (+18, -15)
  18. reproduction/seqence_labelling/chinese_ner/train_cn_ner.py (+56, -14)
  19. reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py (+0, -1)
  20. reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py (+20, -45)
  21. reproduction/seqence_labelling/ner/train_ontonote.py (+18, -33)
  22. test/embeddings/test_bert_embedding.py (+14, -0)
  23. test/io/loader/test_conll_loader.py (+21, -0)
  24. test/io/pipe/test_conll.py (+12, -0)

fastNLP/core/dataset.py (+5, -42)

@@ -613,6 +613,7 @@ class DataSet(object):
raise e
else:
raise KeyError("{} is not a valid field name.".format(name))
return self
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True):
"""
@@ -636,6 +637,7 @@ class DataSet(object):
raise e
else:
raise KeyError("{} is not a valid field name.".format(name))
return self
def set_ignore_type(self, *field_names, flag=True):
"""
@@ -652,6 +654,7 @@ class DataSet(object):
self.field_arrays[name].ignore_type = flag
else:
raise KeyError("{} is not a valid field name.".format(name))
return self
def set_padder(self, field_name, padder):
"""
@@ -667,6 +670,7 @@ class DataSet(object):
if field_name not in self.field_arrays:
raise KeyError("There is no field named {}.".format(field_name))
self.field_arrays[field_name].set_padder(padder)
return self
def set_pad_val(self, field_name, pad_val):
"""
@@ -678,6 +682,7 @@ class DataSet(object):
if field_name not in self.field_arrays:
raise KeyError("There is no field named {}.".format(field_name))
self.field_arrays[field_name].set_pad_val(pad_val)
return self
def get_input_name(self):
"""
@@ -868,48 +873,6 @@ class DataSet(object):
return train_set, dev_set
@classmethod
def read_csv(cls, csv_path, headers=None, sep=",", dropna=True):
r"""
.. warning::
此方法会在下个版本移除,请使用 :class:`fastNLP.io.CSVLoader`
从csv_path路径下以csv的格式读取数据。

:param str csv_path: 从哪里读取csv文件
:param list[str] headers: 如果为None,则使用csv文件的第一行作为header; 如果传入list(str), 则元素的个数必须
与csv文件中每行的元素个数相同。
:param str sep: 分割符
:param bool dropna: 是否忽略与header数量不一致行。
:return: 读取后的 :class:`~fastNLP.读取后的DataSet`。
"""
warnings.warn('DataSet.read_csv is deprecated, use CSVLoader instead',
category=DeprecationWarning)
with open(csv_path, "r", encoding='utf-8') as f:
start_idx = 0
if headers is None:
headers = f.readline().rstrip('\r\n')
headers = headers.split(sep)
start_idx += 1
else:
assert isinstance(headers, (list, tuple)), "headers should be list or tuple, not {}.".format(
type(headers))
_dict = {}
for col in headers:
_dict[col] = []
for line_idx, line in enumerate(f, start_idx):
contents = line.rstrip('\r\n').split(sep)
if len(contents) != len(headers):
if dropna:
continue
else:
# TODO change error type
raise ValueError("Line {} has {} parts, while header has {} parts." \
.format(line_idx, len(contents), len(headers)))
for header, content in zip(headers, contents):
_dict[header].append(content)
return cls(_dict)
def save(self, path):
"""
保存DataSet.


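A minimal usage sketch of the chaining these return-self additions enable on DataSet (the field values are illustrative, already-indexed toy data):

    from fastNLP import DataSet

    ds = DataSet({'words': [[1, 2], [3]], 'target': [[0, 1], [1]]})
    # set_input / set_target / set_pad_val now all return the DataSet itself
    ds.set_input('words').set_target('target').set_pad_val('target', 0)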
fastNLP/embeddings/bert_embedding.py (+62, -11)

@@ -61,6 +61,9 @@ class BertEmbedding(ContextualEmbedding):

# 根据model_dir_or_name检查是否存在并下载
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'):
warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
" faster speed.")
model_url = _get_embedding_url('bert', model_dir_or_name.lower())
model_dir = cached_path(model_url, name='embedding')
# 检查是否存在
@@ -91,19 +94,33 @@ class BertEmbedding(ContextualEmbedding):
:param torch.LongTensor words: [batch_size, max_len]
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
"""
if self._word_sep_index: # 不能drop sep
sep_mask = words.eq(self._word_sep_index)
words = self.drop_word(words)
if self._word_sep_index:
words.masked_fill_(sep_mask, self._word_sep_index)
outputs = self._get_sent_reprs(words)
if outputs is not None:
return self.dropout(words)
return self.dropout(outputs)
outputs = self.model(words)
outputs = torch.cat([*outputs], dim=-1)

return self.dropout(outputs)

def drop_word(self, words):
"""
按照设定随机将words设置为unknown_index。

:param torch.LongTensor words: batch_size x max_len
:return:
"""
if self.word_dropout > 0 and self.training:
with torch.no_grad():
if self._word_sep_index: # 不能drop sep
sep_mask = words.eq(self._word_sep_index)
mask = torch.ones_like(words).float() * self.word_dropout
mask = torch.bernoulli(mask).byte() # dropout_word越大,越多位置为1
words = words.masked_fill(mask, self._word_unk_index)
if self._word_sep_index:
words.masked_fill_(sep_mask, self._word_sep_index)
return words

@property
def requires_grad(self):
"""
@@ -134,10 +151,12 @@ class BertWordPieceEncoder(nn.Module):
:param str layers: 最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层
:param bool pooled_cls: 返回的句子开头的[CLS]是否使用预训练中的BertPool映射一下,仅在include_cls_sep时有效。如果下游任务只取
[CLS]做预测,一般该值为True。
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
:param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
:param bool requires_grad: 是否需要gradient。
"""
def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1',
pooled_cls: bool = False, requires_grad: bool=False):
def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1', pooled_cls: bool = False,
word_dropout=0, dropout=0, requires_grad: bool=False):
super().__init__()

if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
@@ -150,8 +169,12 @@ class BertWordPieceEncoder(nn.Module):
raise ValueError(f"Cannot recognize {model_dir_or_name}.")

self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls)
self._sep_index = self.model._sep_index
self._wordpiece_unk_index = self.model._wordpiece_unknown_index
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
self.requires_grad = requires_grad
self.word_dropout = word_dropout
self.dropout_layer = nn.Dropout(dropout)

@property
def requires_grad(self):
@@ -199,13 +222,41 @@ class BertWordPieceEncoder(nn.Module):
计算words的bert embedding表示。传入的words中应该自行包含[CLS]与[SEP]的tag。

:param words: batch_size x max_len
:param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话
:param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话. 如果不传入,则自动生成(大部分情况,都不需要输入),
第一个[SEP]及之前为0, 第二个[SEP]及到第一个[SEP]之间为1; 第三个[SEP]及到第二个[SEP]之间为0,依次往后推。
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
"""
with torch.no_grad():
sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len
if token_type_ids is None:
sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
token_type_ids = sep_mask_cumsum.fmod(2)
if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0
token_type_ids = token_type_ids.eq(0).long()

word_pieces = self.drop_word(word_pieces)
outputs = self.model(word_pieces, token_type_ids)
outputs = torch.cat([*outputs], dim=-1)

return outputs
return self.dropout_layer(outputs)

def drop_word(self, words):
"""
按照设定随机将words设置为unknown_index。

:param torch.LongTensor words: batch_size x max_len
:return:
"""
if self.word_dropout > 0 and self.training:
with torch.no_grad():
if self._word_sep_index: # 不能drop sep
sep_mask = words.eq(self._wordpiece_unk_index)
mask = torch.ones_like(words).float() * self.word_dropout
mask = torch.bernoulli(mask).byte() # dropout_word越大,越多位置为1
words = words.masked_fill(mask, self._word_unk_index)
if self._word_sep_index:
words.masked_fill_(sep_mask, self._wordpiece_unk_index)
return words


class _WordBertModel(nn.Module):
@@ -288,11 +339,11 @@ class _WordBertModel(nn.Module):
word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces)
word_to_wordpieces.append(word_pieces)
word_pieces_lengths.append(len(word_pieces))
print("Found(Or seg into word pieces) {} words out of {}.".format(found_count, len(vocab)))
self._cls_index = self.tokenzier.vocab['[CLS]']
self._sep_index = self.tokenzier.vocab['[SEP]']
self._word_pad_index = vocab.padding_idx
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # 需要用于生成word_piece
print("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
self.word_to_wordpieces = np.array(word_to_wordpieces)
self.word_pieces_lengths = nn.Parameter(torch.LongTensor(word_pieces_lengths), requires_grad=False)
print("Successfully generate word pieces.")
@@ -339,7 +390,7 @@ class _WordBertModel(nn.Module):
sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
token_type_ids = sep_mask_cumsum.fmod(2)
if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0
token_type_ids = token_type_ids.eq(0).float()
token_type_ids = token_type_ids.eq(0).long()
else:
token_type_ids = torch.zeros_like(word_pieces)
# 2. 获取hidden的结果,根据word_pieces进行对应的pool计算


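A small worked example of the token_type_ids auto-generation introduced above: segment ids are derived purely from the positions of [SEP] in the input (the ids 101/102 are assumed [CLS]/[SEP] ids, for illustration only):

    import torch

    sep_index = 102                                       # assumed [SEP] id
    word_pieces = torch.LongTensor([[101, 5, 6, 102, 7, 8, 102]])  # [CLS] x x [SEP] y y [SEP]
    sep_mask = word_pieces.eq(sep_index).long()           # cast so cumsum is defined on all PyTorch versions
    sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
    token_type_ids = sep_mask_cumsum.fmod(2)
    if token_type_ids[0, 0].item():                       # the first segment must be labeled 0
        token_type_ids = token_type_ids.eq(0).long()
    print(token_type_ids)                                 # tensor([[0, 0, 0, 0, 1, 1, 1]])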
fastNLP/embeddings/static_embedding.py (+4, -2)

@@ -45,7 +45,7 @@ class StaticEmbedding(TokenEmbedding):
:param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding文件夹(文件夹下应该只有一个
以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。
如果输入为None则使用embedding_dim的维度随机初始化一个embedding。
:param int embedding_dim: 随机初始化的embedding的维度,仅在model_dir_or_name为None时有效
:param int embedding_dim: 随机初始化的embedding的维度,当该值为大于0的值时,将忽略model_dir_or_name
:param bool requires_grad: 是否需要gradient. 默认为True
:param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对
:param bool lower: 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独
@@ -55,9 +55,11 @@ class StaticEmbedding(TokenEmbedding):
:param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。
:param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。
"""
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', embedding_dim=100, requires_grad: bool=True,
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', embedding_dim=-1, requires_grad: bool=True,
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs):
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
if embedding_dim>0:
model_dir_or_name = None

# 得到cache_path
if model_dir_or_name is None:


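A sketch of the behaviour this change introduces: an embedding_dim greater than 0 now takes precedence over model_dir_or_name, so no pretrained vectors are loaded and a randomly initialized table of that size is used instead:

    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    vocab = Vocabulary().add_word_lst("中 国 人".split())
    # embedding_dim > 0, so 'en' is ignored and a 50-dim table is randomly initialized
    embed = StaticEmbedding(vocab, model_dir_or_name='en', embedding_dim=50)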
fastNLP/io/__init__.py (+6, -0)

@@ -30,6 +30,9 @@ __all__ = [
'Conll2003NERLoader',
'OntoNotesNERLoader',
'CTBLoader',
"MsraNERLoader",
"WeiboNERLoader",
"PeopleDailyNERLoader",

'CSVLoader',
'JsonLoader',
@@ -50,6 +53,9 @@ __all__ = [

"Conll2003NERPipe",
"OntoNotesNERPipe",
"MsraNERPipe",
"PeopleDailyPipe",
"WeiboNERPipe",

"MatchingBertPipe",
"RTEBertPipe",


fastNLP/io/data_bundle.py (+31, -8)

@@ -133,19 +133,21 @@ class DataBundle:

:param ~fastNLP.Vocabulary vocab: 词表
:param str field_name: 这个vocab对应的field名称
:return:
:return: self
"""
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary supports."
self.vocabs[field_name] = vocab
return self

def set_dataset(self, dataset, name):
"""

:param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet
:param str name: dataset的名称
:return:
:return: self
"""
self.datasets[name] = dataset
return self

def get_dataset(self, name:str)->DataSet:
"""
@@ -165,7 +167,7 @@ class DataBundle:
"""
return self.vocabs[field_name]

def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_field=True):
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True):
"""
将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作::

@@ -176,18 +178,21 @@ class DataBundle:
:param bool flag: 将field_name的input状态设置为flag
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一
行的数据进行类型和维度推断本列的数据的类型和维度。
:param bool ignore_miss_field: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
如果为False,则报错
:return self
"""
for field_name in field_names:
for name, dataset in self.datasets.items():
if not ignore_miss_field and not dataset.has_field(field_name):
if not ignore_miss_dataset and not dataset.has_field(field_name):
raise KeyError(f"Field:{field_name} was not found in DataSet:{name}")
if not dataset.has_field(field_name):
continue
else:
dataset.set_input(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type)
return self

def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_field=True):
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True):
"""
将field_names中的field设置为target, 对data_bundle中所有的dataset执行该操作::

@@ -198,16 +203,34 @@ class DataBundle:
:param bool flag: 将field_name的target状态设置为flag
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一
行的数据进行类型和维度推断本列的数据的类型和维度。
:param bool ignore_miss_field: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错
:return self
"""
for field_name in field_names:
for name, dataset in self.datasets.items():
if not ignore_miss_field and not dataset.has_field(field_name):
if not ignore_miss_dataset and not dataset.has_field(field_name):
raise KeyError(f"Field:{field_name} was not found in DataSet:{name}")
if not dataset.has_field(field_name):
continue
else:
dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type)
return self

def copy_field(self, field_name, new_field_name, ignore_miss_dataset=True):
"""
将DataBundle中所有的field_name复制一份叫new_field_name.

:param str field_name:
:param str new_field_name:
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; 如果为False,则报错
:return: self
"""
for name, dataset in self.datasets.items():
if dataset.has_field(field_name=field_name):
dataset.copy_field(field_name=field_name, new_field_name=new_field_name)
elif ignore_miss_dataset:
raise KeyError(f"{field_name} not found DataSet:{name}.")
return self

def __repr__(self):
_str = 'In total {} datasets:\n'.format(len(self.datasets))


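A minimal sketch (illustrative toy data) of the now-chainable DataBundle setters and the renamed ignore_miss_dataset flag:

    from fastNLP import DataSet
    from fastNLP.io.data_bundle import DataBundle

    ds = DataSet({'chars': [[2, 3], [4, 5]], 'target': [[0, 1], [0, 0]]})
    bundle = DataBundle()
    # every setter returns the bundle; the missing 'seq_len' field is skipped
    # rather than raising because ignore_miss_dataset=True
    bundle.set_dataset(ds, 'train') \
          .set_input('chars', 'seq_len', ignore_miss_dataset=True) \
          .set_target('target')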
fastNLP/io/file_utils.py (+55, -38)

@@ -27,6 +27,7 @@ PRETRAINED_BERT_MODEL_DIR = {
'cn': 'bert-chinese-wwm.zip',
'cn-base': 'bert-base-chinese.zip',
'cn-wwm': 'bert-chinese-wwm.zip',
'cn-wwm-ext': "bert-chinese-wwm-ext.zip"
}

PRETRAINED_ELMO_MODEL_DIR = {
@@ -56,7 +57,7 @@ PRETRAIN_STATIC_FILES = {
'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip",
'en-fasttext-crawl': "crawl-300d-2M.vec.zip",

'cn': "tencent_cn.txt.zip",
'cn': "tencent_cn.zip",
'cn-tencent': "tencent_cn.txt.zip",
'cn-fasttext': "cc.zh.300.vec.gz",
'cn-sgns-literature-word': 'sgns.literature.word.txt.zip',
@@ -71,7 +72,10 @@ DATASET_DIR = {
"qnli": "QNLI.zip",
"sst-2": "SST-2.zip",
"sst": "SST.zip",
"rte": "RTE.zip"
"rte": "RTE.zip",
"msra-ner": "MSRA_NER.zip",
"peopledaily": "peopledaily.zip",
"weibo-ner": "weibo_NER.zip"
}

PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
@@ -320,42 +324,44 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
# GET file object
req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"})
if req.status_code == 200:
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total, unit_scale=1)
fd, temp_filename = tempfile.mkstemp()
print("%s not found in cache, downloading to %s" % (url, temp_filename))

with open(temp_filename, "wb") as temp_file:
for chunk in req.iter_content(chunk_size=1024 * 16):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
print(f"Finish download from {url}.")

# 开始解压
delete_temp_dir = None
if suffix in ('.zip', '.tar.gz'):
uncompress_temp_dir = tempfile.mkdtemp()
delete_temp_dir = uncompress_temp_dir
print(f"Start to uncompress file to {uncompress_temp_dir}")
if suffix == '.zip':
unzip_file(Path(temp_filename), Path(uncompress_temp_dir))
else:
untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir))
filenames = os.listdir(uncompress_temp_dir)
if len(filenames) == 1:
if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])):
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0])

cache_path.mkdir(parents=True, exist_ok=True)
print("Finish un-compressing file.")
else:
uncompress_temp_dir = temp_filename
cache_path = str(cache_path) + suffix
success = False
fd, temp_filename = tempfile.mkstemp()
uncompress_temp_dir = None
try:
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total, unit_scale=1)
print("%s not found in cache, downloading to %s" % (url, temp_filename))

with open(temp_filename, "wb") as temp_file:
for chunk in req.iter_content(chunk_size=1024 * 16):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
print(f"Finish download from {url}")

# 开始解压
if suffix in ('.zip', '.tar.gz', '.gz'):
uncompress_temp_dir = tempfile.mkdtemp()
print(f"Start to uncompress file to {uncompress_temp_dir}")
if suffix == '.zip':
unzip_file(Path(temp_filename), Path(uncompress_temp_dir))
elif suffix == '.gz':
ungzip_file(temp_filename, uncompress_temp_dir, dir_name)
else:
untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir))
filenames = os.listdir(uncompress_temp_dir)
if len(filenames) == 1:
if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])):
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0])

cache_path.mkdir(parents=True, exist_ok=True)
print("Finish un-compressing file.")
else:
uncompress_temp_dir = temp_filename
cache_path = str(cache_path) + suffix

# 复制到指定的位置
print(f"Copy file to {cache_path}")
if os.path.isdir(uncompress_temp_dir):
@@ -377,10 +383,12 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
os.remove(cache_path)
else:
shutil.rmtree(cache_path)
if delete_temp_dir:
shutil.rmtree(delete_temp_dir)
os.close(fd)
os.remove(temp_filename)
if os.path.isdir(uncompress_temp_dir):
shutil.rmtree(uncompress_temp_dir)
elif os.path.isfile(uncompress_temp_dir):
os.remove(uncompress_temp_dir)
return get_filepath(cache_path)
else:
raise HTTPError(f"Status code:{req.status_code}. Fail to download from {url}.")
@@ -402,6 +410,15 @@ def untar_gz_file(file: Path, to: Path):
tar.extractall(to)


def ungzip_file(file: str, to: str, filename:str):
import gzip

g_file = gzip.GzipFile(file)
with open(os.path.join(to, filename), 'wb+') as f:
f.write(g_file.read())
g_file.close()


def match_file(dir_name: str, cache_dir: Path) -> str:
"""
匹配的原则是: 在cache_dir下的文件与dir_name完全一致, 或除了后缀以外和dir_name完全一致。


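A usage sketch of the new ungzip_file helper, which simply inflates a .gz archive into to/filename (the paths below are purely illustrative):

    from fastNLP.io.file_utils import ungzip_file

    # writes the decompressed content of cc.zh.300.vec.gz to /tmp/cc.zh.300.vec
    ungzip_file('cc.zh.300.vec.gz', to='/tmp', filename='cc.zh.300.vec')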
fastNLP/io/loader/__init__.py (+4, -0)

@@ -58,6 +58,9 @@ __all__ = [
'Conll2003NERLoader',
'OntoNotesNERLoader',
'CTBLoader',
"MsraNERLoader",
"PeopleDailyNERLoader",
"WeiboNERLoader",

# 'CSVLoader',
# 'JsonLoader',
@@ -77,3 +80,4 @@ from .cws import CWSLoader
from .json import JsonLoader
from .loader import Loader
from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader
from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader

fastNLP/io/loader/classification.py (+32, -66)

@@ -6,6 +6,8 @@ import os
import random
import shutil
import numpy as np
import glob
import time


class YelpLoader(Loader):
@@ -57,7 +59,7 @@ class YelpLoader(Loader):


class YelpFullLoader(YelpLoader):
def download(self, dev_ratio: float = 0.1, seed: int = 0):
def download(self, dev_ratio: float = 0.1, re_download:bool=False):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章

@@ -68,35 +70,23 @@ class YelpFullLoader(YelpLoader):
dev.csv三个文件。

:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。
:param int seed: 划分dev时的随机数种子
:param bool re_download: 是否重新下载数据,以重新切分数据。
:return: str, 数据集的目录地址
"""
dataset_name = 'yelp-review-full'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否需要重新下载
re_download = True
if dev_ratio > 0:
dev_line_count = 0
tr_line_count = 0
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \
open(os.path.join(data_dir, 'dev.csv'), 'r', encoding='utf-8') as f2:
for line in f1:
tr_line_count += 1
for line in f2:
dev_line_count += 1
if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005):
re_download = True
else:
re_download = False
if re_download:
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
modify_time = 0
for filepath in glob.glob(os.path.join(data_dir, '*')):
modify_time = os.stat(filepath).st_mtime
break
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
random.seed(int(seed))
try:
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \
open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \
@@ -116,44 +106,32 @@ class YelpFullLoader(YelpLoader):


class YelpPolarityLoader(YelpLoader):
def download(self, dev_ratio: float = 0.1, seed: int = 0):
def download(self, dev_ratio: float = 0.1, re_download=False):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章

Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances
in Neural Information Processing Systems 28 (NIPS 2015)

根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分0.1作为dev
根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分dev_ratio这么多作为dev

:param float dev_ratio: 如果路径中不存在dev.csv, 从train划分多少作为dev的数据. 如果为0,则不划分dev
:param int seed: 划分dev时的随机数种子
:param float dev_ratio: 如果路径中不存在dev.csv, 从train划分多少作为dev的数据。 如果为0,则不划分dev。
:param bool re_download: 是否重新下载数据,以重新切分数据。
:return: str, 数据集的目录地址
"""
dataset_name = 'yelp-review-polarity'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否符合比例要求
re_download = True
if dev_ratio > 0:
dev_line_count = 0
tr_line_count = 0
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \
open(os.path.join(data_dir, 'dev.csv'), 'r', encoding='utf-8') as f2:
for line in f1:
tr_line_count += 1
for line in f2:
dev_line_count += 1
if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005):
re_download = True
else:
re_download = False
if re_download:
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
modify_time = 0
for filepath in glob.glob(os.path.join(data_dir, '*')):
modify_time = os.stat(filepath).st_mtime
break
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)

if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
random.seed(int(seed))
try:
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \
open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \
@@ -209,7 +187,7 @@ class IMDBLoader(Loader):
return dataset
def download(self, dev_ratio: float = 0.1, seed: int = 0):
def download(self, dev_ratio: float = 0.1, re_download=False):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章

@@ -218,34 +196,22 @@ class IMDBLoader(Loader):
根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分0.1作为dev

:param float dev_ratio: 如果路径中没有dev.txt。从train划分多少作为dev的数据. 如果为0,则不划分dev
:param int seed: 划分dev时的随机数种子
:param bool re_download: 是否重新下载数据,以重新切分数据。
:return: str, 数据集的目录地址
"""
dataset_name = 'aclImdb'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求
re_download = True
if dev_ratio > 0:
dev_line_count = 0
tr_line_count = 0
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f1, \
open(os.path.join(data_dir, 'dev.txt'), 'r', encoding='utf-8') as f2:
for line in f1:
tr_line_count += 1
for line in f2:
dev_line_count += 1
if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005):
re_download = True
else:
re_download = False
if re_download:
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
modify_time = 0
for filepath in glob.glob(os.path.join(data_dir, '*')):
modify_time = os.stat(filepath).st_mtime
break
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
random.seed(int(seed))
try:
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \
open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \


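A sketch of the download signature these loaders now share; re_download replaces the old seed-based consistency check and forces a fresh download plus a new train/dev split:

    from fastNLP.io.loader.classification import YelpPolarityLoader

    # reuses the cached copy if present; if dev.csv is absent, 10% of train is split into dev
    data_dir = YelpPolarityLoader().download(dev_ratio=0.1, re_download=False)
    print(data_dir)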
fastNLP/io/loader/conll.py (+175, -3)

@@ -4,10 +4,12 @@ from .loader import Loader
from ...core.dataset import DataSet
from ..file_reader import _read_conll
from ...core.instance import Instance
from .. import DataBundle
from ..utils import check_loader_paths
from ...core.const import Const

import glob
import os
import shutil
import time
import random

class ConllLoader(Loader):
"""
@@ -262,3 +264,173 @@ class CTBLoader(Loader):

def _load(self, path:str):
pass


class CNNERLoader(Loader):
def _load(self, path:str):
"""
支持加载形如以下格式的内容,一行两列,以空格隔开两个sample

Example::

我 O
们 O
变 O
而 O
以 O
书 O
会 O
...

:param str path: 文件路径
:return: DataSet,包含raw_words列和target列
"""
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
raw_chars = []
target = []
for line in f:
line = line.strip()
if line:
parts = line.split()
if len(parts) == 1: # 网上下载的数据有一些列少tag,默认补充O
parts.append('O')
raw_chars.append(parts[0])
target.append(parts[1])
else:
if raw_chars:
ds.append(Instance(raw_chars=raw_chars, target=target))
raw_chars = []
target = []
return ds


class MsraNERLoader(CNNERLoader):
"""
读取MSRA-NER数据,数据中的格式应该类似与下列的内容

Example::

我 O
们 O
变 O
而 O
以 O
书 O
会 O
...

读取后的DataSet包含以下的field

.. csv-table:: target列是基于BIO的编码方式
:header: "raw_chars", "target"

"[我, 们, 变...]", "[O, O, ...]"
"[中, 共, 中, ...]", "[B-ORG, I-ORG, I-ORG, ...]"
"[...]", "[...]"

"""
def __init__(self):
super().__init__()

def download(self, dev_ratio:float=0.1, re_download:bool=False)->str:
"""
自动下载MSRA-NER的数据,如果你使用该数据,请引用 Gina-Anne Levow, 2006, The Third International Chinese Language
Processing Bakeoff: Word Segmentation and Named Entity Recognition.

根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后在output_dir中有train.conll, test.conll,
dev.conll三个文件。

:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。
:param bool re_download: 是否重新下载数据,以重新切分数据。
:return: str, 数据集的目录地址
"""
dataset_name = 'msra-ner'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
modify_time = 0
for filepath in glob.glob(os.path.join(data_dir, '*')):
modify_time = os.stat(filepath).st_mtime
break
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)

if not os.path.exists(os.path.join(data_dir, 'dev.conll')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
try:
with open(os.path.join(data_dir, 'train.conll'), 'r', encoding='utf-8') as f, \
open(os.path.join(data_dir, 'middle_file.conll'), 'w', encoding='utf-8') as f1, \
open(os.path.join(data_dir, 'dev.conll'), 'w', encoding='utf-8') as f2:
lines = [] # 一个sample包含很多行
for line in f:
line = line.strip()
if line:
lines.append(line)
else:
if random.random() < dev_ratio:
f2.write('\n'.join(lines) + '\n\n')
else:
f1.write('\n'.join(lines) + '\n\n')
lines.clear()
os.remove(os.path.join(data_dir, 'train.conll'))
os.renames(os.path.join(data_dir, 'middle_file.conll'), os.path.join(data_dir, 'train.conll'))
finally:
if os.path.exists(os.path.join(data_dir, 'middle_file.conll')):
os.remove(os.path.join(data_dir, 'middle_file.conll'))

return data_dir


class WeiboNERLoader(CNNERLoader):
def __init__(self):
super().__init__()

def download(self)->str:
"""
自动下载Weibo-NER的数据,如果你使用了该数据,请引用 Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for
Chinese Social Media with Jointly Trained Embeddings.

:return: str
"""
dataset_name = 'weibo-ner'
data_dir = self._get_dataset_path(dataset_name=dataset_name)

return data_dir


class PeopleDailyNERLoader(CNNERLoader):
"""
支持加载的数据格式如下

Example::

当 O
希 O
望 O
工 O
程 O
救 O
助 O
的 O
百 O

读取后的DataSet包含以下的field

.. csv-table:: target列是基于BIO的编码方式
:header: "raw_chars", "target"

"[我, 们, 变...]", "[O, O, ...]"
"[中, 共, 中, ...]", "[B-ORG, I-ORG, I-ORG, ...]"
"[...]", "[...]"

"""
def __init__(self):
super().__init__()

def download(self) -> str:
dataset_name = 'peopledaily'
data_dir = self._get_dataset_path(dataset_name=dataset_name)

return data_dir

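A usage sketch for the new Chinese NER loaders; MsraNERLoader can additionally split a dev set out of train during download:

    from fastNLP.io.loader.conll import MsraNERLoader

    MsraNERLoader().download(dev_ratio=0.1, re_download=False)
    data_bundle = MsraNERLoader().load()   # DataBundle with raw_chars / target fields
    print(data_bundle)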
fastNLP/io/pipe/__init__.py (+7, -1)

@@ -8,6 +8,8 @@ Pipe用于处理通过 Loader 读取的数据,所有的 Pipe 都包含 ``proce

"""
__all__ = [
"Pipe",

"YelpFullPipe",
"YelpPolarityPipe",
"SSTPipe",
@@ -16,6 +18,9 @@ __all__ = [

"Conll2003NERPipe",
"OntoNotesNERPipe",
"MsraNERPipe",
"WeiboNERPipe",
"PeopleDailyPipe",

"MatchingBertPipe",
"RTEBertPipe",
@@ -32,6 +37,7 @@ __all__ = [
]

from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe
from .conll import Conll2003NERPipe, OntoNotesNERPipe
from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \
MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe
from .pipe import Pipe

fastNLP/io/pipe/conll.py (+132, -33)

@@ -4,6 +4,8 @@ from .utils import iob2, iob2bioes
from ...core.const import Const
from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader
from .utils import _indexize, _add_words_field
from .utils import _add_chars_field
from ..loader.conll import PeopleDailyNERLoader, WeiboNERLoader, MsraNERLoader


class _NERPipe(Pipe):
@@ -17,7 +19,7 @@ class _NERPipe(Pipe):

:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为-100。
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。
"""

def __init__(self, encoding_type: str = 'bio', lower: bool = False, target_pad_val=0):
@@ -32,31 +34,16 @@ class _NERPipe(Pipe):
"""
支持的DataSet的field为

.. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader
.. csv-table::
:header: "raw_words", "target"

"[Nadim, Ladki]", "[B-PER, I-PER]"
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]"
"[...]", "[...]"


:param DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。
在传入DataBundle基础上原位修改。
:return: DataBundle

Example::

data_bundle = Conll2003Loader().load('/path/to/conll2003/')
data_bundle = Conll2003NERPipe().process(data_bundle)

# 获取train
tr_data = data_bundle.get_dataset('train')

# 获取target这个field的词表
target_vocab = data_bundle.get_vocab('target')
# 获取words这个field的词表
word_vocab = data_bundle.get_vocab('words')

"""
# 转换tag
for name, dataset in data_bundle.datasets.items():
@@ -79,18 +66,6 @@ class _NERPipe(Pipe):

return data_bundle

def process_from_file(self, paths) -> DataBundle:
"""

:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。
:return: DataBundle
"""
# 读取数据
data_bundle = Conll2003NERLoader().load(paths)
data_bundle = self.process(data_bundle)

return data_bundle


class Conll2003NERPipe(_NERPipe):
"""
@@ -102,8 +77,8 @@ class Conll2003NERPipe(_NERPipe):
.. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader
:header: "raw_words", "words", "target", "seq_len"

"[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2
"[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 10
"[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2
"[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[3, 4]", 6
"[...]", "[...]", "[...]", .

raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
@@ -134,10 +109,13 @@ class OntoNotesNERPipe(_NERPipe):
.. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader
:header: "raw_words", "words", "target", "seq_len"

"[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2
"[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 6
"[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2
"[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[3, 4]", 6
"[...]", "[...]", "[...]", .

raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。

:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。
@@ -146,3 +124,124 @@ class OntoNotesNERPipe(_NERPipe):
def process_from_file(self, paths):
data_bundle = OntoNotesNERLoader().load(paths)
return self.process(data_bundle)


class _CNNERPipe(Pipe):
"""
中文NER任务的处理Pipe, 该Pipe会(1)复制raw_chars列,并命名为chars; (2)在chars, target列建立词表
(创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将chars,target列根据相应的
Vocabulary转换为index。

raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target, seq_len。

:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。
"""

def __init__(self, encoding_type: str = 'bio', target_pad_val=0):
if encoding_type == 'bio':
self.convert_tag = iob2
else:
self.convert_tag = lambda words: iob2bioes(iob2(words))
self.target_pad_val = int(target_pad_val)

def process(self, data_bundle: DataBundle) -> DataBundle:
"""
支持的DataSet的field为

.. csv-table::
:header: "raw_chars", "target"

"[相, 比, 之, 下,...]", "[O, O, O, O, ...]"
"[青, 岛, 海, 牛, 队, 和, ...]", "[B-ORG, I-ORG, I-ORG, ...]"
"[...]", "[...]"

raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。

:param DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。
在传入DataBundle基础上原位修改。
:return: DataBundle
"""
# 转换tag
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET)

_add_chars_field(data_bundle, lower=False)

# index
_indexize(data_bundle, input_field_name=Const.CHAR_INPUT, target_field_name=Const.TARGET)

input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN]
target_fields = [Const.TARGET, Const.INPUT_LEN]

for name, dataset in data_bundle.datasets.items():
dataset.set_pad_val(Const.TARGET, self.target_pad_val)
dataset.add_seq_len(Const.CHAR_INPUT)

data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)

return data_bundle


class MsraNERPipe(_CNNERPipe):
"""
处理MSRA-NER的数据,处理之后的DataSet的field情况为

.. csv-table::
:header: "raw_chars", "chars", "target", "seq_len"

"[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11
"[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21
"[...]", "[...]", "[...]", .

raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。

"""
def process_from_file(self, paths=None) -> DataBundle:
data_bundle = MsraNERLoader().load(paths)
return self.process(data_bundle)


class PeopleDailyPipe(_CNNERPipe):
"""
处理people daily的ner的数据,处理之后的DataSet的field情况为

.. csv-table::
:header: "raw_chars", "chars", "target", "seq_len"

"[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11
"[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21
"[...]", "[...]", "[...]", .

raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。
"""
def process_from_file(self, paths=None) -> DataBundle:
data_bundle = PeopleDailyNERLoader().load(paths)
return self.process(data_bundle)


class WeiboNERPipe(_CNNERPipe):
"""
处理weibo的ner的数据,处理之后的DataSet的field情况为

.. csv-table::
:header: "raw_chars", "chars", "target", "seq_len"

"[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11
"[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21
"[...]", "[...]", "[...]", .

raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。

:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。
"""
def process_from_file(self, paths=None) -> DataBundle:
data_bundle = WeiboNERLoader().load(paths)
return self.process(data_bundle)

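A usage sketch for the new Chinese NER pipes; called without paths they download the corresponding corpus and return a fully indexed DataBundle:

    from fastNLP.io import WeiboNERPipe

    data_bundle = WeiboNERPipe(encoding_type='bio').process_from_file()
    print(data_bundle.get_dataset('train'))   # raw_chars, chars, target, seq_len
    print(data_bundle.get_vocab('target'))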
fastNLP/io/pipe/matching.py (+2, -2)

@@ -50,8 +50,8 @@ class MatchingBertPipe(Pipe):
dataset.drop(lambda x: x[Const.TARGET] == '-')

for name, dataset in data_bundle.datasets.items():
dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0))
dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1))
dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0), )
dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1), )

if self.lower:
for name, dataset in data_bundle.datasets.items():


fastNLP/io/pipe/utils.py (+28, -10)

@@ -76,25 +76,27 @@ def _raw_split(sent):
return sent.split()


def _indexize(data_bundle):
def _indexize(data_bundle, input_field_name=Const.INPUT, target_field_name=Const.TARGET):
"""
在dataset中的"words"列建立词表,"target"列建立词表,并把词表加入到data_bundle中。
在dataset中的field_name列建立词表,Const.TARGET列建立词表,并把词表加入到data_bundle中。

:param data_bundle:
:param: str input_field_name:
:param: str target_field_name: 这一列的vocabulary没有unknown和padding
:return:
"""
src_vocab = Vocabulary()
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT,
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=input_field_name,
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
name != 'train'])
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name)

tgt_vocab = Vocabulary(unknown=None, padding=None)
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.TARGET)
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=target_field_name)
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=target_field_name)

data_bundle.set_vocab(src_vocab, Const.INPUT)
data_bundle.set_vocab(tgt_vocab, Const.TARGET)
data_bundle.set_vocab(src_vocab, input_field_name)
data_bundle.set_vocab(tgt_vocab, target_field_name)

return data_bundle

@@ -107,14 +109,30 @@ def _add_words_field(data_bundle, lower=False):
:param bool lower:是否要小写化
:return: 传入的DataBundle
"""
for name, dataset in data_bundle.datasets.items():
dataset.copy_field(field_name=Const.RAW_WORD, new_field_name=Const.INPUT)
data_bundle.copy_field(field_name=Const.RAW_WORD, new_field_name=Const.INPUT, ignore_miss_dataset=True)

if lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.INPUT].lower()
return data_bundle


def _add_chars_field(data_bundle, lower=False):
"""
给data_bundle中的dataset中复制一列chars. 并根据lower参数判断是否需要小写化

:param data_bundle:
:param bool lower:是否要小写化
:return: 传入的DataBundle
"""
data_bundle.copy_field(field_name=Const.RAW_CHAR, new_field_name=Const.CHAR_INPUT, ignore_miss_dataset=True)

if lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.CHAR_INPUT].lower()
return data_bundle


def _drop_empty_instance(data_bundle, field_name):
"""
删除data_bundle的DataSet中存在的某个field为空的情况


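A sketch of how the generalized helpers compose for character-level input, mirroring what _CNNERPipe does (the Weibo loader is used here only as a convenient source of raw_chars/target fields):

    from fastNLP.core.const import Const
    from fastNLP.io.loader.conll import WeiboNERLoader
    from fastNLP.io.pipe.utils import _add_chars_field, _indexize

    data_bundle = WeiboNERLoader().load()                 # raw_chars / target
    _add_chars_field(data_bundle, lower=False)            # copy raw_chars -> chars
    _indexize(data_bundle, input_field_name=Const.CHAR_INPUT,
              target_field_name=Const.TARGET)             # build vocabs and convert to indices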
fastNLP/modules/encoder/bert.py (+2, -1)

@@ -868,6 +868,7 @@ class _WordPieceBertModel(nn.Module):

self._cls_index = self.tokenzier.vocab['[CLS]']
self._sep_index = self.tokenzier.vocab['[SEP]']
self._wordpiece_unknown_index = self.tokenzier.vocab['[UNK]']
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # 需要用于生成word_piece
self.pooled_cls = pooled_cls

@@ -919,7 +920,7 @@ class _WordPieceBertModel(nn.Module):
outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1)))
for l_index, l in enumerate(self.layers):
bert_output = bert_outputs[l]
if l==len(bert_outputs) and self.pooled_cls:
if l in (len(bert_outputs)-1, -1) and self.pooled_cls:
bert_output[:, 0] = pooled_cls
outputs[l_index] = bert_output
return outputs

reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py (+0, -115)

@@ -1,115 +0,0 @@


from fastNLP.io.data_bundle import DataSetLoader, DataBundle
from fastNLP.io import ConllLoader
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2
from fastNLP import Const
from reproduction.utils import check_dataloader_paths
from fastNLP import Vocabulary

class ChineseNERLoader(DataSetLoader):
"""
读取中文命名实体数据集,包括PeopleDaily, MSRA-NER, Weibo。数据在这里可以找到https://github.com/OYE93/Chinese-NLP-Corpus/tree/master/NER
请确保输入数据的格式如下, 共两列,第一列为字,第二列为标签,不同句子以空行隔开
我 O
们 O
变 O
而 O
以 O
书 O
会 O
...

"""
def __init__(self, encoding_type:str='bioes'):
"""

:param str encoding_type: 支持bio和bioes格式
"""
super().__init__()
self._loader = ConllLoader(headers=['raw_chars', 'target'], indexes=[0, 1])

assert encoding_type in ('bio', 'bioes')

self._tag_converters = [iob2]
if encoding_type == 'bioes':
self._tag_converters.append(iob2bioes)

def load(self, path:str):
dataset = self._loader.load(path)
def convert_tag_schema(tags):
for converter in self._tag_converters:
tags = converter(tags)
return tags
if self._tag_converters:
dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET)
return dataset

def process(self, paths, bigrams=False, trigrams=False):
"""

:param paths:
:param bool, bigrams: 是否包含生成bigram feature, [a, b, c, d] -> [ab, bc, cd, d<eos>]
:param bool, trigrams: 是否包含trigram feature,[a, b, c, d] -> [abc, bcd, cd<eos>, d<eos><eos>]
:return: ~fastNLP.io.DataBundle
包含以下的fields
raw_chars: List[str]
chars: List[int]
seq_len: int, 字的长度
bigrams: List[int], optional
trigrams: List[int], optional
target: List[int]
"""
paths = check_dataloader_paths(paths)
data = DataBundle()
input_fields = [Const.CHAR_INPUT, Const.INPUT_LEN, Const.TARGET]
target_fields = [Const.TARGET, Const.INPUT_LEN]

for name, path in paths.items():
dataset = self.load(path)
if bigrams:
dataset.apply_field(lambda raw_chars: [c1+c2 for c1, c2 in zip(raw_chars, raw_chars[1:]+['<eos>'])],
field_name='raw_chars', new_field_name='bigrams')

if trigrams:
dataset.apply_field(lambda raw_chars: [c1+c2+c3 for c1, c2, c3 in zip(raw_chars,
raw_chars[1:]+['<eos>'],
raw_chars[2:]+['<eos>']*2)],
field_name='raw_chars', new_field_name='trigrams')
data.datasets[name] = dataset

char_vocab = Vocabulary().from_dataset(data.datasets['train'], field_name='raw_chars',
no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train'])
char_vocab.index_dataset(*data.datasets.values(), field_name='raw_chars', new_field_name=Const.CHAR_INPUT)
data.vocabs[Const.CHAR_INPUT] = char_vocab

target_vocab = Vocabulary(unknown=None, padding=None).from_dataset(data.datasets['train'], field_name=Const.TARGET)
target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
data.vocabs[Const.TARGET] = target_vocab

if bigrams:
bigram_vocab = Vocabulary().from_dataset(data.datasets['train'], field_name='bigrams',
no_create_entry_dataset=[dataset for name, dataset in
data.datasets.items() if name != 'train'])
bigram_vocab.index_dataset(*data.datasets.values(), field_name='bigrams', new_field_name='bigrams')
data.vocabs['bigrams'] = bigram_vocab
input_fields.append('bigrams')

if trigrams:
trigram_vocab = Vocabulary().from_dataset(data.datasets['train'], field_name='trigrams',
no_create_entry_dataset=[dataset for name, dataset in
data.datasets.items() if name != 'train'])
trigram_vocab.index_dataset(*data.datasets.values(), field_name='trigrams', new_field_name='trigrams')
data.vocabs['trigrams'] = trigram_vocab
input_fields.append('trigrams')

for name, dataset in data.datasets.items():
dataset.add_seq_len(Const.CHAR_INPUT)
dataset.set_input(*input_fields)
dataset.set_target(*target_fields)

return data





reproduction/seqence_labelling/chinese_ner/data/__init__.py (+0, -0)


reproduction/seqence_labelling/chinese_ner/train_bert.py (+18, -15)

@@ -12,22 +12,23 @@ sys.path.append('../../../')
from torch import nn

from fastNLP.embeddings import BertEmbedding, Embedding
from reproduction.seqence_labelling.chinese_ner.data.ChineseNER import ChineseNERLoader
from fastNLP import Trainer, Const
from fastNLP import BucketSampler, SpanFPreRecMetric, GradientClipCallback
from fastNLP.modules import MLP
from fastNLP.core.callback import WarmupCallback
from fastNLP import CrossEntropyLoss
from fastNLP.core.optimizer import AdamW
import os
from fastNLP.io import MsraNERPipe, MsraNERLoader, WeiboNERPipe

from fastNLP import cache_results

encoding_type = 'bio'

@cache_results('caches/msra.pkl')
@cache_results('caches/weibo.pkl', _refresh=False)
def get_data():
data = ChineseNERLoader(encoding_type=encoding_type).process("MSRA/")
# data_dir = MsraNERLoader().download(dev_ratio=0)
# data = MsraNERPipe(encoding_type=encoding_type, target_pad_val=-100).process_from_file(data_dir)
data = WeiboNERPipe(encoding_type=encoding_type).process_from_file()
return data
data = get_data()
print(data)
@@ -35,10 +36,10 @@ print(data)
class BertCNNER(nn.Module):
def __init__(self, embed, tag_size):
super().__init__()

self.embedding = Embedding(embed, dropout=0.1)
self.embedding = embed
self.tag_size = tag_size
self.mlp = MLP(size_layer=[self.embedding.embedding_dim, tag_size])

def forward(self, chars):
# batch_size, max_len = words.size()
chars = self.embedding(chars)
@@ -46,11 +47,15 @@ class BertCNNER(nn.Module):

return {Const.OUTPUT: outputs}

embed = BertEmbedding(data.vocabs[Const.CHAR_INPUT], model_dir_or_name='en-base',
pool_method='max', requires_grad=True, layers='11')
def predict(self, chars):
# batch_size, max_len = words.size()
chars = self.embedding(chars)
outputs = self.mlp(chars)

for name, dataset in data.datasets.items():
dataset.set_pad_val(Const.TARGET, -100)
return {Const.OUTPUT: outputs}

embed = BertEmbedding(data.get_vocab(Const.CHAR_INPUT), model_dir_or_name='cn-wwm-ext',
pool_method='first', requires_grad=True, layers='11', include_cls_sep=False, dropout=0.5)

callbacks = [
GradientClipCallback(clip_type='norm', clip_value=1),
@@ -58,7 +63,7 @@ callbacks = [
]

model = BertCNNER(embed, len(data.vocabs[Const.TARGET]))
optimizer = AdamW(model.parameters(), lr=1e-4)
optimizer = AdamW(model.parameters(), lr=3e-5)

for name, dataset in data.datasets.items():
original_len = len(dataset)
@@ -66,13 +71,11 @@ for name, dataset in data.datasets.items():
clipped_len = len(dataset)
print("Delete {} instances in {}.".format(original_len-clipped_len, name))

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(),
device=[0, 1], dev_data=data.datasets['test'], batch_size=20,
device=0, dev_data=data.datasets['test'], batch_size=6,
metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
loss=CrossEntropyLoss(reduction='sum'),
callbacks=callbacks, num_workers=2, n_epochs=5,
check_code_level=-1, update_every=3)
check_code_level=0, update_every=3)
trainer.train()


reproduction/seqence_labelling/chinese_ner/train_cn_ner.py (+56, -14)

@@ -1,7 +1,6 @@
import sys
sys.path.append('../../..')



from reproduction.seqence_labelling.chinese_ner.data.ChineseNER import ChineseNERLoader
from fastNLP.embeddings import StaticEmbedding

from torch import nn
@@ -14,7 +13,51 @@ import torch.nn.functional as F
from fastNLP import seq_len_to_mask
from fastNLP.core.const import Const as C
from fastNLP import SpanFPreRecMetric, Trainer
from fastNLP import cache_results
from fastNLP import cache_results, Vocabulary
from fastNLP.io.pipe.utils import _add_chars_field, _indexize

from fastNLP.io.pipe import Pipe
from fastNLP.core.utils import iob2bioes, iob2
from fastNLP.io import MsraNERLoader, WeiboNERLoader

class ChineseNERPipe(Pipe):
def __init__(self, encoding_type: str = 'bio', target_pad_val=0, bigram=False):
if encoding_type == 'bio':
self.convert_tag = iob2
else:
self.convert_tag = lambda words: iob2bioes(iob2(words))
self.target_pad_val = int(target_pad_val)
self.bigram = bigram

def process(self, data_bundle):
data_bundle.copy_field(C.RAW_CHAR, C.CHAR_INPUT)
input_fields = [C.TARGET, C.CHAR_INPUT, C.INPUT_LEN]
target_fields = [C.TARGET, C.INPUT_LEN]
if self.bigram:
for dataset in data_bundle.datasets.values():
dataset.apply_field(lambda chars:[c1+c2 for c1, c2 in zip(chars, chars[1:]+['<eos>'])],
field_name=C.CHAR_INPUT, new_field_name='bigrams')
bigram_vocab = Vocabulary()
bigram_vocab.from_dataset(data_bundle.get_dataset('train'),field_name='bigrams',
no_create_entry_dataset=[ds for name, ds in data_bundle.datasets.items() if name!='train'])
bigram_vocab.index_dataset(*data_bundle.datasets.values(), field_name='bigrams')
data_bundle.set_vocab(bigram_vocab, field_name='bigrams')
input_fields.append('bigrams')

_add_chars_field(data_bundle, lower=False)

# index
_indexize(data_bundle, input_field_name=C.CHAR_INPUT, target_field_name=C.TARGET)

for name, dataset in data_bundle.datasets.items():
dataset.set_pad_val(C.TARGET, self.target_pad_val)
dataset.add_seq_len(C.CHAR_INPUT)

data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)

return data_bundle


class CNBiLSTMCRFNER(nn.Module):
def __init__(self, char_embed, num_classes, bigram_embed=None, trigram_embed=None, num_layers=1, hidden_size=100,
@@ -73,22 +116,21 @@ class CNBiLSTMCRFNER(nn.Module):
return self._forward(chars, bigrams, trigrams, seq_len)

# data_bundle = pickle.load(open('caches/msra.pkl', 'rb'))
@cache_results('caches/msra.pkl', _refresh=True)
@cache_results('caches/weibo-lstm.pkl', _refresh=False)
def get_data():
data_bundle = ChineseNERLoader().process('MSRA-NER/', bigrams=True)
char_embed = StaticEmbedding(data_bundle.vocabs['chars'],
model_dir_or_name='cn-char')
bigram_embed = StaticEmbedding(data_bundle.vocabs['bigrams'],
model_dir_or_name='cn-bigram')
data_bundle = WeiboNERLoader().load()
data_bundle = ChineseNERPipe(encoding_type='bioes', bigram=True).process(data_bundle)
char_embed = StaticEmbedding(data_bundle.get_vocab(C.CHAR_INPUT), model_dir_or_name='cn-fasttext')
bigram_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'), embedding_dim=100, min_freq=3)
return data_bundle, char_embed, bigram_embed
data_bundle, char_embed, bigram_embed = get_data()
# data_bundle = get_data()
print(data_bundle)

# exit(0)
data_bundle.datasets['train'].set_input('target')
data_bundle.datasets['dev'].set_input('target')
model = CNBiLSTMCRFNER(char_embed, num_classes=len(data_bundle.vocabs['target']), bigram_embed=bigram_embed)

Trainer(data_bundle.datasets['train'], model, batch_size=640,
Trainer(data_bundle.datasets['train'], model, batch_size=20,
metrics=SpanFPreRecMetric(data_bundle.vocabs['target'], encoding_type='bioes'),
num_workers=2, dev_data=data_bundle. datasets['dev'], device=3).train()
num_workers=2, dev_data=data_bundle. datasets['dev'], device=0).train()


reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py (+0, -1)

@@ -2,7 +2,6 @@
import torch
from torch import nn
from fastNLP import seq_len_to_mask
from fastNLP.modules import Embedding
from fastNLP.modules import LSTM
from fastNLP.modules import ConditionalRandomField, allowed_transitions
import torch.nn.functional as F


reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py (+20, -45)

@@ -1,8 +1,7 @@
import sys
sys.path.append('../../..')

from fastNLP.embeddings.embedding import CNNCharEmbedding, StaticEmbedding
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding

from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF
from fastNLP import Trainer
@@ -11,68 +10,44 @@ from fastNLP import BucketSampler
from fastNLP import Const
from torch.optim import SGD
from fastNLP import GradientClipCallback
from fastNLP.core.callback import FitlogCallback, LRScheduler
from fastNLP.core.callback import EvaluateCallback, LRScheduler
from torch.optim.lr_scheduler import LambdaLR
# from reproduction.seqence_labelling.ner.model.swats import SWATS
from fastNLP import cache_results

import fitlog
fitlog.debug()

from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003DataLoader

from fastNLP.io.pipe.conll import Conll2003NERPipe
encoding_type = 'bioes'
@cache_results('caches/upper_conll2003.pkl')
@cache_results('caches/conll2003_new.pkl', _refresh=True)
def load_data():
data = Conll2003DataLoader(encoding_type=encoding_type).process('../../../../others/data/conll2003',
word_vocab_opt=VocabularyOption(min_freq=1),
lower=False)
# 替换路径
paths = {'test':"NER/corpus/CoNLL-2003/eng.testb",
'train':"NER/corpus/CoNLL-2003/eng.train",
'dev':"NER/corpus/CoNLL-2003/eng.testa"}
data = Conll2003NERPipe(encoding_type=encoding_type, target_pad_val=0).process_from_file(paths)
return data
data = load_data()
print(data)
char_embed = CNNCharEmbedding(vocab=data.vocabs['words'], embed_size=30, char_emb_size=30, filter_nums=[30],
kernel_sizes=[3], word_dropout=0.01, dropout=0.5)
# char_embed = LSTMCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30 ,char_emb_size=30)
word_embed = StaticEmbedding(vocab=data.vocabs['words'],
model_dir_or_name='/hdd/fudanNLP/pretrain_vectors/glove.6B.100d.txt',
char_embed = CNNCharEmbedding(vocab=data.get_vocab('words'), embed_size=30, char_emb_size=30, filter_nums=[30],
kernel_sizes=[3], word_dropout=0, dropout=0.5)
word_embed = StaticEmbedding(vocab=data.get_vocab('words'),
model_dir_or_name='en-glove-6b-100d',
requires_grad=True, lower=True, word_dropout=0.01, dropout=0.5)
word_embed.embedding.weight.data = word_embed.embedding.weight.data/word_embed.embedding.weight.data.std()

# import joblib
# raw_data = joblib.load('/hdd/fudanNLP/fastNLP/others/NER-with-LS/data/conll_with_data.joblib')
# def convert_to_ids(raw_words):
# ids = []
# for word in raw_words:
# id = raw_data['word_to_id'][word]
# id = raw_data['id_to_emb_map'][id]
# ids.append(id)
# return ids
# word_embed = raw_data['emb_matrix']
# for name, dataset in data.datasets.items():
# dataset.apply_field(convert_to_ids, field_name='raw_words', new_field_name=Const.INPUT)

# elmo_embed = ElmoEmbedding(vocab=data.vocabs['cap_words'],
# model_dir_or_name='.',
# requires_grad=True, layers='mix')
# char_embed = StackEmbedding([elmo_embed, char_embed])

model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET],
encoding_type=encoding_type)

callbacks = [
GradientClipCallback(clip_type='value', clip_value=5),
FitlogCallback({'test':data.datasets['test']}, verbose=1),
# SaveModelCallback('save_models/', top=3, only_param=False, save_on_exception=True)
EvaluateCallback(data=data.get_dataset('test')) # 额外对test上的数据进行性能评测
]
# optimizer = Adam(model.parameters(), lr=0.001)
# optimizer = SWATS(model.parameters(), verbose=True)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)

optimizer = SGD(model.parameters(), lr=0.008, momentum=0.9)
scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch)))
callbacks.append(scheduler)


trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(batch_size=20),
device=1, dev_data=data.datasets['dev'], batch_size=20,
trainer = Trainer(train_data=data.get_dataset('train'), model=model, optimizer=optimizer, sampler=BucketSampler(),
device=0, dev_data=data.get_dataset('dev'), batch_size=20,
metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
callbacks=callbacks, num_workers=2, n_epochs=100)
callbacks=callbacks, num_workers=2, n_epochs=100, dev_batch_size=512)
trainer.train()

reproduction/seqence_labelling/ner/train_ontonote.py (+18, -33)

@@ -11,52 +11,37 @@ from fastNLP import Const
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from fastNLP import GradientClipCallback
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.core.callback import FitlogCallback, LRScheduler
from functools import partial
from torch import nn
from fastNLP import BucketSampler
from fastNLP.core.callback import EvaluateCallback, LRScheduler
from fastNLP import cache_results
from fastNLP.io.pipe.conll import OntoNotesNERPipe

import fitlog
fitlog.debug()
fitlog.set_log_dir('logs/')

fitlog.add_hyper_in_file(__file__)
#######hyper
normalize = False
divide_std = True
lower = False
lr = 0.015
lr = 0.01
dropout = 0.5
batch_size = 20
init_method = 'default'
batch_size = 32
job_embed = False
data_name = 'ontonote'
#######hyper


init_method = {'default': None,
'xavier': partial(nn.init.xavier_normal_, gain=0.02),
'normal': partial(nn.init.normal_, std=0.02)
}[init_method]


from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader

encoding_type = 'bioes'

@cache_results('caches/ontonotes.pkl')
@cache_results('caches/ontonotes.pkl', _refresh=True)
def cache():
data = OntoNoteNERDataLoader(encoding_type=encoding_type).process('../../../../others/data/v4/english',
lower=lower,
word_vocab_opt=VocabularyOption(min_freq=1))
char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30],
kernel_sizes=[3])
data = OntoNotesNERPipe(encoding_type=encoding_type).process_from_file('../../../../others/data/v4/english')
char_embed = CNNCharEmbedding(vocab=data.vocabs['words'], embed_size=30, char_emb_size=30, filter_nums=[30],
kernel_sizes=[3], dropout=dropout)
word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
model_dir_or_name='/remote-home/hyan01/fastnlp_caches/glove.6B.100d/glove.6B.100d.txt',
model_dir_or_name='en-glove-100d',
requires_grad=True,
normalize=normalize,
init_method=init_method)
word_dropout=0.01,
dropout=dropout,
lower=True,
min_freq=2)
return data, char_embed, word_embed
data, char_embed, word_embed = cache()

@@ -67,7 +52,7 @@ model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=1200, num_layers=1, tag

callbacks = [
GradientClipCallback(clip_value=5, clip_type='value'),
FitlogCallback(data.datasets['test'], verbose=1)
EvaluateCallback(data.datasets['test'])
]

optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
@@ -75,8 +60,8 @@ scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.0
callbacks.append(scheduler)


trainer = Trainer(train_data=data.datasets['dev'][:100], model=model, optimizer=optimizer, sampler=None,
device=0, dev_data=data.datasets['dev'][:100], batch_size=batch_size,
trainer = Trainer(train_data=data.get_dataset('train'), model=model, optimizer=optimizer, sampler=BucketSampler(num_buckets=100),
device=0, dev_data=data.get_dataset('dev'), batch_size=batch_size,
metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
callbacks=callbacks, num_workers=1, n_epochs=100)
callbacks=callbacks, num_workers=1, n_epochs=100, dev_batch_size=256)
trainer.train()

test/embeddings/test_bert_embedding.py (+14, -0)

@@ -0,0 +1,14 @@
import unittest
from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding
import torch
import os

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestDownload(unittest.TestCase):
def test_download(self):
# import os
vocab = Vocabulary().add_word_lst("This is a test .".split())
embed = BertEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/embedding/bert-base-cased')
words = torch.LongTensor([[0, 1, 2]])
print(embed(words).size())

test/io/loader/test_conll_loader.py (+21, -0)

@@ -0,0 +1,21 @@

import unittest
import os
from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader

class MSRANERTest(unittest.TestCase):
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_download(self):
MsraNERLoader().download(re_download=False)
data_bundle = MsraNERLoader().load()
print(data_bundle)

class PeopleDailyTest(unittest.TestCase):
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_download(self):
PeopleDailyNERLoader().download()

class WeiboNERTest(unittest.TestCase):
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_download(self):
WeiboNERLoader().download()

test/io/pipe/test_conll.py (+12, -0)

@@ -0,0 +1,12 @@
import unittest
import os
from fastNLP.io import MsraNERPipe, PeopleDailyPipe, WeiboNERPipe

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestPipe(unittest.TestCase):
def test_process_from_file(self):
for pipe in [MsraNERPipe, PeopleDailyPipe, WeiboNERPipe]:
with self.subTest(pipe=pipe):
print(pipe)
data_bundle = pipe().process_from_file()
print(data_bundle)
