You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

sequence_classification_pipeline.py 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import os
  2. import uuid
  3. from typing import Any, Dict
  4. import json
  5. import numpy as np
  6. from maas_lib.models.nlp import SequenceClassificationModel
  7. from maas_lib.preprocessors import SequenceClassificationPreprocessor
  8. from maas_lib.utils.constant import Tasks
  9. from ..base import Input, Pipeline
  10. from ..builder import PIPELINES
  11. __all__ = ['SequenceClassificationPipeline']
  12. @PIPELINES.register_module(
  13. Tasks.text_classification, module_name=r'bert-sentiment-analysis')
  14. class SequenceClassificationPipeline(Pipeline):
  15. def __init__(self, model: SequenceClassificationModel,
  16. preprocessor: SequenceClassificationPreprocessor, **kwargs):
  17. """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
  18. Args:
  19. model (SequenceClassificationModel): a model instance
  20. preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
  21. """
  22. super().__init__(model=model, preprocessor=preprocessor, **kwargs)
  23. from easynlp.utils import io
  24. self.label_path = os.path.join(model.model_dir, 'label_mapping.json')
  25. with io.open(self.label_path) as f:
  26. self.label_mapping = json.load(f)
  27. self.label_id_to_name = {
  28. idx: name
  29. for name, idx in self.label_mapping.items()
  30. }
  31. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
  32. """process the prediction results
  33. Args:
  34. inputs (Dict[str, Any]): _description_
  35. Returns:
  36. Dict[str, str]: the prediction results
  37. """
  38. probs = inputs['probabilities']
  39. logits = inputs['logits']
  40. predictions = np.argsort(-probs, axis=-1)
  41. preds = predictions[0]
  42. b = 0
  43. new_result = list()
  44. for pred in preds:
  45. new_result.append({
  46. 'pred': self.label_id_to_name[pred],
  47. 'prob': float(probs[b][pred]),
  48. 'logit': float(logits[b][pred])
  49. })
  50. new_results = list()
  51. new_results.append({
  52. 'id':
  53. inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()),
  54. 'output':
  55. new_result,
  56. 'predictions':
  57. new_result[0]['pred'],
  58. 'probabilities':
  59. ','.join([str(t) for t in inputs['probabilities'][b]]),
  60. 'logits':
  61. ','.join([str(t) for t in inputs['logits'][b]])
  62. })
  63. return new_results[0]

致力于通过开放的社区合作，开源AI模型以及相关创新技术，推动基于模型即服务的生态繁荣发展。