|
|
@@ -21,9 +21,8 @@ class Trainer(object): |
|
|
|
|
|
|
|
""" |
|
|
|
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, |
|
|
|
dev_data=None, use_cuda=False, loss=Loss(None), save_path="./save", |
|
|
|
dev_data=None, use_cuda=False, save_path="./save", |
|
|
|
optimizer=Optimizer("Adam", lr=0.001, weight_decay=0), |
|
|
|
evaluator=Evaluator(), |
|
|
|
**kwargs): |
|
|
|
super(Trainer, self).__init__() |
|
|
|
|
|
|
@@ -36,9 +35,16 @@ class Trainer(object): |
|
|
|
self.save_path = str(save_path) |
|
|
|
self.print_every = int(print_every) |
|
|
|
|
|
|
|
self.loss_func = self.model.loss if hasattr(self.model, "loss") else loss.get() |
|
|
|
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) |
|
|
|
self.evaluator = evaluator |
|
|
|
model_name = model.__class__.__name__ |
|
|
|
assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name) |
|
|
|
self.loss_func = self.model.get_loss |
|
|
|
if isinstance(optimizer, torch.optim.Optimizer): |
|
|
|
self.optimizer = optimizer |
|
|
|
else: |
|
|
|
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) |
|
|
|
|
|
|
|
assert hasattr(self.model, 'evaluate'), "model {} has to have a 'evaluate' function.".format(model_name) |
|
|
|
self.evaluator = self.model.evaluate |
|
|
|
|
|
|
|
if self.dev_data is not None: |
|
|
|
valid_args = {"batch_size": self.batch_size, "save_path": self.save_path, |
|
|
@@ -48,7 +54,10 @@ class Trainer(object): |
|
|
|
for k, v in kwargs.items(): |
|
|
|
setattr(self, k, v) |
|
|
|
|
|
|
|
self._summary_writer = SummaryWriter(os.path.join(self.save_path, 'tensorboard_logs')) |
|
|
|
self.tensorboard_path = os.path.join(self.save_path, 'tensorboard_logs') |
|
|
|
if os.path.exists(self.tensorboard_path): |
|
|
|
os.rmdir(self.tensorboard_path) |
|
|
|
self._summary_writer = SummaryWriter(self.tensorboard_path) |
|
|
|
self._graph_summaried = False |
|
|
|
self.step = 0 |
|
|
|
self.start_time = None # start timestamp |
|
|
@@ -138,6 +147,7 @@ class Trainer(object): |
|
|
|
self.optimizer.step() |
|
|
|
|
|
|
|
def data_forward(self, network, x): |
|
|
|
x = _build_args(network.forward, **x) |
|
|
|
y = network(**x) |
|
|
|
if not self._graph_summaried: |
|
|
|
# self._summary_writer.add_graph(network, x, verbose=False) |
|
|
@@ -161,12 +171,9 @@ class Trainer(object): |
|
|
|
:param truth: ground truth label vector |
|
|
|
:return: a scalar |
|
|
|
""" |
|
|
|
if isinstance(predict, dict) and isinstance(truth, dict): |
|
|
|
return self.loss_func(**predict, **truth) |
|
|
|
if len(truth) > 1: |
|
|
|
raise NotImplementedError("Not ready to handle multi-labels.") |
|
|
|
truth = list(truth.values())[0] if len(truth) > 0 else None |
|
|
|
return self.loss_func(predict, truth) |
|
|
|
assert isinstance(predict, dict) and isinstance(truth, dict) |
|
|
|
args = _build_args(self.loss_func, **predict, **truth) |
|
|
|
return self.loss_func(**args) |
|
|
|
|
|
|
|
def save_model(self, model, model_name, only_param=False): |
|
|
|
model_name = os.path.join(self.save_path, model_name) |
|
|
|