|
|
@@ -27,7 +27,7 @@ class Trainer(object): |
|
|
|
""" |
|
|
|
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1, |
|
|
|
dev_data=None, use_cuda=False, save_path="./save", |
|
|
|
optimizer=Optimizer("Adam", lr=0.001, weight_decay=0), need_check_code=True, |
|
|
|
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True, |
|
|
|
**kwargs): |
|
|
|
super(Trainer, self).__init__() |
|
|
|
|
|
|
@@ -84,7 +84,14 @@ class Trainer(object): |
|
|
|
start = time.time() |
|
|
|
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) |
|
|
|
print("training epochs started " + self.start_time) |
|
|
|
if self.save_path is not None: |
|
|
|
if self.save_path is None: |
|
|
|
class psudoSW: |
|
|
|
def __getattr__(self, item): |
|
|
|
def pass_func(*args, **kwargs): |
|
|
|
pass |
|
|
|
return pass_func |
|
|
|
self._summary_writer = psudoSW() |
|
|
|
else: |
|
|
|
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time)) |
|
|
|
self._summary_writer = SummaryWriter(path) |
|
|
|
|
|
|
@@ -98,7 +105,6 @@ class Trainer(object): |
|
|
|
# validate_every override validation at end of epochs |
|
|
|
if self.dev_data and self.validate_every <= 0: |
|
|
|
self.do_validation() |
|
|
|
self.save_model(self.model, 'training_model_' + self.start_time) |
|
|
|
epoch += 1 |
|
|
|
finally: |
|
|
|
self._summary_writer.close() |
|
|
|