Browse Source

* 将enas相关代码放到automl目录下

* 修复fast_param_mapping的一个bug
* Trainer添加自动创建save目录
* Vocabulary的打印,显示内容
tags/v0.4.10
FengZiYjun 6 years ago
parent
commit
f5ab7a5d45
10 changed files with 23 additions and 24 deletions
  1. +0
    -0
      fastNLP/automl/__init__.py
  2. +3
    -3
      fastNLP/automl/enas_controller.py
  3. +4
    -4
      fastNLP/automl/enas_model.py
  4. +3
    -6
      fastNLP/automl/enas_trainer.py
  5. +2
    -5
      fastNLP/automl/enas_utils.py
  6. +1
    -1
      fastNLP/core/metrics.py
  7. +2
    -0
      fastNLP/core/trainer.py
  8. +3
    -0
      fastNLP/core/vocabulary.py
  9. +2
    -1
      fastNLP/modules/utils.py
  10. +3
    -4
      test/automl/test_enas.py

+ 0
- 0
fastNLP/automl/__init__.py View File


fastNLP/models/enas_controller.py → fastNLP/automl/enas_controller.py View File

@@ -5,9 +5,9 @@ import os


import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import fastNLP
import fastNLP.models.enas_utils as utils
from fastNLP.models.enas_utils import Node
import fastNLP.automl.enas_utils as utils
from fastNLP.automl.enas_utils import Node




def _construct_dags(prev_nodes, activations, func_names, num_blocks): def _construct_dags(prev_nodes, activations, func_names, num_blocks):

fastNLP/models/enas_model.py → fastNLP/automl/enas_model.py View File

@@ -1,17 +1,17 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch # Code Modified from https://github.com/carpedm20/ENAS-pytorch


"""Module containing the shared RNN model.""" """Module containing the shared RNN model."""
import numpy as np
import collections import collections


import numpy as np
import torch import torch
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable from torch.autograd import Variable


import fastNLP.models.enas_utils as utils
import fastNLP.automl.enas_utils as utils
from fastNLP.models.base_model import BaseModel from fastNLP.models.base_model import BaseModel
import fastNLP.modules.encoder as encoder


def _get_dropped_weights(w_raw, dropout_p, is_training): def _get_dropped_weights(w_raw, dropout_p, is_training):
"""Drops out weights to implement DropConnect. """Drops out weights to implement DropConnect.

fastNLP/models/enas_trainer.py → fastNLP/automl/enas_trainer.py View File

@@ -1,14 +1,12 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch # Code Modified from https://github.com/carpedm20/ENAS-pytorch


import os
import math
import time import time
from datetime import datetime from datetime import datetime
from datetime import timedelta from datetime import timedelta


import numpy as np import numpy as np
import torch import torch
import math
from torch import nn


try: try:
from tqdm.autonotebook import tqdm from tqdm.autonotebook import tqdm
@@ -16,12 +14,11 @@ except:
from fastNLP.core.utils import pseudo_tqdm as tqdm from fastNLP.core.utils import pseudo_tqdm as tqdm


from fastNLP.core.batch import Batch from fastNLP.core.batch import Batch
from fastNLP.core.callback import CallbackManager, CallbackException
from fastNLP.core.callback import CallbackException
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _move_dict_value_to_device from fastNLP.core.utils import _move_dict_value_to_device
import fastNLP import fastNLP
import fastNLP.models.enas_utils as utils
import fastNLP.automl.enas_utils as utils
from fastNLP.core.utils import _build_args from fastNLP.core.utils import _build_args


from torch.optim import Adam from torch.optim import Adam

fastNLP/models/enas_utils.py → fastNLP/automl/enas_utils.py View File

@@ -2,17 +2,14 @@


from __future__ import print_function from __future__ import print_function


from collections import defaultdict
import collections import collections
from datetime import datetime
import os
import json
from collections import defaultdict


import numpy as np import numpy as np

import torch import torch
from torch.autograd import Variable from torch.autograd import Variable



def detach(h): def detach(h):
if type(h) == Variable: if type(h) == Variable:
return Variable(h.data) return Variable(h.data)

+ 1
- 1
fastNLP/core/metrics.py View File

@@ -157,7 +157,7 @@ class MetricBase(object):
fast_param = {} fast_param = {}
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
fast_param['pred'] = list(pred_dict.values())[0] fast_param['pred'] = list(pred_dict.values())[0]
fast_param['target'] = list(pred_dict.values())[0]
fast_param['target'] = list(target_dict.values())[0]
return fast_param return fast_param
return fast_param return fast_param




+ 2
- 0
fastNLP/core/trainer.py View File

@@ -367,6 +367,8 @@ class Trainer(object):
""" """
if self.save_path is not None: if self.save_path is not None:
model_path = os.path.join(self.save_path, model_name) model_path = os.path.join(self.save_path, model_name)
if not os.path.exists(self.save_path):
os.makedirs(self.save_path, exist_ok=True)
if only_param: if only_param:
state_dict = model.state_dict() state_dict = model.state_dict()
for key in state_dict: for key in state_dict:


+ 3
- 0
fastNLP/core/vocabulary.py View File

@@ -196,3 +196,6 @@ class Vocabulary(object):
""" """
self.__dict__.update(state) self.__dict__.update(state)
self.build_reverse_vocab() self.build_reverse_vocab()

def __repr__(self):
return "Vocabulary({}...)".format(list(self.word_count.keys())[:5])

+ 2
- 1
fastNLP/modules/utils.py View File

@@ -60,7 +60,8 @@ def initial_parameter(net, initial_method=None):
init_method(w.data) # weight init_method(w.data) # weight
else: else:
init.normal_(w.data) # bias init.normal_(w.data) # bias
elif hasattr(m, 'weight') and m.weight.requires_grad:
elif m is not None and hasattr(m, 'weight') and \
hasattr(m.weight, "requires_grad"):
init_method(m.weight.data) init_method(m.weight.data)
else: else:
for w in m.parameters(): for w in m.parameters():


test/models/test_enas.py → test/automl/test_enas.py View File

@@ -69,13 +69,12 @@ class TestENAS(unittest.TestCase):
print("batch_y has: ", batch_y) print("batch_y has: ", batch_y)
break break


from fastNLP.models.enas_model import ENASModel
from fastNLP.models.enas_controller import Controller
from fastNLP.automl.enas_model import ENASModel
from fastNLP.automl.enas_controller import Controller
model = ENASModel(embed_num=len(vocab), num_classes=5) model = ENASModel(embed_num=len(vocab), num_classes=5)
controller = Controller() controller = Controller()


from fastNLP.models.enas_trainer import ENASTrainer
from copy import deepcopy
from fastNLP.automl.enas_trainer import ENASTrainer


# 更改DataSet中对应field的名称,要以模型的forward等参数名一致 # 更改DataSet中对应field的名称,要以模型的forward等参数名一致
train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致

Loading…
Cancel
Save