Browse Source

Add ENAS (Efficient Neural Architecture Search)

tags/v0.4.0
chenkaiyu1997 5 years ago
parent
commit
efeac2c427
5 changed files with 1164 additions and 0 deletions
  1. +223
    -0
      fastNLP/models/enas_controller.py
  2. +388
    -0
      fastNLP/models/enas_model.py
  3. +385
    -0
      fastNLP/models/enas_trainer.py
  4. +56
    -0
      fastNLP/models/enas_utils.py
  5. +112
    -0
      test/models/test_enas.py

+ 223
- 0
fastNLP/models/enas_controller.py View File

@@ -0,0 +1,223 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch
"""A module with NAS controller-related code."""
import collections
import os

import torch
import torch.nn.functional as F
import fastNLP
import fastNLP.models.enas_utils as utils
from fastNLP.models.enas_utils import Node


def _construct_dags(prev_nodes, activations, func_names, num_blocks):
    """Turn sampled controller actions into one DAG per sample.

    Args:
        prev_nodes: Per-sample sequences of sampled "previous node" indices.
        activations: Per-sample sequences of sampled activation indices; the
            first entry belongs to node 0.
        func_names: Sequence mapping an activation index to its name.
        num_blocks: Number of blocks in the target RNN cell.

    Returns:
        A list with one `collections.defaultdict(list)` per sample.

    Representation of each DAG:

    1. `dag[i]` is the list of `Node`s that consume the output of node `i`.

    2. `dag[-1]` (and its duplicate `dag[-2]`) hold node 0, i.e. the node fed
       by x^{(t)} and h^{(t - 1)}.

    3. Every node in `range(num_blocks)` without a child is a leaf; it gets
       the single child `Node(num_blocks, 'avg')`.

    4. `dag[num_blocks]` holds the final `Node(num_blocks + 1, 'h[t]')`.
    """

    def _build_one(parents, func_ids):
        """Build the DAG for a single sampled architecture."""
        graph = collections.defaultdict(list)

        # Entry node: both pseudo-inputs point at node 0.
        first_name = func_names[func_ids[0]]
        graph[-1] = [Node(0, first_name)]
        graph[-2] = [Node(0, first_name)]

        # Node k (k >= 1) hangs off the parent sampled for it.
        for child_id, (parent, func_id) in enumerate(zip(parents,
                                                         func_ids[1:]),
                                                     start=1):
            parent_id = utils.to_item(parent)
            graph[parent_id].append(Node(child_id, func_names[func_id]))

        # Nodes that feed nobody are leaves and get merged by averaging.
        childless = set(range(num_blocks)).difference(graph.keys())
        for leaf_id in childless:
            graph[leaf_id] = [Node(num_blocks, 'avg')]

        # The averaging node feeds the terminal 'h[t]' node.
        graph[num_blocks] = [Node(num_blocks + 1, 'h[t]')]
        return graph

    return [_build_one(parents, func_ids)
            for parents, func_ids in zip(prev_nodes, activations)]


