@@ -5,9 +5,9 @@ import os
 import torch
 import torch.nn.functional as F
 import fastNLP
-import fastNLP.models.enas_utils as utils
-from fastNLP.models.enas_utils import Node
+import fastNLP.automl.enas_utils as utils
+from fastNLP.automl.enas_utils import Node
 
 def _construct_dags(prev_nodes, activations, func_names, num_blocks):
@@ -1,17 +1,17 @@
 # Code Modified from https://github.com/carpedm20/ENAS-pytorch
 """Module containing the shared RNN model."""
-import numpy as np
 import collections
+import numpy as np
 import torch
-from torch import nn
 import torch.nn.functional as F
+from torch import nn
 from torch.autograd import Variable
-import fastNLP.models.enas_utils as utils
+import fastNLP.automl.enas_utils as utils
 from fastNLP.models.base_model import BaseModel
 import fastNLP.modules.encoder as encoder
 
 def _get_dropped_weights(w_raw, dropout_p, is_training):
     """Drops out weights to implement DropConnect.
@@ -1,14 +1,12 @@
 # Code Modified from https://github.com/carpedm20/ENAS-pytorch
 import os
+import math
 import time
 from datetime import datetime
 from datetime import timedelta
 import numpy as np
 import torch
-import math
 from torch import nn
 
 try:
     from tqdm.autonotebook import tqdm
@@ -16,12 +14,11 @@ except:
     from fastNLP.core.utils import pseudo_tqdm as tqdm
 from fastNLP.core.batch import Batch
-from fastNLP.core.callback import CallbackManager, CallbackException
+from fastNLP.core.callback import CallbackException
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _move_dict_value_to_device
 import fastNLP
-import fastNLP.models.enas_utils as utils
+import fastNLP.automl.enas_utils as utils
 from fastNLP.core.utils import _build_args
 from torch.optim import Adam
@@ -2,17 +2,14 @@
 from __future__ import print_function
+from collections import defaultdict
 import collections
-from datetime import datetime
 import os
 import json
-from collections import defaultdict
 import numpy as np
 import torch
 from torch.autograd import Variable
 
 def detach(h):
     if type(h) == Variable:
         return Variable(h.data)
@@ -17,6 +17,38 @@ class Callback(object):
         super(Callback, self).__init__()
         self.trainer = None  # reassigned inside Trainer
+
+        # read-only callback attributes
+        self._n_epochs = None
+        self._n_steps = None
+        self._batch_size = None
+        self._model = None
+        self._pbar = None
+        self._optimizer = None
+
+    @property
+    def n_epochs(self):
+        return self._n_epochs
+
+    @property
+    def n_steps(self):
+        return self._n_steps
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    @property
+    def model(self):
+        return self._model
+
+    @property
+    def pbar(self):
+        return self._pbar
+
+    @property
+    def optimizer(self):
+        return self._optimizer
 
     def on_train_begin(self):
         # before the main training loop
         pass
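The new attributes follow the underscore-plus-property pattern: callbacks can read `self.n_epochs` and friends, but only the manager (via `set_property`, added further below) writes the underscored fields. A standalone sketch of the pattern, independent of fastNLP:

```python
class ReadOnlyDemo:
    def __init__(self):
        self._n_epochs = None  # written by the managing object only

    @property
    def n_epochs(self):
        return self._n_epochs

cb = ReadOnlyDemo()
setattr(cb, "_n_epochs", 10)  # manager-side write, as CallbackManager.set_property does
print(cb.n_epochs)            # 10
try:
    cb.n_epochs = 3           # the property defines no setter
except AttributeError as exc:
    print("read-only:", exc)
```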
@@ -101,8 +133,6 @@ def transfer(func):
     def wrapper(manager, *arg):
         returns = []
         for callback in manager.callbacks:
-            for env_name, env_value in manager.env.items():
-                setattr(callback, env_name, env_value)
             returns.append(getattr(callback, func.__name__)(*arg))
         return returns
@@ -115,15 +145,15 @@ class CallbackManager(Callback):
     """
 
-    def __init__(self, env, callbacks=None):
+    def __init__(self, env, attr, callbacks=None):
        """
 
         :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself.
+        :param dict attr: read-only attributes for all callbacks
         :param Callback callbacks:
         """
         super(CallbackManager, self).__init__()
         # set attribute of trainer environment
-        self.env = env
         self.callbacks = []
         if callbacks is not None:
@@ -136,6 +166,23 @@ class CallbackManager(Callback):
         else:
             raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
 
+        for env_name, env_val in env.items():
+            for callback in self.callbacks:
+                setattr(callback, env_name, env_val)  # Callback.trainer
+
+        self.set_property(**attr)
+
+    def set_property(self, **kwargs):
+        """Set the read-only properties of all callbacks.
+
+        :param kwargs:
+        :return:
+        """
+        for callback in self.callbacks:
+            for k, v in kwargs.items():
+                setattr(callback, "_" + k, v)
+
     @transfer
     def on_train_begin(self):
         pass
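Taken together, the manager now distinguishes two kinds of state: `env` entries become plain attributes on every callback once at construction (replacing the per-call loop removed from `transfer` above), while `attr` entries are routed through `set_property` to the underscored read-only fields. A hedged usage sketch against the post-patch API, with a `NoopCallback` invented for illustration:

```python
from fastNLP.core.callback import Callback, CallbackManager

class NoopCallback(Callback):
    pass

manager = CallbackManager(env={"trainer": None},            # set as plain attributes
                          attr={"n_epochs": 10, "n_steps": 0,
                                "batch_size": 32, "model": None,
                                "optimizer": None},           # set via set_property
                          callbacks=[NoopCallback()])
cb = manager.callbacks[0]
print(cb.trainer)    # None, written directly from env
print(cb.n_epochs)   # 10, exposed through the read-only property
```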
@@ -157,7 +157,7 @@ class MetricBase(object):
         fast_param = {}
         if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
             fast_param['pred'] = list(pred_dict.values())[0]
-            fast_param['target'] = list(pred_dict.values())[0]
+            fast_param['target'] = list(target_dict.values())[0]
             return fast_param
         return fast_param
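This fix is worth spelling out: previously `target` was read from `pred_dict`, so the fast path compared predictions against themselves. A toy illustration of the corrected mapping:

```python
pred_dict = {"predict": [0, 1, 1]}
target_dict = {"y": [0, 1, 0]}

fast_param = {}
fast_param["pred"] = list(pred_dict.values())[0]
fast_param["target"] = list(target_dict.values())[0]  # was pred_dict: metrics looked perfect
print(fast_param)  # {'pred': [0, 1, 1], 'target': [0, 1, 0]}
```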
@@ -121,7 +121,6 @@ class Trainer(object):
         self.best_dev_perf = None
         self.sampler = sampler if sampler is not None else RandomSampler()
         self.prefetch = prefetch
-        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
 
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
@@ -144,6 +143,12 @@
         self.step = 0
         self.start_time = None  # start timestamp
 
+        self.callback_manager = CallbackManager(env={"trainer": self},
+                                                attr={"n_epochs": self.n_epochs, "n_steps": self.step,
+                                                      "batch_size": self.batch_size, "model": self.model,
+                                                      "optimizer": self.optimizer},
+                                                callbacks=callbacks)
+
     def train(self, load_best_model=True):
         """
@@ -236,6 +241,7 @@
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                   prefetch=self.prefetch)
+            self.callback_manager.set_property(pbar=pbar)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
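The progress bar cannot be handed to `CallbackManager` at construction because it only exists once training starts, hence the late `set_property(pbar=pbar)` injection here. Inside a callback it can then be used for tqdm-safe printing; a hedged sketch (the hook signature is illustrative):

```python
from fastNLP.core.callback import Callback

class BatchLogger(Callback):
    def on_batch_end(self, *args):
        # self.pbar is None until Trainer.train() injects the tqdm instance
        if self.pbar is not None:
            self.pbar.write("finished a batch")  # prints without garbling the bar
```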
@@ -361,6 +367,8 @@
         """
         if self.save_path is not None:
             model_path = os.path.join(self.save_path, model_name)
+            if not os.path.exists(self.save_path):
+                os.makedirs(self.save_path, exist_ok=True)
             if only_param:
                 state_dict = model.state_dict()
                 for key in state_dict:
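This guards first-time saves when `save_path` names a directory that does not exist yet. Note that with `exist_ok=True` the surrounding `os.path.exists` check is redundant (though harmless); the core pattern, with a hypothetical path, is simply:

```python
import os

save_path = "models/checkpoints"           # hypothetical directory
model_path = os.path.join(save_path, "best_model.pkl")
os.makedirs(save_path, exist_ok=True)      # no-op if the directory already exists
```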
@@ -196,3 +196,9 @@ class Vocabulary(object):
         """
         self.__dict__.update(state)
         self.build_reverse_vocab()
+
+    def __repr__(self):
+        return "Vocabulary({}...)".format(list(self.word_count.keys())[:5])
+
+    def __iter__(self):
+        return iter(list(self.word_count.keys()))
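With these two methods a `Vocabulary` prints a readable summary and can be looped over directly, which is what the new `test_iteration` at the end of this diff exercises. A quick usage sketch (import path as in fastNLP 0.x, assumed unchanged by this patch):

```python
from fastNLP.core.vocabulary import Vocabulary

vocab = Vocabulary()
vocab.update(["FastNLP", "works", "well", "works"])
print(vocab)        # Vocabulary(['FastNLP', 'works', 'well']...)
for word in vocab:  # __iter__ yields each distinct counted word
    print(word)
```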
@@ -60,7 +60,8 @@ def initial_parameter(net, initial_method=None):
                     init_method(w.data)  # weight
                 else:
                     init.normal_(w.data)  # bias
-        elif hasattr(m, 'weight') and m.weight.requires_grad:
+        elif m is not None and hasattr(m, 'weight') and \
+                hasattr(m.weight, "requires_grad"):
             init_method(m.weight.data)
         else:
             for w in m.parameters():
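The rewritten condition also guards against `m` being `None` and against modules whose `weight` attribute is not a parameter (i.e. has no `requires_grad`). For reference, a hedged example of applying the initializer to a small network (function location and scheme name are assumptions based on fastNLP 0.x):

```python
from torch import nn
from fastNLP.modules.utils import initial_parameter  # assumed location

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
initial_parameter(net, initial_method="xavier_uniform")  # assumed scheme name
```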
@@ -69,13 +69,12 @@ class TestENAS(unittest.TestCase):
             print("batch_y has: ", batch_y)
             break
 
-        from fastNLP.models.enas_model import ENASModel
-        from fastNLP.models.enas_controller import Controller
+        from fastNLP.automl.enas_model import ENASModel
+        from fastNLP.automl.enas_controller import Controller
         model = ENASModel(embed_num=len(vocab), num_classes=5)
         controller = Controller()
 
-        from fastNLP.models.enas_trainer import ENASTrainer
-        from copy import deepcopy
+        from fastNLP.automl.enas_trainer import ENASTrainer
 
         # Rename the corresponding DataSet fields to match the parameter names of the model's forward method
         train_data.rename_field('words', 'word_seq')  # input field matches the forward argument
@@ -136,3 +136,28 @@ class TestCallback(unittest.TestCase):
                           metrics=AccuracyMetric(pred="predict", target="y"),
                           callbacks=[TensorboardCallback("loss", "metric")])
         trainer.train()
+
+    def test_readonly_property(self):
+        from fastNLP.core.callback import Callback
+
+        class MyCallback(Callback):
+            def __init__(self):
+                super(MyCallback, self).__init__()
+
+            def on_epoch_begin(self, cur_epoch, total_epoch):
+                print(self.n_epochs, self.n_steps, self.batch_size)
+                print(self.model)
+                print(self.optimizer)
+
+        data_set, model = prepare_env()
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=5,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=SGD(lr=0.1),
+                          check_code_level=2,
+                          use_tqdm=False,
+                          dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"),
+                          callbacks=[MyCallback()])
+        trainer.train()
@@ -60,6 +60,15 @@ class TestIndexing(unittest.TestCase):
         vocab.update(text)
         self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])
 
+    def test_iteration(self):
+        vocab = Vocabulary()
+        text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
+                "works", "well", "in", "most", "cases", "scales", "well"]
+        vocab.update(text)
+        text = set(text)
+        for word in vocab:
+            self.assertTrue(word in text)
+
 class TestOther(unittest.TestCase):
     def test_additional_update(self):