Browse Source

!5 修复了一些错误

Merge pull request !5 from WillQvQ/dev
tags/v1.0.0alpha
WillQvQ Gitee 4 years ago
parent
commit
22c6e6d59c
100 changed files with 68 additions and 34 deletions
  1. +1
    -1
      .Jenkinsfile
  2. +1
    -1
      .travis.yml
  3. +1
    -1
      MANIFEST.in
  4. +18
    -6
      fastNLP/embeddings/bert_embedding.py
  5. +17
    -7
      fastNLP/embeddings/roberta_embedding.py
  6. +15
    -3
      fastNLP/modules/encoder/bert.py
  7. +2
    -2
      fastNLP/modules/encoder/roberta.py
  8. +1
    -1
      fastNLP/modules/encoder/seq2seq_encoder.py
  9. +0
    -0
      tests/__init__.py
  10. +0
    -0
      tests/core/__init__.py
  11. +0
    -0
      tests/core/test_batch.py
  12. +0
    -0
      tests/core/test_callbacks.py
  13. +1
    -1
      tests/core/test_dataset.py
  14. +0
    -0
      tests/core/test_dist_trainer.py
  15. +0
    -0
      tests/core/test_field.py
  16. +0
    -0
      tests/core/test_instance.py
  17. +0
    -0
      tests/core/test_logger.py
  18. +0
    -0
      tests/core/test_loss.py
  19. +0
    -0
      tests/core/test_metrics.py
  20. +0
    -0
      tests/core/test_optimizer.py
  21. +0
    -0
      tests/core/test_predictor.py
  22. +0
    -0
      tests/core/test_sampler.py
  23. +0
    -0
      tests/core/test_tester.py
  24. +0
    -0
      tests/core/test_trainer.py
  25. +11
    -11
      tests/core/test_utils.py
  26. +0
    -0
      tests/core/test_vocabulary.py
  27. +0
    -0
      tests/data_for_tests/config
  28. +0
    -0
      tests/data_for_tests/conll_2003_example.txt
  29. +0
    -0
      tests/data_for_tests/conll_example.txt
  30. +0
    -0
      tests/data_for_tests/cws_pku_utf_8
  31. +0
    -0
      tests/data_for_tests/cws_test
  32. +0
    -0
      tests/data_for_tests/cws_train
  33. +0
    -0
      tests/data_for_tests/embedding/small_bert/config.json
  34. +0
    -0
      tests/data_for_tests/embedding/small_bert/small_pytorch_model.bin
  35. +0
    -0
      tests/data_for_tests/embedding/small_bert/vocab.txt
  36. +0
    -0
      tests/data_for_tests/embedding/small_elmo/char.dic
  37. +0
    -0
      tests/data_for_tests/embedding/small_elmo/elmo_1x16_16_32cnn_1xhighway_options.json
  38. +0
    -0
      tests/data_for_tests/embedding/small_elmo/elmo_mini_for_testing.pkl
  39. +0
    -0
      tests/data_for_tests/embedding/small_gpt2/config.json
  40. +0
    -0
      tests/data_for_tests/embedding/small_gpt2/merges.txt
  41. +0
    -0
      tests/data_for_tests/embedding/small_gpt2/small_pytorch_model.bin
  42. +0
    -0
      tests/data_for_tests/embedding/small_gpt2/vocab.json
  43. +0
    -0
      tests/data_for_tests/embedding/small_roberta/config.json
  44. +0
    -0
      tests/data_for_tests/embedding/small_roberta/merges.txt
  45. +0
    -0
      tests/data_for_tests/embedding/small_roberta/small_pytorch_model.bin
  46. +0
    -0
      tests/data_for_tests/embedding/small_roberta/vocab.json
  47. +0
    -0
      tests/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt
  48. +0
    -0
      tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt
  49. +0
    -0
      tests/data_for_tests/io/BQCorpus/dev.txt
  50. +0
    -0
      tests/data_for_tests/io/BQCorpus/test.txt
  51. +0
    -0
      tests/data_for_tests/io/BQCorpus/train.txt
  52. +0
    -0
      tests/data_for_tests/io/ChnSentiCorp/dev.txt
  53. +0
    -0
      tests/data_for_tests/io/ChnSentiCorp/test.txt
  54. +0
    -0
      tests/data_for_tests/io/ChnSentiCorp/train.txt
  55. +0
    -0
      tests/data_for_tests/io/LCQMC/dev.txt
  56. +0
    -0
      tests/data_for_tests/io/LCQMC/test.txt
  57. +0
    -0
      tests/data_for_tests/io/LCQMC/train.txt
  58. +0
    -0
      tests/data_for_tests/io/MNLI/dev_matched.tsv
  59. +0
    -0
      tests/data_for_tests/io/MNLI/dev_mismatched.tsv
  60. +0
    -0
      tests/data_for_tests/io/MNLI/test_matched.tsv
  61. +0
    -0
      tests/data_for_tests/io/MNLI/test_mismatched.tsv
  62. +0
    -0
      tests/data_for_tests/io/MNLI/train.tsv
  63. +0
    -0
      tests/data_for_tests/io/MSRA_NER/dev.conll
  64. +0
    -0
      tests/data_for_tests/io/MSRA_NER/test.conll
  65. +0
    -0
      tests/data_for_tests/io/MSRA_NER/train.conll
  66. +0
    -0
      tests/data_for_tests/io/OntoNotes/dev.txt
  67. +0
    -0
      tests/data_for_tests/io/OntoNotes/test.txt
  68. +0
    -0
      tests/data_for_tests/io/OntoNotes/train.txt
  69. +0
    -0
      tests/data_for_tests/io/QNLI/dev.tsv
  70. +0
    -0
      tests/data_for_tests/io/QNLI/test.tsv
  71. +0
    -0
      tests/data_for_tests/io/QNLI/train.tsv
  72. +0
    -0
      tests/data_for_tests/io/Quora/dev.tsv
  73. +0
    -0
      tests/data_for_tests/io/Quora/test.tsv
  74. +0
    -0
      tests/data_for_tests/io/Quora/train.tsv
  75. +0
    -0
      tests/data_for_tests/io/RTE/dev.tsv
  76. +0
    -0
      tests/data_for_tests/io/RTE/test.tsv
  77. +0
    -0
      tests/data_for_tests/io/RTE/train.tsv
  78. +0
    -0
      tests/data_for_tests/io/SNLI/snli_1.0_dev.jsonl
  79. +0
    -0
      tests/data_for_tests/io/SNLI/snli_1.0_test.jsonl
  80. +0
    -0
      tests/data_for_tests/io/SNLI/snli_1.0_train.jsonl
  81. +0
    -0
      tests/data_for_tests/io/SST-2/dev.tsv
  82. +0
    -0
      tests/data_for_tests/io/SST-2/test.tsv
  83. +0
    -0
      tests/data_for_tests/io/SST-2/train.tsv
  84. +0
    -0
      tests/data_for_tests/io/SST/dev.txt
  85. +0
    -0
      tests/data_for_tests/io/SST/test.txt
  86. +0
    -0
      tests/data_for_tests/io/SST/train.txt
  87. +0
    -0
      tests/data_for_tests/io/THUCNews/dev.txt
  88. +0
    -0
      tests/data_for_tests/io/THUCNews/test.txt
  89. +0
    -0
      tests/data_for_tests/io/THUCNews/train.txt
  90. +0
    -0
      tests/data_for_tests/io/WeiboSenti100k/dev.txt
  91. +0
    -0
      tests/data_for_tests/io/WeiboSenti100k/test.txt
  92. +0
    -0
      tests/data_for_tests/io/WeiboSenti100k/train.txt
  93. +0
    -0
      tests/data_for_tests/io/XNLI/dev.txt
  94. +0
    -0
      tests/data_for_tests/io/XNLI/test.txt
  95. +0
    -0
      tests/data_for_tests/io/XNLI/train.txt
  96. +0
    -0
      tests/data_for_tests/io/ag/test.csv
  97. +0
    -0
      tests/data_for_tests/io/ag/train.csv
  98. +0
    -0
      tests/data_for_tests/io/cmrc/dev.json
  99. +0
    -0
      tests/data_for_tests/io/cmrc/train.json
  100. +0
    -0
      tests/data_for_tests/io/cnndm/dev.label.jsonl

