diff --git a/fastNLP/automl/__init__.py b/fastNLP/automl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastNLP/models/enas_controller.py b/fastNLP/automl/enas_controller.py similarity index 98% rename from fastNLP/models/enas_controller.py rename to fastNLP/automl/enas_controller.py index ae9bcfd2..6ddbb211 100644 --- a/fastNLP/models/enas_controller.py +++ b/fastNLP/automl/enas_controller.py @@ -5,9 +5,9 @@ import os import torch import torch.nn.functional as F -import fastNLP -import fastNLP.models.enas_utils as utils -from fastNLP.models.enas_utils import Node + +import fastNLP.automl.enas_utils as utils +from fastNLP.automl.enas_utils import Node def _construct_dags(prev_nodes, activations, func_names, num_blocks): diff --git a/fastNLP/models/enas_model.py b/fastNLP/automl/enas_model.py similarity index 99% rename from fastNLP/models/enas_model.py rename to fastNLP/automl/enas_model.py index cc91e675..4f9fb449 100644 --- a/fastNLP/models/enas_model.py +++ b/fastNLP/automl/enas_model.py @@ -1,17 +1,17 @@ # Code Modified from https://github.com/carpedm20/ENAS-pytorch """Module containing the shared RNN model.""" -import numpy as np import collections +import numpy as np import torch -from torch import nn import torch.nn.functional as F +from torch import nn from torch.autograd import Variable -import fastNLP.models.enas_utils as utils +import fastNLP.automl.enas_utils as utils from fastNLP.models.base_model import BaseModel -import fastNLP.modules.encoder as encoder + def _get_dropped_weights(w_raw, dropout_p, is_training): """Drops out weights to implement DropConnect. diff --git a/fastNLP/models/enas_trainer.py b/fastNLP/automl/enas_trainer.py similarity index 98% rename from fastNLP/models/enas_trainer.py rename to fastNLP/automl/enas_trainer.py index 22e323ce..7c0da752 100644 --- a/fastNLP/models/enas_trainer.py +++ b/fastNLP/automl/enas_trainer.py @@ -1,14 +1,12 @@ # Code Modified from https://github.com/carpedm20/ENAS-pytorch -import os +import math import time from datetime import datetime from datetime import timedelta import numpy as np import torch -import math -from torch import nn try: from tqdm.autonotebook import tqdm @@ -16,12 +14,11 @@ except: from fastNLP.core.utils import pseudo_tqdm as tqdm from fastNLP.core.batch import Batch -from fastNLP.core.callback import CallbackManager, CallbackException +from fastNLP.core.callback import CallbackException from fastNLP.core.dataset import DataSet -from fastNLP.core.utils import CheckError from fastNLP.core.utils import _move_dict_value_to_device import fastNLP -import fastNLP.models.enas_utils as utils +import fastNLP.automl.enas_utils as utils from fastNLP.core.utils import _build_args from torch.optim import Adam diff --git a/fastNLP/models/enas_utils.py b/fastNLP/automl/enas_utils.py similarity index 96% rename from fastNLP/models/enas_utils.py rename to fastNLP/automl/enas_utils.py index e5027d81..7a53dd12 100644 --- a/fastNLP/models/enas_utils.py +++ b/fastNLP/automl/enas_utils.py @@ -2,17 +2,14 @@ from __future__ import print_function -from collections import defaultdict import collections -from datetime import datetime -import os -import json +from collections import defaultdict import numpy as np - import torch from torch.autograd import Variable + def detach(h): if type(h) == Variable: return Variable(h.data) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 54fde815..9d581798 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -157,7 +157,7 @@ class MetricBase(object): fast_param = {} if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: fast_param['pred'] = list(pred_dict.values())[0] - fast_param['target'] = list(pred_dict.values())[0] + fast_param['target'] = list(target_dict.values())[0] return fast_param return fast_param diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 743570fd..ca2ff93b 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -367,6 +367,8 @@ class Trainer(object): """ if self.save_path is not None: model_path = os.path.join(self.save_path, model_name) + if not os.path.exists(self.save_path): + os.makedirs(self.save_path, exist_ok=True) if only_param: state_dict = model.state_dict() for key in state_dict: diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 987a3527..1e0857f3 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -196,3 +196,6 @@ class Vocabulary(object): """ self.__dict__.update(state) self.build_reverse_vocab() + + def __repr__(self): + return "Vocabulary({}...)".format(list(self.word_count.keys())[:5]) diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py index 5287bca4..4ae15b18 100644 --- a/fastNLP/modules/utils.py +++ b/fastNLP/modules/utils.py @@ -60,7 +60,8 @@ def initial_parameter(net, initial_method=None): init_method(w.data) # weight else: init.normal_(w.data) # bias - elif hasattr(m, 'weight') and m.weight.requires_grad: + elif m is not None and hasattr(m, 'weight') and \ + hasattr(m.weight, "requires_grad"): init_method(m.weight.data) else: for w in m.parameters(): diff --git a/test/models/test_enas.py b/test/automl/test_enas.py similarity index 94% rename from test/models/test_enas.py rename to test/automl/test_enas.py index 07a43205..d2d3af05 100644 --- a/test/models/test_enas.py +++ b/test/automl/test_enas.py @@ -69,13 +69,12 @@ class TestENAS(unittest.TestCase): print("batch_y has: ", batch_y) break - from fastNLP.models.enas_model import ENASModel - from fastNLP.models.enas_controller import Controller + from fastNLP.automl.enas_model import ENASModel + from fastNLP.automl.enas_controller import Controller model = ENASModel(embed_num=len(vocab), num_classes=5) controller = Controller() - from fastNLP.models.enas_trainer import ENASTrainer - from copy import deepcopy + from fastNLP.automl.enas_trainer import ENASTrainer # 更改DataSet中对应field的名称,要以模型的forward等参数名一致 train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致