@@ -238,8 +238,8 @@ class CrossEntropyLoss(LossBase): | |||||
pred = pred.tranpose(-1, pred) | pred = pred.tranpose(-1, pred) | ||||
pred = pred.reshape(-1, pred.size(-1)) | pred = pred.reshape(-1, pred.size(-1)) | ||||
target = target.reshape(-1) | target = target.reshape(-1) | ||||
if seq_len is not None: | |||||
mask = seq_len_to_mask(seq_len).reshape(-1).eq(0) | |||||
if seq_len is not None and target.dim()>1: | |||||
mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0) | |||||
target = target.masked_fill(mask, self.padding_idx) | target = target.masked_fill(mask, self.padding_idx) | ||||
return F.cross_entropy(input=pred, target=target, | return F.cross_entropy(input=pred, target=target, | ||||
@@ -347,7 +347,7 @@ class AccuracyMetric(MetricBase): | |||||
pass | pass | ||||
elif pred.dim() == target.dim() + 1: | elif pred.dim() == target.dim() + 1: | ||||
pred = pred.argmax(dim=-1) | pred = pred.argmax(dim=-1) | ||||
if seq_len is None: | |||||
if seq_len is None and target.dim()>1: | |||||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | ||||
else: | else: | ||||
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | ||||
@@ -68,7 +68,7 @@ class BertEmbedding(ContextualEmbedding): | |||||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', | def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', | ||||
pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False, | pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False, | ||||
pooled_cls=True, requires_grad: bool = False, auto_truncate: bool = False): | |||||
pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False): | |||||
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | ||||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | ||||
@@ -165,7 +165,7 @@ class BertWordPieceEncoder(nn.Module): | |||||
""" | """ | ||||
def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, | def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, | ||||
word_dropout=0, dropout=0, requires_grad: bool = False): | |||||
word_dropout=0, dropout=0, requires_grad: bool = True): | |||||
super().__init__() | super().__init__() | ||||
self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls) | self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls) | ||||
@@ -288,7 +288,7 @@ class _WordBertModel(nn.Module): | |||||
self.auto_truncate = auto_truncate | self.auto_truncate = auto_truncate | ||||
# 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP] | # 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP] | ||||
logger.info("Start to generating word pieces for word.") | |||||
logger.info("Start to generate word pieces for word.") | |||||
# 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 | # 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 | ||||
word_piece_dict = {'[CLS]': 1, '[SEP]': 1} # 用到的word_piece以及新增的 | word_piece_dict = {'[CLS]': 1, '[SEP]': 1} # 用到的word_piece以及新增的 | ||||
found_count = 0 | found_count = 0 | ||||
@@ -374,7 +374,8 @@ class _WordBertModel(nn.Module): | |||||
else: | else: | ||||
raise RuntimeError( | raise RuntimeError( | ||||
"After split words into word pieces, the lengths of word pieces are longer than the " | "After split words into word pieces, the lengths of word pieces are longer than the " | ||||
f"maximum allowed sequence length:{self._max_position_embeddings} of bert.") | |||||
f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set " | |||||
f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.") | |||||
# +2是由于需要加入[CLS]与[SEP] | # +2是由于需要加入[CLS]与[SEP] | ||||
word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)), | word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)), | ||||
@@ -407,15 +408,26 @@ class _WordBertModel(nn.Module): | |||||
# output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size | # output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size | ||||
if self.include_cls_sep: | if self.include_cls_sep: | ||||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, | |||||
bert_outputs[-1].size(-1)) | |||||
s_shift = 1 | s_shift = 1 | ||||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, | |||||
bert_outputs[-1].size(-1)) | |||||
else: | else: | ||||
s_shift = 0 | |||||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len, | outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len, | ||||
bert_outputs[-1].size(-1)) | bert_outputs[-1].size(-1)) | ||||
s_shift = 0 | |||||
batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1) | batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1) | ||||
batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len | batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len | ||||
if self.pool_method == 'first': | |||||
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()] | |||||
batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) | |||||
batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) | |||||
elif self.pool_method == 'last': | |||||
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max()+1] - 1 | |||||
batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) | |||||
batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) | |||||
for l_index, l in enumerate(self.layers): | for l_index, l in enumerate(self.layers): | ||||
output_layer = bert_outputs[l] | output_layer = bert_outputs[l] | ||||
real_word_piece_length = output_layer.size(1) - 2 | real_word_piece_length = output_layer.size(1) - 2 | ||||
@@ -426,16 +438,15 @@ class _WordBertModel(nn.Module): | |||||
output_layer = torch.cat((output_layer, paddings), dim=1).contiguous() | output_layer = torch.cat((output_layer, paddings), dim=1).contiguous() | ||||
# 从word_piece collapse到word的表示 | # 从word_piece collapse到word的表示 | ||||
truncate_output_layer = output_layer[:, 1:-1] # 删除[CLS]与[SEP] batch_size x len x hidden_size | truncate_output_layer = output_layer[:, 1:-1] # 删除[CLS]与[SEP] batch_size x len x hidden_size | ||||
outputs_seq_len = seq_len + s_shift | |||||
if self.pool_method == 'first': | if self.pool_method == 'first': | ||||
for i in range(batch_size): | |||||
i_word_pieces_cum_length = batch_word_pieces_cum_length[i, :seq_len[i]] # 每个word的start位置 | |||||
outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[ | |||||
i, i_word_pieces_cum_length] # num_layer x batch_size x len x hidden_size | |||||
tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length] | |||||
tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0) | |||||
outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp | |||||
elif self.pool_method == 'last': | elif self.pool_method == 'last': | ||||
for i in range(batch_size): | |||||
i_word_pieces_cum_length = batch_word_pieces_cum_length[i, 1:seq_len[i] + 1] - 1 # 每个word的end | |||||
outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[i, i_word_pieces_cum_length] | |||||
tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length] | |||||
tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0) | |||||
outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp | |||||
elif self.pool_method == 'max': | elif self.pool_method == 'max': | ||||
for i in range(batch_size): | for i in range(batch_size): | ||||
for j in range(seq_len[i]): | for j in range(seq_len[i]): | ||||
@@ -452,5 +463,6 @@ class _WordBertModel(nn.Module): | |||||
else: | else: | ||||
outputs[l_index, :, 0] = output_layer[:, 0] | outputs[l_index, :, 0] = output_layer[:, 0] | ||||
outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift] | outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift] | ||||
# 3. 最终的embedding结果 | # 3. 最终的embedding结果 | ||||
return outputs | return outputs |
@@ -24,6 +24,7 @@ __all__ = [ | |||||
'IMDBLoader', | 'IMDBLoader', | ||||
'SSTLoader', | 'SSTLoader', | ||||
'SST2Loader', | 'SST2Loader', | ||||
"ChnSentiCorpLoader", | |||||
'ConllLoader', | 'ConllLoader', | ||||
'Conll2003Loader', | 'Conll2003Loader', | ||||
@@ -52,8 +53,9 @@ __all__ = [ | |||||
"SSTPipe", | "SSTPipe", | ||||
"SST2Pipe", | "SST2Pipe", | ||||
"IMDBPipe", | "IMDBPipe", | ||||
"Conll2003Pipe", | |||||
"ChnSentiCorpPipe", | |||||
"Conll2003Pipe", | |||||
"Conll2003NERPipe", | "Conll2003NERPipe", | ||||
"OntoNotesNERPipe", | "OntoNotesNERPipe", | ||||
"MsraNERPipe", | "MsraNERPipe", | ||||
@@ -306,12 +306,15 @@ class DataBundle: | |||||
return self | return self | ||||
def __repr__(self): | def __repr__(self): | ||||
_str = 'In total {} datasets:\n'.format(len(self.datasets)) | |||||
for name, dataset in self.datasets.items(): | |||||
_str += '\t{} has {} instances.\n'.format(name, len(dataset)) | |||||
_str += 'In total {} vocabs:\n'.format(len(self.vocabs)) | |||||
for name, vocab in self.vocabs.items(): | |||||
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) | |||||
_str = '' | |||||
if len(self.datasets): | |||||
_str += 'In total {} datasets:\n'.format(len(self.datasets)) | |||||
for name, dataset in self.datasets.items(): | |||||
_str += '\t{} has {} instances.\n'.format(name, len(dataset)) | |||||
if len(self.vocabs): | |||||
_str += 'In total {} vocabs:\n'.format(len(self.vocabs)) | |||||
for name, vocab in self.vocabs.items(): | |||||
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) | |||||
return _str | return _str | ||||
@@ -77,6 +77,9 @@ PRETRAIN_STATIC_FILES = { | |||||
'cn-tencent': "tencent_cn.zip", | 'cn-tencent': "tencent_cn.zip", | ||||
'cn-fasttext': "cc.zh.300.vec.gz", | 'cn-fasttext': "cc.zh.300.vec.gz", | ||||
'cn-sgns-literature-word': 'sgns.literature.word.txt.zip', | 'cn-sgns-literature-word': 'sgns.literature.word.txt.zip', | ||||
'cn-char-fastnlp-100d': "cn_char_fastnlp_100d.zip", | |||||
'cn-bi-fastnlp-100d': "cn_bi_fastnlp_100d.zip", | |||||
"cn-tri-fastnlp-100d": "cn_tri_fastnlp_100d.zip" | |||||
} | } | ||||
DATASET_DIR = { | DATASET_DIR = { | ||||
@@ -96,7 +99,9 @@ DATASET_DIR = { | |||||
"cws-pku": 'cws_pku.zip', | "cws-pku": 'cws_pku.zip', | ||||
"cws-cityu": "cws_cityu.zip", | "cws-cityu": "cws_cityu.zip", | ||||
"cws-as": 'cws_as.zip', | "cws-as": 'cws_as.zip', | ||||
"cws-msra": 'cws_msra.zip' | |||||
"cws-msra": 'cws_msra.zip', | |||||
"chn-senti-corp":"chn_senti_corp.zip" | |||||
} | } | ||||
PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, | PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, | ||||
@@ -52,6 +52,7 @@ __all__ = [ | |||||
'IMDBLoader', | 'IMDBLoader', | ||||
'SSTLoader', | 'SSTLoader', | ||||
'SST2Loader', | 'SST2Loader', | ||||
"ChnSentiCorpLoader", | |||||
'ConllLoader', | 'ConllLoader', | ||||
'Conll2003Loader', | 'Conll2003Loader', | ||||
@@ -73,7 +74,7 @@ __all__ = [ | |||||
"QNLILoader", | "QNLILoader", | ||||
"RTELoader" | "RTELoader" | ||||
] | ] | ||||
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader | |||||
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader | |||||
from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader | from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader | ||||
from .csv import CSVLoader | from .csv import CSVLoader | ||||
from .cws import CWSLoader | from .cws import CWSLoader | ||||
@@ -7,6 +7,7 @@ __all__ = [ | |||||
"IMDBLoader", | "IMDBLoader", | ||||
"SSTLoader", | "SSTLoader", | ||||
"SST2Loader", | "SST2Loader", | ||||
"ChnSentiCorpLoader" | |||||
] | ] | ||||
import glob | import glob | ||||
@@ -346,3 +347,59 @@ class SST2Loader(Loader): | |||||
""" | """ | ||||
output_dir = self._get_dataset_path(dataset_name='sst-2') | output_dir = self._get_dataset_path(dataset_name='sst-2') | ||||
return output_dir | return output_dir | ||||
class ChnSentiCorpLoader(Loader): | |||||
""" | |||||
支持读取的数据的格式为,第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是label,第 | |||||
一个制表符及之后认为是句子 | |||||
Example:: | |||||
label raw_chars | |||||
1 這間酒店環境和服務態度亦算不錯,但房間空間太小~~ | |||||
1 <荐书> 推荐所有喜欢<红楼>的红迷们一定要收藏这本书,要知道... | |||||
0 商品的不足暂时还没发现,京东的订单处理速度实在.......周二就打包完成,周五才发货... | |||||
读取后的DataSet具有以下的field | |||||
.. csv-table:: | |||||
:header: "raw_chars", "target" | |||||
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1" | |||||
"<荐书> 推荐所有喜欢<红楼>...", "1" | |||||
"..." | |||||
""" | |||||
def __init__(self): | |||||
super().__init__() | |||||
def _load(self, path:str): | |||||
""" | |||||
从path中读取数据 | |||||
:param path: | |||||
:return: | |||||
""" | |||||
ds = DataSet() | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
f.readline() | |||||
for line in f: | |||||
line = line.strip() | |||||
tab_index = line.index('\t') | |||||
if tab_index!=-1: | |||||
target = line[:tab_index] | |||||
raw_chars = line[tab_index+1:] | |||||
if raw_chars: | |||||
ds.append(Instance(raw_chars=raw_chars, target=target)) | |||||
return ds | |||||
def download(self)->str: | |||||
""" | |||||
自动下载数据,该数据取自https://github.com/pengming617/bert_classification/tree/master/data,在 | |||||
https://arxiv.org/pdf/1904.09223.pdf与https://arxiv.org/pdf/1906.08101.pdf有使用 | |||||
:return: | |||||
""" | |||||
output_dir = self._get_dataset_path('chn-senti-corp') | |||||
return output_dir |
@@ -17,6 +17,7 @@ __all__ = [ | |||||
"SSTPipe", | "SSTPipe", | ||||
"SST2Pipe", | "SST2Pipe", | ||||
"IMDBPipe", | "IMDBPipe", | ||||
"ChnSentiCorpPipe", | |||||
"Conll2003NERPipe", | "Conll2003NERPipe", | ||||
"OntoNotesNERPipe", | "OntoNotesNERPipe", | ||||
@@ -39,7 +40,7 @@ __all__ = [ | |||||
"MNLIPipe", | "MNLIPipe", | ||||
] | ] | ||||
from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe | |||||
from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe | |||||
from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe | from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe | ||||
from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ | from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ | ||||
MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe | MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe | ||||
@@ -5,7 +5,8 @@ __all__ = [ | |||||
"YelpPolarityPipe", | "YelpPolarityPipe", | ||||
"SSTPipe", | "SSTPipe", | ||||
"SST2Pipe", | "SST2Pipe", | ||||
'IMDBPipe' | |||||
'IMDBPipe', | |||||
"ChnSentiCorpPipe" | |||||
] | ] | ||||
import re | import re | ||||
@@ -13,18 +14,18 @@ import re | |||||
from nltk import Tree | from nltk import Tree | ||||
from .pipe import Pipe | from .pipe import Pipe | ||||
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance | |||||
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance, _add_chars_field | |||||
from ..data_bundle import DataBundle | from ..data_bundle import DataBundle | ||||
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | ||||
from ...core.const import Const | from ...core.const import Const | ||||
from ...core.dataset import DataSet | from ...core.dataset import DataSet | ||||
from ...core.instance import Instance | from ...core.instance import Instance | ||||
from ...core.vocabulary import Vocabulary | from ...core.vocabulary import Vocabulary | ||||
from ..loader.classification import ChnSentiCorpLoader | |||||
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | ||||
class _CLSPipe(Pipe): | class _CLSPipe(Pipe): | ||||
""" | """ | ||||
分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列 | 分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列 | ||||
@@ -457,3 +458,97 @@ class IMDBPipe(_CLSPipe): | |||||
data_bundle = self.process(data_bundle) | data_bundle = self.process(data_bundle) | ||||
return data_bundle | return data_bundle | ||||
class ChnSentiCorpPipe(Pipe): | |||||
""" | |||||
处理之后的DataSet有以下的结构 | |||||
.. csv-table:: | |||||
:header: "raw_chars", "chars", "target", "seq_len" | |||||
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "[2, 3, 4, 5, ...]", 1, 31 | |||||
"<荐书> 推荐所有喜欢<红楼>...", "[10, 21, ....]", 1, 25 | |||||
"..." | |||||
其中chars, seq_len是input,target是target | |||||
:param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 | |||||
设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 | |||||
data_bundle.get_vocab('bigrams')获取. | |||||
:param bool trigrams: 是否增加一列trigrams. trigrams的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] | |||||
。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 | |||||
data_bundle.get_vocab('trigrams')获取. | |||||
""" | |||||
def __init__(self, bigrams=False, trigrams=False): | |||||
super().__init__() | |||||
self.bigrams = bigrams | |||||
self.trigrams = trigrams | |||||
def _tokenize(self, data_bundle): | |||||
""" | |||||
将DataSet中的"复旦大学"拆分为["复", "旦", "大", "学"]. 未来可以通过扩展这个函数实现分词。 | |||||
:param data_bundle: | |||||
:return: | |||||
""" | |||||
data_bundle.apply_field(list, field_name=Const.CHAR_INPUT, new_field_name=Const.CHAR_INPUT) | |||||
return data_bundle | |||||
def process(self, data_bundle:DataBundle): | |||||
""" | |||||
可以处理的DataSet应该具备以下的field | |||||
.. csv-table:: | |||||
:header: "raw_chars", "target" | |||||
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1" | |||||
"<荐书> 推荐所有喜欢<红楼>...", "1" | |||||
"..." | |||||
:param data_bundle: | |||||
:return: | |||||
""" | |||||
_add_chars_field(data_bundle, lower=False) | |||||
data_bundle = self._tokenize(data_bundle) | |||||
input_field_names = [Const.CHAR_INPUT] | |||||
if self.bigrams: | |||||
for name, dataset in data_bundle.iter_datasets(): | |||||
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])], | |||||
field_name=Const.CHAR_INPUT, new_field_name='bigrams') | |||||
input_field_names.append('bigrams') | |||||
if self.trigrams: | |||||
for name, dataset in data_bundle.iter_datasets(): | |||||
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in | |||||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], | |||||
field_name=Const.CHAR_INPUT, new_field_name='trigrams') | |||||
input_field_names.append('trigrams') | |||||
# index | |||||
_indexize(data_bundle, input_field_names, Const.TARGET) | |||||
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names | |||||
target_fields = [Const.TARGET] | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.add_seq_len(Const.CHAR_INPUT) | |||||
data_bundle.set_input(*input_fields) | |||||
data_bundle.set_target(*target_fields) | |||||
return data_bundle | |||||
def process_from_file(self, paths=None): | |||||
""" | |||||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | |||||
:return: DataBundle | |||||
""" | |||||
# 读取数据 | |||||
data_bundle = ChnSentiCorpLoader().load(paths) | |||||
data_bundle = self.process(data_bundle) | |||||
return data_bundle |
@@ -222,14 +222,23 @@ class _CNNERPipe(Pipe): | |||||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target, seq_len。 | target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target, seq_len。 | ||||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | ||||
:param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 | |||||
设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 | |||||
data_bundle.get_vocab('bigrams')获取. | |||||
:param bool trigrams: 是否增加一列trigrams. trigrams的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] | |||||
。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 | |||||
data_bundle.get_vocab('trigrams')获取. | |||||
""" | """ | ||||
def __init__(self, encoding_type: str = 'bio'): | |||||
def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False): | |||||
if encoding_type == 'bio': | if encoding_type == 'bio': | ||||
self.convert_tag = iob2 | self.convert_tag = iob2 | ||||
else: | else: | ||||
self.convert_tag = lambda words: iob2bioes(iob2(words)) | self.convert_tag = lambda words: iob2bioes(iob2(words)) | ||||
self.bigrams = bigrams | |||||
self.trigrams = trigrams | |||||
def process(self, data_bundle: DataBundle) -> DataBundle: | def process(self, data_bundle: DataBundle) -> DataBundle: | ||||
""" | """ | ||||
支持的DataSet的field为 | 支持的DataSet的field为 | ||||
@@ -241,11 +250,11 @@ class _CNNERPipe(Pipe): | |||||
"[青, 岛, 海, 牛, 队, 和, ...]", "[B-ORG, I-ORG, I-ORG, ...]" | "[青, 岛, 海, 牛, 队, 和, ...]", "[B-ORG, I-ORG, I-ORG, ...]" | ||||
"[...]", "[...]" | "[...]", "[...]" | ||||
raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | |||||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | |||||
raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int], | |||||
是转换为index的target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | |||||
:param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。 | |||||
在传入DataBundle基础上原位修改。 | |||||
:param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field | |||||
的内容均为List[str]。在传入DataBundle基础上原位修改。 | |||||
:return: DataBundle | :return: DataBundle | ||||
""" | """ | ||||
# 转换tag | # 转换tag | ||||
@@ -253,11 +262,24 @@ class _CNNERPipe(Pipe): | |||||
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) | dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) | ||||
_add_chars_field(data_bundle, lower=False) | _add_chars_field(data_bundle, lower=False) | ||||
input_field_names = [Const.CHAR_INPUT] | |||||
if self.bigrams: | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])], | |||||
field_name=Const.CHAR_INPUT, new_field_name='bigrams') | |||||
input_field_names.append('bigrams') | |||||
if self.trigrams: | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in | |||||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], | |||||
field_name=Const.CHAR_INPUT, new_field_name='trigrams') | |||||
input_field_names.append('trigrams') | |||||
# index | # index | ||||
_indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET) | |||||
_indexize(data_bundle, input_field_names, Const.TARGET) | |||||
input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN] | |||||
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names | |||||
target_fields = [Const.TARGET, Const.INPUT_LEN] | target_fields = [Const.TARGET, Const.INPUT_LEN] | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
@@ -13,6 +13,12 @@ class TestDownload(unittest.TestCase): | |||||
words = torch.LongTensor([[2, 3, 4, 0]]) | words = torch.LongTensor([[2, 3, 4, 0]]) | ||||
print(embed(words).size()) | print(embed(words).size()) | ||||
for pool_method in ['first', 'last', 'max', 'avg']: | |||||
for include_cls_sep in [True, False]: | |||||
embed = BertEmbedding(vocab, model_dir_or_name='en', pool_method=pool_method, | |||||
include_cls_sep=include_cls_sep) | |||||
print(embed(words).size()) | |||||
def test_word_drop(self): | def test_word_drop(self): | ||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | vocab = Vocabulary().add_word_lst("This is a test .".split()) | ||||
embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2) | embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2) | ||||
@@ -5,22 +5,22 @@ from fastNLP.io.loader.classification import YelpPolarityLoader | |||||
from fastNLP.io.loader.classification import IMDBLoader | from fastNLP.io.loader.classification import IMDBLoader | ||||
from fastNLP.io.loader.classification import SST2Loader | from fastNLP.io.loader.classification import SST2Loader | ||||
from fastNLP.io.loader.classification import SSTLoader | from fastNLP.io.loader.classification import SSTLoader | ||||
from fastNLP.io.loader.classification import ChnSentiCorpLoader | |||||
import os | import os | ||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | ||||
class TestDownload(unittest.TestCase): | class TestDownload(unittest.TestCase): | ||||
def test_download(self): | def test_download(self): | ||||
for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]: | |||||
for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]: | |||||
loader().download() | loader().download() | ||||
def test_load(self): | def test_load(self): | ||||
for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]: | |||||
for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]: | |||||
data_bundle = loader().load() | data_bundle = loader().load() | ||||
print(data_bundle) | print(data_bundle) | ||||
class TestLoad(unittest.TestCase): | class TestLoad(unittest.TestCase): | ||||
def test_load(self): | def test_load(self): | ||||
for loader in [IMDBLoader]: | for loader in [IMDBLoader]: | ||||
data_bundle = loader().load('test/data_for_tests/io/imdb') | data_bundle = loader().load('test/data_for_tests/io/imdb') | ||||
@@ -5,7 +5,7 @@ from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNE | |||||
Conll2003Loader | Conll2003Loader | ||||
class MSRANERTest(unittest.TestCase): | |||||
class TestMSRANER(unittest.TestCase): | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | ||||
def test_download(self): | def test_download(self): | ||||
MsraNERLoader().download(re_download=False) | MsraNERLoader().download(re_download=False) | ||||
@@ -13,13 +13,13 @@ class MSRANERTest(unittest.TestCase): | |||||
print(data_bundle) | print(data_bundle) | ||||
class PeopleDailyTest(unittest.TestCase): | |||||
class TestPeopleDaily(unittest.TestCase): | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | ||||
def test_download(self): | def test_download(self): | ||||
PeopleDailyNERLoader().download() | PeopleDailyNERLoader().download() | ||||
class WeiboNERTest(unittest.TestCase): | |||||
class TestWeiboNER(unittest.TestCase): | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | ||||
def test_download(self): | def test_download(self): | ||||
WeiboNERLoader().download() | WeiboNERLoader().download() | ||||
@@ -3,7 +3,7 @@ import os | |||||
from fastNLP.io.loader import CWSLoader | from fastNLP.io.loader import CWSLoader | ||||
class CWSLoaderTest(unittest.TestCase): | |||||
class TestCWSLoader(unittest.TestCase): | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | ||||
def test_download(self): | def test_download(self): | ||||
dataset_names = ['pku', 'cityu', 'as', 'msra'] | dataset_names = ['pku', 'cityu', 'as', 'msra'] | ||||
@@ -13,7 +13,7 @@ class CWSLoaderTest(unittest.TestCase): | |||||
print(data_bundle) | print(data_bundle) | ||||
class RunCWSLoaderTest(unittest.TestCase): | |||||
class TestRunCWSLoader(unittest.TestCase): | |||||
def test_cws_loader(self): | def test_cws_loader(self): | ||||
dataset_names = ['msra'] | dataset_names = ['msra'] | ||||
for dataset_name in dataset_names: | for dataset_name in dataset_names: | ||||
@@ -8,7 +8,7 @@ from fastNLP.io.loader.matching import MNLILoader | |||||
import os | import os | ||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | ||||
class TestDownload(unittest.TestCase): | |||||
class TestMatchingDownload(unittest.TestCase): | |||||
def test_download(self): | def test_download(self): | ||||
for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]: | for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]: | ||||
loader().download() | loader().download() | ||||
@@ -21,8 +21,7 @@ class TestDownload(unittest.TestCase): | |||||
print(data_bundle) | print(data_bundle) | ||||
class TestLoad(unittest.TestCase): | |||||
class TestMatchingLoad(unittest.TestCase): | |||||
def test_load(self): | def test_load(self): | ||||
for loader in [RTELoader]: | for loader in [RTELoader]: | ||||
data_bundle = loader().load('test/data_for_tests/io/rte') | data_bundle = loader().load('test/data_for_tests/io/rte') | ||||
@@ -2,9 +2,10 @@ import unittest | |||||
import os | import os | ||||
from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe | from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe | ||||
from fastNLP.io.pipe.classification import ChnSentiCorpPipe | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | ||||
class TestPipe(unittest.TestCase): | |||||
class TestClassificationPipe(unittest.TestCase): | |||||
def test_process_from_file(self): | def test_process_from_file(self): | ||||
for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: | for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: | ||||
with self.subTest(pipe=pipe): | with self.subTest(pipe=pipe): | ||||
@@ -14,8 +15,16 @@ class TestPipe(unittest.TestCase): | |||||
class TestRunPipe(unittest.TestCase): | class TestRunPipe(unittest.TestCase): | ||||
def test_load(self): | def test_load(self): | ||||
for pipe in [IMDBPipe]: | for pipe in [IMDBPipe]: | ||||
data_bundle = pipe(tokenizer='raw').process_from_file('test/data_for_tests/io/imdb') | data_bundle = pipe(tokenizer='raw').process_from_file('test/data_for_tests/io/imdb') | ||||
print(data_bundle) | print(data_bundle) | ||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||||
class TestCNClassificationPipe(unittest.TestCase): | |||||
def test_process_from_file(self): | |||||
for pipe in [ChnSentiCorpPipe]: | |||||
with self.subTest(pipe=pipe): | |||||
data_bundle = pipe(bigrams=True, trigrams=True).process_from_file() | |||||
print(data_bundle) |
@@ -4,12 +4,14 @@ from fastNLP.io import MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, Conll2003Pipe | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | ||||
class TestPipe(unittest.TestCase): | |||||
class TestConllPipe(unittest.TestCase): | |||||
def test_process_from_file(self): | def test_process_from_file(self): | ||||
for pipe in [MsraNERPipe, PeopleDailyPipe, WeiboNERPipe]: | for pipe in [MsraNERPipe, PeopleDailyPipe, WeiboNERPipe]: | ||||
with self.subTest(pipe=pipe): | with self.subTest(pipe=pipe): | ||||
print(pipe) | print(pipe) | ||||
data_bundle = pipe().process_from_file() | |||||
data_bundle = pipe(bigrams=True, trigrams=True).process_from_file() | |||||
print(data_bundle) | |||||
data_bundle = pipe(encoding_type='bioes').process_from_file() | |||||
print(data_bundle) | print(data_bundle) | ||||
@@ -4,7 +4,7 @@ import os | |||||
from fastNLP.io.pipe.cws import CWSPipe | from fastNLP.io.pipe.cws import CWSPipe | ||||
class CWSPipeTest(unittest.TestCase): | |||||
class TestCWSPipe(unittest.TestCase): | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | ||||
def test_process_from_file(self): | def test_process_from_file(self): | ||||
dataset_names = ['pku', 'cityu', 'as', 'msra'] | dataset_names = ['pku', 'cityu', 'as', 'msra'] | ||||
@@ -14,7 +14,7 @@ class CWSPipeTest(unittest.TestCase): | |||||
print(data_bundle) | print(data_bundle) | ||||
class RunCWSPipeTest(unittest.TestCase): | |||||
class TestRunCWSPipe(unittest.TestCase): | |||||
def test_process_from_file(self): | def test_process_from_file(self): | ||||
dataset_names = ['msra'] | dataset_names = ['msra'] | ||||
for dataset_name in dataset_names: | for dataset_name in dataset_names: | ||||
@@ -7,7 +7,7 @@ from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MN | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | ||||
class TestPipe(unittest.TestCase): | |||||
class TestMatchingPipe(unittest.TestCase): | |||||
def test_process_from_file(self): | def test_process_from_file(self): | ||||
for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]: | for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]: | ||||
with self.subTest(pipe=pipe): | with self.subTest(pipe=pipe): | ||||
@@ -17,7 +17,7 @@ class TestPipe(unittest.TestCase): | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | ||||
class TestBertPipe(unittest.TestCase): | |||||
class TestMatchingBertPipe(unittest.TestCase): | |||||
def test_process_from_file(self): | def test_process_from_file(self): | ||||
for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]: | for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]: | ||||
with self.subTest(pipe=pipe): | with self.subTest(pipe=pipe): | ||||
@@ -26,7 +26,7 @@ class TestBertPipe(unittest.TestCase): | |||||
print(data_bundle) | print(data_bundle) | ||||
class TestRunPipe(unittest.TestCase): | |||||
class TestRunMatchingPipe(unittest.TestCase): | |||||
def test_load(self): | def test_load(self): | ||||
for pipe in [RTEPipe, RTEBertPipe]: | for pipe in [RTEPipe, RTEBertPipe]: | ||||