+ 1
- 1
.Jenkinsfile View File

@@ -29,7 +29,7 @@ pipeline {
steps {
sh 'python -m spacy download en'
sh 'pip install fitlog'
sh 'pytest ./test --html=test_results.html --self-contained-html'
sh 'pytest ./tests --html=test_results.html --self-contained-html'
}
}
}


+ 1
- 1
.travis.yml View File

@@ -14,7 +14,7 @@ install:
# command to run tests
script:
- python -m spacy download en
- pytest --cov=fastNLP test/
- pytest --cov=fastNLP tests/

after_success:
- bash <(curl -s https://codecov.io/bash)


+ 1
- 1
MANIFEST.in View File

@@ -1,7 +1,7 @@
include requirements.txt
include LICENSE
include README.md
prune test/
prune tests/
prune reproduction/
prune fastNLP/api
prune fastNLP/automl

+ 18
- 6
fastNLP/embeddings/bert_embedding.py View File

@@ -93,7 +93,7 @@ class BertEmbedding(ContextualEmbedding):
"""
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

if word_dropout>0:
if word_dropout > 0:
assert vocab.unknown != None, "When word_drop>0, Vocabulary must contain the unknown token."

if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
@@ -370,17 +370,29 @@ class _BertWordModel(nn.Module):
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
super().__init__()

self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
self.encoder = BertModel.from_pretrained(model_dir_or_name)
self._max_position_embeddings = self.encoder.config.max_position_embeddings
# 检查encoder_layer_number是否合理
encoder_layer_number = len(self.encoder.encoder.layer)
if isinstance(layers, list):
self.layers = [int(l) for l in layers]
elif isinstance(layers, str):
self.layers = list(map(int, layers.split(',')))
else:
raise TypeError("`layers` only supports str or list[int]")
assert len(self.layers) > 0, "There is no layer selected!"

neg_num_output_layer = -16384
pos_num_output_layer = 0
for layer in self.layers:
if layer < 0:
neg_num_output_layer = max(layer, neg_num_output_layer)
else:
pos_num_output_layer = max(layer, pos_num_output_layer)

self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
self.encoder = BertModel.from_pretrained(model_dir_or_name,
neg_num_output_layer=neg_num_output_layer,
pos_num_output_layer=pos_num_output_layer)
self._max_position_embeddings = self.encoder.config.max_position_embeddings
# 检查encoder_layer_number是否合理
encoder_layer_number = len(self.encoder.encoder.layer)
for layer in self.layers:
if layer < 0:
assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \


+ 17
- 7
fastNLP/embeddings/roberta_embedding.py View File

@@ -196,20 +196,30 @@ class _RobertaWordModel(nn.Module):
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
super().__init__()

self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
self.encoder = RobertaModel.from_pretrained(model_dir_or_name)
# 由于RobertaEmbedding中设置了padding_idx为1, 且使用了非常神奇的position计算方式,所以-2
self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
# 检查encoder_layer_number是否合理
encoder_layer_number = len(self.encoder.encoder.layer)

if isinstance(layers, list):
self.layers = [int(l) for l in layers]
elif isinstance(layers, str):
self.layers = list(map(int, layers.split(',')))
else:
raise TypeError("`layers` only supports str or list[int]")
assert len(self.layers) > 0, "There is no layer selected!"

neg_num_output_layer = -16384
pos_num_output_layer = 0
for layer in self.layers:
if layer < 0:
neg_num_output_layer = max(layer, neg_num_output_layer)
else:
pos_num_output_layer = max(layer, pos_num_output_layer)

self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
self.encoder = RobertaModel.from_pretrained(model_dir_or_name,
neg_num_output_layer=neg_num_output_layer,
pos_num_output_layer=pos_num_output_layer)
# 由于RobertaEmbedding中设置了padding_idx为1, 且使用了非常神奇的position计算方式,所以-2
self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
# 检查encoder_layer_number是否合理
encoder_layer_number = len(self.encoder.encoder.layer)
for layer in self.layers:
if layer < 0:
assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \


+ 15
- 3
fastNLP/modules/encoder/bert.py View File

@@ -366,19 +366,28 @@ class BertLayer(nn.Module):


class BertEncoder(nn.Module):
def __init__(self, config):
def __init__(self, config, num_output_layer=-1):
super(BertEncoder, self).__init__()
layer = BertLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
num_output_layer = num_output_layer if num_output_layer >= 0 else (len(self.layer) + num_output_layer)
self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)
if self.num_output_layer + 1 < len(self.layer):
logger.info(f'The transformer encoder will early exit after layer-{self.num_output_layer} '
f'(start from 0)!')

def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
all_encoder_layers = []
for layer_module in self.layer:
for idx, layer_module in enumerate(self.layer):
if idx > self.num_output_layer:
break
hidden_states = layer_module(hidden_states, attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if len(all_encoder_layers) == 0:
all_encoder_layers.append(hidden_states)
return all_encoder_layers


@@ -435,6 +444,9 @@ class BertModel(nn.Module):
self.config = config
self.hidden_size = self.config.hidden_size
self.model_type = 'bert'
neg_num_output_layer = kwargs.get('neg_num_output_layer', -1)
pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers - 1)
self.num_output_layer = max(neg_num_output_layer + self.config.num_hidden_layers, pos_num_output_layer)
if hasattr(config, 'sinusoidal_pos_embds'):
self.model_type = 'distilbert'
elif 'model_type' in kwargs:
@@ -445,7 +457,7 @@ class BertModel(nn.Module):
else:
self.embeddings = BertEmbeddings(config)

self.encoder = BertEncoder(config)
self.encoder = BertEncoder(config, num_output_layer=self.num_output_layer)
if self.model_type != 'distilbert':
self.pooler = BertPooler(config)
else:


+ 2
- 2
fastNLP/modules/encoder/roberta.py View File

@@ -64,8 +64,8 @@ class RobertaModel(BertModel):
undocumented
"""

def __init__(self, config):
super().__init__(config)
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

self.embeddings = RobertaEmbeddings(config)
self.apply(self.init_bert_weights)


+ 1
- 1
fastNLP/modules/encoder/seq2seq_encoder.py View File

@@ -132,7 +132,7 @@ class TransformerSeq2SeqEncoder(Seq2SeqEncoder):
x = self.input_fc(x)
x = F.dropout(x, p=self.dropout, training=self.training)

encoder_mask = seq_len_to_mask(seq_len)
encoder_mask = seq_len_to_mask(seq_len, max_len=max_src_len)
encoder_mask = encoder_mask.to(device)

for layer in self.layer_stacks:


test/__init__.py → tests/__init__.py View File


test/core/__init__.py → tests/core/__init__.py View File


test/core/test_batch.py → tests/core/test_batch.py View File


test/core/test_callbacks.py → tests/core/test_callbacks.py View File


test/core/test_dataset.py → tests/core/test_dataset.py View File

@@ -228,7 +228,7 @@ class TestDataSetMethods(unittest.TestCase):
def split_sent(ins):
return ins['raw_sentence'].split()
csv_loader = CSVLoader(headers=['raw_sentence', 'label'], sep='\t')
data_bundle = csv_loader.load('test/data_for_tests/tutorial_sample_dataset.csv')
data_bundle = csv_loader.load('tests/data_for_tests/tutorial_sample_dataset.csv')
dataset = data_bundle.datasets['train']
dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True)
dataset.apply(split_sent, new_field_name='words', is_input=True)

