Browse Source

Delete matching.py

tags/v0.4.10
xuyige 6 years ago
parent
commit
bc5e071253
1 changed files with 0 additions and 44 deletions
  1. +0
    -44
      reproduction/matching/matching.py

+ 0
- 44
reproduction/matching/matching.py View File

@@ -1,44 +0,0 @@
import os

import torch

from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const

from fastNLP.io.dataset_loader import MatchingLoader

from reproduction.matching.model.bert import BertForNLI
from reproduction.matching.model.esim import ESIMModel


bert_dirs = 'path/to/bert/dir'

# load data set
# data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(...
data_info = MatchingLoader(data_format='snli', for_model='esim').process(
{'train': './data/snli/snli_1.0_train.jsonl',
'dev': './data/snli/snli_1.0_dev.jsonl',
'test': './data/snli/snli_1.0_test.jsonl'},
input_field=[Const.TARGET]
)

# model = BertForNLI(bert_dir=bert_dirs)
model = ESIMModel(data_info.embeddings['elmo'],)

trainer = Trainer(train_data=data_info.datasets['train'], model=model,
optimizer=Adam(lr=1e-4, model_params=model.parameters()),
batch_size=torch.cuda.device_count() * 24, n_epochs=20, print_every=-1,
dev_data=data_info.datasets['dev'],
metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1)
trainer.train(load_best_model=True)

tester = Tester(
data=data_info.datasets['test'],
model=model,
metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * 12,
device=[i for i in range(torch.cuda.device_count())],
)
tester.test()



Loading…
Cancel
Save