|
|
@@ -203,7 +203,7 @@ callbacks.append( |
|
|
|
def train(model,datainfo,loss,metrics,optimizer,num_epochs=100): |
|
|
|
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'),batch_size=ops.batch_size, |
|
|
|
metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=[0,1,2], check_code_level=-1, |
|
|
|
n_epochs=num_epochs) |
|
|
|
n_epochs=num_epochs,callbacks=callbacks) |
|
|
|
print(trainer.train()) |
|
|
|
|
|
|
|
|
|
|
|