
* Move the ENAS-related code into the automl directory
* Fix a bug in fast_param_mapping
* Have Trainer automatically create the save directory
* Make printing a Vocabulary display its contents
tags/v0.4.10
FengZiYjun, 5 years ago
commit f5ab7a5d45
10 changed files with 23 additions and 24 deletions

  1. +0 -0  fastNLP/automl/__init__.py
  2. +3 -3  fastNLP/automl/enas_controller.py
  3. +4 -4  fastNLP/automl/enas_model.py
  4. +3 -6  fastNLP/automl/enas_trainer.py
  5. +2 -5  fastNLP/automl/enas_utils.py
  6. +1 -1  fastNLP/core/metrics.py
  7. +2 -0  fastNLP/core/trainer.py
  8. +3 -0  fastNLP/core/vocabulary.py
  9. +2 -1  fastNLP/modules/utils.py
 10. +3 -4  test/automl/test_enas.py

+0 -0  fastNLP/automl/__init__.py (new, empty file)


fastNLP/models/enas_controller.py → fastNLP/automl/enas_controller.py

@@ -5,9 +5,9 @@ import os
 
 import torch
 import torch.nn.functional as F
 import fastNLP
-import fastNLP.models.enas_utils as utils
-from fastNLP.models.enas_utils import Node
+import fastNLP.automl.enas_utils as utils
+from fastNLP.automl.enas_utils import Node
 
 
 def _construct_dags(prev_nodes, activations, func_names, num_blocks):
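For downstream users, the practical effect of the move is a new import root. A minimal sketch of the updated imports (paths taken from this diff; assumes fastNLP at this commit is installed):

    # After this commit the ENAS modules live under fastNLP.automl
    import fastNLP.automl.enas_utils as utils
    from fastNLP.automl.enas_utils import Node
    from fastNLP.automl.enas_controller import Controller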

fastNLP/models/enas_model.py → fastNLP/automl/enas_model.py

@@ -1,17 +1,17 @@
 # Code Modified from https://github.com/carpedm20/ENAS-pytorch
 
 """Module containing the shared RNN model."""
-import numpy as np
 import collections
 
+import numpy as np
 import torch
-from torch import nn
 import torch.nn.functional as F
+from torch import nn
 from torch.autograd import Variable
 
-import fastNLP.models.enas_utils as utils
+import fastNLP.automl.enas_utils as utils
 from fastNLP.models.base_model import BaseModel
 import fastNLP.modules.encoder as encoder
 
 def _get_dropped_weights(w_raw, dropout_p, is_training):
     """Drops out weights to implement DropConnect.

fastNLP/models/enas_trainer.py → fastNLP/automl/enas_trainer.py

@@ -1,14 +1,12 @@
 # Code Modified from https://github.com/carpedm20/ENAS-pytorch
 
 import os
-import math
 import time
 from datetime import datetime
 from datetime import timedelta
 
 import numpy as np
 import torch
+import math
 from torch import nn
 
 try:
     from tqdm.autonotebook import tqdm
@@ -16,12 +14,11 @@ except:
     from fastNLP.core.utils import pseudo_tqdm as tqdm
 
 from fastNLP.core.batch import Batch
-from fastNLP.core.callback import CallbackManager, CallbackException
+from fastNLP.core.callback import CallbackException
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _move_dict_value_to_device
 import fastNLP
-import fastNLP.models.enas_utils as utils
+import fastNLP.automl.enas_utils as utils
 from fastNLP.core.utils import _build_args
 
 from torch.optim import Adam
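The Adam import reflects ENAS's two-loop training, which alternates between updating the shared model and the controller, each with its own optimizer. A hedged sketch with stand-in modules and illustrative learning rates (not fastNLP's defaults):

    import torch.nn as nn
    from torch.optim import Adam

    shared = nn.Linear(8, 8)      # stand-in for the shared ENAS model
    controller = nn.Linear(8, 8)  # stand-in for the controller network
    shared_optim = Adam(shared.parameters(), lr=1e-3)
    controller_optim = Adam(controller.parameters(), lr=3.5e-4)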

fastNLP/models/enas_utils.py → fastNLP/automl/enas_utils.py

@@ -2,17 +2,14 @@
 
 from __future__ import print_function
 
-from collections import defaultdict
 import collections
-from datetime import datetime
-import os
 import json
+from collections import defaultdict
 
 import numpy as np
 
 import torch
 from torch.autograd import Variable
 
 
 def detach(h):
     if type(h) == Variable:
         return Variable(h.data)
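detach implements the truncation step of truncated BPTT: hidden state carried across steps is copied out of the autograd graph so gradients stop there. A sketch of the same idea in current PyTorch, where Tensor has absorbed Variable (the recursive handling of (h, c) tuples is an assumption, mirroring the upstream ENAS code):

    import torch

    def detach(h):
        # Cut the autograd history so backprop stops at this step.
        if isinstance(h, torch.Tensor):
            return h.detach()
        # Recurse into containers such as an LSTM's (h, c) pair.
        return tuple(detach(v) for v in h)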

+1 -1  fastNLP/core/metrics.py

@@ -157,7 +157,7 @@ class MetricBase(object):
         fast_param = {}
         if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
             fast_param['pred'] = list(pred_dict.values())[0]
-            fast_param['target'] = list(pred_dict.values())[0]
+            fast_param['target'] = list(target_dict.values())[0]
             return fast_param
         return fast_param
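This is the fast_param_mapping fix from the commit message: previously both 'pred' and 'target' were read from pred_dict, so the fast path silently evaluated predictions against themselves. A minimal illustration of the corrected mapping, with made-up dict contents:

    import torch

    pred_dict = {'output': torch.tensor([0, 1, 1])}   # hypothetical model output
    target_dict = {'label': torch.tensor([0, 1, 0])}  # hypothetical gold labels

    fast_param = {}
    if len(pred_dict) == 1 and len(target_dict) == 1:
        fast_param['pred'] = list(pred_dict.values())[0]
        fast_param['target'] = list(target_dict.values())[0]  # fixed: was pred_dict
    # fast_param now pairs the prediction with the real target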



+2 -0  fastNLP/core/trainer.py

@@ -367,6 +367,8 @@ class Trainer(object):
         """
         if self.save_path is not None:
             model_path = os.path.join(self.save_path, model_name)
+            if not os.path.exists(self.save_path):
+                os.makedirs(self.save_path, exist_ok=True)
             if only_param:
                 state_dict = model.state_dict()
                 for key in state_dict:
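The added lines are the usual "create the directory on first save" pattern; with exist_ok=True the makedirs call alone would suffice, since it tolerates an already-existing directory. A standalone sketch of the same behavior (save_model here is a hypothetical helper, not the Trainer method):

    import os
    import torch

    def save_model(model, save_path, model_name, only_param=False):
        # Create the save directory on demand instead of failing on a
        # missing path; exist_ok=True also tolerates concurrent creation.
        os.makedirs(save_path, exist_ok=True)
        model_path = os.path.join(save_path, model_name)
        torch.save(model.state_dict() if only_param else model, model_path)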


+3 -0  fastNLP/core/vocabulary.py

@@ -196,3 +196,6 @@ class Vocabulary(object):
         """
         self.__dict__.update(state)
         self.build_reverse_vocab()
+
+    def __repr__(self):
+        return "Vocabulary({}...)".format(list(self.word_count.keys())[:5])

+2 -1  fastNLP/modules/utils.py

@@ -60,7 +60,8 @@ def initial_parameter(net, initial_method=None):
                 init_method(w.data)  # weight
             else:
                 init.normal_(w.data)  # bias
-        elif hasattr(m, 'weight') and m.weight.requires_grad:
+        elif m is not None and hasattr(m, 'weight') and \
+                hasattr(m.weight, "requires_grad"):
             init_method(m.weight.data)
         else:
             for w in m.parameters():
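The widened guard lets the walk skip None entries and modules whose weight lacks requires_grad. Typical use initializes every submodule of a network in one call; a usage sketch (assuming 'xavier_uniform' is one of the accepted method names):

    import torch.nn as nn
    from fastNLP.modules.utils import initial_parameter

    net = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2))
    # Applies the chosen scheme to every submodule's weights and biases.
    initial_parameter(net, initial_method='xavier_uniform')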


test/models/test_enas.py → test/automl/test_enas.py

@@ -69,13 +69,12 @@ class TestENAS(unittest.TestCase):
print("batch_y has: ", batch_y)
break

from fastNLP.models.enas_model import ENASModel
from fastNLP.models.enas_controller import Controller
from fastNLP.automl.enas_model import ENASModel
from fastNLP.automl.enas_controller import Controller
model = ENASModel(embed_num=len(vocab), num_classes=5)
controller = Controller()

from fastNLP.models.enas_trainer import ENASTrainer
from copy import deepcopy
from fastNLP.automl.enas_trainer import ENASTrainer

# 更改DataSet中对应field的名称,要以模型的forward等参数名一致
train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致
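The rename_field convention matters because fastNLP feeds DataSet fields to the model by keyword: field names must equal the parameter names of forward. A toy sketch (field values are made up):

    from fastNLP.core.dataset import DataSet

    ds = DataSet({'words': [[1, 2, 3], [4, 5]], 'label': [0, 1]})
    # The model's forward(word_seq, ...) expects a field named 'word_seq'.
    ds.rename_field('words', 'word_seq')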
