diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index d5549cec..7402a568 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -238,8 +238,8 @@ class CrossEntropyLoss(LossBase): pred = pred.tranpose(-1, pred) pred = pred.reshape(-1, pred.size(-1)) target = target.reshape(-1) - if seq_len is not None: - mask = seq_len_to_mask(seq_len).reshape(-1).eq(0) + if seq_len is not None and target.dim()>1: + mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0) target = target.masked_fill(mask, self.padding_idx) return F.cross_entropy(input=pred, target=target, diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index b06e5459..c0f14c90 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -347,7 +347,7 @@ class AccuracyMetric(MetricBase): pass elif pred.dim() == target.dim() + 1: pred = pred.argmax(dim=-1) - if seq_len is None: + if seq_len is None and target.dim()>1: warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") else: raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py index f6c36623..08615fe0 100644 --- a/fastNLP/embeddings/bert_embedding.py +++ b/fastNLP/embeddings/bert_embedding.py @@ -68,7 +68,7 @@ class BertEmbedding(ContextualEmbedding): def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False, - pooled_cls=True, requires_grad: bool = False, auto_truncate: bool = False): + pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False): super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: @@ -165,7 +165,7 @@ class BertWordPieceEncoder(nn.Module): """ def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, - word_dropout=0, dropout=0, requires_grad: bool = False): + word_dropout=0, dropout=0, requires_grad: bool = True): super().__init__() self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls) @@ -288,7 +288,7 @@ class _WordBertModel(nn.Module): self.auto_truncate = auto_truncate # 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP] - logger.info("Start to generating word pieces for word.") + logger.info("Start to generate word pieces for word.") # 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 word_piece_dict = {'[CLS]': 1, '[SEP]': 1} # 用到的word_piece以及新增的 found_count = 0 @@ -374,7 +374,8 @@ class _WordBertModel(nn.Module): else: raise RuntimeError( "After split words into word pieces, the lengths of word pieces are longer than the " - f"maximum allowed sequence length:{self._max_position_embeddings} of bert.") + f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set " + f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.") # +2是由于需要加入[CLS]与[SEP] word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)), @@ -407,15 +408,26 @@ class _WordBertModel(nn.Module): # output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size if self.include_cls_sep: - outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, - bert_outputs[-1].size(-1)) s_shift = 1 + outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, + bert_outputs[-1].size(-1)) + else: + s_shift = 0 outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len, bert_outputs[-1].size(-1)) - s_shift = 0 batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1) batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len + + if self.pool_method == 'first': + batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()] + batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) + batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) + elif self.pool_method == 'last': + batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max()+1] - 1 + batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) + batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) + for l_index, l in enumerate(self.layers): output_layer = bert_outputs[l] real_word_piece_length = output_layer.size(1) - 2 @@ -426,16 +438,15 @@ class _WordBertModel(nn.Module): output_layer = torch.cat((output_layer, paddings), dim=1).contiguous() # 从word_piece collapse到word的表示 truncate_output_layer = output_layer[:, 1:-1] # 删除[CLS]与[SEP] batch_size x len x hidden_size - outputs_seq_len = seq_len + s_shift if self.pool_method == 'first': - for i in range(batch_size): - i_word_pieces_cum_length = batch_word_pieces_cum_length[i, :seq_len[i]] # 每个word的start位置 - outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[ - i, i_word_pieces_cum_length] # num_layer x batch_size x len x hidden_size + tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length] + tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0) + outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp + elif self.pool_method == 'last': - for i in range(batch_size): - i_word_pieces_cum_length = batch_word_pieces_cum_length[i, 1:seq_len[i] + 1] - 1 # 每个word的end - outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[i, i_word_pieces_cum_length] + tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length] + tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0) + outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp elif self.pool_method == 'max': for i in range(batch_size): for j in range(seq_len[i]): @@ -452,5 +463,6 @@ class _WordBertModel(nn.Module): else: outputs[l_index, :, 0] = output_layer[:, 0] outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift] + # 3. 最终的embedding结果 return outputs diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py index 251b7292..6f727f05 100644 --- a/fastNLP/io/__init__.py +++ b/fastNLP/io/__init__.py @@ -24,6 +24,7 @@ __all__ = [ 'IMDBLoader', 'SSTLoader', 'SST2Loader', + "ChnSentiCorpLoader", 'ConllLoader', 'Conll2003Loader', @@ -52,8 +53,9 @@ __all__ = [ "SSTPipe", "SST2Pipe", "IMDBPipe", - "Conll2003Pipe", + "ChnSentiCorpPipe", + "Conll2003Pipe", "Conll2003NERPipe", "OntoNotesNERPipe", "MsraNERPipe", diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index 3e7f39d3..19b48828 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -306,12 +306,15 @@ class DataBundle: return self def __repr__(self): - _str = 'In total {} datasets:\n'.format(len(self.datasets)) - for name, dataset in self.datasets.items(): - _str += '\t{} has {} instances.\n'.format(name, len(dataset)) - _str += 'In total {} vocabs:\n'.format(len(self.vocabs)) - for name, vocab in self.vocabs.items(): - _str += '\t{} has {} entries.\n'.format(name, len(vocab)) + _str = '' + if len(self.datasets): + _str += 'In total {} datasets:\n'.format(len(self.datasets)) + for name, dataset in self.datasets.items(): + _str += '\t{} has {} instances.\n'.format(name, len(dataset)) + if len(self.vocabs): + _str += 'In total {} vocabs:\n'.format(len(self.vocabs)) + for name, vocab in self.vocabs.items(): + _str += '\t{} has {} entries.\n'.format(name, len(vocab)) return _str diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py index 8ecdff25..f76bcd26 100644 --- a/fastNLP/io/file_utils.py +++ b/fastNLP/io/file_utils.py @@ -77,6 +77,9 @@ PRETRAIN_STATIC_FILES = { 'cn-tencent': "tencent_cn.zip", 'cn-fasttext': "cc.zh.300.vec.gz", 'cn-sgns-literature-word': 'sgns.literature.word.txt.zip', + 'cn-char-fastnlp-100d': "cn_char_fastnlp_100d.zip", + 'cn-bi-fastnlp-100d': "cn_bi_fastnlp_100d.zip", + "cn-tri-fastnlp-100d": "cn_tri_fastnlp_100d.zip" } DATASET_DIR = { @@ -96,7 +99,9 @@ DATASET_DIR = { "cws-pku": 'cws_pku.zip', "cws-cityu": "cws_cityu.zip", "cws-as": 'cws_as.zip', - "cws-msra": 'cws_msra.zip' + "cws-msra": 'cws_msra.zip', + + "chn-senti-corp":"chn_senti_corp.zip" } PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, diff --git a/fastNLP/io/loader/__init__.py b/fastNLP/io/loader/__init__.py index 6c23f213..3ad1b47d 100644 --- a/fastNLP/io/loader/__init__.py +++ b/fastNLP/io/loader/__init__.py @@ -52,6 +52,7 @@ __all__ = [ 'IMDBLoader', 'SSTLoader', 'SST2Loader', + "ChnSentiCorpLoader", 'ConllLoader', 'Conll2003Loader', @@ -73,7 +74,7 @@ __all__ = [ "QNLILoader", "RTELoader" ] -from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader +from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader from .csv import CSVLoader from .cws import CWSLoader diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py index ec00d2b4..4ebd58e1 100644 --- a/fastNLP/io/loader/classification.py +++ b/fastNLP/io/loader/classification.py @@ -7,6 +7,7 @@ __all__ = [ "IMDBLoader", "SSTLoader", "SST2Loader", + "ChnSentiCorpLoader" ] import glob @@ -346,3 +347,59 @@ class SST2Loader(Loader): """ output_dir = self._get_dataset_path(dataset_name='sst-2') return output_dir + + +class ChnSentiCorpLoader(Loader): + """ + 支持读取的数据的格式为,第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是label,第 + 一个制表符及之后认为是句子 + + Example:: + + label raw_chars + 1 這間酒店環境和服務態度亦算不錯,但房間空間太小~~ + 1 <荐书> 推荐所有喜欢<红楼>的红迷们一定要收藏这本书,要知道... + 0 商品的不足暂时还没发现,京东的订单处理速度实在.......周二就打包完成,周五才发货... + + 读取后的DataSet具有以下的field + + .. csv-table:: + :header: "raw_chars", "target" + + "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1" + "<荐书> 推荐所有喜欢<红楼>...", "1" + "..." + + """ + def __init__(self): + super().__init__() + + def _load(self, path:str): + """ + 从path中读取数据 + + :param path: + :return: + """ + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + f.readline() + for line in f: + line = line.strip() + tab_index = line.index('\t') + if tab_index!=-1: + target = line[:tab_index] + raw_chars = line[tab_index+1:] + if raw_chars: + ds.append(Instance(raw_chars=raw_chars, target=target)) + return ds + + def download(self)->str: + """ + 自动下载数据,该数据取自https://github.com/pengming617/bert_classification/tree/master/data,在 + https://arxiv.org/pdf/1904.09223.pdf与https://arxiv.org/pdf/1906.08101.pdf有使用 + + :return: + """ + output_dir = self._get_dataset_path('chn-senti-corp') + return output_dir diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py index 048e4cfe..943709e7 100644 --- a/fastNLP/io/pipe/__init__.py +++ b/fastNLP/io/pipe/__init__.py @@ -17,6 +17,7 @@ __all__ = [ "SSTPipe", "SST2Pipe", "IMDBPipe", + "ChnSentiCorpPipe", "Conll2003NERPipe", "OntoNotesNERPipe", @@ -39,7 +40,7 @@ __all__ = [ "MNLIPipe", ] -from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe +from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py index 30c591a4..d1c7aa0e 100644 --- a/fastNLP/io/pipe/classification.py +++ b/fastNLP/io/pipe/classification.py @@ -5,7 +5,8 @@ __all__ = [ "YelpPolarityPipe", "SSTPipe", "SST2Pipe", - 'IMDBPipe' + 'IMDBPipe', + "ChnSentiCorpPipe" ] import re @@ -13,18 +14,18 @@ import re from nltk import Tree from .pipe import Pipe -from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance +from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance, _add_chars_field from ..data_bundle import DataBundle from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader from ...core.const import Const from ...core.dataset import DataSet from ...core.instance import Instance from ...core.vocabulary import Vocabulary +from ..loader.classification import ChnSentiCorpLoader nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') - class _CLSPipe(Pipe): """ 分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列 @@ -457,3 +458,97 @@ class IMDBPipe(_CLSPipe): data_bundle = self.process(data_bundle) return data_bundle + + +class ChnSentiCorpPipe(Pipe): + """ + 处理之后的DataSet有以下的结构 + + .. csv-table:: + :header: "raw_chars", "chars", "target", "seq_len" + + "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "[2, 3, 4, 5, ...]", 1, 31 + "<荐书> 推荐所有喜欢<红楼>...", "[10, 21, ....]", 1, 25 + "..." + + 其中chars, seq_len是input,target是target + + :param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 + 设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 + data_bundle.get_vocab('bigrams')获取. + :param bool trigrams: 是否增加一列trigrams. trigrams的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] + 。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 + data_bundle.get_vocab('trigrams')获取. + """ + def __init__(self, bigrams=False, trigrams=False): + super().__init__() + + self.bigrams = bigrams + self.trigrams = trigrams + + def _tokenize(self, data_bundle): + """ + 将DataSet中的"复旦大学"拆分为["复", "旦", "大", "学"]. 未来可以通过扩展这个函数实现分词。 + + :param data_bundle: + :return: + """ + data_bundle.apply_field(list, field_name=Const.CHAR_INPUT, new_field_name=Const.CHAR_INPUT) + return data_bundle + + def process(self, data_bundle:DataBundle): + """ + 可以处理的DataSet应该具备以下的field + + .. csv-table:: + :header: "raw_chars", "target" + + "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1" + "<荐书> 推荐所有喜欢<红楼>...", "1" + "..." + + :param data_bundle: + :return: + """ + _add_chars_field(data_bundle, lower=False) + + data_bundle = self._tokenize(data_bundle) + + input_field_names = [Const.CHAR_INPUT] + if self.bigrams: + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])], + field_name=Const.CHAR_INPUT, new_field_name='bigrams') + input_field_names.append('bigrams') + if self.trigrams: + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in + zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], + field_name=Const.CHAR_INPUT, new_field_name='trigrams') + input_field_names.append('trigrams') + + # index + _indexize(data_bundle, input_field_names, Const.TARGET) + + input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names + target_fields = [Const.TARGET] + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.CHAR_INPUT) + + data_bundle.set_input(*input_fields) + data_bundle.set_target(*target_fields) + + return data_bundle + + def process_from_file(self, paths=None): + """ + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = ChnSentiCorpLoader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle \ No newline at end of file diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py index 2edc9008..a96b259a 100644 --- a/fastNLP/io/pipe/conll.py +++ b/fastNLP/io/pipe/conll.py @@ -222,14 +222,23 @@ class _CNNERPipe(Pipe): target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target, seq_len。 :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 + :param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 + 设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 + data_bundle.get_vocab('bigrams')获取. + :param bool trigrams: 是否增加一列trigrams. trigrams的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] + 。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 + data_bundle.get_vocab('trigrams')获取. """ - def __init__(self, encoding_type: str = 'bio'): + def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False): if encoding_type == 'bio': self.convert_tag = iob2 else: self.convert_tag = lambda words: iob2bioes(iob2(words)) - + + self.bigrams = bigrams + self.trigrams = trigrams + def process(self, data_bundle: DataBundle) -> DataBundle: """ 支持的DataSet的field为 @@ -241,11 +250,11 @@ class _CNNERPipe(Pipe): "[青, 岛, 海, 牛, 队, 和, ...]", "[B-ORG, I-ORG, I-ORG, ...]" "[...]", "[...]" - raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 - target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 + raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int], + 是转换为index的target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 - :param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。 - 在传入DataBundle基础上原位修改。 + :param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field + 的内容均为List[str]。在传入DataBundle基础上原位修改。 :return: DataBundle """ # 转换tag @@ -253,11 +262,24 @@ class _CNNERPipe(Pipe): dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) _add_chars_field(data_bundle, lower=False) - + + input_field_names = [Const.CHAR_INPUT] + if self.bigrams: + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])], + field_name=Const.CHAR_INPUT, new_field_name='bigrams') + input_field_names.append('bigrams') + if self.trigrams: + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in + zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], + field_name=Const.CHAR_INPUT, new_field_name='trigrams') + input_field_names.append('trigrams') + # index - _indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET) + _indexize(data_bundle, input_field_names, Const.TARGET) - input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN] + input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names target_fields = [Const.TARGET, Const.INPUT_LEN] for name, dataset in data_bundle.datasets.items(): diff --git a/test/embeddings/test_bert_embedding.py b/test/embeddings/test_bert_embedding.py index 46ad74c3..6a4a0ffa 100644 --- a/test/embeddings/test_bert_embedding.py +++ b/test/embeddings/test_bert_embedding.py @@ -13,6 +13,12 @@ class TestDownload(unittest.TestCase): words = torch.LongTensor([[2, 3, 4, 0]]) print(embed(words).size()) + for pool_method in ['first', 'last', 'max', 'avg']: + for include_cls_sep in [True, False]: + embed = BertEmbedding(vocab, model_dir_or_name='en', pool_method=pool_method, + include_cls_sep=include_cls_sep) + print(embed(words).size()) + def test_word_drop(self): vocab = Vocabulary().add_word_lst("This is a test .".split()) embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2) diff --git a/test/io/loader/test_classification_loader.py b/test/io/loader/test_classification_loader.py index 1438a014..f099c1b2 100644 --- a/test/io/loader/test_classification_loader.py +++ b/test/io/loader/test_classification_loader.py @@ -5,22 +5,22 @@ from fastNLP.io.loader.classification import YelpPolarityLoader from fastNLP.io.loader.classification import IMDBLoader from fastNLP.io.loader.classification import SST2Loader from fastNLP.io.loader.classification import SSTLoader +from fastNLP.io.loader.classification import ChnSentiCorpLoader import os @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") class TestDownload(unittest.TestCase): def test_download(self): - for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]: + for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]: loader().download() def test_load(self): - for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]: + for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]: data_bundle = loader().load() print(data_bundle) class TestLoad(unittest.TestCase): - def test_load(self): for loader in [IMDBLoader]: data_bundle = loader().load('test/data_for_tests/io/imdb') diff --git a/test/io/loader/test_conll_loader.py b/test/io/loader/test_conll_loader.py index 861de5a5..31859a6b 100644 --- a/test/io/loader/test_conll_loader.py +++ b/test/io/loader/test_conll_loader.py @@ -5,7 +5,7 @@ from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNE Conll2003Loader -class MSRANERTest(unittest.TestCase): +class TestMSRANER(unittest.TestCase): @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") def test_download(self): MsraNERLoader().download(re_download=False) @@ -13,13 +13,13 @@ class MSRANERTest(unittest.TestCase): print(data_bundle) -class PeopleDailyTest(unittest.TestCase): +class TestPeopleDaily(unittest.TestCase): @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") def test_download(self): PeopleDailyNERLoader().download() -class WeiboNERTest(unittest.TestCase): +class TestWeiboNER(unittest.TestCase): @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") def test_download(self): WeiboNERLoader().download() diff --git a/test/io/loader/test_cws_loader.py b/test/io/loader/test_cws_loader.py index 8b5d4081..55e48910 100644 --- a/test/io/loader/test_cws_loader.py +++ b/test/io/loader/test_cws_loader.py @@ -3,7 +3,7 @@ import os from fastNLP.io.loader import CWSLoader -class CWSLoaderTest(unittest.TestCase): +class TestCWSLoader(unittest.TestCase): @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") def test_download(self): dataset_names = ['pku', 'cityu', 'as', 'msra'] @@ -13,7 +13,7 @@ class CWSLoaderTest(unittest.TestCase): print(data_bundle) -class RunCWSLoaderTest(unittest.TestCase): +class TestRunCWSLoader(unittest.TestCase): def test_cws_loader(self): dataset_names = ['msra'] for dataset_name in dataset_names: diff --git a/test/io/loader/test_matching_loader.py b/test/io/loader/test_matching_loader.py index 652cf161..cb1334e0 100644 --- a/test/io/loader/test_matching_loader.py +++ b/test/io/loader/test_matching_loader.py @@ -8,7 +8,7 @@ from fastNLP.io.loader.matching import MNLILoader import os @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") -class TestDownload(unittest.TestCase): +class TestMatchingDownload(unittest.TestCase): def test_download(self): for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]: loader().download() @@ -21,8 +21,7 @@ class TestDownload(unittest.TestCase): print(data_bundle) -class TestLoad(unittest.TestCase): - +class TestMatchingLoad(unittest.TestCase): def test_load(self): for loader in [RTELoader]: data_bundle = loader().load('test/data_for_tests/io/rte') diff --git a/test/io/pipe/test_classification.py b/test/io/pipe/test_classification.py index c6e2005e..45c276a3 100644 --- a/test/io/pipe/test_classification.py +++ b/test/io/pipe/test_classification.py @@ -2,9 +2,10 @@ import unittest import os from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe +from fastNLP.io.pipe.classification import ChnSentiCorpPipe @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") -class TestPipe(unittest.TestCase): +class TestClassificationPipe(unittest.TestCase): def test_process_from_file(self): for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: with self.subTest(pipe=pipe): @@ -14,8 +15,16 @@ class TestPipe(unittest.TestCase): class TestRunPipe(unittest.TestCase): - def test_load(self): for pipe in [IMDBPipe]: data_bundle = pipe(tokenizer='raw').process_from_file('test/data_for_tests/io/imdb') print(data_bundle) + + +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") +class TestCNClassificationPipe(unittest.TestCase): + def test_process_from_file(self): + for pipe in [ChnSentiCorpPipe]: + with self.subTest(pipe=pipe): + data_bundle = pipe(bigrams=True, trigrams=True).process_from_file() + print(data_bundle) \ No newline at end of file diff --git a/test/io/pipe/test_conll.py b/test/io/pipe/test_conll.py index 6f6c4fad..4ecd7969 100644 --- a/test/io/pipe/test_conll.py +++ b/test/io/pipe/test_conll.py @@ -4,12 +4,14 @@ from fastNLP.io import MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, Conll2003Pipe @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") -class TestPipe(unittest.TestCase): +class TestConllPipe(unittest.TestCase): def test_process_from_file(self): for pipe in [MsraNERPipe, PeopleDailyPipe, WeiboNERPipe]: with self.subTest(pipe=pipe): print(pipe) - data_bundle = pipe().process_from_file() + data_bundle = pipe(bigrams=True, trigrams=True).process_from_file() + print(data_bundle) + data_bundle = pipe(encoding_type='bioes').process_from_file() print(data_bundle) diff --git a/test/io/pipe/test_cws.py b/test/io/pipe/test_cws.py index dd901a25..063b6d9a 100644 --- a/test/io/pipe/test_cws.py +++ b/test/io/pipe/test_cws.py @@ -4,7 +4,7 @@ import os from fastNLP.io.pipe.cws import CWSPipe -class CWSPipeTest(unittest.TestCase): +class TestCWSPipe(unittest.TestCase): @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") def test_process_from_file(self): dataset_names = ['pku', 'cityu', 'as', 'msra'] @@ -14,7 +14,7 @@ class CWSPipeTest(unittest.TestCase): print(data_bundle) -class RunCWSPipeTest(unittest.TestCase): +class TestRunCWSPipe(unittest.TestCase): def test_process_from_file(self): dataset_names = ['msra'] for dataset_name in dataset_names: diff --git a/test/io/pipe/test_matching.py b/test/io/pipe/test_matching.py index 33904e7a..932d8289 100644 --- a/test/io/pipe/test_matching.py +++ b/test/io/pipe/test_matching.py @@ -7,7 +7,7 @@ from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MN @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") -class TestPipe(unittest.TestCase): +class TestMatchingPipe(unittest.TestCase): def test_process_from_file(self): for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]: with self.subTest(pipe=pipe): @@ -17,7 +17,7 @@ class TestPipe(unittest.TestCase): @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") -class TestBertPipe(unittest.TestCase): +class TestMatchingBertPipe(unittest.TestCase): def test_process_from_file(self): for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]: with self.subTest(pipe=pipe): @@ -26,7 +26,7 @@ class TestBertPipe(unittest.TestCase): print(data_bundle) -class TestRunPipe(unittest.TestCase): +class TestRunMatchingPipe(unittest.TestCase): def test_load(self): for pipe in [RTEPipe, RTEBertPipe]: