@@ -330,7 +330,10 @@ class JsonLoader(DataSetLoader): | |||||
def load(self, path): | def load(self, path): | ||||
ds = DataSet() | ds = DataSet() | ||||
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | ||||
ins = {self.fields[k]:v for k,v in d.items()} | |||||
if self.fields: | |||||
ins = {self.fields[k]:v for k,v in d.items()} | |||||
else: | |||||
ins = d | |||||
ds.append(Instance(**ins)) | ds.append(Instance(**ins)) | ||||
return ds | return ds | ||||
@@ -142,13 +142,12 @@ class TestCase1(unittest.TestCase): | |||||
def test_sequential_batch(self): | def test_sequential_batch(self): | ||||
batch_size = 32 | batch_size = 32 | ||||
pause_seconds = 0.01 | |||||
num_samples = 1000 | num_samples = 1000 | ||||
dataset = generate_fake_dataset(num_samples) | dataset = generate_fake_dataset(num_samples) | ||||
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler()) | batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler()) | ||||
for batch_x, batch_y in batch: | for batch_x, batch_y in batch: | ||||
time.sleep(pause_seconds) | |||||
pass | |||||
""" | """ | ||||
def test_multi_workers_batch(self): | def test_multi_workers_batch(self): | ||||
@@ -132,6 +132,19 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
return | return | ||||
self.assertTrue(True, False), "No exception catches." | self.assertTrue(True, False), "No exception catches." | ||||
def test_seq_len(self): | |||||
N = 256 | |||||
seq_len = torch.zeros(N).long() | |||||
seq_len[0] = 2 | |||||
pred = {'pred': torch.ones(N, 2)} | |||||
target = {'target': torch.ones(N, 2), 'seq_len': seq_len} | |||||
metric = AccuracyMetric() | |||||
metric(pred_dict=pred, target_dict=target) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1.}) | |||||
seq_len[1:] = 1 | |||||
metric(pred_dict=pred, target_dict=target) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1.}) | |||||
class SpanF1PreRecMetric(unittest.TestCase): | class SpanF1PreRecMetric(unittest.TestCase): | ||||
def test_case1(self): | def test_case1(self): | ||||
from fastNLP.core.metrics import _bmes_tag_to_spans | from fastNLP.core.metrics import _bmes_tag_to_spans | ||||
@@ -77,13 +77,13 @@ def init_data(): | |||||
class TestBiaffineParser(unittest.TestCase): | class TestBiaffineParser(unittest.TestCase): | ||||
def test_train(self): | def test_train(self): | ||||
ds, v1, v2, v3 = init_data() | ds, v1, v2, v3 = init_data() | ||||
model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30, | |||||
model = BiaffineParser(init_embed=(len(v1), 30), | |||||
pos_vocab_size=len(v2), pos_emb_dim=30, | pos_vocab_size=len(v2), pos_emb_dim=30, | ||||
num_label=len(v3), encoder='var-lstm') | num_label=len(v3), encoder='var-lstm') | ||||
trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds, | trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds, | ||||
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | ||||
batch_size=1, validate_every=10, | batch_size=1, validate_every=10, | ||||
n_epochs=10, use_cuda=False, use_tqdm=False) | |||||
n_epochs=10, use_tqdm=False) | |||||
trainer.train(load_best_model=False) | trainer.train(load_best_model=False) | ||||