|
|
@@ -3,31 +3,12 @@ import torch |
|
|
|
|
|
|
|
class BaseModel(torch.nn.Module): |
|
|
|
"""Base PyTorch model for all models. |
|
|
|
Three network modules presented: |
|
|
|
- encoder module |
|
|
|
- aggregation module |
|
|
|
- decoder module |
|
|
|
Subclasses must implement these three modules with "components". |
|
|
|
To do: add some useful common features |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super(BaseModel, self).__init__() |
|
|
|
|
|
|
|
def forward(self, *inputs): |
|
|
|
x = self.encode(*inputs) |
|
|
|
x = self.aggregate(x) |
|
|
|
x = self.decode(x) |
|
|
|
return x |
|
|
|
|
|
|
|
def encode(self, x): |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
def aggregate(self, x): |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
def decode(self, x): |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
|
|
|
class Vocabulary(object): |
|
|
|
"""A look-up table that allows you to access `Lexeme` objects. The `Vocab` |
|
|
|