@@ -5,9 +5,9 @@ import os
 import torch
 import torch.nn.functional as F
-import fastNLP
-import fastNLP.models.enas_utils as utils
-from fastNLP.models.enas_utils import Node
+import fastNLP.automl.enas_utils as utils
+from fastNLP.automl.enas_utils import Node
 def _construct_dags(prev_nodes, activations, func_names, num_blocks):
@@ -1,17 +1,17 @@
 # Code Modified from https://github.com/carpedm20/ENAS-pytorch
 """Module containing the shared RNN model."""
-import numpy as np
 import collections
+import numpy as np
 import torch
-from torch import nn
 import torch.nn.functional as F
+from torch import nn
 from torch.autograd import Variable
-import fastNLP.models.enas_utils as utils
+import fastNLP.automl.enas_utils as utils
 from fastNLP.models.base_model import BaseModel
-import fastNLP.modules.encoder as encoder
 def _get_dropped_weights(w_raw, dropout_p, is_training):
     """Drops out weights to implement DropConnect.
@@ -1,14 +1,12 @@
 # Code Modified from https://github.com/carpedm20/ENAS-pytorch
-import os
+import math
 import time
 from datetime import datetime
 from datetime import timedelta
 import numpy as np
 import torch
-import math
-from torch import nn
 try:
     from tqdm.autonotebook import tqdm
@@ -16,12 +14,11 @@ except:
     from fastNLP.core.utils import pseudo_tqdm as tqdm
 from fastNLP.core.batch import Batch
-from fastNLP.core.callback import CallbackManager, CallbackException
+from fastNLP.core.callback import CallbackException
 from fastNLP.core.dataset import DataSet
-from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _move_dict_value_to_device
 import fastNLP
-import fastNLP.models.enas_utils as utils
+import fastNLP.automl.enas_utils as utils
 from fastNLP.core.utils import _build_args
 from torch.optim import Adam
@@ -2,17 +2,14 @@
 from __future__ import print_function
-from collections import defaultdict
 import collections
-from datetime import datetime
-import os
-import json
+from collections import defaultdict
 import numpy as np
 import torch
 from torch.autograd import Variable
 def detach(h):
     if type(h) == Variable:
         return Variable(h.data)
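
Note: `detach` implements truncated backpropagation through time: wrapping `h.data` in a fresh `Variable` severs the autograd history, so gradients stop at the hidden-state boundary between batches. The diff cuts the function off after its first branch; in the upstream ENAS-pytorch code it also recurses over tuples (e.g. LSTM `(h, c)` pairs), roughly:

    def detach(h):
        if type(h) == Variable:
            return Variable(h.data)   # sever autograd history (pre-0.4 PyTorch idiom)
        else:
            return tuple(detach(v) for v in h)  # recurse into (h, c) tuples
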
@@ -17,6 +17,38 @@ class Callback(object):
         super(Callback, self).__init__()
         self.trainer = None  # reassigned inside Trainer
+        # read-only attributes for callbacks
+        self._n_epochs = None
+        self._n_steps = None
+        self._batch_size = None
+        self._model = None
+        self._pbar = None
+        self._optimizer = None
+
+    @property
+    def n_epochs(self):
+        return self._n_epochs
+
+    @property
+    def n_steps(self):
+        return self._n_steps
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    @property
+    def model(self):
+        return self._model
+
+    @property
+    def pbar(self):
+        return self._pbar
+
+    @property
+    def optimizer(self):
+        return self._optimizer
+
     def on_train_begin(self):
         # before the main training loop
         pass
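
Note: these attributes use the bare-`@property` pattern: each getter exposes an underscore-prefixed backing field and no setter is defined, so callback code can read but not assign the values; only `CallbackManager.set_property` (added further down) writes the backing fields via `setattr(callback, "_" + k, v)`. A hypothetical callback consuming them:

    class PrintSetupCallback(Callback):  # illustrative, not part of this diff
        def on_train_begin(self):
            # populated by CallbackManager before training starts
            print("epochs:", self.n_epochs, "batch size:", self.batch_size)
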
@@ -101,8 +133,6 @@ def transfer(func):
     def wrapper(manager, *arg):
         returns = []
         for callback in manager.callbacks:
-            for env_name, env_value in manager.env.items():
-                setattr(callback, env_name, env_value)
             returns.append(getattr(callback, func.__name__)(*arg))
         return returns
@@ -115,15 +145,15 @@ class CallbackManager(Callback): | |||||
""" | """ | ||||
def __init__(self, env, callbacks=None): | |||||
def __init__(self, env, attr, callbacks=None): | |||||
""" | """ | ||||
:param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. | :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. | ||||
:param dict attr: read-only attributes for all callbacks | |||||
:param Callback callbacks: | :param Callback callbacks: | ||||
""" | """ | ||||
super(CallbackManager, self).__init__() | super(CallbackManager, self).__init__() | ||||
# set attribute of trainer environment | # set attribute of trainer environment | ||||
self.env = env | |||||
self.callbacks = [] | self.callbacks = [] | ||||
if callbacks is not None: | if callbacks is not None: | ||||
@@ -136,6 +166,23 @@ class CallbackManager(Callback):
         else:
             raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
+        for env_name, env_val in env.items():
+            for callback in self.callbacks:
+                setattr(callback, env_name, env_val)  # Callback.trainer
+        self.set_property(**attr)
+
+    def set_property(self, **kwargs):
+        """Set the read-only attributes of all callbacks.
+
+        :param kwargs:
+        :return:
+        """
+        for callback in self.callbacks:
+            for k, v in kwargs.items():
+                setattr(callback, "_" + k, v)
+
     @transfer
     def on_train_begin(self):
         pass
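
Note: taken together with the `transfer` change above, the environment wiring now happens once at construction instead of being re-copied onto every callback each time an event is dispatched. A minimal usage sketch of the new contract (values are placeholders):

    manager = CallbackManager(env={"trainer": trainer},       # plain attributes, e.g. Callback.trainer
                              attr={"n_epochs": 5, "n_steps": 0,
                                    "batch_size": 32, "model": model,
                                    "optimizer": optimizer},  # read-only properties
                              callbacks=[MyCallback()])
    manager.set_property(pbar=pbar)  # refresh a read-only property mid-training
    manager.on_train_begin()         # @transfer fans the call out to every callback
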
@@ -157,7 +157,7 @@ class MetricBase(object):
         fast_param = {}
         if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
             fast_param['pred'] = list(pred_dict.values())[0]
-            fast_param['target'] = list(pred_dict.values())[0]
+            fast_param['target'] = list(target_dict.values())[0]
             return fast_param
         return fast_param
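
Note: this one-line fix matters: previously `fast_param['target']` aliased the prediction tensor, so any metric taking this fast path scored predictions against themselves. A toy illustration of the corrected mapping:

    import torch
    pred_dict = {"pred": torch.tensor([1, 0, 1])}
    target_dict = {"target": torch.tensor([1, 1, 1])}
    # before the fix: "target" was read from pred_dict -> a spuriously perfect score
    # after the fix:  fast_param == {"pred": tensor([1, 0, 1]), "target": tensor([1, 1, 1])}
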
@@ -121,7 +121,6 @@ class Trainer(object):
         self.best_dev_perf = None
         self.sampler = sampler if sampler is not None else RandomSampler()
         self.prefetch = prefetch
-        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
@@ -144,6 +143,12 @@ class Trainer(object):
         self.step = 0
         self.start_time = None  # start timestamp
+        self.callback_manager = CallbackManager(env={"trainer": self},
+                                                attr={"n_epochs": self.n_epochs, "n_steps": self.step,
+                                                      "batch_size": self.batch_size, "model": self.model,
+                                                      "optimizer": self.optimizer},
+                                                callbacks=callbacks)
     def train(self, load_best_model=True):
         """
@@ -236,6 +241,7 @@ class Trainer(object):
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                   prefetch=self.prefetch)
+            self.callback_manager.set_property(pbar=pbar)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
@@ -361,6 +367,8 @@ class Trainer(object):
         """
         if self.save_path is not None:
             model_path = os.path.join(self.save_path, model_name)
+            if not os.path.exists(self.save_path):
+                os.makedirs(self.save_path, exist_ok=True)
             if only_param:
                 state_dict = model.state_dict()
                 for key in state_dict:
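
Note: this guards the subsequent `torch.save` against a missing checkpoint directory. Since `os.makedirs(..., exist_ok=True)` already tolerates a pre-existing directory, the `os.path.exists` check is redundant but harmless:

    import os
    os.makedirs("checkpoints/run1", exist_ok=True)  # safe to call repeatedly
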
@@ -196,3 +196,9 @@ class Vocabulary(object):
         """
         self.__dict__.update(state)
         self.build_reverse_vocab()
+
+    def __repr__(self):
+        return "Vocabulary({}...)".format(list(self.word_count.keys())[:5])
+
+    def __iter__(self):
+        return iter(list(self.word_count.keys()))
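
Note: with these two methods a `Vocabulary` prints a short preview of its first five counted words and is directly iterable. Iteration goes over `word_count` keys, so words seen via `update` are yielded; special tokens added to `word2idx` at build time would not be. A small usage sketch, assuming the usual `update` API:

    vocab = Vocabulary()
    vocab.update(["fast", "NLP", "fast"])
    print(vocab)         # Vocabulary(['fast', 'NLP']...)
    words = list(vocab)  # ['fast', 'NLP'] -- iterates word_count keys
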
@@ -60,7 +60,8 @@ def initial_parameter(net, initial_method=None):
                     init_method(w.data)  # weight
                 else:
                     init.normal_(w.data)  # bias
-        elif hasattr(m, 'weight') and m.weight.requires_grad:
+        elif m is not None and hasattr(m, 'weight') and \
+                hasattr(m.weight, "requires_grad"):
             init_method(m.weight.data)
         else:
             for w in m.parameters():
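
Note: the rewritten guard is not just a null check; it also changes semantics. `hasattr(m.weight, "requires_grad")` is true for any tensor-valued weight, frozen or not, whereas the old `m.weight.requires_grad` skipped frozen weights, so this branch now initializes every module that carries a weight tensor:

    import torch.nn as nn
    layer = nn.Linear(3, 3)
    layer.weight.requires_grad = False      # frozen
    hasattr(layer.weight, "requires_grad")  # True -- still re-initialized now
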
@@ -69,13 +69,12 @@ class TestENAS(unittest.TestCase):
             print("batch_y has: ", batch_y)
             break
-        from fastNLP.models.enas_model import ENASModel
-        from fastNLP.models.enas_controller import Controller
+        from fastNLP.automl.enas_model import ENASModel
+        from fastNLP.automl.enas_controller import Controller
         model = ENASModel(embed_num=len(vocab), num_classes=5)
         controller = Controller()
-        from fastNLP.models.enas_trainer import ENASTrainer
-        from copy import deepcopy
+        from fastNLP.automl.enas_trainer import ENASTrainer
         # rename the corresponding DataSet fields to match the parameter names of the model's forward method
         train_data.rename_field('words', 'word_seq')  # input field matches the forward parameter
@@ -136,3 +136,28 @@ class TestCallback(unittest.TestCase):
                           metrics=AccuracyMetric(pred="predict", target="y"),
                           callbacks=[TensorboardCallback("loss", "metric")])
         trainer.train()
+
+    def test_readonly_property(self):
+        from fastNLP.core.callback import Callback
+        class MyCallback(Callback):
+            def __init__(self):
+                super(MyCallback, self).__init__()
+
+            def on_epoch_begin(self, cur_epoch, total_epoch):
+                print(self.n_epochs, self.n_steps, self.batch_size)
+                print(self.model)
+                print(self.optimizer)
+        data_set, model = prepare_env()
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=5,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=SGD(lr=0.1),
+                          check_code_level=2,
+                          use_tqdm=False,
+                          dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"),
+                          callbacks=[MyCallback()])
+        trainer.train()
@@ -60,6 +60,15 @@ class TestIndexing(unittest.TestCase):
         vocab.update(text)
         self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])
+
+    def test_iteration(self):
+        vocab = Vocabulary()
+        text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
+                "works", "well", "in", "most", "cases", "scales", "well"]
+        vocab.update(text)
+        text = set(text)
+        for word in vocab:
+            self.assertTrue(word in text)
 class TestOther(unittest.TestCase):
     def test_additional_update(self):