- # Copyright (c) Microsoft Corporation.
- # Licensed under the MIT license.
-
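- # This trial script samples an architecture at random with ``ClassicMutator`` and trains the
- # resulting child network on MNIST, appending per-epoch test accuracy to ``--result_path``.
-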
- import copy
- import logging
- import os
- import argparse
- import sys
- sys.path.append(os.path.join('..', '..'))
- from collections import OrderedDict
-
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from torchvision import datasets, transforms
- from model import Net
-
- from pytorch.trainer import Trainer
- from pytorch.utils import AverageMeterGroup
- from pytorch.utils import mkdirs
- from pytorch.mutables import LayerChoice, InputChoice
-
- from mutator import ClassicMutator
- import numpy as np
- import time
- import json
-
- logger = logging.getLogger(__name__)
- #logger.setLevel(logging.INFO)
-
-
- class ClassicnasTrainer(Trainer):
- """
- Classic NAS trainer: samples a fixed architecture with ``ClassicMutator`` and trains the resulting child network on MNIST.
-
- Parameters
- ----------
- model : nn.Module
- PyTorch model to be trained.
- loss : callable
- Receives logits and ground truth label, returns a loss tensor.
- metrics : callable
- Receives logits and ground truth label, returns a dict of metrics.
- optimizer : Optimizer
- The optimizer used for optimizing the model.
- epochs : int
- Number of epochs used to train each sampled child network.
- dataset_train : str
- Path to the MNIST data directory used for training.
- dataset_valid : str
- Path to the MNIST data directory used for testing.
- search_space_path : str
- Path to the search space JSON file.
- selected_space_path : str
- Path to the JSON file recording the selected architecture.
- trial_id : int
- Identifier of the current trial.
- mutator : ClassicMutator
- Use in case of customizing your own ClassicMutator. By default a ClassicMutator is instantiated.
- batch_size : int
- Batch size.
- workers : int
- Workers for data loading.
- device : torch.device
- ``torch.device("cpu")`` or ``torch.device("cuda")``.
- log_frequency : int
- Step count per logging.
- callbacks : list of Callback
- List of callbacks to trigger at events.
- arc_learning_rate : float
- Learning rate of architecture parameters. Unused by this trainer; kept for interface compatibility.
- unrolled : bool
- ``True`` if using second order optimization, else first order optimization. Unused by this trainer.
- """
- def __init__(self, model, loss, metrics,
- optimizer, epochs, dataset_train, dataset_valid, search_space_path, selected_space_path, trial_id,
- mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
- callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
-
-
- self.model = model
-
- self.loss = loss
- self.metrics = metrics
- self.optimizer = optimizer
- self.epochs = epochs
- self.device = device
- self.batch_size = batch_size
- self.log_frequency = log_frequency
-
- # Build DataLoader options from the constructor parameters instead of relying on a
- # module-level ``kwargs`` variable; pin memory only when training on GPU.
- loader_kwargs = {'num_workers': workers,
- 'pin_memory': device is not None and device.type == 'cuda'}
-
- self.train_loader = torch.utils.data.DataLoader(
- datasets.MNIST(dataset_train, train=True, download=False,
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])),
- batch_size=batch_size, shuffle=True, **loader_kwargs)
-
- self.test_loader = torch.utils.data.DataLoader(
- datasets.MNIST(dataset_valid, train=False, transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])),
- batch_size=1000, shuffle=True, **loader_kwargs)
-
- self.search_space_path = search_space_path
- self.selected_space_path = selected_space_path
- self.trial_id = trial_id
- # Number of architectures sampled by ``train()``; each is trained for ``epochs`` epochs.
- self.num_epochs = 10
- # Honor a user-supplied mutator; otherwise build the default ClassicMutator.
- self.classicmutator = mutator if mutator is not None else ClassicMutator(
- self.model, trial_id=self.trial_id,
- selected_path=self.selected_space_path,
- search_space_path=self.search_space_path)
-
- self.result = {"accuracy": [], "cost_time": 0.}
-
- def train_one_epoch(self, epoch):
-
- # phase 1: architecture step, sample a random architecture for this epoch
- self.classicmutator.trial_id = epoch
- self.classicmutator._chosen_arch = self.classicmutator.random_generate_chosen()
-
- # phase 2: child network step
- for child_epoch in range(1, self.epochs + 1):
-
- self.model.train()
- for batch_idx, (data, target) in enumerate(self.train_loader):
- data, target = data.to(self.device), target.to(self.device)
- self.optimizer.zero_grad()
- output = self.model(data)
- loss = F.nll_loss(output, target)
- loss.backward()
- self.optimizer.step()
- if self.log_frequency and batch_idx % self.log_frequency == 0:
- logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
- child_epoch, batch_idx * len(data), len(self.train_loader.dataset),
- 100. * batch_idx / len(self.train_loader), loss.item()))
-
- test_acc = self.validate_one_epoch(epoch)
- epoch_result = {"type": "accuracy",
- "result": {"sequence": child_epoch, "category": "epoch", "value": test_acc}}
- print(epoch_result)
- # ``args`` is the argument dict defined when the script is run as __main__.
- with open(args['result_path'], "a") as ss_file:
- ss_file.write(json.dumps(epoch_result) + '\n')
- self.result['accuracy'].append(test_acc)
-
- def validate_one_epoch(self, epoch):
- self.model.eval()
- test_loss = 0
- correct = 0
- with torch.no_grad():
- for data, target in self.test_loader:
- data, target = data.to(self.device), target.to(self.device)
- output = self.model(data)
- # sum up batch loss
- test_loss += F.nll_loss(output, target, reduction='sum').item()
- # get the index of the max log-probability
- pred = output.argmax(dim=1, keepdim=True)
- correct += pred.eq(target.view_as(pred)).sum().item()
-
- test_loss /= len(self.test_loader.dataset)
-
- accuracy = 100. * correct / len(self.test_loader.dataset)
-
- logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
- test_loss, correct, len(self.test_loader.dataset), accuracy))
-
- return accuracy
-
- def train(self):
- """
- Train ``num_epochs``.
- Trigger callbacks at the start and the end of each epoch.
-
- Parameters
- ----------
- validate : bool
- If ``true``, will do validation every epoch.
- """
- for epoch in range(self.num_epochs):
- # training
- self.train_one_epoch(epoch)
-
-
- def dump_global_result(args,global_result):
- with open(args['result_path'], "w") as ss_file:
- json.dump(global_result, ss_file, sort_keys=True, indent=2)
-
- def get_params():
- # Training settings
- parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
- parser.add_argument("--data_dir", type=str,
- default='./data', help="data directory")
- parser.add_argument("--model_selected_space_path", type=str,
- default='./selected_space.json', help="selected_space_path")
- parser.add_argument("--search_space_path", type=str,
- default='./search_space.json', help="search_space_path")
- parser.add_argument("--result_path", type=str,
- default='./model_result.json', help="result_path")
- parser.add_argument('--batch_size', type=int, default=64, metavar='N',
- help='input batch size for training (default: 64)')
- parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
- help='hidden layer size (default: 512)')
- parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
- help='learning rate (default: 0.01)')
- parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
- help='SGD momentum (default: 0.5)')
- parser.add_argument('--epochs', type=int, default=10, metavar='N',
- help='number of epochs to train (default: 10)')
- parser.add_argument('--seed', type=int, default=1, metavar='S',
- help='random seed (default: 1)')
- parser.add_argument('--no_cuda', action='store_true', default=False,
- help='disables CUDA training')
- parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
- help='how many batches to wait before logging training status')
- parser.add_argument('--trial_id', type=int, default=0, metavar='N',
- help='trial id, starting from 0')
-
- args, _ = parser.parse_known_args()
- return args
-
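- # Example invocation (the script filename below is an assumption; adjust it and the paths to your setup):
- #   python classic_nas_trial.py --data_dir ./data --epochs 2 --trial_id 0 \
- #       --search_space_path ./search_space.json \
- #       --model_selected_space_path ./selected_space.json \
- #       --result_path ./model_result.json
-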
- if __name__ == '__main__':
- try:
- start = time.time()
-
- params = vars(get_params())
- args = params
-
- use_cuda = not args['no_cuda'] and torch.cuda.is_available()
-
- torch.manual_seed(args['seed'])
-
- device = torch.device("cuda" if use_cuda else "cpu")
-
- kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
-
- data_dir = args['data_dir']
-
- hidden_size = args['hidden_size']
-
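- # ``Net`` (from model.py) is expected to declare its searchable parts with ``LayerChoice``/
- # ``InputChoice`` mutables, which ``ClassicMutator`` fixes to a concrete architecture per trial.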
- model = Net(hidden_size=hidden_size).to(device)
-
- optimizer = torch.optim.SGD(model.parameters(), lr=args['lr'],
- momentum=args['momentum'])
-
- mkdirs(args['search_space_path'])
- mkdirs(args['model_selected_space_path'])
- mkdirs(args['result_path'])
- trainer = ClassicnasTrainer(model,
- loss=None,
- metrics=None,
- optimizer=optimizer,
- epochs=args['epochs'],
- dataset_train=data_dir,
- dataset_valid=data_dir,
- search_space_path = args['search_space_path'],
- selected_space_path = args['model_selected_space_path'],
- trial_id = args['trial_id'],
- batch_size=args['batch_size'],
- workers=1,
- log_frequency=args['log_interval'],
- device=device,
- unrolled=None,
- callbacks=None)
-
- # Truncate any previous results before training appends per-epoch records.
- with open(args['result_path'], "w") as ss_file:
- ss_file.write('')
- trainer.train_one_epoch(args['trial_id'])
- # trainer.train()
- global_result = trainer.result
- # global_result['cost_time'] = str(time.time() - start) + 's'
- # dump_global_result(params, global_result)
- except Exception as exception:
- logger.exception(exception)
- raise