import os
import logging
import torch
import torch.nn as nn
import numpy as np
from collections import OrderedDict
import json
from pytorch.callbacks import LRSchedulerCallback
from pytorch.trainer import BaseTrainer, TorchTensorEncoder
from pytorch.utils import dump_global_result

from model import CNN
from pdartsmutator import PdartsMutator
from pytorch.darts.utils import accuracy
from pytorch.darts import datasets
from pytorch.darts.dartstrainer import DartsTrainer

logger = logging.getLogger(__name__)


class PdartsTrainer(BaseTrainer):
    """
    This trainer implements the P-DARTS algorithm.

    P-DARTS builds on DARTS and grows the network progressively so that a deeper and
    better architecture can be found. Two parameters control how the network grows:
    ``pdarts_num_layers`` gives how many layers each stage adds on top of ``init_layers``,
    and ``pdarts_num_to_drop`` gives how many candidate operations are dropped from each
    edge at the end of each stage, so the grown network keeps a similar search cost.
    """

    def __init__(self, init_layers, pdarts_num_layers, pdarts_num_to_drop, pdarts_dropout_rates,
                 num_epochs, num_pre_epochs, model_lr, class_num, arch_lr, channels, batch_size,
                 result_path, log_frequency, unrolled, data_dir, search_space_path,
                 best_selected_space_path, device=None, workers=4):
        super(PdartsTrainer, self).__init__()
        self.init_layers = init_layers
        self.class_num = class_num
        self.channels = channels
        self.model_lr = model_lr
        self.num_epochs = num_epochs
        self.pdarts_num_layers = pdarts_num_layers
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.pdarts_dropout_rates = pdarts_dropout_rates
        self.pdarts_epoches = len(pdarts_num_to_drop)
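        # The number of P-DARTS stages is given by how many drop counts are provided;
        # each stage runs a full DARTS search on a deeper supernet.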
        self.search_space_path = search_space_path
        self.best_selected_space_path = best_selected_space_path

        logger.info("loading data")
        dataset_train, dataset_valid = datasets.get_dataset(
            "cifar10", root=data_dir)
        self.darts_parameters = {
            "metrics": lambda output, target: accuracy(output, target, topk=(1,)),
            "arch_lr": arch_lr,
            "num_epochs": num_epochs,
            "num_pre_epochs": num_pre_epochs,
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
            "batch_size": batch_size,
            "result_path": result_path,
            "workers": workers,
            "device": device,
            "log_frequency": log_frequency,
            "unrolled": unrolled,
            "search_space_path": None
        }
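        # These settings are forwarded unchanged to the DartsTrainer created for every
        # stage; search_space_path is None here because train() dumps the full search
        # space once, during the first stage.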

    def train(self, validate=False):
        switches = None
        last = False
        for epoch in range(self.pdarts_epoches):
            if epoch == self.pdarts_epoches - 1:
                last = True
            # create network for each stage
            layers = self.init_layers + self.pdarts_num_layers[epoch]
            init_dropout_rate = float(self.pdarts_dropout_rates[epoch])
            model = CNN(32, 3, self.channels, self.class_num, layers,
                        init_dropout_rate, n_nodes=4, search=True)
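            # Every stage trains a fresh supernet: init_layers cells plus the extra
            # pdarts_num_layers[epoch] cells, with this stage's initial dropout rate.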
            criterion = nn.CrossEntropyLoss()
            optim = torch.optim.SGD(
                model.parameters(), self.model_lr, momentum=0.9, weight_decay=3.0E-4)
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optim, self.num_epochs, eta_min=0.001)

            logger.info(
                "############Start PDARTS training epoch %s############", epoch)
            self.mutator = PdartsMutator(
                model, epoch, self.pdarts_num_to_drop, switches)
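            # switches holds the candidate operations that survived the previous stage
            # (None for stage 0); the mutator uses them to restrict the candidates
            # searched in this stage.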
            if epoch == 0:
                # only write original search space in first stage
                search_space = self.mutator._generate_search_space()
                dump_global_result(self.search_space_path,
                                   search_space)
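                # search_space (the full candidate op list) is kept and reused after the
                # last stage to build the pruned op list written with the final result.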

            darts_callbacks = []
            if lr_scheduler is not None:
                darts_callbacks.append(LRSchedulerCallback(lr_scheduler))
            # darts_callbacks.append(ArchitectureCheckpoint(
            #     os.path.join(self.selected_space_path, "stage_{}".format(epoch))))
            self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion,
                                        optimizer=optim, callbacks=darts_callbacks,
                                        **self.darts_parameters)
            # keep references to the current supernet and optimizer so that
            # checkpoint() below can save them
            self.model = model
            self.optimizer = optim

            for train_epoch in range(self.darts_parameters["num_epochs"]):
                for callback in darts_callbacks:
                    callback.on_epoch_begin(train_epoch)

                # training
                logger.info("Epoch %d Training", train_epoch)
                if train_epoch < self.darts_parameters["num_pre_epochs"]:
                    dropout_rate = init_dropout_rate * \
                        (self.darts_parameters["num_epochs"] - train_epoch - 1) / \
                        self.darts_parameters["num_epochs"]
                else:
                    # scale_factor = 0.2
                    dropout_rate = init_dropout_rate * \
                        np.exp(-(train_epoch - self.darts_parameters["num_pre_epochs"]) * 0.2)
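                # The dropout rate decays linearly over the first num_pre_epochs epochs
                # and exponentially afterwards; it is applied to the supernet as the
                # drop-path probability below.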

                model.drop_path_prob(search=True, p=dropout_rate)
                self.trainer.train_one_epoch(train_epoch)

                if validate:
                    # validation
                    logger.info("Epoch %d Validating", train_epoch + 1)
                    self.trainer.validate_one_epoch(train_epoch, log_print=last)

                for callback in darts_callbacks:
                    callback.on_epoch_end(train_epoch)

            switches = self.mutator.drop_paths()
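            # At the end of the stage the mutator drops pdarts_num_to_drop[epoch]
            # candidate operations per edge; the returned switches seed the next,
            # deeper stage.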

            # In the last stage, restrict the number of skip-connects and save the best structure
            if last:
                res = OrderedDict()
                op_value = [value for value in search_space["op_list"]["_value"] if value != 'none']
                res["op_list"] = search_space["op_list"]
                res["op_list"]["_value"] = op_value
                res["best_selected_space"] = self.mutator.export(last, switches)
                logger.info(res)
                dump_global_result(self.best_selected_space_path, res)

    def validate(self):
        self.trainer.validate()

    def export(self, file, last, switches):
        mutator_export = self.mutator.export(last, switches)
        with open(file, "w") as f:
            json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def checkpoint(self, file_path, epoch):
        if isinstance(self.model, nn.DataParallel):
            child_model_state_dict = self.model.module.state_dict()
        else:
            child_model_state_dict = self.model.state_dict()

        save_state = {'child_model_state_dict': child_model_state_dict,
                      'optimizer_state_dict': self.optimizer.state_dict(),
                      'epoch': epoch}

        dest_path = os.path.join(
            file_path, "best_checkpoint_epoch_{}.pth.tar".format(epoch))
        logger.info("Saving model to %s", dest_path)
        torch.save(save_state, dest_path)
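

# Minimal usage sketch (illustrative only): the constructor arguments below are
# assumptions chosen to show the expected shapes of the P-DARTS parameters, not
# recommended settings, and the data/result paths are placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    trainer = PdartsTrainer(init_layers=5,
                            pdarts_num_layers=[0, 6, 12],
                            pdarts_num_to_drop=[3, 2, 2],
                            pdarts_dropout_rates=[0.1, 0.4, 0.7],
                            num_epochs=25,
                            num_pre_epochs=15,
                            model_lr=0.025,
                            class_num=10,
                            arch_lr=3.0E-4,
                            channels=16,
                            batch_size=96,
                            result_path="./result.json",
                            log_frequency=10,
                            unrolled=False,
                            data_dir="./data",
                            search_space_path="./search_space.json",
                            best_selected_space_path="./best_selected_space.json",
                            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                            workers=4)
    trainer.train(validate=True)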