test/core/test_dist_trainer.py → tests/core/test_dist_trainer.py View File


test/core/test_field.py → tests/core/test_field.py View File


test/core/test_instance.py → tests/core/test_instance.py View File


test/core/test_logger.py → tests/core/test_logger.py View File


test/core/test_loss.py → tests/core/test_loss.py View File


test/core/test_metrics.py → tests/core/test_metrics.py View File


test/core/test_optimizer.py → tests/core/test_optimizer.py View File


test/core/test_predictor.py → tests/core/test_predictor.py View File


test/core/test_sampler.py → tests/core/test_sampler.py View File


test/core/test_tester.py → tests/core/test_tester.py View File


test/core/test_trainer.py → tests/core/test_trainer.py View File


test/core/test_utils.py → tests/core/test_utils.py View File

@@ -120,8 +120,8 @@ class TestCache(unittest.TestCase):
def test_cache_save(self):
try:
start_time = time.time()
embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'test/data_for_tests/cws_train')
embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'tests/data_for_tests/cws_train')
end_time = time.time()
pre_time = end_time - start_time
with open('test/demo1.pkl', 'rb') as f:
@@ -130,8 +130,8 @@ class TestCache(unittest.TestCase):
for i in range(embed.shape[0]):
self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
start_time = time.time()
embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'test/data_for_tests/cws_train')
embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'tests/data_for_tests/cws_train')
end_time = time.time()
read_time = end_time - start_time
print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
@@ -142,7 +142,7 @@ class TestCache(unittest.TestCase):
def test_cache_save_overwrite_path(self):
try:
start_time = time.time()
embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt', 'test/data_for_tests/cws_train',
embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt', 'tests/data_for_tests/cws_train',
_cache_fp='test/demo_overwrite.pkl')
end_time = time.time()
pre_time = end_time - start_time
@@ -152,8 +152,8 @@ class TestCache(unittest.TestCase):
for i in range(embed.shape[0]):
self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
start_time = time.time()
embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'test/data_for_tests/cws_train',
embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'tests/data_for_tests/cws_train',
_cache_fp='test/demo_overwrite.pkl')
end_time = time.time()
read_time = end_time - start_time
@@ -165,8 +165,8 @@ class TestCache(unittest.TestCase):
def test_cache_refresh(self):
try:
start_time = time.time()
embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'test/data_for_tests/cws_train',
embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'tests/data_for_tests/cws_train',
_refresh=True)
end_time = time.time()
pre_time = end_time - start_time
@@ -176,8 +176,8 @@ class TestCache(unittest.TestCase):
for i in range(embed.shape[0]):
self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
start_time = time.time()
embed, vocab, d = process_data_1('test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'test/data_for_tests/cws_train',
embed, vocab, d = process_data_1('tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt',
'tests/data_for_tests/cws_train',
_refresh=True)
end_time = time.time()
read_time = end_time - start_time

test/core/test_vocabulary.py → tests/core/test_vocabulary.py View File


test/data_for_tests/config → tests/data_for_tests/config View File


test/data_for_tests/conll_2003_example.txt → tests/data_for_tests/conll_2003_example.txt View File


test/data_for_tests/conll_example.txt → tests/data_for_tests/conll_example.txt View File


test/data_for_tests/cws_pku_utf_8 → tests/data_for_tests/cws_pku_utf_8 View File


test/data_for_tests/cws_test → tests/data_for_tests/cws_test View File


test/data_for_tests/cws_train → tests/data_for_tests/cws_train View File


test/data_for_tests/embedding/small_bert/config.json → tests/data_for_tests/embedding/small_bert/config.json View File


test/data_for_tests/embedding/small_bert/small_pytorch_model.bin → tests/data_for_tests/embedding/small_bert/small_pytorch_model.bin View File


test/data_for_tests/embedding/small_bert/vocab.txt → tests/data_for_tests/embedding/small_bert/vocab.txt View File


test/data_for_tests/embedding/small_elmo/char.dic → tests/data_for_tests/embedding/small_elmo/char.dic View File


test/data_for_tests/embedding/small_elmo/elmo_1x16_16_32cnn_1xhighway_options.json → tests/data_for_tests/embedding/small_elmo/elmo_1x16_16_32cnn_1xhighway_options.json View File


test/data_for_tests/embedding/small_elmo/elmo_mini_for_testing.pkl → tests/data_for_tests/embedding/small_elmo/elmo_mini_for_testing.pkl View File


test/data_for_tests/embedding/small_gpt2/config.json → tests/data_for_tests/embedding/small_gpt2/config.json View File


test/data_for_tests/embedding/small_gpt2/merges.txt → tests/data_for_tests/embedding/small_gpt2/merges.txt View File


test/data_for_tests/embedding/small_gpt2/small_pytorch_model.bin → tests/data_for_tests/embedding/small_gpt2/small_pytorch_model.bin View File


test/data_for_tests/embedding/small_gpt2/vocab.json → tests/data_for_tests/embedding/small_gpt2/vocab.json View File


test/data_for_tests/embedding/small_roberta/config.json → tests/data_for_tests/embedding/small_roberta/config.json View File


test/data_for_tests/embedding/small_roberta/merges.txt → tests/data_for_tests/embedding/small_roberta/merges.txt View File


test/data_for_tests/embedding/small_roberta/small_pytorch_model.bin → tests/data_for_tests/embedding/small_roberta/small_pytorch_model.bin View File


test/data_for_tests/embedding/small_roberta/vocab.json → tests/data_for_tests/embedding/small_roberta/vocab.json View File


test/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt → tests/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt View File


test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt → tests/data_for_tests/embedding/small_static_embedding/word2vec_test.txt View File


test/data_for_tests/io/BQCorpus/dev.txt → tests/data_for_tests/io/BQCorpus/dev.txt View File


test/data_for_tests/io/BQCorpus/test.txt → tests/data_for_tests/io/BQCorpus/test.txt View File


test/data_for_tests/io/BQCorpus/train.txt → tests/data_for_tests/io/BQCorpus/train.txt View File


test/data_for_tests/io/ChnSentiCorp/dev.txt → tests/data_for_tests/io/ChnSentiCorp/dev.txt View File


test/data_for_tests/io/ChnSentiCorp/test.txt → tests/data_for_tests/io/ChnSentiCorp/test.txt View File


test/data_for_tests/io/ChnSentiCorp/train.txt → tests/data_for_tests/io/ChnSentiCorp/train.txt View File


test/data_for_tests/io/LCQMC/dev.txt → tests/data_for_tests/io/LCQMC/dev.txt View File


test/data_for_tests/io/LCQMC/test.txt → tests/data_for_tests/io/LCQMC/test.txt View File


test/data_for_tests/io/LCQMC/train.txt → tests/data_for_tests/io/LCQMC/train.txt View File


test/data_for_tests/io/MNLI/dev_matched.tsv → tests/data_for_tests/io/MNLI/dev_matched.tsv View File


test/data_for_tests/io/MNLI/dev_mismatched.tsv → tests/data_for_tests/io/MNLI/dev_mismatched.tsv View File


test/data_for_tests/io/MNLI/test_matched.tsv → tests/data_for_tests/io/MNLI/test_matched.tsv View File


test/data_for_tests/io/MNLI/test_mismatched.tsv → tests/data_for_tests/io/MNLI/test_mismatched.tsv View File


test/data_for_tests/io/MNLI/train.tsv → tests/data_for_tests/io/MNLI/train.tsv View File


test/data_for_tests/io/MSRA_NER/dev.conll → tests/data_for_tests/io/MSRA_NER/dev.conll View File


test/data_for_tests/io/MSRA_NER/test.conll → tests/data_for_tests/io/MSRA_NER/test.conll View File


test/data_for_tests/io/MSRA_NER/train.conll → tests/data_for_tests/io/MSRA_NER/train.conll View File


test/data_for_tests/io/OntoNotes/dev.txt → tests/data_for_tests/io/OntoNotes/dev.txt View File


test/data_for_tests/io/OntoNotes/test.txt → tests/data_for_tests/io/OntoNotes/test.txt View File


test/data_for_tests/io/OntoNotes/train.txt → tests/data_for_tests/io/OntoNotes/train.txt View File


test/data_for_tests/io/QNLI/dev.tsv → tests/data_for_tests/io/QNLI/dev.tsv View File


test/data_for_tests/io/QNLI/test.tsv → tests/data_for_tests/io/QNLI/test.tsv View File


test/data_for_tests/io/QNLI/train.tsv → tests/data_for_tests/io/QNLI/train.tsv View File


test/data_for_tests/io/Quora/dev.tsv → tests/data_for_tests/io/Quora/dev.tsv View File


test/data_for_tests/io/Quora/test.tsv → tests/data_for_tests/io/Quora/test.tsv View File


test/data_for_tests/io/Quora/train.tsv → tests/data_for_tests/io/Quora/train.tsv View File


test/data_for_tests/io/RTE/dev.tsv → tests/data_for_tests/io/RTE/dev.tsv View File


test/data_for_tests/io/RTE/test.tsv → tests/data_for_tests/io/RTE/test.tsv View File


test/data_for_tests/io/RTE/train.tsv → tests/data_for_tests/io/RTE/train.tsv View File


test/data_for_tests/io/SNLI/snli_1.0_dev.jsonl → tests/data_for_tests/io/SNLI/snli_1.0_dev.jsonl View File


test/data_for_tests/io/SNLI/snli_1.0_test.jsonl → tests/data_for_tests/io/SNLI/snli_1.0_test.jsonl View File


test/data_for_tests/io/SNLI/snli_1.0_train.jsonl → tests/data_for_tests/io/SNLI/snli_1.0_train.jsonl View File


test/data_for_tests/io/SST-2/dev.tsv → tests/data_for_tests/io/SST-2/dev.tsv View File


test/data_for_tests/io/SST-2/test.tsv → tests/data_for_tests/io/SST-2/test.tsv View File


test/data_for_tests/io/SST-2/train.tsv → tests/data_for_tests/io/SST-2/train.tsv View File


test/data_for_tests/io/SST/dev.txt → tests/data_for_tests/io/SST/dev.txt View File


test/data_for_tests/io/SST/test.txt → tests/data_for_tests/io/SST/test.txt View File


test/data_for_tests/io/SST/train.txt → tests/data_for_tests/io/SST/train.txt View File


test/data_for_tests/io/THUCNews/dev.txt → tests/data_for_tests/io/THUCNews/dev.txt View File


test/data_for_tests/io/THUCNews/test.txt → tests/data_for_tests/io/THUCNews/test.txt View File


test/data_for_tests/io/THUCNews/train.txt → tests/data_for_tests/io/THUCNews/train.txt View File


test/data_for_tests/io/WeiboSenti100k/dev.txt → tests/data_for_tests/io/WeiboSenti100k/dev.txt View File


test/data_for_tests/io/WeiboSenti100k/test.txt → tests/data_for_tests/io/WeiboSenti100k/test.txt View File


test/data_for_tests/io/WeiboSenti100k/train.txt → tests/data_for_tests/io/WeiboSenti100k/train.txt View File


test/data_for_tests/io/XNLI/dev.txt → tests/data_for_tests/io/XNLI/dev.txt View File


test/data_for_tests/io/XNLI/test.txt → tests/data_for_tests/io/XNLI/test.txt View File


test/data_for_tests/io/XNLI/train.txt → tests/data_for_tests/io/XNLI/train.txt View File


test/data_for_tests/io/ag/test.csv → tests/data_for_tests/io/ag/test.csv View File


test/data_for_tests/io/ag/train.csv → tests/data_for_tests/io/ag/train.csv View File


test/data_for_tests/io/cmrc/dev.json → tests/data_for_tests/io/cmrc/dev.json View File


test/data_for_tests/io/cmrc/train.json → tests/data_for_tests/io/cmrc/train.json View File


test/data_for_tests/io/cnndm/dev.label.jsonl → tests/data_for_tests/io/cnndm/dev.label.jsonl View File


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save