Browse Source

- add test

- fix jsonloader
tags/v0.4.10
yunfan 5 years ago
parent
commit
799d4dbc68
4 changed files with 20 additions and 5 deletions
  1. +4
    -1
      fastNLP/io/dataset_loader.py
  2. +1
    -2
      test/core/test_batch.py
  3. +13
    -0
      test/core/test_metrics.py
  4. +2
    -2
      test/models/test_biaffine_parser.py

+ 4
- 1
fastNLP/io/dataset_loader.py View File

@@ -330,7 +330,10 @@ class JsonLoader(DataSetLoader):
def load(self, path):
ds = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
ins = {self.fields[k]:v for k,v in d.items()}
if self.fields:
ins = {self.fields[k]:v for k,v in d.items()}
else:
ins = d
ds.append(Instance(**ins))
return ds




+ 1
- 2
test/core/test_batch.py View File

@@ -142,13 +142,12 @@ class TestCase1(unittest.TestCase):


def test_sequential_batch(self):
batch_size = 32
pause_seconds = 0.01
num_samples = 1000
dataset = generate_fake_dataset(num_samples)


batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_x, batch_y in batch:
time.sleep(pause_seconds)
pass


""" """
def test_multi_workers_batch(self): def test_multi_workers_batch(self):


+ 13
- 0
test/core/test_metrics.py View File

@@ -132,6 +132,19 @@ class TestAccuracyMetric(unittest.TestCase):
return
self.assertTrue(True, False), "No exception catches."


def test_seq_len(self):
N = 256
seq_len = torch.zeros(N).long()
seq_len[0] = 2
pred = {'pred': torch.ones(N, 2)}
target = {'target': torch.ones(N, 2), 'seq_len': seq_len}
metric = AccuracyMetric()
metric(pred_dict=pred, target_dict=target)
self.assertDictEqual(metric.get_metric(), {'acc': 1.})
seq_len[1:] = 1
metric(pred_dict=pred, target_dict=target)
self.assertDictEqual(metric.get_metric(), {'acc': 1.})

class SpanF1PreRecMetric(unittest.TestCase):
def test_case1(self):
from fastNLP.core.metrics import _bmes_tag_to_spans


+ 2
- 2
test/models/test_biaffine_parser.py View File

@@ -77,13 +77,13 @@ def init_data():
class TestBiaffineParser(unittest.TestCase):
def test_train(self):
ds, v1, v2, v3 = init_data()
model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30,
model = BiaffineParser(init_embed=(len(v1), 30),
pos_vocab_size=len(v2), pos_emb_dim=30,
num_label=len(v3), encoder='var-lstm')
trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds,
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
batch_size=1, validate_every=10,
n_epochs=10, use_cuda=False, use_tqdm=False)
n_epochs=10, use_tqdm=False)
trainer.train(load_best_model=False)






Loading…
Cancel
Save