class Controller(torch.nn.Module):
    """LSTM policy that samples RNN-cell architectures as DAGs.

    Based on
    https://github.com/pytorch/examples/blob/master/word_language_model/model.py
    with the controller RNN modelled after
    https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py

    Args:
        num_blocks: Number of blocks (nodes) in the RNN cell to design.
        controller_hid: Size of the controller embedding and LSTM state.
        cuda: If True, cached zero inputs/hidden states are moved to CUDA.
    """

    def __init__(self, num_blocks=4, controller_hid=100, cuda=False):
        torch.nn.Module.__init__(self)

        # `num_tokens[k]` is the number of choices at decoding step k:
        # an activation choice first, then (previous node, activation) pairs.
        self.shared_rnn_activations = ['tanh', 'ReLU', 'identity', 'sigmoid']
        self.num_tokens = [len(self.shared_rnn_activations)]
        self.controller_hid = controller_hid
        self.use_cuda = cuda
        self.num_blocks = num_blocks
        for idx in range(num_blocks):
            self.num_tokens += [idx + 1, len(self.shared_rnn_activations)]
        self.func_names = self.shared_rnn_activations

        num_total_tokens = sum(self.num_tokens)

        self.encoder = torch.nn.Embedding(num_total_tokens,
                                          controller_hid)
        self.lstm = torch.nn.LSTMCell(controller_hid, controller_hid)

        # One linear decoder per decoding step. They are kept in a plain
        # list for indexing and additionally registered through a ModuleList
        # so that their parameters are visible to `self.parameters()`.
        self.decoders = []
        for size in self.num_tokens:
            decoder = torch.nn.Linear(controller_hid, size)
            self.decoders.append(decoder)

        self._decoders = torch.nn.ModuleList(self.decoders)

        self.reset_parameters()
        # Per-batch-size caches of the zero hidden state and the zero input.
        self.static_init_hidden = utils.keydefaultdict(self.init_hidden)

        def _get_default_hidden(key):
            return utils.get_variable(
                torch.zeros(key, self.controller_hid),
                self.use_cuda,
                requires_grad=False)

        self.static_inputs = utils.keydefaultdict(_get_default_hidden)

    def reset_parameters(self):
        """Initialise all parameters uniformly in [-0.1, 0.1]; zero decoder biases."""
        init_range = 0.1
        for param in self.parameters():
            param.data.uniform_(-init_range, init_range)
        for decoder in self.decoders:
            decoder.bias.data.fill_(0)

    def forward(self,  # pylint:disable=arguments-differ
                inputs,
                hidden,
                block_idx,
                is_embed):
        """Run one controller step.

        Args:
            inputs: Token indices, or an already embedded input when
                `is_embed` is True.
            hidden: `(hx, cx)` state for the LSTM cell.
            block_idx: Index of the decoder to apply at this step.
            is_embed: Whether `inputs` skips the embedding lookup.

        Returns:
            `(logits, (hx, cx))`, with the logits divided by 5.0.
        """
        if not is_embed:
            embed = self.encoder(inputs)
        else:
            embed = inputs

        hx, cx = self.lstm(embed, hidden)
        logits = self.decoders[block_idx](hx)

        # Fixed temperature applied to the logits.
        logits /= 5.0

        return logits, (hx, cx)

    def sample(self, batch_size=1, with_details=False, save_dir=None):
        """Sample `batch_size` architectures from the controller.

        Each architecture has `num_blocks` nodes; every node gets an
        activation function and every node except the first also gets a
        previous node.

        Args:
            batch_size: Number of architectures to sample; must be >= 1.
            with_details: If True, also return the log-probabilities and
                entropies of the sampled actions.
            save_dir: If not None, each DAG is drawn to
                `<save_dir>/graph<idx>.png` via `utils.draw_network`.
                NOTE(review): `enas_utils` as added alongside this module
                does not define `draw_network`, so passing `save_dir`
                currently fails with AttributeError.

        Returns:
            The list of DAGs, or `(dags, log_probs, entropies)` when
            `with_details` is True.

        Raises:
            ValueError: If `batch_size` is smaller than 1.
        """
        if batch_size < 1:
            # ValueError (a subclass of Exception) names the actual problem.
            raise ValueError(f'Wrong batch_size: {batch_size} < 1')

        # [B, L, H]
        inputs = self.static_inputs[batch_size]
        hidden = self.static_init_hidden[batch_size]

        activations = []
        entropies = []
        log_probs = []
        prev_nodes = []
        # The controller alternately emits an activation (even steps) and a
        # previous-node index (odd steps); the sequence starts and ends with
        # an activation.
        for block_idx in range(2*(self.num_blocks - 1) + 1):
            logits, hidden = self.forward(inputs,
                                          hidden,
                                          block_idx,
                                          is_embed=(block_idx == 0))

            probs = F.softmax(logits, dim=-1)
            log_prob = F.log_softmax(logits, dim=-1)
            entropy = -(log_prob * probs).sum(1, keepdim=False)

            action = probs.multinomial(num_samples=1).data
            selected_log_prob = log_prob.gather(
                1, utils.get_variable(action, requires_grad=False))

            entropies.append(entropy)
            log_probs.append(selected_log_prob[:, 0])

            # 0: function, 1: previous node
            mode = block_idx % 2
            # The sampled action (plus an offset) is the next step's token.
            inputs = utils.get_variable(
                action[:, 0] + sum(self.num_tokens[:mode]),
                requires_grad=False)

            if mode == 0:
                activations.append(action[:, 0])
            elif mode == 1:
                prev_nodes.append(action[:, 0])

        # Stack per-step actions and put the batch dimension first.
        prev_nodes = torch.stack(prev_nodes).transpose(0, 1)
        activations = torch.stack(activations).transpose(0, 1)

        dags = _construct_dags(prev_nodes,
                               activations,
                               self.func_names,
                               self.num_blocks)

        if save_dir is not None:
            for idx, dag in enumerate(dags):
                utils.draw_network(dag,
                                   os.path.join(save_dir, f'graph{idx}.png'))

        if with_details:
            return dags, torch.cat(log_probs), torch.cat(entropies)

        return dags

    def init_hidden(self, batch_size):
        """Return a zero `(hx, cx)` pair of shape `(batch_size, controller_hid)`."""
        zeros = torch.zeros(batch_size, self.controller_hid)
        return (utils.get_variable(zeros, self.use_cuda, requires_grad=False),
                utils.get_variable(zeros.clone(), self.use_cuda, requires_grad=False))

+ 388
- 0
fastNLP/models/enas_model.py View File

@@ -0,0 +1,388 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch

"""Module containing the shared RNN model."""
import numpy as np
import collections

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable

import fastNLP.models.enas_utils as utils
from fastNLP.models.base_model import BaseModel
import fastNLP.modules.encoder as encoder

def _get_dropped_weights(w_raw, dropout_p, is_training):
"""Drops out weights to implement DropConnect.

Args:
w_raw: Full, pre-dropout, weights to be dropped out.
dropout_p: Proportion of weights to drop out.
is_training: True iff _shared_ model is training.

Returns:
The dropped weights.

Why does torch.nn.functional.dropout() return:
1. `torch.autograd.Variable()` on the training loop
2. `torch.nn.Parameter()` on the controller or eval loop, when
training = False...

Even though the call to `_setweights` in the Smerity repo's
`weight_drop.py` does not have this behaviour, and `F.dropout` always
returns `torch.autograd.Variable` there, even when `training=False`?

The above TODO is the reason for the hacky check for `torch.nn.Parameter`.
"""
dropped_w = F.dropout(w_raw, p=dropout_p, training=is_training)

