@@ -1,3 +1,7 @@ | |||||
"""undocumented""" | |||||
__all__ = [] | |||||
import torch | import torch | ||||
from ..modules.decoder.mlp import MLP | from ..modules.decoder.mlp import MLP | ||||
@@ -1,16 +1,20 @@ | |||||
""" | |||||
"""undocumented | |||||
bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. | bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. | ||||
""" | """ | ||||
__all__ = [] | |||||
import os | import os | ||||
import torch | import torch | ||||
from torch import nn | from torch import nn | ||||
from .base_model import BaseModel | from .base_model import BaseModel | ||||
from ..core.const import Const | from ..core.const import Const | ||||
from ..core.utils import seq_len_to_mask | |||||
from ..modules.encoder import BertModel | from ..modules.encoder import BertModel | ||||
from ..modules.encoder.bert import BertConfig, CONFIG_FILE | from ..modules.encoder.bert import BertConfig, CONFIG_FILE | ||||
from ..core.utils import seq_len_to_mask | |||||
class BertForSequenceClassification(BaseModel): | class BertForSequenceClassification(BaseModel): | ||||
@@ -1,3 +1,8 @@ | |||||
""" | |||||
.. todo:: | |||||
doc | |||||
""" | |||||
__all__ = [ | __all__ = [ | ||||
"CNNText" | "CNNText" | ||||
] | ] | ||||
@@ -7,8 +12,8 @@ import torch.nn as nn | |||||
from ..core.const import Const as C | from ..core.const import Const as C | ||||
from ..core.utils import seq_len_to_mask | from ..core.utils import seq_len_to_mask | ||||
from ..modules import encoder | |||||
from ..embeddings import embedding | from ..embeddings import embedding | ||||
from ..modules import encoder | |||||
class CNNText(torch.nn.Module): | class CNNText(torch.nn.Module): | ||||
@@ -1,5 +1,10 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
"""A module with NAS controller-related code.""" | |||||
"""undocumented | |||||
Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
A module with NAS controller-related code. | |||||
""" | |||||
__all__ = [] | |||||
import collections | import collections | ||||
import os | import os | ||||
@@ -1,7 +1,10 @@ | |||||
""" | |||||
"""undocumented | |||||
Module containing the shared RNN model. | Module containing the shared RNN model. | ||||
Code Modified from https://github.com/carpedm20/ENAS-pytorch | Code Modified from https://github.com/carpedm20/ENAS-pytorch | ||||
""" | """ | ||||
__all__ = [] | |||||
import collections | import collections | ||||
import numpy as np | import numpy as np | ||||
@@ -1,11 +1,15 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
"""undocumented | |||||
Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
""" | |||||
__all__ = [] | |||||
import math | import math | ||||
import numpy as np | |||||
import time | import time | ||||
import torch | |||||
from datetime import datetime, timedelta | from datetime import datetime, timedelta | ||||
import numpy as np | |||||
import torch | |||||
from torch.optim import Adam | from torch.optim import Adam | ||||
try: | try: | ||||
@@ -15,7 +19,7 @@ except: | |||||
from ..core.trainer import Trainer | from ..core.trainer import Trainer | ||||
from ..core.batch import DataSetIter | from ..core.batch import DataSetIter | ||||
from ..core.callback import CallbackManager, CallbackException | |||||
from ..core.callback import CallbackException | |||||
from ..core.dataset import DataSet | from ..core.dataset import DataSet | ||||
from ..core.utils import _move_dict_value_to_device | from ..core.utils import _move_dict_value_to_device | ||||
from . import enas_utils as utils | from . import enas_utils as utils | ||||
@@ -1,7 +1,11 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
"""undocumented | |||||
Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
""" | |||||
__all__ = [] | |||||
from collections import defaultdict | |||||
import collections | import collections | ||||
from collections import defaultdict | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
@@ -1,5 +1,5 @@ | |||||
""" | """ | ||||
本模块实现了几种序列标注模型 | |||||
本模块实现了几种序列标注模型 | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
"SeqLabeling", | "SeqLabeling", | ||||
@@ -12,14 +12,14 @@ import torch.nn as nn | |||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from .base_model import BaseModel | from .base_model import BaseModel | ||||
from ..embeddings import embedding | |||||
from ..modules import decoder, encoder | |||||
from ..modules.decoder.crf import allowed_transitions | |||||
from ..core.utils import seq_len_to_mask | |||||
from ..core.const import Const as C | from ..core.const import Const as C | ||||
from ..modules import LSTM | |||||
from ..core.utils import seq_len_to_mask | |||||
from ..embeddings import embedding | |||||
from ..embeddings import get_embeddings | from ..embeddings import get_embeddings | ||||
from ..modules import ConditionalRandomField | from ..modules import ConditionalRandomField | ||||
from ..modules import LSTM | |||||
from ..modules import decoder, encoder | |||||
from ..modules.decoder.crf import allowed_transitions | |||||
class BiLSTMCRF(BaseModel): | class BiLSTMCRF(BaseModel): | ||||
@@ -1,3 +1,7 @@ | |||||
""" | |||||
.. todo:: | |||||
doc | |||||
""" | |||||
__all__ = [ | __all__ = [ | ||||
"ESIM" | "ESIM" | ||||
] | ] | ||||
@@ -5,13 +9,12 @@ __all__ = [ | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from torch.nn import CrossEntropyLoss | from torch.nn import CrossEntropyLoss | ||||
from .base_model import BaseModel | from .base_model import BaseModel | ||||
from ..embeddings.embedding import TokenEmbedding, Embedding | |||||
from ..core.const import Const | from ..core.const import Const | ||||
from ..core.utils import seq_len_to_mask | from ..core.utils import seq_len_to_mask | ||||
from ..embeddings.embedding import TokenEmbedding, Embedding | |||||
class ESIM(BaseModel): | class ESIM(BaseModel): | ||||