|
|
@@ -27,7 +27,7 @@ class Trainer(object): |
|
|
|
""" |
|
|
|
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1, |
|
|
|
dev_data=None, use_cuda=False, save_path="./save", |
|
|
|
optimizer=Optimizer("Adam", lr=0.001, weight_decay=0), |
|
|
|
optimizer=Optimizer("Adam", lr=0.001, weight_decay=0), need_check_code=True, |
|
|
|
**kwargs): |
|
|
|
super(Trainer, self).__init__() |
|
|
|
|
|
|
@@ -37,9 +37,13 @@ class Trainer(object): |
|
|
|
self.n_epochs = int(n_epochs) |
|
|
|
self.batch_size = int(batch_size) |
|
|
|
self.use_cuda = bool(use_cuda) |
|
|
|
self.save_path = str(save_path) |
|
|
|
self.save_path = save_path |
|
|
|
self.print_every = int(print_every) |
|
|
|
self.validate_every = int(validate_every) |
|
|
|
self._best_accuracy = 0 |
|
|
|
|
|
|
|
if need_check_code: |
|
|
|
_check_code(dataset=train_data, model=model, dev_data=dev_data) |
|
|
|
|
|
|
|
model_name = model.__class__.__name__ |
|
|
|
assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name) |
|
|
@@ -56,16 +60,11 @@ class Trainer(object): |
|
|
|
self.tester = Tester(model=self.model, |
|
|
|
data=self.dev_data, |
|
|
|
batch_size=self.batch_size, |
|
|
|
save_path=self.save_path, |
|
|
|
use_cuda=self.use_cuda) |
|
|
|
|
|
|
|
for k, v in kwargs.items(): |
|
|
|
setattr(self, k, v) |
|
|
|
|
|
|
|
self.tensorboard_path = os.path.join(self.save_path, 'tensorboard_logs') |
|
|
|
if os.path.exists(self.tensorboard_path): |
|
|
|
shutil.rmtree(self.tensorboard_path) |
|
|
|
self._graph_summaried = False |
|
|
|
self.step = 0 |
|
|
|
self.start_time = None # start timestamp |
|
|
|
|
|
|
@@ -77,8 +76,6 @@ class Trainer(object): |
|
|
|
:return: |
|
|
|
""" |
|
|
|
try: |
|
|
|
self._summary_writer = SummaryWriter(self.tensorboard_path) |
|
|
|
|
|
|
|
if torch.cuda.is_available() and self.use_cuda: |
|
|
|
self.model = self.model.cuda() |
|
|
|
|
|
|
@@ -87,6 +84,9 @@ class Trainer(object): |
|
|
|
start = time.time() |
|
|
|
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) |
|
|
|
print("training epochs started " + self.start_time) |
|
|
|
if self.save_path is not None: |
|
|
|
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time)) |
|
|
|
self._summary_writer = SummaryWriter(path) |
|
|
|
|
|
|
|
epoch = 1 |
|
|
|
while epoch <= self.n_epochs: |
|
|
@@ -143,7 +143,8 @@ class Trainer(object): |
|
|
|
res = self.tester.test() |
|
|
|
for name, num in res.items(): |
|
|
|
self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step) |
|
|
|
self.save_model(self.model, 'best_model_' + self.start_time) |
|
|
|
if self.save_path is not None and self.best_eval_result(res): |
|
|
|
self.save_model(self.model, 'best_model_' + self.start_time) |
|
|
|
|
|
|
|
def mode(self, model, is_test=False): |
|
|
|
"""Train mode or Test mode. This is for PyTorch currently. |
|
|
@@ -166,9 +167,6 @@ class Trainer(object): |
|
|
|
def data_forward(self, network, x):
    """Run one forward pass of ``network`` on the batch ``x``.

    :param network: callable model exposing a ``forward`` method.
    :param x: mapping of argument names to values for this batch. It is
        passed through ``_build_args`` together with ``network.forward``
        before being unpacked as keyword arguments.
    :return: whatever ``network(**x)`` returns.
    """
    x = _build_args(network.forward, **x)
    y = network(**x)
    # Graph export to the summary writer is currently disabled; only the
    # bookkeeping flag is updated on the first call.
    if not self._graph_summaried:
        self._graph_summaried = True
    return y
|
|
|
|
|
|
|
def grad_backward(self, loss): |
|
|
@@ -199,28 +197,27 @@ class Trainer(object): |
|
|
|
else: |
|
|
|
torch.save(model, model_name) |
|
|
|
|
|
|
|
def best_eval_result(self, metrics):
    """Check if the current epoch yields better validation results.

    :param metrics: validation result. Either a ``(loss, metrics)`` tuple,
        a dict of metric values, or a single comparable value.
    :return: bool, True means current results on dev set is the best.
    :raises KeyError: if ``metrics`` is a dict whose size is not 1 and it
        lacks the ``self.eval_sort_key`` entry.
    """
    # A tuple is unpacked as (loss, metrics); only the second part is compared.
    if isinstance(metrics, tuple):
        _, metrics = metrics

    if isinstance(metrics, dict):
        if len(metrics) == 1:
            # Single entry: use its value, whatever the key is called.
            accuracy = next(iter(metrics.values()))
        else:
            # Otherwise self.eval_sort_key selects the entry to compare.
            accuracy = metrics[self.eval_sort_key]
    else:
        accuracy = metrics

    # Strict comparison: a tie with the stored best is not an improvement.
    if accuracy > self._best_accuracy:
        self._best_accuracy = accuracy
        return True
    return False
|
|
|
|
|
|
|
|
|
|
|
# Default ``batch_size`` of ``_check_code``; its training check loop also
# stops once this many batches have been processed.
DEFAULT_CHECK_BATCH_SIZE = 2
|
|
@@ -268,9 +265,6 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No |
|
|
|
loss.backward() |
|
|
|
if batch_count + 1 >= DEFAULT_CHECK_BATCH_SIZE: |
|
|
|
break |
|
|
|
if check_level > IGNORE_CHECK_LEVEL: |
|
|
|
print('Finish checking training process.', flush=True) |
|
|
|
|
|
|
|
|
|
|
|
if dev_data is not None: |
|
|
|
if not hasattr(model, 'evaluate'): |
|
|
@@ -310,8 +304,6 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No |
|
|
|
func_signature = get_func_signature(model.evaluate) |
|
|
|
assert isinstance(metrics, dict), "The return value of {} should be dict.". \ |
|
|
|
format(func_signature) |
|
|
|
if check_level > IGNORE_CHECK_LEVEL: |
|
|
|
print("Finish checking evaluate process.", flush=True) |
|
|
|
|
|
|
|
|
|
|
|
def _check_forward_error(model_func, check_level, batch_x): |
|
|
|