diff --git a/fastNLP/api/__init__.py b/fastNLP/api/__init__.py index a21a4c42..5171d8c2 100644 --- a/fastNLP/api/__init__.py +++ b/fastNLP/api/__init__.py @@ -1 +1,2 @@ +__all__ = ["CWS", "POS", "Parser"] from .api import CWS, POS, Parser diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 88f1755a..c72f3690 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -1,6 +1,3 @@ -""" -api.api的介绍文档 -""" import warnings import torch @@ -8,15 +5,14 @@ import torch warnings.filterwarnings('ignore') import os -from fastNLP.core.dataset import DataSet - -from fastNLP.api.utils import load_url -from fastNLP.api.processor import ModelProcessor -from fastNLP.io.dataset_loader import _cut_long_sentence, ConllLoader -from fastNLP.core.instance import Instance -from fastNLP.api.pipeline import Pipeline -from fastNLP.core.metrics import SpanFPreRecMetric -from fastNLP.api.processor import IndexerProcessor +from ..core.dataset import DataSet +from .utils import load_url +from .processor import ModelProcessor +from ..io.dataset_loader import _cut_long_sentence, ConllLoader +from ..core.instance import Instance +from ..api.pipeline import Pipeline +from ..core.metrics import SpanFPreRecMetric +from .processor import IndexerProcessor # TODO add pretrain urls model_urls = { @@ -28,9 +24,10 @@ model_urls = { class ConllCWSReader(object): """Deprecated. Use ConllLoader for all types of conll-format files.""" + def __init__(self): pass - + def load(self, path, cut_long_sent=False): """ 返回的DataSet只包含raw_sentence这个field,内容为str。 @@ -63,7 +60,7 @@ class ConllCWSReader(object): sample.append(line.strip().split()) if len(sample) > 0: datalist.append(sample) - + ds = DataSet() for sample in datalist: # print(sample) @@ -78,7 +75,7 @@ class ConllCWSReader(object): for raw_sentence in sents: ds.append(Instance(raw_sentence=raw_sentence)) return ds - + def get_char_lst(self, sample): if len(sample) == 0: return None @@ -90,11 +87,13 @@ class ConllCWSReader(object): text.append(t1) return text + class ConllxDataLoader(ConllLoader): """返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。 Deprecated. Use ConllLoader for all types of conll-format files. """ + def __init__(self): headers = [ 'words', 'pos_tags', 'heads', 'labels', @@ -106,18 +105,15 @@ class ConllxDataLoader(ConllLoader): class API: - """ - 这是 API 类的文档 - """ def __init__(self): self.pipeline = None self._dict = None - + def predict(self, *args, **kwargs): """Do prediction for the given input. """ raise NotImplementedError - + def test(self, file_path): """Test performance over the given data set. @@ -125,7 +121,7 @@ class API: :return: a dictionary of metric values """ raise NotImplementedError - + def load(self, path, device): if os.path.exists(os.path.expanduser(path)): _dict = torch.load(path, map_location='cpu') @@ -145,14 +141,14 @@ class POS(API): :param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch. """ - + def __init__(self, model_path=None, device='cpu'): super(POS, self).__init__() if model_path is None: model_path = model_urls['pos'] - + self.load(model_path, device) - + def predict(self, content): """predict函数的介绍, 函数介绍的第二句,这句话不会换行 @@ -162,48 +158,48 @@ class POS(API): """ if not hasattr(self, "pipeline"): raise ValueError("You have to load model first.") - + sentence_list = content # 1. 检查sentence的类型 for sentence in sentence_list: if not all((type(obj) == str for obj in sentence)): raise ValueError("Input must be list of list of string.") - + # 2. 组建dataset dataset = DataSet() dataset.add_field("words", sentence_list) - + # 3. 使用pipeline self.pipeline(dataset) - + def merge_tag(words_list, tags_list): rtn = [] for words, tags in zip(words_list, tags_list): rtn.append([w + "/" + t for w, t in zip(words, tags)]) return rtn - + output = dataset.field_arrays["tag"].content if isinstance(content, str): return output[0] elif isinstance(content, list): return merge_tag(content, output) - + def test(self, file_path): test_data = ConllxDataLoader().load(file_path) - + save_dict = self._dict tag_vocab = save_dict["tag_vocab"] pipeline = save_dict["pipeline"] index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) pipeline.pipeline = [index_tag] + pipeline.pipeline - + test_data.rename_field("pos_tags", "tag") pipeline(test_data) test_data.set_target("truth") prediction = test_data.field_arrays["predict"].content truth = test_data.field_arrays["truth"].content seq_len = test_data.field_arrays["word_seq_origin_len"].content - + # padding by hand max_length = max([len(seq) for seq in prediction]) for idx in range(len(prediction)): @@ -217,7 +213,7 @@ class POS(API): f1 = round(test_result['f'] * 100, 2) pre = round(test_result['pre'] * 100, 2) rec = round(test_result['rec'] * 100, 2) - + return {"F1": f1, "precision": pre, "recall": rec} @@ -228,14 +224,15 @@ class CWS(API): :param model_path: 当model_path为None,使用默认位置的model。如果默认位置不存在,则自动下载模型 :param device: str,可以为'cpu', 'cuda'或'cuda:0'等。会将模型load到相应device进行推断。 """ + def __init__(self, model_path=None, device='cpu'): super(CWS, self).__init__() if model_path is None: model_path = model_urls['cws'] - + self.load(model_path, device) - + def predict(self, content): """ 分词接口。 @@ -246,27 +243,27 @@ class CWS(API): """ if not hasattr(self, 'pipeline'): raise ValueError("You have to load model first.") - + sentence_list = [] # 1. 检查sentence的类型 if isinstance(content, str): sentence_list.append(content) elif isinstance(content, list): sentence_list = content - + # 2. 组建dataset dataset = DataSet() dataset.add_field('raw_sentence', sentence_list) - + # 3. 使用pipeline self.pipeline(dataset) - + output = dataset.get_field('output').content if isinstance(content, str): return output[0] elif isinstance(content, list): return output - + def test(self, filepath): """ 传入一个分词文件路径,返回该数据集上分词f1, precision, recall。 @@ -292,28 +289,28 @@ class CWS(API): tag_proc = self._dict['tag_proc'] cws_model = self.pipeline.pipeline[-2].model pipeline = self.pipeline.pipeline[:-2] - + pipeline.insert(1, tag_proc) pp = Pipeline(pipeline) - + reader = ConllCWSReader() - + # te_filename = '/home/hyan/ctb3/test.conllx' te_dataset = reader.load(filepath) pp(te_dataset) - + from fastNLP.core.tester import Tester from fastNLP.core.metrics import BMESF1PreRecMetric - + tester = Tester(data=te_dataset, model=cws_model, metrics=BMESF1PreRecMetric(target='target'), batch_size=64, verbose=0) eval_res = tester.test() - + f1 = eval_res['BMESF1PreRecMetric']['f'] pre = eval_res['BMESF1PreRecMetric']['pre'] rec = eval_res['BMESF1PreRecMetric']['rec'] # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) - + return {"F1": f1, "precision": pre, "recall": rec} @@ -322,25 +319,25 @@ class Parser(API): super(Parser, self).__init__() if model_path is None: model_path = model_urls['parser'] - + self.pos_tagger = POS(device=device) self.load(model_path, device) - + def predict(self, content): if not hasattr(self, 'pipeline'): raise ValueError("You have to load model first.") - + # 1. 利用POS得到分词和pos tagging结果 pos_out = self.pos_tagger.predict(content) # pos_out = ['这里/NN 是/VB 分词/NN 结果/NN'.split()] - + # 2. 组建dataset dataset = DataSet() dataset.add_field('wp', pos_out) dataset.apply(lambda x: [''] + [w.split('/')[0] for w in x['wp']], new_field_name='words') dataset.apply(lambda x: [''] + [w.split('/')[1] for w in x['wp']], new_field_name='pos') dataset.rename_field("words", "raw_words") - + # 3. 使用pipeline self.pipeline(dataset) dataset.apply(lambda x: [str(arc) for arc in x['arc_pred']], new_field_name='arc_pred') @@ -348,7 +345,7 @@ class Parser(API): zip(x['arc_pred'], x['label_pred_seq'])][1:], new_field_name='output') # output like: [['2/top', '0/root', '4/nn', '2/dep']] return dataset.field_arrays['output'].content - + def load_test_file(self, path): def get_one(sample): sample = list(map(list, zip(*sample))) @@ -360,7 +357,7 @@ class Parser(API): return None # return word_seq, pos_seq, head_seq, head_tag_seq return sample[1], sample[3], list(map(int, sample[6])), sample[7] - + datalist = [] with open(path, 'r', encoding='utf-8') as f: sample = [] @@ -374,14 +371,14 @@ class Parser(API): sample.append(line.split('\t')) if len(sample) > 0: datalist.append(sample) - + data = [get_one(sample) for sample in datalist] data_list = list(filter(lambda x: x is not None, data)) return data_list - + def test(self, filepath): data = self.load_test_file(filepath) - + def convert(data): BOS = '' dataset = DataSet() @@ -396,7 +393,7 @@ class Parser(API): arc_true=heads, tags=head_tags)) return dataset - + ds = convert(data) pp = self.pipeline for p in pp: @@ -417,23 +414,23 @@ class Parser(API): head_cor += 1 if head_pred[i] == head_gold[i] else 0 uas = head_cor / total # print('uas:{:.2f}'.format(uas)) - + for p in pp: if p.field_name == 'gold_words': p.field_name = 'word_list' elif p.field_name == 'gold_pos': p.field_name = 'pos_list' - + return {"USA": round(uas, 5)} class Analyzer: def __init__(self, device='cpu'): - + self.cws = CWS(device=device) self.pos = POS(device=device) self.parser = Parser(device=device) - + def predict(self, content, seg=False, pos=False, parser=False): if seg is False and pos is False and parser is False: seg = True @@ -447,9 +444,9 @@ class Analyzer: if parser: parser_output = self.parser.predict(content) output_dict['parser'] = parser_output - + return output_dict - + def test(self, filepath): output_dict = {} if self.cws: @@ -461,5 +458,5 @@ class Analyzer: if self.parser: parser_output = self.parser.test(filepath) output_dict['parser'] = parser_output - + return output_dict diff --git a/fastNLP/api/examples.py b/fastNLP/api/examples.py index a85e7c30..c1b2e155 100644 --- a/fastNLP/api/examples.py +++ b/fastNLP/api/examples.py @@ -3,7 +3,7 @@ api/example.py contains all API examples provided by fastNLP. It is used as a tutorial for API or a test script since it is difficult to test APIs in travis. """ -from fastNLP.api import CWS, POS, Parser +from . import CWS, POS, Parser text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', diff --git a/fastNLP/api/pipeline.py b/fastNLP/api/pipeline.py index 0c567678..2cec16b3 100644 --- a/fastNLP/api/pipeline.py +++ b/fastNLP/api/pipeline.py @@ -1,4 +1,4 @@ -from fastNLP.api.processor import Processor +from ..api.processor import Processor class Pipeline: diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py index 0bba96c0..be111cd0 100644 --- a/fastNLP/api/processor.py +++ b/fastNLP/api/processor.py @@ -3,10 +3,10 @@ from collections import defaultdict import torch -from fastNLP.core.batch import Batch -from fastNLP.core.dataset import DataSet -from fastNLP.core.sampler import SequentialSampler -from fastNLP.core.vocabulary import Vocabulary +from ..core.batch import Batch +from ..core.dataset import DataSet +from ..core.sampler import SequentialSampler +from ..core.vocabulary import Vocabulary class Processor(object): diff --git a/fastNLP/automl/enas_trainer.py b/fastNLP/automl/enas_trainer.py index a6316341..a9b1b8c3 100644 --- a/fastNLP/automl/enas_trainer.py +++ b/fastNLP/automl/enas_trainer.py @@ -11,15 +11,15 @@ import torch try: from tqdm.autonotebook import tqdm except: - from fastNLP.core.utils import _pseudo_tqdm as tqdm + from ..core.utils import _pseudo_tqdm as tqdm -from fastNLP.core.batch import Batch -from fastNLP.core.callback import CallbackException -from fastNLP.core.dataset import DataSet -from fastNLP.core.utils import _move_dict_value_to_device +from ..core.batch import Batch +from ..core.callback import CallbackException +from ..core.dataset import DataSet +from ..core.utils import _move_dict_value_to_device import fastNLP -import fastNLP.automl.enas_utils as utils -from fastNLP.core.utils import _build_args +from . import enas_utils as utils +from ..core.utils import _build_args from torch.optim import Adam diff --git a/fastNLP/models/base_model.py b/fastNLP/models/base_model.py index ec532014..39ac99a0 100644 --- a/fastNLP/models/base_model.py +++ b/fastNLP/models/base_model.py @@ -1,6 +1,6 @@ import torch -from fastNLP.modules.decoder.MLP import MLP +from ..modules.decoder.MLP import MLP class BaseModel(torch.nn.Module): diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py index 42626934..7934b435 100644 --- a/fastNLP/models/bert.py +++ b/fastNLP/models/bert.py @@ -6,7 +6,7 @@ import torch from torch import nn from .base_model import BaseModel -from fastNLP.modules.encoder import BertModel +from ..modules.encoder import BertModel class BertForSequenceClassification(BaseModel): diff --git a/fastNLP/models/char_language_model.py b/fastNLP/models/char_language_model.py index d5e3359d..d0b4c426 100644 --- a/fastNLP/models/char_language_model.py +++ b/fastNLP/models/char_language_model.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from fastNLP.modules.encoder.lstm import LSTM +from ..modules.encoder.lstm import LSTM class Highway(nn.Module): diff --git a/fastNLP/models/enas_controller.py b/fastNLP/models/enas_controller.py index ae9bcfd2..16b970e6 100644 --- a/fastNLP/models/enas_controller.py +++ b/fastNLP/models/enas_controller.py @@ -5,9 +5,8 @@ import os import torch import torch.nn.functional as F -import fastNLP -import fastNLP.models.enas_utils as utils -from fastNLP.models.enas_utils import Node +from . import enas_utils as utils +from .enas_utils import Node def _construct_dags(prev_nodes, activations, func_names, num_blocks): diff --git a/fastNLP/models/enas_model.py b/fastNLP/models/enas_model.py index cc91e675..5c667927 100644 --- a/fastNLP/models/enas_model.py +++ b/fastNLP/models/enas_model.py @@ -9,9 +9,8 @@ from torch import nn import torch.nn.functional as F from torch.autograd import Variable -import fastNLP.models.enas_utils as utils -from fastNLP.models.base_model import BaseModel -import fastNLP.modules.encoder as encoder +from . import enas_utils as utils +from .base_model import BaseModel def _get_dropped_weights(w_raw, dropout_p, is_training): """Drops out weights to implement DropConnect. diff --git a/fastNLP/models/enas_trainer.py b/fastNLP/models/enas_trainer.py index d8110db0..824b8184 100644 --- a/fastNLP/models/enas_trainer.py +++ b/fastNLP/models/enas_trainer.py @@ -1,6 +1,5 @@ # Code Modified from https://github.com/carpedm20/ENAS-pytorch -import os import time from datetime import datetime from datetime import timedelta @@ -8,21 +7,19 @@ from datetime import timedelta import numpy as np import torch import math -from torch import nn try: from tqdm.autonotebook import tqdm except: - from fastNLP.core.utils import _pseudo_tqdm as tqdm + from ..core.utils import _pseudo_tqdm as tqdm -from fastNLP.core.batch import Batch -from fastNLP.core.callback import CallbackManager, CallbackException -from fastNLP.core.dataset import DataSet -from fastNLP.core.utils import _CheckError -from fastNLP.core.utils import _move_dict_value_to_device -import fastNLP -import fastNLP.models.enas_utils as utils -from fastNLP.core.utils import _build_args +from ..core.trainer import Trainer +from ..core.batch import Batch +from ..core.callback import CallbackManager, CallbackException +from ..core.dataset import DataSet +from ..core.utils import _move_dict_value_to_device +from . import enas_utils as utils +from ..core.utils import _build_args from torch.optim import Adam @@ -34,7 +31,7 @@ def _get_no_grad_ctx_mgr(): return torch.no_grad() -class ENASTrainer(fastNLP.Trainer): +class ENASTrainer(Trainer): """A class to wrap training code.""" def __init__(self, train_data, model, controller, **kwargs): """Constructor for training algorithm. diff --git a/fastNLP/models/enas_utils.py b/fastNLP/models/enas_utils.py index e5027d81..aafcb3a7 100644 --- a/fastNLP/models/enas_utils.py +++ b/fastNLP/models/enas_utils.py @@ -4,21 +4,20 @@ from __future__ import print_function from collections import defaultdict import collections -from datetime import datetime -import os -import json import numpy as np import torch from torch.autograd import Variable + def detach(h): if type(h) == Variable: return Variable(h.data) else: return tuple(detach(v) for v in h) + def get_variable(inputs, cuda=False, **kwargs): if type(inputs) in [list, np.ndarray]: inputs = torch.Tensor(inputs) @@ -28,10 +27,12 @@ def get_variable(inputs, cuda=False, **kwargs): out = Variable(inputs, **kwargs) return out + def update_lr(optimizer, lr): for param_group in optimizer.param_groups: param_group['lr'] = lr + Node = collections.namedtuple('Node', ['id', 'name']) @@ -48,9 +49,9 @@ def to_item(x): """Converts x, possibly scalar and possibly tensor, to a Python scalar.""" if isinstance(x, (float, int)): return x - + if float(torch.__version__[0:3]) < 0.4: assert (x.dim() == 1) and (len(x) == 1) return x[0] - + return x.item() diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_modeling.py index b9b0677d..e076910f 100644 --- a/fastNLP/models/sequence_modeling.py +++ b/fastNLP/models/sequence_modeling.py @@ -1,9 +1,9 @@ import torch -from fastNLP.models.base_model import BaseModel -from fastNLP.modules import decoder, encoder -from fastNLP.modules.decoder.CRF import allowed_transitions -from fastNLP.modules.utils import seq_mask +from .base_model import BaseModel +from ..modules import decoder, encoder +from ..modules.decoder.CRF import allowed_transitions +from ..modules.utils import seq_mask class SeqLabeling(BaseModel): diff --git a/fastNLP/models/snli.py b/fastNLP/models/snli.py index d4bf3d59..6b54bee6 100644 --- a/fastNLP/models/snli.py +++ b/fastNLP/models/snli.py @@ -1,11 +1,11 @@ import torch import torch.nn as nn -from fastNLP.models.base_model import BaseModel -from fastNLP.modules import decoder as Decoder -from fastNLP.modules import encoder as Encoder -from fastNLP.modules import aggregator as Aggregator -from fastNLP.modules.utils import seq_mask +from .base_model import BaseModel +from ..modules import decoder as Decoder +from ..modules import encoder as Encoder +from ..modules import aggregator as Aggregator +from ..modules.utils import seq_mask my_inf = 10e12 diff --git a/fastNLP/models/star_transformer.py b/fastNLP/models/star_transformer.py index c3247333..93ee72f6 100644 --- a/fastNLP/models/star_transformer.py +++ b/fastNLP/models/star_transformer.py @@ -7,7 +7,6 @@ from ..core.const import Const import torch from torch import nn -import torch.nn.functional as F class StarTransEnc(nn.Module): diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py index f2f2ac68..67f68ff2 100644 --- a/fastNLP/modules/aggregator/attention.py +++ b/fastNLP/modules/aggregator/attention.py @@ -4,10 +4,10 @@ import torch import torch.nn.functional as F from torch import nn -from fastNLP.modules.dropout import TimestepDropout -from fastNLP.modules.utils import mask_softmax +from ..dropout import TimestepDropout +from ..utils import mask_softmax -from fastNLP.modules.utils import initial_parameter +from ..utils import initial_parameter class Attention(torch.nn.Module): diff --git a/fastNLP/modules/aggregator/pooling.py b/fastNLP/modules/aggregator/pooling.py index 9961b87f..fd4414b7 100644 --- a/fastNLP/modules/aggregator/pooling.py +++ b/fastNLP/modules/aggregator/pooling.py @@ -1,17 +1,12 @@ -# python: 3.6 -# encoding: utf-8 - import torch import torch.nn as nn class MaxPool(nn.Module): """Max-pooling模块。""" - - def __init__( - self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None, - return_indices=False, ceil_mode=False - ): + + def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None, + return_indices=False, ceil_mode=False): """ :param stride: 窗口移动大小,默认为kernel_size :param padding: padding的内容,默认为0 @@ -30,7 +25,7 @@ class MaxPool(nn.Module): self.kernel_size = kernel_size self.return_indices = return_indices self.ceil_mode = ceil_mode - + def forward(self, x): if self.dimension == 1: pooling = nn.MaxPool1d( @@ -57,10 +52,11 @@ class MaxPool(nn.Module): class MaxPoolWithMask(nn.Module): """带mask矩阵的1维max pooling""" + def __init__(self): super(MaxPoolWithMask, self).__init__() self.inf = 10e12 - + def forward(self, tensor, mask, dim=1): """ :param torch.FloatTensor tensor: [batch_size, seq_len, channels] 初始tensor @@ -75,11 +71,11 @@ class MaxPoolWithMask(nn.Module): class KMaxPool(nn.Module): """K max-pooling module.""" - + def __init__(self, k=1): super(KMaxPool, self).__init__() self.k = k - + def forward(self, x): """ :param torch.Tensor x: [N, C, L] 初始tensor @@ -92,12 +88,12 @@ class KMaxPool(nn.Module): class AvgPool(nn.Module): """1-d average pooling module.""" - + def __init__(self, stride=None, padding=0): super(AvgPool, self).__init__() self.stride = stride self.padding = padding - + def forward(self, x): """ :param torch.Tensor x: [N, C, L] 初始tensor @@ -117,7 +113,7 @@ class MeanPoolWithMask(nn.Module): def __init__(self): super(MeanPoolWithMask, self).__init__() self.inf = 10e12 - + def forward(self, tensor, mask, dim=1): """ :param torch.FloatTensor tensor: [batch_size, seq_len, channels] 初始tensor @@ -127,7 +123,3 @@ class MeanPoolWithMask(nn.Module): """ masks = mask.view(mask.size(0), mask.size(1), -1).float() return torch.sum(tensor * masks.float(), dim=dim) / torch.sum(masks.float(), dim=1) - - - - diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index cc713bc6..4c3ac122 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -1,8 +1,8 @@ import torch from torch import nn -from fastNLP.modules.utils import initial_parameter -from fastNLP.modules.decoder.utils import log_sum_exp +from ..utils import initial_parameter +from ..decoder.utils import log_sum_exp def seq_len_to_byte_mask(seq_lens): diff --git a/fastNLP/modules/decoder/MLP.py b/fastNLP/modules/decoder/MLP.py index e7fafd68..35484932 100644 --- a/fastNLP/modules/decoder/MLP.py +++ b/fastNLP/modules/decoder/MLP.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from fastNLP.modules.utils import initial_parameter +from ..utils import initial_parameter class MLP(nn.Module): diff --git a/fastNLP/modules/encoder/char_encoder.py b/fastNLP/modules/encoder/char_encoder.py index 39e4b43e..54b702ea 100644 --- a/fastNLP/modules/encoder/char_encoder.py +++ b/fastNLP/modules/encoder/char_encoder.py @@ -1,7 +1,7 @@ import torch from torch import nn -from fastNLP.modules.utils import initial_parameter +from ..utils import initial_parameter # from torch.nn.init import xavier_uniform diff --git a/fastNLP/modules/encoder/conv_maxpool.py b/fastNLP/modules/encoder/conv_maxpool.py index d7a8b286..d01eddea 100644 --- a/fastNLP/modules/encoder/conv_maxpool.py +++ b/fastNLP/modules/encoder/conv_maxpool.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from fastNLP.modules.utils import initial_parameter +from ..utils import initial_parameter class ConvMaxpool(nn.Module): diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index 098788a8..8cc53b0b 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -1,5 +1,5 @@ import torch.nn as nn -from fastNLP.modules.utils import get_embeddings +from ..utils import get_embeddings class Embedding(nn.Embedding): """Embedding组件. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度""" diff --git a/fastNLP/modules/encoder/linear.py b/fastNLP/modules/encoder/linear.py index 2dc31eea..06edf81b 100644 --- a/fastNLP/modules/encoder/linear.py +++ b/fastNLP/modules/encoder/linear.py @@ -1,6 +1,6 @@ import torch.nn as nn -from fastNLP.modules.utils import initial_parameter +from ..utils import initial_parameter class Linear(nn.Module): diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py index cff39c84..cc6b1183 100644 --- a/fastNLP/modules/encoder/lstm.py +++ b/fastNLP/modules/encoder/lstm.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.utils.rnn as rnn -from fastNLP.modules.utils import initial_parameter +from ..utils import initial_parameter class LSTM(nn.Module): diff --git a/fastNLP/modules/encoder/variational_rnn.py b/fastNLP/modules/encoder/variational_rnn.py index 89ab44d9..2657ebf4 100644 --- a/fastNLP/modules/encoder/variational_rnn.py +++ b/fastNLP/modules/encoder/variational_rnn.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence -from fastNLP.modules.utils import initial_parameter +from ..utils import initial_parameter try: from torch import flip