if isinstance(dropped_w, torch.nn.Parameter):
dropped_w = dropped_w.clone()

return dropped_w

class EmbeddingDropout(torch.nn.Embedding):
    """Embedding layer that drops whole rows of the embedding matrix.

    Zeroing a row is equivalent to dropping every occurrence of that word,
    e.g. dropping 'the' from 'the quick brown fox jumps over the lazy dog'
    gives '### quick brown fox jumps over ### lazy dog' (in embedding space).

    See 'A Theoretically Grounded Application of Dropout in Recurrent Neural
    Networks', (Gal and Ghahramani, 2016).
    """

    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 max_norm=None,
                 norm_type=2,
                 scale_grad_by_freq=False,
                 sparse=False,
                 dropout=0.1,
                 scale=None):
        """Build the embedding.

        Args:
            dropout: Probability of dropping a row; must be in [0, 1).
            scale: Optional extra factor applied to the (masked) weights, on
                top of the `1/(1 - dropout)` rescaling.

        All other arguments are forwarded to `torch.nn.Embedding`.
        """
        super().__init__(num_embeddings=num_embeddings,
                         embedding_dim=embedding_dim,
                         max_norm=max_norm,
                         norm_type=norm_type,
                         scale_grad_by_freq=scale_grad_by_freq,
                         sparse=sparse)
        self.dropout = dropout
        assert 0.0 <= dropout < 1.0, 'Dropout must be >= 0.0 and < 1.0'
        self.scale = scale

    def forward(self, inputs):  # pylint:disable=arguments-differ
        """Look up `inputs` in the row-masked (training only) weight matrix."""
        drop_p = self.dropout if self.training else 0

        weight = self.weight
        if drop_p:
            # One Bernoulli draw per row, broadcast across the embedding dim
            # and rescaled so the expected value is unchanged.
            keep = self.weight.data.new(self.weight.size(0), 1)
            keep.bernoulli_(1 - drop_p)
            keep = keep.expand_as(self.weight) / (1 - drop_p)
            weight = self.weight * Variable(keep)

        if self.scale and self.scale != 1:
            weight = weight * self.scale

        return F.embedding(inputs,
                           weight,
                           max_norm=self.max_norm,
                           norm_type=self.norm_type,
                           scale_grad_by_freq=self.scale_grad_by_freq,
                           sparse=self.sparse)


class LockedDropout(nn.Module):
    """Dropout whose mask is shared along dimension 0 of the input.

    Code from
    https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py
    """

    def forward(self, x, dropout=0.5):
        """Return `x` masked by one `(1, x.size(1), x.size(2))` dropout mask.

        The input is returned untouched outside training or when `dropout`
        is falsy.
        """
        if self.training and dropout:
            keep = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout)
            scaled = Variable(keep, requires_grad=False) / (1 - dropout)
            return scaled.expand_as(x) * x
        return x


