|
- from fastNLP.embeddings import BertEmbedding
- import torch
- import torch.nn as nn
- from fastNLP.core.const import Const as C
-
- class BertTC(nn.Module):
- def __init__(self, vocab,num_class,bert_model_dir_or_name,fine_tune=False):
- super(BertTC, self).__init__()
- self.embed=BertEmbedding(vocab, requires_grad=fine_tune,
- model_dir_or_name=bert_model_dir_or_name,include_cls_sep=True)
- self.classifier = nn.Linear(self.embed.embedding_dim, num_class)
-
- def forward(self, words):
- embedding_cls=self.embed(words)[:,0]
- output=self.classifier(embedding_cls)
- return {C.OUTPUT: output}
-
- def predict(self,words):
- return self.forward(words)
-
- if __name__=="__main__":
- ta=torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
- tb=ta[:,0]
- print(tb)
|