|
|
@@ -77,13 +77,13 @@ def init_data(): |
|
|
|
class TestBiaffineParser(unittest.TestCase): |
|
|
|
def test_train(self): |
|
|
|
ds, v1, v2, v3 = init_data() |
|
|
|
model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30, |
|
|
|
model = BiaffineParser(init_embed=(len(v1), 30), |
|
|
|
pos_vocab_size=len(v2), pos_emb_dim=30, |
|
|
|
num_label=len(v3), encoder='var-lstm') |
|
|
|
trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds, |
|
|
|
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', |
|
|
|
batch_size=1, validate_every=10, |
|
|
|
n_epochs=10, use_cuda=False, use_tqdm=False) |
|
|
|
n_epochs=10, use_tqdm=False) |
|
|
|
trainer.train(load_best_model=False) |
|
|
|
|
|
|
|
|
|
|
|