class ENASModel(BaseModel):
    """Shared RNN model whose cell structure is given by a sampled DAG.

    Args:
        embed_num: Vocabulary size of the embedding.
        num_classes: Output size of the final linear decoder.
        num_blocks: Number of nodes in the RNN cell.
        cuda: If True, helper tensors are created on / moved to CUDA.
        shared_hid: Hidden size of the cell.
        shared_embed: Embedding size.
    """

    def __init__(self, embed_num, num_classes, num_blocks=4, cuda=False, shared_hid=1000, shared_embed=1000):
        super(ENASModel, self).__init__()

        self.use_cuda = cuda

        self.shared_hid = shared_hid
        self.num_blocks = num_blocks
        self.decoder = nn.Linear(self.shared_hid, num_classes)
        self.encoder = EmbeddingDropout(embed_num,
                                        shared_embed,
                                        dropout=0.1)
        self.lockdrop = LockedDropout()
        self.dag = None

        # Since W^{x, c} and W^{h, c} are always summed, there
        # is no point duplicating their bias offset parameter. Likewise for
        # W^{x, h} and W^{h, h}.
        self.w_xc = nn.Linear(shared_embed, self.shared_hid)
        self.w_xh = nn.Linear(shared_embed, self.shared_hid)

        # The raw weights are stored here because the hidden-to-hidden weights
        # are weight dropped on the forward pass.
        self.w_hc_raw = torch.nn.Parameter(
            torch.Tensor(self.shared_hid, self.shared_hid))
        self.w_hh_raw = torch.nn.Parameter(
            torch.Tensor(self.shared_hid, self.shared_hid))
        self.w_hc = None
        self.w_hh = None

        # w_h[i][j] / w_c[i][j]: weights of the edge from node i to node j,
        # created for every pair i < j.
        self.w_h = collections.defaultdict(dict)
        self.w_c = collections.defaultdict(dict)

        for idx in range(self.num_blocks):
            for jdx in range(idx + 1, self.num_blocks):
                self.w_h[idx][jdx] = nn.Linear(self.shared_hid,
                                               self.shared_hid,
                                               bias=False)
                self.w_c[idx][jdx] = nn.Linear(self.shared_hid,
                                               self.shared_hid,
                                               bias=False)

        # Register the edge weights so `self.parameters()` can see them.
        self._w_h = nn.ModuleList([self.w_h[idx][jdx]
                                   for idx in self.w_h
                                   for jdx in self.w_h[idx]])
        self._w_c = nn.ModuleList([self.w_c[idx][jdx]
                                   for idx in self.w_c
                                   for jdx in self.w_c[idx]])

        # Batch norm on the cell output is disabled.
        self.batch_norm = None

        self.reset_parameters()
        self.static_init_hidden = utils.keydefaultdict(self.init_hidden)

    def setDAG(self, dag):
        """Set the cell architecture, but only if none has been set yet.

        NOTE(review): once a DAG is stored, later calls are silently ignored;
        confirm that callers which pass a newly sampled DAG expect this.
        """
        if self.dag is None:
            self.dag = dag

    def forward(self, word_seq, hidden=None):
        """Run the cell over the sequence and classify the pooled output.

        Args:
            word_seq: Index tensor; dimensions 0 and 1 are swapped before
                use so that time becomes dimension 0 (presumably
                [batch_size, seq_len] on input, as stated by `predict`).
            hidden: Accepted for interface compatibility but currently
                ignored; the state always starts from the cached zero state.

        Returns:
            dict with 'pred' (decoder output of the time-averaged cell
            outputs), 'hidden' (last hidden state) and 'extra_out'
            ('dropped', 'hiddens', 'raw').
        """
        inputs = torch.transpose(word_seq, 0, 1)

        time_steps = inputs.size(0)
        batch_size = inputs.size(1)

        self.w_hh = _get_dropped_weights(self.w_hh_raw,
                                         0.5,
                                         self.training)
        self.w_hc = _get_dropped_weights(self.w_hc_raw,
                                         0.5,
                                         self.training)

        hidden = self.static_init_hidden[batch_size]

        embed = self.encoder(inputs)

        embed = self.lockdrop(embed, 0.65 if self.training else 0)

        # The norm of hidden states are clipped here because
        # otherwise ENAS is especially prone to exploding activations on the
        # forward pass. For more details, see
        # https://github.com/carpedm20/ENAS-pytorch/issues/6
        h1tohT = []
        logits = []
        for step in range(time_steps):
            x_t = embed[step]
            logit, hidden = self.cell(x_t, hidden, self.dag)

            hidden_norms = hidden.norm(dim=-1)
            max_norm = 25.0
            if hidden_norms.data.max() > max_norm:
                # Build the per-row rescaling mask in numpy (workaround
                # inherited from the PyTorch v0.3.1 code).
                hidden_norms = hidden_norms.data.cpu().numpy()

                clip_select = hidden_norms > max_norm
                clip_norms = hidden_norms[clip_select]

                mask = np.ones(hidden.size())
                normalizer = max_norm/clip_norms
                normalizer = normalizer[:, np.newaxis]

                # Rows over the limit are scaled down to norm `max_norm`.
                mask[clip_select] = normalizer

                if self.use_cuda:
                    hidden *= torch.autograd.Variable(
                        torch.FloatTensor(mask).cuda(), requires_grad=False)
                else:
                    hidden *= torch.autograd.Variable(
                        torch.FloatTensor(mask), requires_grad=False)
            logits.append(logit)
            h1tohT.append(hidden)

        h1tohT = torch.stack(h1tohT)
        output = torch.stack(logits)
        raw_output = output

        output = self.lockdrop(output, 0.4 if self.training else 0)

        # Pooling: average over the time dimension.
        output = torch.mean(output, 0)

        decoded = self.decoder(output)

        extra_out = {'dropped': decoded,
                     'hiddens': h1tohT,
                     'raw': raw_output}
        return {'pred': decoded, 'hidden': hidden, 'extra_out': extra_out}

    def cell(self, x, h_prev, dag):
        """Compute a single pass through the discovered RNN cell.

        Args:
            x: Input for the current time step.
            h_prev: Hidden state from the previous time step.
            dag: Architecture mapping a node id to its list of child `Node`s.

        Returns:
            `(output, h)`: the average of all leaf-node states, and the
            state of node `num_blocks - 1`.
        """
        c = {}
        h = {}
        f = {}

        f[0] = self.get_f(dag[-1][0].name)
        c[0] = torch.sigmoid(self.w_xc(x) + F.linear(h_prev, self.w_hc, None))
        h[0] = (c[0]*f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) +
                (1 - c[0])*h_prev)

        leaf_node_ids = []
        q = collections.deque()
        q.append(0)

        # Breadth-first traversal from node 0 (`q.popleft()`), computing for
        # every edge i -> j
        #     h_j = c_j*f_{ij}{(W^h_{ij}*h_i)} + (1 - c_j)*h_i,
        # where c_j = \sigmoid{(W^c_{ij}*h_i)}.
        # A child whose id == `self.num_blocks` marks its parent as a leaf.
        #
        # See Training details from Section 3.1 of the paper.
        while q:
            node_id = q.popleft()
            nodes = dag[node_id]

            for next_node in nodes:
                next_id = next_node.id
                if next_id == self.num_blocks:
                    leaf_node_ids.append(node_id)
                    assert len(nodes) == 1, ('parent of leaf node should have '
                                             'only one child')
                    continue

                w_h = self.w_h[node_id][next_id]
                w_c = self.w_c[node_id][next_id]

                f[next_id] = self.get_f(next_node.name)
                c[next_id] = torch.sigmoid(w_c(h[node_id]))
                h[next_id] = (c[next_id]*f[next_id](w_h(h[node_id])) +
                              (1 - c[next_id])*h[node_id])

                q.append(next_id)

        # average all the loose ends
        leaf_nodes = [h[node_id] for node_id in leaf_node_ids]
        output = torch.mean(torch.stack(leaf_nodes, 2), -1)

        # stabilizing the Updates of omega
        if self.batch_norm is not None:
            output = self.batch_norm(output)

        return output, h[self.num_blocks - 1]

    def init_hidden(self, batch_size):
        """Return a zero hidden state of shape `(batch_size, shared_hid)`."""
        zeros = torch.zeros(batch_size, self.shared_hid)
        return utils.get_variable(zeros, self.use_cuda, requires_grad=False)

    def get_f(self, name):
        """Map an activation name (case-insensitive) to its function.

        Raises:
            ValueError: If `name` is not one of 'relu', 'tanh', 'identity'
                or 'sigmoid'. Previously this surfaced as an
                UnboundLocalError.
        """
        name = name.lower()
        if name == 'relu':
            return torch.relu
        if name == 'tanh':
            return torch.tanh
        if name == 'identity':
            return lambda x: x
        if name == 'sigmoid':
            return torch.sigmoid
        raise ValueError('Unknown activation function: {}'.format(name))

    @property
    def num_parameters(self):
        """Total number of scalar parameters in the model."""
        def size(p):
            return np.prod(p.size())
        return sum([size(param) for param in self.parameters()])

    def reset_parameters(self):
        """Initialise all parameters uniformly in [-0.025, 0.025]; zero the decoder bias."""
        init_range = 0.025
        for param in self.parameters():
            param.data.uniform_(-init_range, init_range)
        self.decoder.bias.data.fill_(0)

    def predict(self, word_seq):
        """Predict a class index for every sequence.

        :param word_seq: torch.LongTensor, [batch_size, seq_len]
        :return predict: dict with key 'pred' holding the argmax of the
            model's 'pred' output along dim 1 (one index per sequence, since
            the output is pooled over time).
        """
        output = self(word_seq)
        _, predict = output['pred'].max(dim=1)
        return {'pred': predict}

+ 385
- 0
fastNLP/models/enas_trainer.py View File

@@ -0,0 +1,385 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch

import os
import time
from datetime import datetime
from datetime import timedelta

import numpy as np
import torch
import math
from torch import nn

try:
from tqdm.autonotebook import tqdm
except:
from fastNLP.core.utils import pseudo_tqdm as tqdm

from fastNLP.core.batch import Batch
from fastNLP.core.callback import CallbackManager, CallbackException
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _move_dict_value_to_device
import fastNLP
import fastNLP.models.enas_utils as utils
from fastNLP.core.utils import _build_args

from torch.optim import Adam


def _get_no_grad_ctx_mgr():
"""Returns a the `torch.no_grad` context manager for PyTorch version >=
0.4, or a no-op context manager otherwise.
"""
return torch.no_grad()


class ENASTrainer(fastNLP.Trainer):
    """Trainer that alternates between training the shared model and the
    architecture controller (ENAS)."""

    def __init__(self, train_data, model, controller, **kwargs):
        """Constructor for training algorithm.

        :param DataSet train_data: the training data
        :param torch.nn.modules.module model: the shared (child) model
        :param torch.nn.modules.module controller: the architecture controller
        :param kwargs: must contain ``final_epochs`` (number of trailing
            epochs in which the controller is no longer trained); every
            other entry is forwarded to ``fastNLP.Trainer``.
        :raises KeyError: if ``final_epochs`` is missing.
        """
        self.final_epochs = kwargs.pop('final_epochs')
        super(ENASTrainer, self).__init__(train_data, model, **kwargs)
        self.controller_step = 0
        self.shared_step = 0
        self.max_length = 35

        self.shared = model
        self.controller = controller

        # NOTE(review): lr=20.0 is unusually large for Adam -- confirm.
        self.shared_optim = Adam(
            self.shared.parameters(),
            lr=20.0,
            weight_decay=1e-7)

        self.controller_optim = Adam(
            self.controller.parameters(),
            lr=3.5e-4)

    def train(self, load_best_model=True):
        """Run the whole training procedure.

        :param bool load_best_model: only effective when dev_data was given
            at construction time; if True, the parameters with the best dev
            performance are reloaded before returning.
        :return results: a dict with the following entries::

            seconds: float, the training time
            The next three entries exist only when dev_data was provided.
            best_eval: Dict of Dict, the evaluation results
            best_epoch: int, the epoch in which the best value was reached
            best_step: int, the step (batch update) with the best value

        """
        results = {}
        if self.n_epochs <= 0:
            print(f"training epoch is {self.n_epochs}, nothing was done.")
            results['seconds'] = 0.
            return results

        if torch.cuda.is_available() and self.use_cuda:
            self.model = self.model.cuda()
        self._model_device = next(self.model.parameters()).device
        self._mode(self.model, is_test=False)

        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
        start_time = time.time()
        print("training epochs started " + self.start_time, flush=True)

        try:
            self.callback_manager.on_train_begin()
            self._train()
            self.callback_manager.on_train_end(self.model)
        except (CallbackException, KeyboardInterrupt) as e:
            self.callback_manager.on_exception(e, self.model)

        if self.dev_data is not None:
            print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
                  self.tester._format_eval_results(self.best_dev_perf),)
            results['best_eval'] = self.best_dev_perf
            results['best_epoch'] = self.best_dev_epoch
            results['best_step'] = self.best_dev_step
            if load_best_model:
                model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
                load_succeed = self._load_model(self.model, model_name)
                if load_succeed:
                    print("Reloaded the best model.")
                else:
                    print("Fail to reload best model.")
        results['seconds'] = round(time.time() - start_time, 2)

        return results

    def _train(self):
        """Epoch loop: train shared weights, then the controller, then validate."""
        if not self.use_tqdm:
            from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
        else:
            inner_tqdm = tqdm
        self.step = 0
        # Remembered on the instance so `train_shared` can report elapsed time.
        self._enas_start_time = time.time()
        total_steps = (len(self.train_data) // self.batch_size + int(
            len(self.train_data) % self.batch_size != 0)) * self.n_epochs
        # First epoch that belongs to the final stage.
        final_stage_start = self.n_epochs + 1 - self.final_epochs
        with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                  prefetch=self.prefetch)
            for epoch in range(1, self.n_epochs+1):
                # Remembered on the instance so `train_shared` can log it.
                self._enas_epoch = epoch
                pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                # `>=` so that the epoch announcing the final stage is
                # itself part of it (with `>` a value of final_epochs=1
                # never reached the final stage).
                last_stage = (epoch >= final_stage_start)
                if epoch == final_stage_start:
                    print('Entering the final stage. (Only train the selected structure)')
                # early stopping
                self.callback_manager.on_epoch_begin(epoch, self.n_epochs)

                # 1. Training the shared parameters omega of the child models
                self.train_shared(pbar)

                # 2. Training the controller parameters theta
                if not last_stage:
                    self.train_controller()

                if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                    (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                        and self.dev_data is not None:
                    if not last_stage:
                        self.derive()
                    eval_res = self._do_validation(epoch=epoch, step=self.step)
                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                total_steps) + \
                               self.tester._format_eval_results(eval_res)
                    pbar.write(eval_str)

                # lr decay; early stopping
                self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer)
            # =============== epochs end =================== #
            pbar.close()
        # ============ tqdm end ============== #

    def get_loss(self, inputs, targets, hidden, dags):
        """Compute the loss of the shared model on one batch.

        :param inputs: dict of model inputs; filtered to the arguments that
            the shared model's ``forward`` accepts.
        :param targets: dict of targets passed to the loss.
        :param hidden: hidden state handed to the shared model.
        :param dags: a single DAG or a list containing exactly one DAG.
        :return: ``(loss, hidden, extra_out)``
        """
        if not isinstance(dags, list):
            dags = [dags]

        loss = 0
        for dag in dags:
            self.shared.setDAG(dag)
            inputs = _build_args(self.shared.forward, **inputs)
            inputs['hidden'] = hidden
            result = self.shared(**inputs)
            hidden, extra_out = result['hidden'], result['extra_out']

            self.callback_manager.on_loss_begin(targets, result)
            sample_loss = self._compute_loss(result, targets)
            loss += sample_loss

        assert len(dags) == 1, 'there are multiple `hidden` for multple `dags`'
        return loss, hidden, extra_out

    def train_shared(self, pbar=None, max_step=None, dag=None):
        """Train the shared model for one pass over the training data.

        A new architecture is sampled from the (fixed) controller for every
        batch.

        :param pbar: progress bar to update; may be None.
        :param max_step: currently unused.
        :param dag: currently unused.
        """
        model = self.shared
        model.train()
        self.controller.eval()

        hidden = self.shared.init_hidden(self.batch_size)

        # `_train` records these; fall back to sensible values when this
        # method is called on its own.
        start = getattr(self, '_enas_start_time', None) or time.time()
        epoch = getattr(self, '_enas_epoch', 0)

        avg_loss = 0
        data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                              prefetch=self.prefetch)

        for batch_x, batch_y in data_iterator:
            _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
            indices = data_iterator.get_batch_indices()
            # negative sampling; replace unknown; re-weight batch_y
            self.callback_manager.on_batch_begin(batch_x, batch_y, indices)

            dags = self.controller.sample(1)
            inputs, targets = batch_x, batch_y
            loss, hidden, extra_out = self.get_loss(inputs,
                                                    targets,
                                                    hidden,
                                                    dags)
            hidden.detach_()
            avg_loss += loss.item()

            # Is loss NaN or inf? requires_grad = False
            self.callback_manager.on_backward_begin(loss, self.model)
            self._grad_backward(loss)
            self.callback_manager.on_backward_end(self.model)

            self._update()
            self.callback_manager.on_step_end(self.optimizer)

            if (self.step+1) % self.print_every == 0:
                if self.use_tqdm:
                    print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
                    if pbar is not None:
                        pbar.update(self.print_every)
                else:
                    end = time.time()
                    diff = timedelta(seconds=round(end - start))
                    print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
                        epoch, self.step, avg_loss, diff)
                if pbar is not None:
                    pbar.set_postfix_str(print_output)
                avg_loss = 0
            self.step += 1
            self.shared_step += 1
            self.callback_manager.on_batch_end()
        # ================= mini-batch end ==================== #

    def get_reward(self, dag, entropies, hidden, valid_idx=0):
        """Compute the controller reward of one sampled model on the dev data.

        :param dag: the architecture (or single-element list) to evaluate.
        :param entropies: entropies of the sampled actions (tensor or ndarray).
        :param hidden: hidden state handed to the shared model.
        :param valid_idx: currently unused.
        :return: ``(rewards, hidden)`` with
            ``rewards = 80 / exp(valid_loss) + 1e-4 * entropies``.

        NOTE(review): every dev batch is evaluated but only the loss of the
        last batch enters the reward -- confirm this is intended.
        """
        if not isinstance(entropies, np.ndarray):
            entropies = entropies.data.cpu().numpy()

        data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                              prefetch=self.prefetch)

        for inputs, targets in data_iterator:
            # Same device handling as in `train_shared`.
            _move_dict_value_to_device(inputs, targets, device=self._model_device)
            valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag)
            valid_loss = utils.to_item(valid_loss.data)

        valid_ppl = math.exp(valid_loss)

        R = 80 / valid_ppl

        rewards = R + 1e-4 * entropies

        return rewards, hidden

    def train_controller(self):
        """Fix the shared parameters and update the controller parameters.

        The controller is updated with a score function gradient estimator
        (i.e., REINFORCE) for 20 steps, using the reward from `get_reward`
        and a moving average baseline (decay 0.95).
        """
        model = self.controller
        model.train()
        # The shared model is deliberately not switched to eval mode here;
        # upstream reports that doing so leads to a uniformly zero loss.

        baseline = None

        hidden = self.shared.init_hidden(self.batch_size)
        valid_idx = 0
        for step in range(20):
            # sample models
            dags, log_probs, entropies = self.controller.sample(
                with_details=True)

            # calculate reward
            np_entropies = entropies.data.cpu().numpy()
            # No gradients should be backpropagated to the
            # shared model during controller training, obviously.
            with _get_no_grad_ctx_mgr():
                rewards, hidden = self.get_reward(dags,
                                                  np_entropies,
                                                  hidden,
                                                  valid_idx)

            # moving average baseline
            if baseline is None:
                baseline = rewards
            else:
                decay = 0.95
                baseline = decay * baseline + (1 - decay) * rewards

            adv = rewards - baseline

            # policy loss
            loss = -log_probs*utils.get_variable(adv,
                                                 self.use_cuda,
                                                 requires_grad=False)

            loss = loss.sum()  # or loss.mean()

            # update
            self.controller_optim.zero_grad()
            loss.backward()

            self.controller_optim.step()

            self.controller_step += 1

    def derive(self, sample_num=10, valid_idx=0):
        """Sample `sample_num` architectures and hand the one with the
        highest reward to the model via `setDAG`.

        :param sample_num: number of architectures to sample.
        :param valid_idx: forwarded to `get_reward`.
        """
        hidden = self.shared.init_hidden(self.batch_size)

        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)

        max_R = 0
        best_dag = None
        for dag in dags:
            R, _ = self.get_reward(dag, entropies, hidden, valid_idx)
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        self.model.setDAG(best_dag)

+ 56
- 0
fastNLP/models/enas_utils.py View File

@@ -0,0 +1,56 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch

from __future__ import print_function

from collections import defaultdict
import collections
from datetime import datetime
import os
import json

import numpy as np

import torch
from torch.autograd import Variable

def detach(h):
    """Detach a tensor, or every tensor in a nested iterable, from its graph.

    Args:
        h: A tensor, or an iterable (possibly nested) of tensors.

    Returns:
        A detached tensor sharing storage with `h`, or a tuple mirroring
        the structure of `h` with every tensor detached.
    """
    # `isinstance` is required here: since PyTorch 0.4 `type(tensor)` is
    # `torch.Tensor`, so the former `type(h) == Variable` test was never
    # true and tensors fell through to the recursive branch.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(detach(v) for v in h)

def get_variable(inputs, cuda=False, **kwargs):
    """Wrap `inputs` in a `Variable`, optionally moving it to CUDA.

    Args:
        inputs: A tensor, or a list / numpy array (including subclasses),
            which is first converted with `torch.Tensor(inputs)`.
        cuda: If True, the data is moved to CUDA before wrapping.
        **kwargs: Forwarded to `Variable` (e.g. `requires_grad`).

    Returns:
        The resulting `Variable`.
    """
    # isinstance also accepts subclasses of list / ndarray, which the former
    # exact `type(...) in [...]` comparison let through unconverted.
    if isinstance(inputs, (list, np.ndarray)):
        inputs = torch.Tensor(inputs)
    if cuda:
        inputs = inputs.cuda()
    return Variable(inputs, **kwargs)

