
Convert all imports to relative references

tags/v0.4.10
ChenXin, 5 years ago
commit f66012a640
26 changed files with 131 additions and 146 deletions
 1. fastNLP/api/__init__.py  +1 -0
 2. fastNLP/api/api.py  +63 -66
 3. fastNLP/api/examples.py  +1 -1
 4. fastNLP/api/pipeline.py  +1 -1
 5. fastNLP/api/processor.py  +4 -4
 6. fastNLP/automl/enas_trainer.py  +7 -7
 7. fastNLP/models/base_model.py  +1 -1
 8. fastNLP/models/bert.py  +1 -1
 9. fastNLP/models/char_language_model.py  +1 -1
10. fastNLP/models/enas_controller.py  +2 -3
11. fastNLP/models/enas_model.py  +2 -3
12. fastNLP/models/enas_trainer.py  +9 -12
13. fastNLP/models/enas_utils.py  +6 -5
14. fastNLP/models/sequence_modeling.py  +4 -4
15. fastNLP/models/snli.py  +5 -5
16. fastNLP/models/star_transformer.py  +0 -1
17. fastNLP/modules/aggregator/attention.py  +3 -3
18. fastNLP/modules/aggregator/pooling.py  +11 -19
19. fastNLP/modules/decoder/CRF.py  +2 -2
20. fastNLP/modules/decoder/MLP.py  +1 -1
21. fastNLP/modules/encoder/char_encoder.py  +1 -1
22. fastNLP/modules/encoder/conv_maxpool.py  +1 -1
23. fastNLP/modules/encoder/embedding.py  +1 -1
24. fastNLP/modules/encoder/linear.py  +1 -1
25. fastNLP/modules/encoder/lstm.py  +1 -1
26. fastNLP/modules/encoder/variational_rnn.py  +1 -1
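The hunks below all apply one mechanical rule: an absolute fastNLP.<pkg>.<mod> import becomes a path relative to the importing file, where a single leading dot means the current package and each additional dot climbs one level. A minimal sketch of the rule, using paths from this repo:

# In fastNLP/models/bert.py (package fastNLP.models):
from ..modules.encoder import BertModel   # was: from fastNLP.modules.encoder import BertModel
from .base_model import BaseModel         # sibling module inside fastNLP.models

# In fastNLP/modules/decoder/MLP.py (package fastNLP.modules.decoder):
from ..utils import initial_parameter     # was: from fastNLP.modules.utils import initial_parameter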

fastNLP/api/__init__.py  +1 -0

@@ -1 +1,2 @@
__all__ = ["CWS", "POS", "Parser"]
from .api import CWS, POS, Parser
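With __all__ set alongside the relative import, a wildcard import of the subpackage exposes exactly these three names. A quick sketch (the usage line is hypothetical):

from fastNLP.api import *   # binds CWS, POS, Parser and nothing else

tagger = POS(device='cpu')  # model_path=None triggers a download of the default model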

fastNLP/api/api.py  +63 -66

@@ -1,6 +1,3 @@
"""
Introductory documentation for api.api
"""
import warnings

import torch
@@ -8,15 +5,14 @@ import torch
warnings.filterwarnings('ignore')
import os

from fastNLP.core.dataset import DataSet

from fastNLP.api.utils import load_url
from fastNLP.api.processor import ModelProcessor
from fastNLP.io.dataset_loader import _cut_long_sentence, ConllLoader
from fastNLP.core.instance import Instance
from fastNLP.api.pipeline import Pipeline
from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.api.processor import IndexerProcessor
from ..core.dataset import DataSet
from .utils import load_url
from .processor import ModelProcessor
from ..io.dataset_loader import _cut_long_sentence, ConllLoader
from ..core.instance import Instance
from ..api.pipeline import Pipeline
from ..core.metrics import SpanFPreRecMetric
from .processor import IndexerProcessor
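One quirk in this block: most rewrites use the single-dot sibling form (from .utils, from .processor), but Pipeline climbs up and comes back down via ..api.pipeline. Both spellings resolve to the same module; a sketch:

# fastNLP/api/api.py -- these two are equivalent:
from ..api.pipeline import Pipeline   # as committed: up to fastNLP, back into api
from .pipeline import Pipeline        # single-dot sibling form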

# TODO add pretrain urls
model_urls = {
@@ -28,9 +24,10 @@ model_urls = {

class ConllCWSReader(object):
"""Deprecated. Use ConllLoader for all types of conll-format files."""
def __init__(self):
pass
def load(self, path, cut_long_sent=False):
"""
The returned DataSet contains only the raw_sentence field, whose content is a str.
@@ -63,7 +60,7 @@ class ConllCWSReader(object):
sample.append(line.strip().split())
if len(sample) > 0:
datalist.append(sample)
ds = DataSet()
for sample in datalist:
# print(sample)
@@ -78,7 +75,7 @@ class ConllCWSReader(object):
for raw_sentence in sents:
ds.append(Instance(raw_sentence=raw_sentence))
return ds
def get_char_lst(self, sample):
if len(sample) == 0:
return None
@@ -90,11 +87,13 @@ class ConllCWSReader(object):
text.append(t1)
return text


class ConllxDataLoader(ConllLoader):
Returns word-level label information, including words, POS tags, (syntactic) head dependencies, and (syntactic) edge labels. Completely different from ``ZhConllPOSReader``.

Deprecated. Use ConllLoader for all types of conll-format files.
"""
def __init__(self):
headers = [
'words', 'pos_tags', 'heads', 'labels',
@@ -106,18 +105,15 @@ class ConllxDataLoader(ConllLoader):


class API:
"""
Documentation for the API class.
"""
def __init__(self):
self.pipeline = None
self._dict = None
def predict(self, *args, **kwargs):
"""Do prediction for the given input.
"""
raise NotImplementedError
def test(self, file_path):
"""Test performance over the given data set.

@@ -125,7 +121,7 @@ class API:
:return: a dictionary of metric values
"""
raise NotImplementedError
def load(self, path, device):
if os.path.exists(os.path.expanduser(path)):
_dict = torch.load(path, map_location='cpu')
@@ -145,14 +141,14 @@ class POS(API):
:param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch.

"""
def __init__(self, model_path=None, device='cpu'):
super(POS, self).__init__()
if model_path is None:
model_path = model_urls['pos']
self.load(model_path, device)
def predict(self, content):
"""predict函数的介绍,
函数介绍的第二句,这句话不会换行
@@ -162,48 +158,48 @@ class POS(API):
"""
if not hasattr(self, "pipeline"):
raise ValueError("You have to load model first.")
sentence_list = content
# 1. Check the type of each sentence
for sentence in sentence_list:
if not all((type(obj) == str for obj in sentence)):
raise ValueError("Input must be list of list of string.")
# 2. Build the dataset
dataset = DataSet()
dataset.add_field("words", sentence_list)
# 3. Run the pipeline
self.pipeline(dataset)
def merge_tag(words_list, tags_list):
rtn = []
for words, tags in zip(words_list, tags_list):
rtn.append([w + "/" + t for w, t in zip(words, tags)])
return rtn
output = dataset.field_arrays["tag"].content
if isinstance(content, str):
return output[0]
elif isinstance(content, list):
return merge_tag(content, output)
def test(self, file_path):
test_data = ConllxDataLoader().load(file_path)
save_dict = self._dict
tag_vocab = save_dict["tag_vocab"]
pipeline = save_dict["pipeline"]
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)
pipeline.pipeline = [index_tag] + pipeline.pipeline
test_data.rename_field("pos_tags", "tag")
pipeline(test_data)
test_data.set_target("truth")
prediction = test_data.field_arrays["predict"].content
truth = test_data.field_arrays["truth"].content
seq_len = test_data.field_arrays["word_seq_origin_len"].content
# padding by hand
max_length = max([len(seq) for seq in prediction])
for idx in range(len(prediction)):
@@ -217,7 +213,7 @@ class POS(API):
f1 = round(test_result['f'] * 100, 2)
pre = round(test_result['pre'] * 100, 2)
rec = round(test_result['rec'] * 100, 2)
return {"F1": f1, "precision": pre, "recall": rec}


@@ -228,14 +224,15 @@ class CWS(API):
:param model_path: when model_path is None, the model at the default location is used; if no model exists there, it is downloaded automatically
:param device: str, e.g. 'cpu', 'cuda', or 'cuda:0'. The model is loaded onto the given device for inference.
"""
def __init__(self, model_path=None, device='cpu'):
super(CWS, self).__init__()
if model_path is None:
model_path = model_urls['cws']
self.load(model_path, device)
def predict(self, content):
"""
Word segmentation interface.
@@ -246,27 +243,27 @@ class CWS(API):
"""
if not hasattr(self, 'pipeline'):
raise ValueError("You have to load model first.")
sentence_list = []
# 1. Check the type of each sentence
if isinstance(content, str):
sentence_list.append(content)
elif isinstance(content, list):
sentence_list = content
# 2. Build the dataset
dataset = DataSet()
dataset.add_field('raw_sentence', sentence_list)
# 3. Run the pipeline
self.pipeline(dataset)
output = dataset.get_field('output').content
if isinstance(content, str):
return output[0]
elif isinstance(content, list):
return output
def test(self, filepath):
"""
Given the path of a word-segmentation file, return segmentation F1, precision, and recall on that data set.
@@ -292,28 +289,28 @@ class CWS(API):
tag_proc = self._dict['tag_proc']
cws_model = self.pipeline.pipeline[-2].model
pipeline = self.pipeline.pipeline[:-2]
pipeline.insert(1, tag_proc)
pp = Pipeline(pipeline)
reader = ConllCWSReader()
# te_filename = '/home/hyan/ctb3/test.conllx'
te_dataset = reader.load(filepath)
pp(te_dataset)
from fastNLP.core.tester import Tester
from fastNLP.core.metrics import BMESF1PreRecMetric
tester = Tester(data=te_dataset, model=cws_model, metrics=BMESF1PreRecMetric(target='target'), batch_size=64,
verbose=0)
eval_res = tester.test()
f1 = eval_res['BMESF1PreRecMetric']['f']
pre = eval_res['BMESF1PreRecMetric']['pre']
rec = eval_res['BMESF1PreRecMetric']['rec']
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))
return {"F1": f1, "precision": pre, "recall": rec}


@@ -322,25 +319,25 @@ class Parser(API):
super(Parser, self).__init__()
if model_path is None:
model_path = model_urls['parser']
self.pos_tagger = POS(device=device)
self.load(model_path, device)
def predict(self, content):
if not hasattr(self, 'pipeline'):
raise ValueError("You have to load model first.")
# 1. Use POS to obtain segmentation and POS-tagging results
pos_out = self.pos_tagger.predict(content)
# pos_out = ['这里/NN 是/VB 分词/NN 结果/NN'.split()]
# 2. Build the dataset
dataset = DataSet()
dataset.add_field('wp', pos_out)
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words')
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos')
dataset.rename_field("words", "raw_words")
# 3. Run the pipeline
self.pipeline(dataset)
dataset.apply(lambda x: [str(arc) for arc in x['arc_pred']], new_field_name='arc_pred')
@@ -348,7 +345,7 @@ class Parser(API):
zip(x['arc_pred'], x['label_pred_seq'])][1:], new_field_name='output')
# output like: [['2/top', '0/root', '4/nn', '2/dep']]
return dataset.field_arrays['output'].content
def load_test_file(self, path):
def get_one(sample):
sample = list(map(list, zip(*sample)))
@@ -360,7 +357,7 @@ class Parser(API):
return None
# return word_seq, pos_seq, head_seq, head_tag_seq
return sample[1], sample[3], list(map(int, sample[6])), sample[7]
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
@@ -374,14 +371,14 @@ class Parser(API):
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)
data = [get_one(sample) for sample in datalist]
data_list = list(filter(lambda x: x is not None, data))
return data_list
def test(self, filepath):
data = self.load_test_file(filepath)
def convert(data):
BOS = '<BOS>'
dataset = DataSet()
@@ -396,7 +393,7 @@ class Parser(API):
arc_true=heads,
tags=head_tags))
return dataset
ds = convert(data)
pp = self.pipeline
for p in pp:
@@ -417,23 +414,23 @@ class Parser(API):
head_cor += 1 if head_pred[i] == head_gold[i] else 0
uas = head_cor / total
# print('uas:{:.2f}'.format(uas))
for p in pp:
if p.field_name == 'gold_words':
p.field_name = 'word_list'
elif p.field_name == 'gold_pos':
p.field_name = 'pos_list'
return {"USA": round(uas, 5)}


class Analyzer:
def __init__(self, device='cpu'):
self.cws = CWS(device=device)
self.pos = POS(device=device)
self.parser = Parser(device=device)
def predict(self, content, seg=False, pos=False, parser=False):
if seg is False and pos is False and parser is False:
seg = True
@@ -447,9 +444,9 @@ class Analyzer:
if parser:
parser_output = self.parser.predict(content)
output_dict['parser'] = parser_output
return output_dict
def test(self, filepath):
output_dict = {}
if self.cws:
@@ -461,5 +458,5 @@ class Analyzer:
if self.parser:
parser_output = self.parser.test(filepath)
output_dict['parser'] = parser_output
return output_dict
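A hedged usage sketch for Analyzer (assuming the three default models can be fetched from model_urls; the output keys follow the 'parser' pattern above, and the sample sentence is illustrative):

from fastNLP.api.api import Analyzer

analyzer = Analyzer(device='cpu')
# with seg, pos and parser all False, seg is switched on by default
result = analyzer.predict('明天天气很好。')
print(result)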

fastNLP/api/examples.py  +1 -1

@@ -3,7 +3,7 @@ api/example.py contains all API examples provided by fastNLP.
It is used as a tutorial for the API, or as a test script, since it is difficult to test APIs on Travis.

"""
from fastNLP.api import CWS, POS, Parser
from . import CWS, POS, Parser

text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
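A minimal usage sketch built on this examples file (assuming the default models are reachable; text is the list defined above):

from fastNLP.api import CWS, POS

cws = CWS(device='cpu')        # model_path=None falls back to model_urls['cws']
print(cws.predict(text[0]))    # str in -> segmented str out

pos = POS(device='cpu')
print(pos.predict([['编者', '按']]))  # a list of token lists in -> 'word/tag' tokens out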


fastNLP/api/pipeline.py  +1 -1

@@ -1,4 +1,4 @@
from fastNLP.api.processor import Processor
from ..api.processor import Processor


class Pipeline:


fastNLP/api/processor.py  +4 -4

@@ -3,10 +3,10 @@ from collections import defaultdict

import torch

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.vocabulary import Vocabulary
from ..core.batch import Batch
from ..core.dataset import DataSet
from ..core.sampler import SequentialSampler
from ..core.vocabulary import Vocabulary


class Processor(object):


fastNLP/automl/enas_trainer.py  +7 -7

@@ -11,15 +11,15 @@ import torch
try:
from tqdm.autonotebook import tqdm
except:
from fastNLP.core.utils import _pseudo_tqdm as tqdm
from ..core.utils import _pseudo_tqdm as tqdm

from fastNLP.core.batch import Batch
from fastNLP.core.callback import CallbackException
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import _move_dict_value_to_device
from ..core.batch import Batch
from ..core.callback import CallbackException
from ..core.dataset import DataSet
from ..core.utils import _move_dict_value_to_device
import fastNLP
import fastNLP.automl.enas_utils as utils
from fastNLP.core.utils import _build_args
from . import enas_utils as utils
from ..core.utils import _build_args

from torch.optim import Adam
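The tqdm fallback above relies on a bare except:, which also swallows unrelated errors; a narrower variant of the same optional-dependency pattern (a suggestion, not part of this commit):

try:
    from tqdm.autonotebook import tqdm
except ImportError:
    # fall back to fastNLP's no-op progress bar when tqdm is missing
    from ..core.utils import _pseudo_tqdm as tqdm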



fastNLP/models/base_model.py  +1 -1

@@ -1,6 +1,6 @@
import torch

from fastNLP.modules.decoder.MLP import MLP
from ..modules.decoder.MLP import MLP


class BaseModel(torch.nn.Module):


fastNLP/models/bert.py  +1 -1

@@ -6,7 +6,7 @@ import torch
from torch import nn

from .base_model import BaseModel
from fastNLP.modules.encoder import BertModel
from ..modules.encoder import BertModel


class BertForSequenceClassification(BaseModel):


fastNLP/models/char_language_model.py  +1 -1

@@ -2,7 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from fastNLP.modules.encoder.lstm import LSTM
from ..modules.encoder.lstm import LSTM


class Highway(nn.Module):


fastNLP/models/enas_controller.py  +2 -3

@@ -5,9 +5,8 @@ import os

import torch
import torch.nn.functional as F
import fastNLP
import fastNLP.models.enas_utils as utils
from fastNLP.models.enas_utils import Node
from . import enas_utils as utils
from .enas_utils import Node


def _construct_dags(prev_nodes, activations, func_names, num_blocks):


fastNLP/models/enas_model.py  +2 -3

@@ -9,9 +9,8 @@ from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable

import fastNLP.models.enas_utils as utils
from fastNLP.models.base_model import BaseModel
import fastNLP.modules.encoder as encoder
from . import enas_utils as utils
from .base_model import BaseModel

def _get_dropped_weights(w_raw, dropout_p, is_training):
"""Drops out weights to implement DropConnect.


fastNLP/models/enas_trainer.py  +9 -12

@@ -1,6 +1,5 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch

import os
import time
from datetime import datetime
from datetime import timedelta
@@ -8,21 +7,19 @@ from datetime import timedelta
import numpy as np
import torch
import math
from torch import nn

try:
from tqdm.autonotebook import tqdm
except:
from fastNLP.core.utils import _pseudo_tqdm as tqdm
from ..core.utils import _pseudo_tqdm as tqdm

from fastNLP.core.batch import Batch
from fastNLP.core.callback import CallbackManager, CallbackException
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import _CheckError
from fastNLP.core.utils import _move_dict_value_to_device
import fastNLP
import fastNLP.models.enas_utils as utils
from fastNLP.core.utils import _build_args
from ..core.trainer import Trainer
from ..core.batch import Batch
from ..core.callback import CallbackManager, CallbackException
from ..core.dataset import DataSet
from ..core.utils import _move_dict_value_to_device
from . import enas_utils as utils
from ..core.utils import _build_args

from torch.optim import Adam

@@ -34,7 +31,7 @@ def _get_no_grad_ctx_mgr():
return torch.no_grad()


class ENASTrainer(fastNLP.Trainer):
class ENASTrainer(Trainer):
"""A class to wrap training code."""
def __init__(self, train_data, model, controller, **kwargs):
"""Constructor for training algorithm.


fastNLP/models/enas_utils.py  +6 -5

@@ -4,21 +4,20 @@ from __future__ import print_function

from collections import defaultdict
import collections
from datetime import datetime
import os
import json

import numpy as np

import torch
from torch.autograd import Variable


def detach(h):
if type(h) == Variable:
return Variable(h.data)
else:
return tuple(detach(v) for v in h)


def get_variable(inputs, cuda=False, **kwargs):
if type(inputs) in [list, np.ndarray]:
inputs = torch.Tensor(inputs)
@@ -28,10 +27,12 @@ def get_variable(inputs, cuda=False, **kwargs):
out = Variable(inputs, **kwargs)
return out


def update_lr(optimizer, lr):
for param_group in optimizer.param_groups:
param_group['lr'] = lr


Node = collections.namedtuple('Node', ['id', 'name'])


@@ -48,9 +49,9 @@ def to_item(x):
"""Converts x, possibly scalar and possibly tensor, to a Python scalar."""
if isinstance(x, (float, int)):
return x
if float(torch.__version__[0:3]) < 0.4:
assert (x.dim() == 1) and (len(x) == 1)
return x[0]
return x.item()
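The float(torch.__version__[0:3]) check above is adequate for the 0.x and 1.x releases around this commit; a sturdier sketch (my suggestion, not part of the commit) compares integer tuples instead:

import torch

def _torch_version_tuple():
    # "0.4.1" -> (0, 4); tolerant of suffixes such as "1.0.0a0"
    parts = torch.__version__.split('.')[:2]
    return tuple(int(''.join(c for c in p if c.isdigit()) or '0') for p in parts)

if _torch_version_tuple() < (0, 4):
    pass  # pre-0.4 path: a scalar tensor is a 1-element vector, use x[0]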

fastNLP/models/sequence_modeling.py  +4 -4

@@ -1,9 +1,9 @@
import torch

from fastNLP.models.base_model import BaseModel
from fastNLP.modules import decoder, encoder
from fastNLP.modules.decoder.CRF import allowed_transitions
from fastNLP.modules.utils import seq_mask
from .base_model import BaseModel
from ..modules import decoder, encoder
from ..modules.decoder.CRF import allowed_transitions
from ..modules.utils import seq_mask


class SeqLabeling(BaseModel):


fastNLP/models/snli.py  +5 -5

@@ -1,11 +1,11 @@
import torch
import torch.nn as nn

from fastNLP.models.base_model import BaseModel
from fastNLP.modules import decoder as Decoder
from fastNLP.modules import encoder as Encoder
from fastNLP.modules import aggregator as Aggregator
from fastNLP.modules.utils import seq_mask
from .base_model import BaseModel
from ..modules import decoder as Decoder
from ..modules import encoder as Encoder
from ..modules import aggregator as Aggregator
from ..modules.utils import seq_mask


my_inf = 10e12


fastNLP/models/star_transformer.py  +0 -1

@@ -7,7 +7,6 @@ from ..core.const import Const

import torch
from torch import nn
import torch.nn.functional as F


class StarTransEnc(nn.Module):


fastNLP/modules/aggregator/attention.py  +3 -3

@@ -4,10 +4,10 @@ import torch
import torch.nn.functional as F
from torch import nn

from fastNLP.modules.dropout import TimestepDropout
from fastNLP.modules.utils import mask_softmax
from ..dropout import TimestepDropout
from ..utils import mask_softmax

from fastNLP.modules.utils import initial_parameter
from ..utils import initial_parameter


class Attention(torch.nn.Module):


fastNLP/modules/aggregator/pooling.py  +11 -19

@@ -1,17 +1,12 @@
# python: 3.6
# encoding: utf-8

import torch
import torch.nn as nn


class MaxPool(nn.Module):
"""Max-pooling模块。"""

def __init__(
self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None,
return_indices=False, ceil_mode=False
):
def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None,
return_indices=False, ceil_mode=False):
"""
:param stride: step size of the sliding window; defaults to kernel_size
:param padding: padding added to the input; defaults to 0
@@ -30,7 +25,7 @@ class MaxPool(nn.Module):
self.kernel_size = kernel_size
self.return_indices = return_indices
self.ceil_mode = ceil_mode
def forward(self, x):
if self.dimension == 1:
pooling = nn.MaxPool1d(
@@ -57,10 +52,11 @@ class MaxPool(nn.Module):

class MaxPoolWithMask(nn.Module):
"""带mask矩阵的1维max pooling"""
def __init__(self):
super(MaxPoolWithMask, self).__init__()
self.inf = 10e12
def forward(self, tensor, mask, dim=1):
"""
:param torch.FloatTensor tensor: [batch_size, seq_len, channels], the input tensor
@@ -75,11 +71,11 @@ class MaxPoolWithMask(nn.Module):

class KMaxPool(nn.Module):
"""K max-pooling module."""
def __init__(self, k=1):
super(KMaxPool, self).__init__()
self.k = k
def forward(self, x):
"""
:param torch.Tensor x: [N, C, L], the input tensor
@@ -92,12 +88,12 @@ class KMaxPool(nn.Module):

class AvgPool(nn.Module):
"""1-d average pooling module."""
def __init__(self, stride=None, padding=0):
super(AvgPool, self).__init__()
self.stride = stride
self.padding = padding
def forward(self, x):
"""
:param torch.Tensor x: [N, C, L], the input tensor
@@ -117,7 +113,7 @@ class MeanPoolWithMask(nn.Module):
def __init__(self):
super(MeanPoolWithMask, self).__init__()
self.inf = 10e12
def forward(self, tensor, mask, dim=1):
"""
:param torch.FloatTensor tensor: [batch_size, seq_len, channels], the input tensor
@@ -127,7 +123,3 @@ class MeanPoolWithMask(nn.Module):
"""
masks = mask.view(mask.size(0), mask.size(1), -1).float()
return torch.sum(tensor * masks.float(), dim=dim) / torch.sum(masks.float(), dim=1)
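The line above implements masked mean pooling: a masked sum divided by the per-sentence token count. Its max-pool counterpart in MaxPoolWithMask pushes padded positions toward -inf (via self.inf = 10e12) before taking the max. A self-contained sketch of both, with shapes as in the docstrings:

import torch

def masked_mean(tensor, mask, dim=1):
    # tensor: [batch, seq_len, channels]; mask: [batch, seq_len], 1 = real token, 0 = pad
    masks = mask.unsqueeze(-1).float()
    return (tensor * masks).sum(dim=dim) / masks.sum(dim=dim)

def masked_max(tensor, mask, dim=1, inf=10e12):
    masks = mask.unsqueeze(-1).float()
    # padded positions become very negative so they can never win the max
    return (tensor + (masks - 1) * inf).max(dim=dim)[0]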





fastNLP/modules/decoder/CRF.py  +2 -2

@@ -1,8 +1,8 @@
import torch
from torch import nn

from fastNLP.modules.utils import initial_parameter
from fastNLP.modules.decoder.utils import log_sum_exp
from ..utils import initial_parameter
from ..decoder.utils import log_sum_exp


def seq_len_to_byte_mask(seq_lens):


fastNLP/modules/decoder/MLP.py  +1 -1

@@ -1,7 +1,7 @@
import torch
import torch.nn as nn

from fastNLP.modules.utils import initial_parameter
from ..utils import initial_parameter


class MLP(nn.Module):


fastNLP/modules/encoder/char_encoder.py  +1 -1

@@ -1,7 +1,7 @@
import torch
from torch import nn

from fastNLP.modules.utils import initial_parameter
from ..utils import initial_parameter


# from torch.nn.init import xavier_uniform


fastNLP/modules/encoder/conv_maxpool.py  +1 -1

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from fastNLP.modules.utils import initial_parameter
from ..utils import initial_parameter


class ConvMaxpool(nn.Module):


fastNLP/modules/encoder/embedding.py  +1 -1

@@ -1,5 +1,5 @@
import torch.nn as nn
from fastNLP.modules.utils import get_embeddings
from ..utils import get_embeddings

class Embedding(nn.Embedding):
"""Embedding组件. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度"""


fastNLP/modules/encoder/linear.py  +1 -1

@@ -1,6 +1,6 @@
import torch.nn as nn

from fastNLP.modules.utils import initial_parameter
from ..utils import initial_parameter


class Linear(nn.Module):


fastNLP/modules/encoder/lstm.py  +1 -1

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn

from fastNLP.modules.utils import initial_parameter
from ..utils import initial_parameter


class LSTM(nn.Module):


fastNLP/modules/encoder/variational_rnn.py  +1 -1

@@ -3,7 +3,7 @@
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
from fastNLP.modules.utils import initial_parameter
from ..utils import initial_parameter

try:
from torch import flip

