# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import logging
import shutil
import random
import math

import time
import argparse
import distutils.util

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

from model import Model
from nni.nas.pytorch.fixed import apply_fixed_architecture
from dataloader import read_data_sst


logger = logging.getLogger("nni.textnas")

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--reset_output_dir",
        type=distutils.util.strtobool,
        default=True,
        help="Whether to clean the output dir if it exists. (default: %(default)s)")
    parser.add_argument(
        "--child_fixed_arc",
        type=str,
        required=True,
        help="Architecture json file. (required)")
    parser.add_argument(
        "--data_path",
        type=str,
        default="data",
        help="Directory containing the dataset and embedding file. (default: %(default)s)")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="The output directory. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_decay_scheme",
        type=str,
        default="cosine",
        help="Learning rate annealing strategy; only 'cosine' is supported. (default: %(default)s)")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=128,
        help="Number of samples per batch for training. (default: %(default)s)")
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=128,
        help="Number of samples per batch for evaluation. (default: %(default)s)")
    parser.add_argument(
        "--class_num",
        type=int,
        default=5,
        help="The number of categories. (default: %(default)s)")
    parser.add_argument(
        "--global_seed",
        type=int,
        default=1234,
        help="Seed for reproducibility. (default: %(default)s)")
    parser.add_argument(
        "--max_input_length",
        type=int,
        default=64,
        help="The maximum sentence length. (default: %(default)s)")
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=10,
        help="The number of training epochs. (default: %(default)s)")
    parser.add_argument(
        "--child_num_layers",
        type=int,
        default=24,
        help="The number of layers in the architecture. (default: %(default)s)")
    parser.add_argument(
        "--child_out_filters",
        type=int,
        default=256,
        help="The dimension of the hidden states. (default: %(default)s)")
    parser.add_argument(
        "--child_out_filters_scale",
        type=int,
        default=1,
        help="The scale of the hidden state dimension. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_T_0",
        type=int,
        default=10,
        help="The length of one learning-rate cycle, in epochs. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_T_mul",
        type=int,
        default=2,
        help="The multiplication factor applied to the cycle length after each restart. (default: %(default)s)")
    parser.add_argument(
        "--min_count",
        type=int,
        default=1,
        help="The threshold for cutting off low-frequency words. (default: %(default)s)")
    parser.add_argument(
        "--train_ratio",
        type=float,
        default=1.0,
        help="The sample ratio for the training set. (default: %(default)s)")
    parser.add_argument(
        "--valid_ratio",
        type=float,
        default=1.0,
        help="The sample ratio for the dev set. (default: %(default)s)")
    parser.add_argument(
        "--child_grad_bound",
        type=float,
        default=5.0,
        help="The threshold for gradient clipping. (default: %(default)s)")
    parser.add_argument(
        "--child_lr",
        type=float,
        default=0.02,
        help="The initial learning rate. (default: %(default)s)")
    parser.add_argument(
        "--cnn_keep_prob",
        type=float,
        default=0.8,
        help="Keep prob for the CNN layers. (default: %(default)s)")
    parser.add_argument(
        "--final_output_keep_prob",
        type=float,
        default=1.0,
        help="Keep prob for the final output layer. (default: %(default)s)")
    parser.add_argument(
        "--lstm_out_keep_prob",
        type=float,
        default=0.8,
        help="Keep prob for the RNN layers. (default: %(default)s)")
    parser.add_argument(
        "--embed_keep_prob",
        type=float,
        default=0.8,
        help="Keep prob for the embedding layer. (default: %(default)s)")
    parser.add_argument(
        "--attention_keep_prob",
        type=float,
        default=0.8,
        help="Keep prob for the self-attention layer. (default: %(default)s)")
    parser.add_argument(
        "--child_l2_reg",
        type=float,
        default=3e-6,
        help="Weight decay factor. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_max",
        type=float,
        default=0.002,
        help="The max learning rate. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_min",
        type=float,
        default=0.001,
        help="The min learning rate. (default: %(default)s)")
    parser.add_argument(
        "--child_optim_algo",
        type=str,
        default="adam",
        help="Optimization algorithm. (default: %(default)s)")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="best_checkpoint",
        help="Path for saved checkpoints. (default: %(default)s)")
    parser.add_argument(
        "--output_type",
        type=str,
        default="avg",
        help="Operator type for the time-step reduction. (default: %(default)s)")
    parser.add_argument(
        "--multi_path",
        type=distutils.util.strtobool,
        default=False,
        help="Search for multiple paths in the architecture. (default: %(default)s)")
    parser.add_argument(
        "--is_binary",
        type=distutils.util.strtobool,
        default=False,
        help="Use binary labels for the SST dataset. (default: %(default)s)")
    parser.add_argument(
        "--is_cuda",
        type=distutils.util.strtobool,
        default=True,
        help="Specify the device type. (default: %(default)s)")
    parser.add_argument(
        "--is_mask",
        type=distutils.util.strtobool,
        default=True,
        help="Apply mask. (default: %(default)s)")
    parser.add_argument(
        "--fixed_seed",
        type=distutils.util.strtobool,
        default=True,
        help="Fix the seed. (default: %(default)s)")
    parser.add_argument(
        "--load_checkpoint",
        type=distutils.util.strtobool,
        default=False,
        help="Whether to load a checkpoint. (default: %(default)s)")
    parser.add_argument(
        "--log_every",
        type=int,
        default=50,
        help="Number of steps between logging. (default: %(default)s)")
    parser.add_argument(
        "--eval_every_epochs",
        type=int,
        default=1,
        help="Number of epochs between evaluations. (default: %(default)s)")

    global FLAGS

    FLAGS = parser.parse_args()

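# A sketch of how this script might be launched with the flags parsed above
# (the script and file names here are illustrative assumptions, not taken
# from this repo; --child_fixed_arc is the only required flag):
#
#   python retrain.py --child_fixed_arc checkpoints/architecture.json \
#       --data_path data --output_dir output --num_epochs 10
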
def set_random_seed(seed):
    logger.info("set random seed for data reading: {}".format(seed))
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if FLAGS.is_cuda:
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True  # reproducible, at some speed cost

def get_model(embedding, num_layers):
    logger.info("num layers: {0}".format(num_layers))
    assert FLAGS.child_fixed_arc is not None, "Architecture should be provided."

    child_model = Model(
        embedding=embedding,
        hidden_units=FLAGS.child_out_filters_scale * FLAGS.child_out_filters,
        num_layers=num_layers,
        num_classes=FLAGS.class_num,
        choose_from_k=5 if FLAGS.multi_path else 1,
        lstm_keep_prob=FLAGS.lstm_out_keep_prob,
        cnn_keep_prob=FLAGS.cnn_keep_prob,
        att_keep_prob=FLAGS.attention_keep_prob,
        att_mask=FLAGS.is_mask,
        embed_keep_prob=FLAGS.embed_keep_prob,
        final_output_keep_prob=FLAGS.final_output_keep_prob,
        global_pool=FLAGS.output_type)

    apply_fixed_architecture(child_model, FLAGS.child_fixed_arc)
    return child_model

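# Note on the helper above: per NNI's documented NAS API, apply_fixed_architecture
# loads the architecture JSON and fixes every mutable in the model (e.g.
# LayerChoice / InputChoice) to the recorded selection, so the returned
# child_model trains as an ordinary PyTorch module.
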
def eval_once(child_model, device, eval_set, criterion, valid_dataloader=None, test_dataloader=None):
    if eval_set == "test":
        assert test_dataloader is not None
        dataloader = test_dataloader
    elif eval_set == "valid":
        assert valid_dataloader is not None
        dataloader = valid_dataloader
    else:
        raise NotImplementedError("Unknown eval_set '{}'".format(eval_set))

    tot_acc = 0
    tot = 0
    losses = []

    with torch.no_grad():  # save memory
        for batch in dataloader:
            (sent_ids, mask), labels = batch

            sent_ids = sent_ids.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            logits = child_model((sent_ids, mask))  # run

            loss = criterion(logits, labels.long())
            preds = logits.argmax(dim=1).long()
            acc = torch.eq(preds, labels.long()).long().sum().item()

            losses.append(loss.item())
            tot_acc += acc
            tot += len(labels)

    if tot > 0:
        final_acc = float(tot_acc) / tot
        # mean of per-batch means; close to the per-sample mean unless the
        # last batch is much smaller than the rest
        final_loss = sum(losses) / len(losses)
    else:
        final_acc = 0
        final_loss = 0.0
        logger.info("Error in calculating final_acc")
    return final_acc, final_loss

def print_user_flags(FLAGS, line_limit=80):
    log_strings = "\n" + "-" * line_limit + "\n"
    for flag_name in sorted(vars(FLAGS)):
        value = "{}".format(getattr(FLAGS, flag_name))
        log_string = flag_name
        log_string += "." * (line_limit - len(flag_name) - len(value))
        log_string += value
        log_strings = log_strings + log_string + "\n"
    log_strings += "-" * line_limit
    logger.info(log_strings)

def count_model_params(trainable_params):
    num_vars = 0
    for var in trainable_params:
        num_vars += var.numel()
    return num_vars

def update_lr(
        optimizer,
        epoch,
        l2_reg=1e-4,
        lr_warmup_val=None,
        lr_init=0.1,
        lr_decay_scheme="cosine",
        lr_max=0.002,
        lr_min=0.000000001,
        lr_T_0=4,
        lr_T_mul=1,
        sync_replicas=False,
        num_aggregate=None,
        num_replicas=None):
    if lr_decay_scheme == "cosine":
        assert lr_max is not None, "Need lr_max to use lr_cosine"
        assert lr_min is not None, "Need lr_min to use lr_cosine"
        assert lr_T_0 is not None, "Need lr_T_0 to use lr_cosine"
        assert lr_T_mul is not None, "Need lr_T_mul to use lr_cosine"

        # locate the current restart cycle: T_i is its length and last_reset
        # is the epoch at which it started
        T_i = lr_T_0
        t_epoch = epoch
        last_reset = 0
        while True:
            t_epoch -= T_i
            if t_epoch < 0:
                break
            last_reset += T_i
            T_i *= lr_T_mul

        T_curr = epoch - last_reset

        # cosine annealing within the cycle
        rate = T_curr / T_i * math.pi
        learning_rate = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(rate))
    else:
        raise ValueError("Unknown learning rate decay scheme {}".format(lr_decay_scheme))

    # update the learning rate in the optimizer
    for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate
    return learning_rate


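# Worked example for update_lr above: with the default lr_T_0=10 and
# lr_T_mul=2, warm restarts occur at epochs 10 and 30, giving cycle lengths
# of 10, 20, 40, ... Within a cycle of length T_i, the rate follows
#   lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * T_curr / T_i)),
# i.e. it starts at lr_max right after a restart and decays to lr_min.
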
def train(device, data_path, output_dir, num_layers):
    logger.info("Build dataloader")
    train_dataset, valid_dataset, test_dataset, embedding = \
        read_data_sst(data_path,
                      FLAGS.max_input_length,
                      FLAGS.min_count,
                      train_ratio=FLAGS.train_ratio,
                      valid_ratio=FLAGS.valid_ratio,
                      is_binary=FLAGS.is_binary)
    train_dataloader = DataLoader(train_dataset, batch_size=FLAGS.batch_size, shuffle=True, pin_memory=True)
    test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.eval_batch_size, pin_memory=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=FLAGS.eval_batch_size, pin_memory=True)

    logger.info("Build model")
    child_model = get_model(embedding, num_layers)
    logger.info("Finish build model")

    # for name, var in child_model.named_parameters():
    #     logger.info("{}: {}, requires_grad={}".format(name, var.size(), var.requires_grad))  # print all params

    num_vars = count_model_params(child_model.parameters())
    logger.info("Model has {} params".format(num_vars))

    for m in child_model.modules():  # initialize conv and linear weights
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)

    criterion = nn.CrossEntropyLoss()

    # get optimizer
    if FLAGS.child_optim_algo == "adam":
        optimizer = optim.Adam(child_model.parameters(), eps=1e-3, weight_decay=FLAGS.child_l2_reg)  # with L2
    else:
        raise ValueError("Unknown optim_algo {}".format(FLAGS.child_optim_algo))

    child_model.to(device)
    criterion.to(device)

    logger.info("Start training")
    start_time = time.time()
    step = 0

    # save path
    model_save_path = os.path.join(FLAGS.output_dir, "model.pth")
    best_model_save_path = os.path.join(FLAGS.output_dir, "best_model.pth")
    best_acc = 0
    start_epoch = 0
    if FLAGS.load_checkpoint:
        if os.path.isfile(model_save_path):
            checkpoint = torch.load(model_save_path, map_location=torch.device('cpu'))
            step = checkpoint['step']
            start_epoch = checkpoint['epoch']
            child_model.load_state_dict(checkpoint['child_model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    for epoch in range(start_epoch, FLAGS.num_epochs):
        lr = update_lr(optimizer,
                       epoch,
                       l2_reg=FLAGS.child_l2_reg,
                       lr_warmup_val=None,
                       lr_init=FLAGS.child_lr,
                       lr_decay_scheme=FLAGS.child_lr_decay_scheme,
                       lr_max=FLAGS.child_lr_max,
                       lr_min=FLAGS.child_lr_min,
                       lr_T_0=FLAGS.child_lr_T_0,
                       lr_T_mul=FLAGS.child_lr_T_mul)
        child_model.train()
        for batch in train_dataloader:
            (sent_ids, mask), labels = batch

            sent_ids = sent_ids.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            step += 1

            logits = child_model((sent_ids, mask))  # run

            loss = criterion(logits, labels.long())
            preds = logits.argmax(dim=1).long()
            acc = torch.eq(preds, labels.long()).long().sum().item()

            optimizer.zero_grad()
            loss.backward()

            assert FLAGS.child_grad_bound is not None, "Need grad_bound to clip gradients."
            # materialize the parameter list: clip_grad_norm_ would otherwise
            # exhaust the generator and the per-parameter clipping below would
            # silently do nothing
            trainable_params = list(child_model.parameters())
            # report the global gradient norm (the huge bound makes this call a no-op clip)
            grad_norm = nn.utils.clip_grad_norm_(trainable_params, 99999999)
            for param in trainable_params:
                nn.utils.clip_grad_norm_(param, FLAGS.child_grad_bound)  # clip grad per parameter

            optimizer.step()

            if step % FLAGS.log_every == 0:
                curr_time = time.time()
                log_string = ""
                log_string += "epoch={:<6d}".format(epoch)
                log_string += " ch_step={:<6d}".format(step)
                log_string += " loss={:<8.6f}".format(loss.item())
                log_string += " lr={:<8.4f}".format(lr)
                log_string += " |g|={:<8.4f}".format(grad_norm)
                log_string += " tr_acc={:<3d}/{:>3d}".format(acc, logits.size()[0])
                log_string += " mins={:<10.2f}".format(float(curr_time - start_time) / 60)
                logger.info(log_string)

        epoch += 1  # store the next epoch so a resumed run picks up where this one left off
        save_state = {
            'step': step,
            'epoch': epoch,
            'child_model_state_dict': child_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}
        torch.save(save_state, model_save_path)
        child_model.eval()
        logger.info("Epoch {}: Eval".format(epoch))
        eval_acc, eval_loss = eval_once(child_model, device, "test", criterion, test_dataloader=test_dataloader)
        logger.info("ch_step={} test_accuracy={:<6.4f} test_loss={:<6.4f}".format(step, eval_acc, eval_loss))
        if eval_acc > best_acc:
            best_acc = eval_acc
            logger.info("Save best model")
            save_state = {
                'step': step,
                'epoch': epoch,
                'child_model_state_dict': child_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()}
            torch.save(save_state, best_model_save_path)

    return eval_acc


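# Restoring the best checkpoint later might look like this (a sketch; it
# assumes a child_model built exactly as in train() above):
#
#   checkpoint = torch.load(os.path.join(FLAGS.output_dir, "best_model.pth"),
#                           map_location="cpu")
#   child_model.load_state_dict(checkpoint['child_model_state_dict'])
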
def main():
    parse_args()
    if not os.path.isdir(FLAGS.output_dir):
        logger.info("Path {} does not exist. Creating.".format(FLAGS.output_dir))
        os.makedirs(FLAGS.output_dir)
    elif FLAGS.reset_output_dir:
        logger.info("Path {} exists. Remove and remake.".format(FLAGS.output_dir))
        shutil.rmtree(FLAGS.output_dir, ignore_errors=True)
        os.makedirs(FLAGS.output_dir)

    print_user_flags(FLAGS)

    if FLAGS.fixed_seed:
        set_random_seed(FLAGS.global_seed)

    device = torch.device("cuda" if FLAGS.is_cuda else "cpu")
    train(device, FLAGS.data_path, FLAGS.output_dir, FLAGS.child_num_layers)


if __name__ == "__main__":
    main()