@@ -1,4 +1,6 @@ | |||
from typing import Dict, Optional, Union | |||
from typing import Dict, Optional, Union, Any | |||
import torch | |||
from ...models import Model | |||
from ...models.nlp.masked_language_model import \ | |||
@@ -35,6 +37,7 @@ class FillMaskPipeline(Pipeline): | |||
fill_mask_model.model_dir, | |||
first_sequence=first_sequence, | |||
second_sequence=None) | |||
fill_mask_model.eval() | |||
super().__init__(model=fill_mask_model, preprocessor=preprocessor, **kwargs) | |||
self.preprocessor = preprocessor | |||
self.tokenizer = preprocessor.tokenizer | |||
@@ -61,6 +64,11 @@ class FillMaskPipeline(Pipeline): | |||
} | |||
} | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
with torch.no_grad(): | |||
return super().forward(inputs, **forward_params) | |||
def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
"""process the prediction results | |||
@@ -1,6 +1,6 @@ | |||
import uuid | |||
from typing import Any, Dict, Union | |||
import torch | |||
import uuid | |||
from typing import Any, Dict, Union | |||
@@ -42,9 +42,15 @@ class NLIPipeline(Pipeline): | |||
sc_model.model_dir, | |||
first_sequence=first_sequence, | |||
second_sequence=second_sequence) | |||
sc_model.eval() | |||
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||
assert len(sc_model.id2label) > 0 | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
with torch.no_grad(): | |||
return super().forward(inputs, **forward_params) | |||
def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]: | |||
"""process the prediction results | |||
@@ -1,7 +1,7 @@ | |||
from typing import Any, Dict, Union | |||
import numpy as np | |||
import torch | |||
from ...metainfo import Pipelines | |||
from ...models.nlp import SbertForSentenceSimilarity | |||
from ...preprocessors import SequenceClassificationPreprocessor | |||
@@ -39,11 +39,17 @@ class SentenceSimilarityPipeline(Pipeline): | |||
sc_model.model_dir, | |||
first_sequence=first_sequence, | |||
second_sequence=second_sequence) | |||
sc_model.eval() | |||
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||
assert hasattr(self.model, 'id2label'), \ | |||
'id2label map should be initalizaed in init function.' | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
with torch.no_grad(): | |||
return super().forward(inputs, **forward_params) | |||
def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]: | |||
"""process the prediction results | |||
@@ -1,7 +1,7 @@ | |||
import os | |||
import uuid | |||
from typing import Any, Dict, Union | |||
import torch | |||
import json | |||
import numpy as np | |||
@@ -43,9 +43,15 @@ class SentimentClassificationPipeline(Pipeline): | |||
sc_model.model_dir, | |||
first_sequence=first_sequence, | |||
second_sequence=second_sequence) | |||
sc_model.eval() | |||
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||
assert len(sc_model.id2label) > 0 | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
with torch.no_grad(): | |||
return super().forward(inputs, **forward_params) | |||
def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]: | |||
"""process the prediction results | |||
@@ -1,5 +1,5 @@ | |||
from typing import Dict, Optional, Union | |||
from typing import Dict, Optional, Union, Any | |||
import torch | |||
from ...metainfo import Pipelines | |||
from ...models import Model | |||
from ...models.nlp import PalmForTextGeneration | |||
@@ -33,9 +33,15 @@ class TextGenerationPipeline(Pipeline): | |||
model.tokenizer, | |||
first_sequence='sentence', | |||
second_sequence=None) | |||
model.eval() | |||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
self.tokenizer = model.tokenizer | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
with torch.no_grad(): | |||
return super().forward(inputs, **forward_params) | |||
def postprocess(self, inputs: Dict[str, Tensor], **postprocess_params) -> Dict[str, str]: | |||
"""process the prediction results | |||
@@ -1,5 +1,5 @@ | |||
from typing import Any, Dict, Optional, Union | |||
import torch | |||
from ...metainfo import Pipelines | |||
from ...models import Model | |||
from ...models.nlp import SbertForTokenClassification | |||
@@ -30,12 +30,18 @@ class WordSegmentationPipeline(Pipeline): | |||
SbertForTokenClassification) else Model.from_pretrained(model) | |||
if preprocessor is None: | |||
preprocessor = TokenClassifcationPreprocessor(model.model_dir) | |||
model.eval() | |||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
self.tokenizer = preprocessor.tokenizer | |||
self.config = model.config | |||
assert len(self.config.id2label) > 0 | |||
self.id2label = self.config.id2label | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
with torch.no_grad(): | |||
return super().forward(inputs, **forward_params) | |||
def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]: | |||
"""process the prediction results | |||
@@ -1,7 +1,7 @@ | |||
import os | |||
import uuid | |||
from typing import Any, Dict, Union | |||
import torch | |||
import json | |||
import numpy as np | |||
from scipy.special import softmax | |||
@@ -44,6 +44,7 @@ class ZeroShotClassificationPipeline(Pipeline): | |||
if preprocessor is None: | |||
preprocessor = ZeroShotClassificationPreprocessor( | |||
sc_model.model_dir) | |||
model.eval() | |||
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||
def _sanitize_parameters(self, **kwargs): | |||
@@ -62,6 +63,11 @@ class ZeroShotClassificationPipeline(Pipeline): | |||
postprocess_params['multi_label'] = kwargs.pop('multi_label', False) | |||
return preprocess_params, {}, postprocess_params | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
with torch.no_grad(): | |||
return super().forward(inputs, **forward_params) | |||
def postprocess(self, | |||
inputs: Dict[str, Any], | |||
candidate_labels, | |||