* 修复fast_param_mapping的一个bug * Trainer添加自动创建save目录 * Vocabulary的打印,显示内容tags/v0.4.10
@@ -5,9 +5,9 @@ import os | |||||
import torch | import torch | ||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
import fastNLP | |||||
import fastNLP.models.enas_utils as utils | |||||
from fastNLP.models.enas_utils import Node | |||||
import fastNLP.automl.enas_utils as utils | |||||
from fastNLP.automl.enas_utils import Node | |||||
def _construct_dags(prev_nodes, activations, func_names, num_blocks): | def _construct_dags(prev_nodes, activations, func_names, num_blocks): |
@@ -1,17 +1,17 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | # Code Modified from https://github.com/carpedm20/ENAS-pytorch | ||||
"""Module containing the shared RNN model.""" | """Module containing the shared RNN model.""" | ||||
import numpy as np | |||||
import collections | import collections | ||||
import numpy as np | |||||
import torch | import torch | ||||
from torch import nn | |||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from torch import nn | |||||
from torch.autograd import Variable | from torch.autograd import Variable | ||||
import fastNLP.models.enas_utils as utils | |||||
import fastNLP.automl.enas_utils as utils | |||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
import fastNLP.modules.encoder as encoder | |||||
def _get_dropped_weights(w_raw, dropout_p, is_training): | def _get_dropped_weights(w_raw, dropout_p, is_training): | ||||
"""Drops out weights to implement DropConnect. | """Drops out weights to implement DropConnect. |
@@ -1,14 +1,12 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | # Code Modified from https://github.com/carpedm20/ENAS-pytorch | ||||
import os | |||||
import math | |||||
import time | import time | ||||
from datetime import datetime | from datetime import datetime | ||||
from datetime import timedelta | from datetime import timedelta | ||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
import math | |||||
from torch import nn | |||||
try: | try: | ||||
from tqdm.autonotebook import tqdm | from tqdm.autonotebook import tqdm | ||||
@@ -16,12 +14,11 @@ except: | |||||
from fastNLP.core.utils import pseudo_tqdm as tqdm | from fastNLP.core.utils import pseudo_tqdm as tqdm | ||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.callback import CallbackManager, CallbackException | |||||
from fastNLP.core.callback import CallbackException | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.utils import _move_dict_value_to_device | from fastNLP.core.utils import _move_dict_value_to_device | ||||
import fastNLP | import fastNLP | ||||
import fastNLP.models.enas_utils as utils | |||||
import fastNLP.automl.enas_utils as utils | |||||
from fastNLP.core.utils import _build_args | from fastNLP.core.utils import _build_args | ||||
from torch.optim import Adam | from torch.optim import Adam |
@@ -2,17 +2,14 @@ | |||||
from __future__ import print_function | from __future__ import print_function | ||||
from collections import defaultdict | |||||
import collections | import collections | ||||
from datetime import datetime | |||||
import os | |||||
import json | |||||
from collections import defaultdict | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
from torch.autograd import Variable | from torch.autograd import Variable | ||||
def detach(h): | def detach(h): | ||||
if type(h) == Variable: | if type(h) == Variable: | ||||
return Variable(h.data) | return Variable(h.data) |
@@ -157,7 +157,7 @@ class MetricBase(object): | |||||
fast_param = {} | fast_param = {} | ||||
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | ||||
fast_param['pred'] = list(pred_dict.values())[0] | fast_param['pred'] = list(pred_dict.values())[0] | ||||
fast_param['target'] = list(pred_dict.values())[0] | |||||
fast_param['target'] = list(target_dict.values())[0] | |||||
return fast_param | return fast_param | ||||
return fast_param | return fast_param | ||||
@@ -367,6 +367,8 @@ class Trainer(object): | |||||
""" | """ | ||||
if self.save_path is not None: | if self.save_path is not None: | ||||
model_path = os.path.join(self.save_path, model_name) | model_path = os.path.join(self.save_path, model_name) | ||||
if not os.path.exists(self.save_path): | |||||
os.makedirs(self.save_path, exist_ok=True) | |||||
if only_param: | if only_param: | ||||
state_dict = model.state_dict() | state_dict = model.state_dict() | ||||
for key in state_dict: | for key in state_dict: | ||||
@@ -196,3 +196,6 @@ class Vocabulary(object): | |||||
""" | """ | ||||
self.__dict__.update(state) | self.__dict__.update(state) | ||||
self.build_reverse_vocab() | self.build_reverse_vocab() | ||||
def __repr__(self): | |||||
return "Vocabulary({}...)".format(list(self.word_count.keys())[:5]) |
@@ -60,7 +60,8 @@ def initial_parameter(net, initial_method=None): | |||||
init_method(w.data) # weight | init_method(w.data) # weight | ||||
else: | else: | ||||
init.normal_(w.data) # bias | init.normal_(w.data) # bias | ||||
elif hasattr(m, 'weight') and m.weight.requires_grad: | |||||
elif m is not None and hasattr(m, 'weight') and \ | |||||
hasattr(m.weight, "requires_grad"): | |||||
init_method(m.weight.data) | init_method(m.weight.data) | ||||
else: | else: | ||||
for w in m.parameters(): | for w in m.parameters(): | ||||
@@ -69,13 +69,12 @@ class TestENAS(unittest.TestCase): | |||||
print("batch_y has: ", batch_y) | print("batch_y has: ", batch_y) | ||||
break | break | ||||
from fastNLP.models.enas_model import ENASModel | |||||
from fastNLP.models.enas_controller import Controller | |||||
from fastNLP.automl.enas_model import ENASModel | |||||
from fastNLP.automl.enas_controller import Controller | |||||
model = ENASModel(embed_num=len(vocab), num_classes=5) | model = ENASModel(embed_num=len(vocab), num_classes=5) | ||||
controller = Controller() | controller = Controller() | ||||
from fastNLP.models.enas_trainer import ENASTrainer | |||||
from copy import deepcopy | |||||
from fastNLP.automl.enas_trainer import ENASTrainer | |||||
# 更改DataSet中对应field的名称,要以模型的forward等参数名一致 | # 更改DataSet中对应field的名称,要以模型的forward等参数名一致 | ||||
train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 | train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 |