@@ -238,8 +238,8 @@ class CrossEntropyLoss(LossBase):
             pred = pred.transpose(-1, -2)
             pred = pred.reshape(-1, pred.size(-1))
             target = target.reshape(-1)
-        if seq_len is not None:
-            mask = seq_len_to_mask(seq_len).reshape(-1).eq(0)
+        if seq_len is not None and target.dim() > 1:
+            mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0)
             target = target.masked_fill(mask, self.padding_idx)
 
         return F.cross_entropy(input=pred, target=target,
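
Without `max_len`, `seq_len_to_mask` infers the width from `seq_len.max()`, which can be shorter than the padded width of `target`, so the flattened mask and target go out of alignment. A minimal sketch of the guarded masking, assuming fastNLP's `seq_len_to_mask` semantics (True marks valid positions); the padding index -100 is chosen purely for illustration:

    import torch
    from fastNLP.core.utils import seq_len_to_mask

    # Padded sequence-labelling targets: batch of 2, padded to length 4.
    target = torch.tensor([[1, 2, 0, 0],
                           [3, 4, 5, 0]])
    seq_len = torch.tensor([2, 3])

    # max_len=target.size(1) keeps the mask aligned with target even when
    # every sequence in the batch is shorter than the padded width.
    mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0)
    target = target.reshape(-1).masked_fill(mask, -100)  # -100: illustrative padding_idx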
@@ -347,7 +347,7 @@ class AccuracyMetric(MetricBase):
             pass
         elif pred.dim() == target.dim() + 1:
             pred = pred.argmax(dim=-1)
-            if seq_len is None:
+            if seq_len is None and target.dim() > 1:
                 warnings.warn("You are not passing `seq_len` to exclude pad when calculating accuracy.")
         else:
             raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred has "
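
For context, a sketch of the two shape regimes the new `target.dim() > 1` check distinguishes (made-up tensors, not taken from the test suite):

    import torch

    # Sequence labelling: pred (batch, len, n_classes), target (batch, len).
    # Padding exists, so `seq_len` should be passed and the warning is useful.
    tag_pred, tag_target = torch.randn(2, 4, 7), torch.zeros(2, 4).long()

    # Classification: pred (batch, n_classes), target (batch,).
    # target.dim() == 1, nothing is padded, so no warning is needed.
    cls_pred, cls_target = torch.randn(2, 7), torch.tensor([1, 3])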
@@ -68,7 +68,7 @@ class BertEmbedding(ContextualEmbedding):
     def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1',
                  pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False,
-                 pooled_cls=True, requires_grad: bool = False, auto_truncate: bool = False):
+                 pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False):
         super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
@@ -165,7 +165,7 @@ class BertWordPieceEncoder(nn.Module):
     """
     def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
-                 word_dropout=0, dropout=0, requires_grad: bool = False):
+                 word_dropout=0, dropout=0, requires_grad: bool = True):
         super().__init__()
 
         self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
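
With `requires_grad` now defaulting to True, the BERT weights are fine-tuned along with the rest of the model unless frozen explicitly. A usage sketch; the `requires_grad` property setter and the model name follow the fastNLP API visible elsewhere in this diff, so treat the exact flags as assumptions:

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    vocab = Vocabulary().add_word_lst("this is a test .".split())
    # Fine-tunes BERT by default now; auto_truncate guards against overlong input.
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', auto_truncate=True)
    embed.requires_grad = False  # opt back into fixed-feature mode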
@@ -288,7 +288,7 @@ class _WordBertModel(nn.Module):
         self.auto_truncate = auto_truncate
 
         # compute the word pieces for every word in the vocab; [CLS] and [SEP] need extra handling
-        logger.info("Start to generating word pieces for word.")
+        logger.info("Start to generate word pieces for word.")
         # step 1: collect the needed word pieces, then create the new embed and word_piece_vocab and fill in the values
         word_piece_dict = {'[CLS]': 1, '[SEP]': 1}  # word pieces in use, plus newly added ones
         found_count = 0
@@ -374,7 +374,8 @@ class _WordBertModel(nn.Module):
             else:
                 raise RuntimeError(
                     "After splitting words into word pieces, the lengths of word pieces are longer than the "
-                    f"maximum allowed sequence length:{self._max_position_embeddings} of bert.")
+                    f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set "
+                    f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.")
 
         # +2 because [CLS] and [SEP] need to be added
         word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)),
@@ -407,15 +408,26 @@ class _WordBertModel(nn.Module):
         # output_layers = [self.layers]  # len(self.layers) x batch_size x real_word_piece_length x hidden_size
         if self.include_cls_sep:
-            outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
-                                                 bert_outputs[-1].size(-1))
             s_shift = 1
+            outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
+                                                 bert_outputs[-1].size(-1))
         else:
-            s_shift = 0
             outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len,
                                                  bert_outputs[-1].size(-1))
+            s_shift = 0
         batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1)
         batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1)  # batch_size x max_len
+
+        if self.pool_method == 'first':
+            batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()]
+            batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
+            batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
+        elif self.pool_method == 'last':
+            batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max() + 1] - 1
+            batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
+            batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
+
         for l_index, l in enumerate(self.layers):
             output_layer = bert_outputs[l]
             real_word_piece_length = output_layer.size(1) - 2
@@ -426,16 +438,15 @@ class _WordBertModel(nn.Module):
                 output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
             # collapse the word-piece representations back to word level
             truncate_output_layer = output_layer[:, 1:-1]  # drop [CLS] and [SEP]; batch_size x len x hidden_size
-            outputs_seq_len = seq_len + s_shift
             if self.pool_method == 'first':
-                for i in range(batch_size):
-                    i_word_pieces_cum_length = batch_word_pieces_cum_length[i, :seq_len[i]]  # start position of each word
-                    outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[
-                        i, i_word_pieces_cum_length]  # num_layer x batch_size x len x hidden_size
+                tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length]
+                tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0)
+                outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1) + s_shift] = tmp
             elif self.pool_method == 'last':
-                for i in range(batch_size):
-                    i_word_pieces_cum_length = batch_word_pieces_cum_length[i, 1:seq_len[i] + 1] - 1  # end position of each word
-                    outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[i, i_word_pieces_cum_length]
+                tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length]
+                tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0)
+                outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1) + s_shift] = tmp
             elif self.pool_method == 'max':
                 for i in range(batch_size):
                     for j in range(seq_len[i]):
@@ -452,5 +463,6 @@ class _WordBertModel(nn.Module):
             else:
                 outputs[l_index, :, 0] = output_layer[:, 0]
                 outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift]
+        # 3. the final embedding result
        return outputs
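
The rewrite above replaces the per-sample Python loop with a single advanced-indexing gather over the whole batch. A self-contained sketch of the trick for 'first' pooling, with made-up shapes and data (the names mirror the diff, but nothing here is copied from it):

    import torch

    batch_size, max_word_len, hidden = 2, 3, 4
    # Number of word pieces per word; the second sentence has only 2 words.
    word_piece_len = torch.tensor([[1, 2, 1],
                                   [2, 1, 0]])
    reps = torch.randn(batch_size, 5, hidden)  # word-piece vectors, CLS/SEP stripped

    # cumsum gives each word's starting offset in the word-piece sequence.
    cum = word_piece_len.new_zeros(batch_size, max_word_len + 1)
    cum[:, 1:] = word_piece_len.cumsum(dim=-1)
    first_idx = cum[:, :max_word_len]  # 'first' pooling: index of each word's first piece

    # One gather replaces the loop over the batch; padded word positions still
    # have to be zeroed afterwards (the diff does this with word_mask/masked_fill).
    batch_idx = torch.arange(batch_size)[:, None].expand(-1, max_word_len)
    word_reps = reps[batch_idx, first_idx]  # batch_size x max_word_len x hidden

'last' pooling works the same way, except the indices are `cum[:, 1:] - 1`, i.e. the offset of each word's final piece.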
@@ -24,6 +24,7 @@ __all__ = [
     'IMDBLoader',
     'SSTLoader',
     'SST2Loader',
+    "ChnSentiCorpLoader",
 
     'ConllLoader',
     'Conll2003Loader',
@@ -52,8 +53,9 @@ __all__ = [
     "SSTPipe",
     "SST2Pipe",
     "IMDBPipe",
-    "Conll2003Pipe",
+    "ChnSentiCorpPipe",
 
+    "Conll2003Pipe",
     "Conll2003NERPipe",
     "OntoNotesNERPipe",
     "MsraNERPipe",
@@ -306,12 +306,15 @@ class DataBundle:
         return self
 
     def __repr__(self):
-        _str = 'In total {} datasets:\n'.format(len(self.datasets))
-        for name, dataset in self.datasets.items():
-            _str += '\t{} has {} instances.\n'.format(name, len(dataset))
-        _str += 'In total {} vocabs:\n'.format(len(self.vocabs))
-        for name, vocab in self.vocabs.items():
-            _str += '\t{} has {} entries.\n'.format(name, len(vocab))
+        _str = ''
+        if len(self.datasets):
+            _str += 'In total {} datasets:\n'.format(len(self.datasets))
+            for name, dataset in self.datasets.items():
+                _str += '\t{} has {} instances.\n'.format(name, len(dataset))
+        if len(self.vocabs):
+            _str += 'In total {} vocabs:\n'.format(len(self.vocabs))
+            for name, vocab in self.vocabs.items():
+                _str += '\t{} has {} entries.\n'.format(name, len(vocab))
         return _str
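
The guards keep `repr()` of an empty bundle from printing misleading "In total 0 datasets" headers. Assuming `DataBundle` can be constructed empty, as its use here suggests:

    from fastNLP.io import DataBundle

    print(repr(DataBundle()))  # now '' instead of headers announcing zero datasets/vocabs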
@@ -77,6 +77,9 @@ PRETRAIN_STATIC_FILES = {
     'cn-tencent': "tencent_cn.zip",
     'cn-fasttext': "cc.zh.300.vec.gz",
     'cn-sgns-literature-word': 'sgns.literature.word.txt.zip',
+    'cn-char-fastnlp-100d': "cn_char_fastnlp_100d.zip",
+    'cn-bi-fastnlp-100d': "cn_bi_fastnlp_100d.zip",
+    "cn-tri-fastnlp-100d": "cn_tri_fastnlp_100d.zip"
 }
 
 DATASET_DIR = {
@@ -96,7 +99,9 @@ DATASET_DIR = {
     "cws-pku": 'cws_pku.zip',
     "cws-cityu": "cws_cityu.zip",
     "cws-as": 'cws_as.zip',
-    "cws-msra": 'cws_msra.zip'
+    "cws-msra": 'cws_msra.zip',
+
+    "chn-senti-corp": "chn_senti_corp.zip"
 }
 
 PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
@@ -52,6 +52,7 @@ __all__ = [
     'IMDBLoader',
     'SSTLoader',
     'SST2Loader',
+    "ChnSentiCorpLoader",
 
     'ConllLoader',
     'Conll2003Loader',
@@ -73,7 +74,7 @@ __all__ = [
     "QNLILoader",
     "RTELoader"
 ]
-from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader
+from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader
 from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
 from .csv import CSVLoader
 from .cws import CWSLoader
@@ -7,6 +7,7 @@ __all__ = [
     "IMDBLoader",
     "SSTLoader",
     "SST2Loader",
+    "ChnSentiCorpLoader"
 ]
 
 import glob
@@ -346,3 +347,59 @@ class SST2Loader(Loader):
         """
         output_dir = self._get_dataset_path(dataset_name='sst-2')
         return output_dir
+
+
+class ChnSentiCorpLoader(Loader):
+    """
+    The expected data format: the first line is a header (its content is ignored); each
+    following line is one sample, where everything before the first tab is taken as the
+    label and everything after the first tab as the sentence.
+
+    Example::
+
+        label   raw_chars
+        1       這間酒店環境和服務態度亦算不錯,但房間空間太小~~
+        1       <荐书> 推荐所有喜欢<红楼>的红迷们一定要收藏这本书,要知道...
+        0       商品的不足暂时还没发现,京东的订单处理速度实在.......周二就打包完成,周五才发货...
+
+    The loaded DataSet has the following fields
+
+    .. csv-table::
+        :header: "raw_chars", "target"
+
+        "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1"
+        "<荐书> 推荐所有喜欢<红楼>...", "1"
+        "..."
+
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def _load(self, path: str):
+        """
+        Load data from path.
+
+        :param path:
+        :return:
+        """
+        ds = DataSet()
+        with open(path, 'r', encoding='utf-8') as f:
+            f.readline()  # skip the header line
+            for line in f:
+                line = line.strip()
+                tab_index = line.find('\t')  # find() returns -1 when no tab exists; index() would raise instead
+                if tab_index != -1:
+                    target = line[:tab_index]
+                    raw_chars = line[tab_index + 1:]
+                    if raw_chars:
+                        ds.append(Instance(raw_chars=raw_chars, target=target))
+        return ds
+
+    def download(self) -> str:
+        """
+        Download the data automatically. It is taken from
+        https://github.com/pengming617/bert_classification/tree/master/data and is used in
+        https://arxiv.org/pdf/1904.09223.pdf and https://arxiv.org/pdf/1906.08101.pdf.
+
+        :return:
+        """
+        output_dir = self._get_dataset_path('chn-senti-corp')
+        return output_dir
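
A usage sketch for the new loader: `download()` resolves the `chn-senti-corp` entry added to DATASET_DIR above, and `load()` follows the usual Loader conventions.

    from fastNLP.io.loader.classification import ChnSentiCorpLoader

    loader = ChnSentiCorpLoader()
    data_dir = loader.download()         # fetch and cache chn-senti-corp
    data_bundle = loader.load(data_dir)  # DataSets with raw_chars and target fields
    print(data_bundle)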
@@ -17,6 +17,7 @@ __all__ = [
     "SSTPipe",
     "SST2Pipe",
     "IMDBPipe",
+    "ChnSentiCorpPipe",
 
     "Conll2003NERPipe",
     "OntoNotesNERPipe",
@@ -39,7 +40,7 @@ __all__ = [
     "MNLIPipe",
 ]
 
-from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe
+from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe
 from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
 from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \
     MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe
@@ -5,7 +5,8 @@ __all__ = [
     "YelpPolarityPipe",
     "SSTPipe",
     "SST2Pipe",
-    'IMDBPipe'
+    'IMDBPipe',
+    "ChnSentiCorpPipe"
 ]
 
 import re
@@ -13,18 +14,18 @@ import re
 from nltk import Tree
 
 from .pipe import Pipe
-from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
+from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance, _add_chars_field
 from ..data_bundle import DataBundle
 from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
 from ...core.const import Const
 from ...core.dataset import DataSet
 from ...core.instance import Instance
 from ...core.vocabulary import Vocabulary
+from ..loader.classification import ChnSentiCorpLoader
 
 nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
 
 
 class _CLSPipe(Pipe):
     """
     Base class for classification pipes; it tokenizes the classification data. By default it operates
     on the raw_words column and generates a words column.
@@ -457,3 +458,97 @@ class IMDBPipe(_CLSPipe):
         data_bundle = self.process(data_bundle)
         return data_bundle
+
+
+class ChnSentiCorpPipe(Pipe):
+    """
+    The processed DataSet has the following structure
+
+    .. csv-table::
+        :header: "raw_chars", "chars", "target", "seq_len"
+
+        "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "[2, 3, 4, 5, ...]", 1, 31
+        "<荐书> 推荐所有喜欢<红楼>...", "[10, 21, ....]", 1, 25
+        "..."
+
+    where chars and seq_len are input, and target is target.
+
+    :param bool bigrams: whether to add a bigrams column, built as ['复', '旦', '大', '学', ...] -> ["复旦", "旦大", ...].
+        If True, the returned DataSet has a column named bigrams, already converted to indices and set as input;
+        the corresponding vocab can be obtained via data_bundle.get_vocab('bigrams').
+    :param bool trigrams: whether to add a trigrams column, built as ['复', '旦', '大', '学', ...] -> ["复旦大", "旦大学", ...].
+        If True, the returned DataSet has a column named trigrams, already converted to indices and set as input;
+        the corresponding vocab can be obtained via data_bundle.get_vocab('trigrams').
+    """
+
+    def __init__(self, bigrams=False, trigrams=False):
+        super().__init__()
+
+        self.bigrams = bigrams
+        self.trigrams = trigrams
+
+    def _tokenize(self, data_bundle):
+        """
+        Split "复旦大学" in the DataSet into ["复", "旦", "大", "学"]. Proper word segmentation
+        could later be supported by extending this function.
+
+        :param data_bundle:
+        :return:
+        """
+        data_bundle.apply_field(list, field_name=Const.CHAR_INPUT, new_field_name=Const.CHAR_INPUT)
+        return data_bundle
+
+    def process(self, data_bundle: DataBundle):
+        """
+        A processable DataSet should have the following fields
+
+        .. csv-table::
+            :header: "raw_chars", "target"
+
+            "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1"
+            "<荐书> 推荐所有喜欢<红楼>...", "1"
+            "..."
+
+        :param data_bundle:
+        :return:
+        """
+        _add_chars_field(data_bundle, lower=False)
+
+        data_bundle = self._tokenize(data_bundle)
+
+        input_field_names = [Const.CHAR_INPUT]
+        if self.bigrams:
+            for name, dataset in data_bundle.iter_datasets():
+                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
+                                    field_name=Const.CHAR_INPUT, new_field_name='bigrams')
+            input_field_names.append('bigrams')
+        if self.trigrams:
+            for name, dataset in data_bundle.iter_datasets():
+                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
+                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
+                                    field_name=Const.CHAR_INPUT, new_field_name='trigrams')
+            input_field_names.append('trigrams')
+
+        # index
+        _indexize(data_bundle, input_field_names, Const.TARGET)
+
+        input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
+        target_fields = [Const.TARGET]
+
+        for name, dataset in data_bundle.datasets.items():
+            dataset.add_seq_len(Const.CHAR_INPUT)
+
+        data_bundle.set_input(*input_fields)
+        data_bundle.set_target(*target_fields)
+
+        return data_bundle
+
+    def process_from_file(self, paths=None):
+        """
+        :param paths: for supported path types, see the load method of :class:`fastNLP.io.loader.Loader`.
+        :return: DataBundle
+        """
+        # load the data
+        data_bundle = ChnSentiCorpLoader().load(paths)
+
+        data_bundle = self.process(data_bundle)
+
+        return data_bundle
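
End to end, the pipe ties the loader and the n-gram columns together. A sketch, assuming the vocab names follow the docstring above (Const.CHAR_INPUT resolves to 'chars'):

    from fastNLP.io.pipe.classification import ChnSentiCorpPipe

    data_bundle = ChnSentiCorpPipe(bigrams=True, trigrams=True).process_from_file()
    char_vocab = data_bundle.get_vocab('chars')      # always present
    bigram_vocab = data_bundle.get_vocab('bigrams')  # only when bigrams=True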
@@ -222,14 +222,23 @@ class _CNNERPipe(Pipe):
    target. In the returned DataSet, chars, target and seq_len are set as input; target and seq_len are set as target.
 
    :param: str encoding_type: the encoding scheme of the target column; bioes and bio are supported.
+    :param bool bigrams: whether to add a bigrams column, built as ['复', '旦', '大', '学', ...] -> ["复旦", "旦大", ...].
+        If True, the returned DataSet has a column named bigrams, already converted to indices and set as input;
+        the corresponding vocab can be obtained via data_bundle.get_vocab('bigrams').
+    :param bool trigrams: whether to add a trigrams column, built as ['复', '旦', '大', '学', ...] -> ["复旦大", "旦大学", ...].
+        If True, the returned DataSet has a column named trigrams, already converted to indices and set as input;
+        the corresponding vocab can be obtained via data_bundle.get_vocab('trigrams').
    """
 
-    def __init__(self, encoding_type: str = 'bio'):
+    def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False):
         if encoding_type == 'bio':
             self.convert_tag = iob2
         else:
             self.convert_tag = lambda words: iob2bioes(iob2(words))
+        self.bigrams = bigrams
+        self.trigrams = trigrams
 
     def process(self, data_bundle: DataBundle) -> DataBundle:
         """
         The supported DataSet fields are
@@ -241,11 +250,11 @@ class _CNNERPipe(Pipe):
         "[青, 岛, 海, 牛, 队, 和, ...]", "[B-ORG, I-ORG, I-ORG, ...]"
         "[...]", "[...]"
 
-    raw_chars is List[str], the raw untransformed data; chars is List[int], the input converted to indices; target is List[int], the converted
-    target. In the returned DataSet, chars, target and seq_len are set as input; target is set as target.
+    raw_chars is List[str], the raw untransformed data; chars is List[int], the input converted to indices; target
+    is List[int], the target converted to indices. In the returned DataSet, chars, target and seq_len are set as input; target is set as target.
 
-    :param ~fastNLP.DataBundle data_bundle: every DataSet in the passed-in DataBundle must contain the raw_words and ner fields, and the content of both fields must be List[str].
-        It is modified in place on top of the passed-in DataBundle.
+    :param ~fastNLP.DataBundle data_bundle: every DataSet in the passed-in DataBundle must contain the raw_words and
+        ner fields, and the content of both fields must be List[str]. It is modified in place on top of the passed-in DataBundle.
    :return: DataBundle
    """
        # convert tags
@@ -253,11 +262,24 @@ class _CNNERPipe(Pipe):
             dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET)
 
         _add_chars_field(data_bundle, lower=False)
 
+        input_field_names = [Const.CHAR_INPUT]
+        if self.bigrams:
+            for name, dataset in data_bundle.datasets.items():
+                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
+                                    field_name=Const.CHAR_INPUT, new_field_name='bigrams')
+            input_field_names.append('bigrams')
+        if self.trigrams:
+            for name, dataset in data_bundle.datasets.items():
+                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
+                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
+                                    field_name=Const.CHAR_INPUT, new_field_name='trigrams')
+            input_field_names.append('trigrams')
+
         # index
-        _indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET)
+        _indexize(data_bundle, input_field_names, Const.TARGET)
 
-        input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN]
+        input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
         target_fields = [Const.TARGET, Const.INPUT_LEN]
 
         for name, dataset in data_bundle.datasets.items():
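
The bigram and trigram columns are built with the same zip-and-pad pattern in both pipes. In isolation it behaves like this:

    chars = ['复', '旦', '大', '学']
    bigrams = [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]
    trigrams = [c1 + c2 + c3 for c1, c2, c3 in
                zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)]
    # bigrams  -> ['复旦', '旦大', '大学', '学<eos>']
    # trigrams -> ['复旦大', '旦大学', '大学<eos>', '学<eos><eos>']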
@@ -13,6 +13,12 @@ class TestDownload(unittest.TestCase):
         words = torch.LongTensor([[2, 3, 4, 0]])
         print(embed(words).size())
 
+        for pool_method in ['first', 'last', 'max', 'avg']:
+            for include_cls_sep in [True, False]:
+                embed = BertEmbedding(vocab, model_dir_or_name='en', pool_method=pool_method,
+                                      include_cls_sep=include_cls_sep)
+                print(embed(words).size())
+
     def test_word_drop(self):
         vocab = Vocabulary().add_word_lst("This is a test .".split())
         embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2)
@@ -5,22 +5,22 @@ from fastNLP.io.loader.classification import YelpPolarityLoader
 from fastNLP.io.loader.classification import IMDBLoader
 from fastNLP.io.loader.classification import SST2Loader
 from fastNLP.io.loader.classification import SSTLoader
+from fastNLP.io.loader.classification import ChnSentiCorpLoader
 import os
 
 
 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
 class TestDownload(unittest.TestCase):
     def test_download(self):
-        for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]:
+        for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]:
             loader().download()
 
     def test_load(self):
-        for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]:
+        for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]:
             data_bundle = loader().load()
             print(data_bundle)
 
 
 class TestLoad(unittest.TestCase):
     def test_load(self):
         for loader in [IMDBLoader]:
             data_bundle = loader().load('test/data_for_tests/io/imdb')
@@ -5,7 +5,7 @@ from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNE
     Conll2003Loader
 
 
-class MSRANERTest(unittest.TestCase):
+class TestMSRANER(unittest.TestCase):
     @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
     def test_download(self):
         MsraNERLoader().download(re_download=False)
@@ -13,13 +13,13 @@ class MSRANERTest(unittest.TestCase):
         print(data_bundle)
 
 
-class PeopleDailyTest(unittest.TestCase):
+class TestPeopleDaily(unittest.TestCase):
     @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
     def test_download(self):
         PeopleDailyNERLoader().download()
 
 
-class WeiboNERTest(unittest.TestCase):
+class TestWeiboNER(unittest.TestCase):
     @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
     def test_download(self):
         WeiboNERLoader().download()
@@ -3,7 +3,7 @@ import os
 from fastNLP.io.loader import CWSLoader
 
 
-class CWSLoaderTest(unittest.TestCase):
+class TestCWSLoader(unittest.TestCase):
     @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
     def test_download(self):
         dataset_names = ['pku', 'cityu', 'as', 'msra']
@@ -13,7 +13,7 @@ class CWSLoaderTest(unittest.TestCase):
         print(data_bundle)
 
 
-class RunCWSLoaderTest(unittest.TestCase):
+class TestRunCWSLoader(unittest.TestCase):
     def test_cws_loader(self):
         dataset_names = ['msra']
         for dataset_name in dataset_names:
@@ -8,7 +8,7 @@ from fastNLP.io.loader.matching import MNLILoader
 import os
 
 
 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
-class TestDownload(unittest.TestCase):
+class TestMatchingDownload(unittest.TestCase):
     def test_download(self):
         for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]:
             loader().download()
@@ -21,8 +21,7 @@ class TestDownload(unittest.TestCase):
         print(data_bundle)
 
 
-class TestLoad(unittest.TestCase):
+class TestMatchingLoad(unittest.TestCase):
     def test_load(self):
         for loader in [RTELoader]:
             data_bundle = loader().load('test/data_for_tests/io/rte')
@@ -2,9 +2,10 @@ import unittest
 import os
 
 from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe
+from fastNLP.io.pipe.classification import ChnSentiCorpPipe
 
 
 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
-class TestPipe(unittest.TestCase):
+class TestClassificationPipe(unittest.TestCase):
     def test_process_from_file(self):
         for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]:
             with self.subTest(pipe=pipe):
@@ -14,8 +15,16 @@ class TestPipe(unittest.TestCase):
 
 class TestRunPipe(unittest.TestCase):
     def test_load(self):
         for pipe in [IMDBPipe]:
             data_bundle = pipe(tokenizer='raw').process_from_file('test/data_for_tests/io/imdb')
             print(data_bundle)
+
+
+@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+class TestCNClassificationPipe(unittest.TestCase):
+    def test_process_from_file(self):
+        for pipe in [ChnSentiCorpPipe]:
+            with self.subTest(pipe=pipe):
+                data_bundle = pipe(bigrams=True, trigrams=True).process_from_file()
+                print(data_bundle)
@@ -4,12 +4,14 @@ from fastNLP.io import MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, Conll2003Pipe
 
 
 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
-class TestPipe(unittest.TestCase):
+class TestConllPipe(unittest.TestCase):
     def test_process_from_file(self):
         for pipe in [MsraNERPipe, PeopleDailyPipe, WeiboNERPipe]:
             with self.subTest(pipe=pipe):
                 print(pipe)
-                data_bundle = pipe().process_from_file()
+                data_bundle = pipe(bigrams=True, trigrams=True).process_from_file()
                 print(data_bundle)
+                data_bundle = pipe(encoding_type='bioes').process_from_file()
+                print(data_bundle)
@@ -4,7 +4,7 @@ import os
 from fastNLP.io.pipe.cws import CWSPipe
 
 
-class CWSPipeTest(unittest.TestCase):
+class TestCWSPipe(unittest.TestCase):
     @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
     def test_process_from_file(self):
         dataset_names = ['pku', 'cityu', 'as', 'msra']
@@ -14,7 +14,7 @@ class CWSPipeTest(unittest.TestCase):
         print(data_bundle)
 
 
-class RunCWSPipeTest(unittest.TestCase):
+class TestRunCWSPipe(unittest.TestCase):
     def test_process_from_file(self):
         dataset_names = ['msra']
         for dataset_name in dataset_names:
@@ -7,7 +7,7 @@ from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MN
 
 
 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
-class TestPipe(unittest.TestCase):
+class TestMatchingPipe(unittest.TestCase):
     def test_process_from_file(self):
         for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]:
             with self.subTest(pipe=pipe):
@@ -17,7 +17,7 @@ class TestPipe(unittest.TestCase):
 
 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
-class TestBertPipe(unittest.TestCase):
+class TestMatchingBertPipe(unittest.TestCase):
     def test_process_from_file(self):
         for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]:
             with self.subTest(pipe=pipe):
@@ -26,7 +26,7 @@ class TestBertPipe(unittest.TestCase):
             print(data_bundle)
 
 
-class TestRunPipe(unittest.TestCase):
+class TestRunMatchingPipe(unittest.TestCase):
     def test_load(self):
         for pipe in [RTEPipe, RTEBertPipe]: