diff --git a/fastNLP/models/base_model.py b/fastNLP/models/base_model.py index 54e28687..24dfdb85 100644 --- a/fastNLP/models/base_model.py +++ b/fastNLP/models/base_model.py @@ -3,31 +3,12 @@ import torch class BaseModel(torch.nn.Module): """Base PyTorch model for all models. - Three network modules presented: - - encoder module - - aggregation module - - decoder module - Subclasses must implement these three modules with "components". + To do: add some useful common features """ def __init__(self): super(BaseModel, self).__init__() - def forward(self, *inputs): - x = self.encode(*inputs) - x = self.aggregate(x) - x = self.decode(x) - return x - - def encode(self, x): - raise NotImplementedError - - def aggregate(self, x): - raise NotImplementedError - - def decode(self, x): - raise NotImplementedError - class Vocabulary(object): """A look-up table that allows you to access `Lexeme` objects. The `Vocab`