def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group of `optimizer` to `lr`."""
    for group in optimizer.param_groups:
        group.update(lr=lr)

# Lightweight record for one DAG node: `id` is the node index and `name` the
# label attached to it (an activation name, 'avg' or 'h[t]' in
# enas_controller).
Node = collections.namedtuple('Node', ['id', 'name'])


class keydefaultdict(defaultdict):
    """A `defaultdict` whose factory receives the missing key as argument."""

    def __missing__(self, key):
        """Create, store and return the value for `key`.

        Raises:
            KeyError: If no default factory is configured.
        """
        factory = self.default_factory
        if factory is not None:
            value = factory(key)
            self[key] = value
            return value
        raise KeyError(key)


def to_item(x):
    """Convert x, possibly scalar and possibly tensor, to a Python scalar.

    Args:
        x: A Python `int`/`float`, or an object holding exactly one value
            (e.g. a one-element tensor).

    Returns:
        `x` itself for Python numbers, otherwise the single contained value.

    Raises:
        ValueError: On the legacy path (objects without `.item()`), if `x`
            is not one-dimensional with exactly one element.
    """
    if isinstance(x, (float, int)):
        return x

    # Detect the capability instead of parsing `torch.__version__`, which
    # relied on slicing the first three characters of the version string.
    item = getattr(x, 'item', None)
    if item is not None:
        return item()

    # Fallback for tensor types that predate `.item()`.
    if x.dim() != 1 or len(x) != 1:
        raise ValueError('expected a one-element, one-dimensional tensor')
    return x[0]

+ 112
- 0
test/models/test_enas.py View File

@@ -0,0 +1,112 @@
import unittest

from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import Vocabulary
from fastNLP.core.losses import CrossEntropyLoss
from fastNLP.core.metrics import AccuracyMetric


class TestENAS(unittest.TestCase):
    """End-to-end smoke test: train an ENAS model on the tutorial sample data."""

    def testENAS(self):
        # Read the data from a csv file into a DataSet.
        sample_path = "tutorials/sample_data/tutorial_sample_dataset.csv"
        dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'),
                                   sep='\t')
        print(len(dataset))
        print(dataset[0])
        print(dataset[-3])

        dataset.append(Instance(raw_sentence='fake data', label='0'))
        # Lower-case every sentence.
        dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')
        # Convert the label to int.
        dataset.apply(lambda x: int(x['label']), new_field_name='label')

        # Split sentences on whitespace.
        def split_sent(ins):
            return ins['raw_sentence'].split()

        dataset.apply(split_sent, new_field_name='words')

        # Add the sequence length.
        dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')
        print(len(dataset))
        print(dataset[0])

        # DataSet.drop(func) filters out instances.
        dataset.drop(lambda x: x['seq_len'] <= 3)
        print(len(dataset))

        # Declare which fields of the DataSet are converted to tensors.
        # set_target: the gold values used by the loss and the evaluation.
        dataset.set_target("label")
        # set_input: the fields passed to the model's forward.
        dataset.set_input("words", "seq_len")

        # Split into test set and training set.
        test_data, train_data = dataset.split(0.5)
        print(len(test_data))
        print(len(train_data))

        # Build the vocabulary, Vocabulary.add(word).
        vocab = Vocabulary(min_freq=2)
        train_data.apply(lambda x: [vocab.add(word) for word in x['words']])
        vocab.build_vocab()

        # Index the sentences, Vocabulary.to_index(word).
        train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')
        test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')
        print(test_data[0])

        # The preprocessing tools can also be used on their own, e.g. for
        # reinforcement learning or GAN projects.
        from fastNLP.core.batch import Batch
        from fastNLP.core.sampler import RandomSampler

        batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())
        for batch_x, batch_y in batch_iterator:
            print("batch_x has: ", batch_x)
            print("batch_y has: ", batch_y)
            break

        from fastNLP.models.enas_model import ENASModel
        from fastNLP.models.enas_controller import Controller
        model = ENASModel(embed_num=len(vocab), num_classes=5)
        controller = Controller()

        from fastNLP.models.enas_trainer import ENASTrainer

        # Rename the fields so that they match the parameter names of the
        # model's forward.
        train_data.rename_field('words', 'word_seq')  # input field matches the forward parameter
        train_data.rename_field('label', 'label_seq')
        test_data.rename_field('words', 'word_seq')
        test_data.rename_field('label', 'label_seq')

        trainer = ENASTrainer(model=model, controller=controller, train_data=train_data, dev_data=test_data,
                              loss=CrossEntropyLoss(pred="output", target="label_seq"),
                              metrics=AccuracyMetric(pred="predict", target="label_seq"),
                              check_code_level=-1,
                              save_path=None,
                              batch_size=32,
                              print_every=1,
                              n_epochs=3,
                              final_epochs=1)
        trainer.train()
        print('Train finished!')

        # Evaluate on test_data with the Tester.
        from fastNLP import Tester

        tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"),
                        batch_size=4)

        acc = tester.test()
        print(acc)


# Run the tests when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

Loading…
Cancel
Save