You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

task.py 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. """
  2. Copyright 2020 Tianshu AI Platform. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. =============================================================
  13. """
import abc
import sys
import typing
# Mapping/Sequence must come from collections.abc: the bare `collections`
# aliases were deprecated in Python 3.3 and removed in Python 3.10.
from collections.abc import Mapping, Sequence
from typing import Any, Callable, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import loss
from kamal.core import metrics, exceptions
from kamal.core.attach import AttachTo
  25. class Task(object):
  26. def __init__(self, name):
  27. self.name = name
  28. @abc.abstractmethod
  29. def get_loss( self, outputs, targets ) -> Dict:
  30. pass
  31. @abc.abstractmethod
  32. def predict(self, outputs) -> Any:
  33. pass
  34. class GeneralTask(Task):
  35. def __init__(self,
  36. name: str,
  37. loss_fn: Callable,
  38. scaling:float=1.0,
  39. pred_fn: Callable=lambda x: x,
  40. attach_to=None):
  41. super(GeneralTask, self).__init__(name)
  42. self._attach = AttachTo(attach_to)
  43. self.loss_fn = loss_fn
  44. self.pred_fn = pred_fn
  45. self.scaling = scaling
  46. def get_loss(self, outputs, targets):
  47. outputs, targets = self._attach(outputs, targets)
  48. return { self.name: self.loss_fn( outputs, targets ) * self.scaling }
  49. def predict(self, outputs):
  50. outputs = self._attach(outputs)
  51. return self.pred_fn(outputs)
  52. def __repr__(self):
  53. rep = "Task: [%s loss_fn=%s scaling=%.4f attach=%s]"%(self.name, str(self.loss_fn), self.scaling, self._attach)
  54. return rep
  55. class TaskCompose(list):
  56. def __init__(self, tasks: list):
  57. for task in tasks:
  58. if isinstance(task, Task):
  59. self.append(task)
  60. def get_loss(self, outputs, targets):
  61. loss_dict = {}
  62. for task in self:
  63. loss_dict.update( task.get_loss( outputs, targets ) )
  64. return loss_dict
  65. def predict(self, outputs):
  66. results = []
  67. for task in self:
  68. results.append( task.predict( outputs ) )
  69. return results
  70. def __repr__(self):
  71. rep="TaskCompose: \n"
  72. for task in self:
  73. rep+="\t%s\n"%task
  74. class StandardTask:
  75. @staticmethod
  76. def classification(name='ce', scaling=1.0, attach_to=None):
  77. return GeneralTask( name=name,
  78. loss_fn=nn.CrossEntropyLoss(),
  79. scaling=scaling,
  80. pred_fn=lambda x: x.max(1)[1],
  81. attach_to=attach_to )
  82. @staticmethod
  83. def binary_classification(name='bce', scaling=1.0, attach_to=None):
  84. return GeneralTask(name=name,
  85. loss_fn=F.binary_cross_entropy_with_logits,
  86. scaling=scaling,
  87. pred_fn=lambda x: (x>0.5),
  88. attach_to=attach_to )
  89. @staticmethod
  90. def regression(name='mse', scaling=1.0, attach_to=None):
  91. return GeneralTask(name=name,
  92. loss_fn=nn.MSELoss(),
  93. scaling=scaling,
  94. pred_fn=lambda x: x,
  95. attach_to=attach_to )
  96. @staticmethod
  97. def segmentation(name='ce', scaling=1.0, attach_to=None):
  98. return GeneralTask(name=name,
  99. loss_fn=nn.CrossEntropyLoss(ignore_index=255),
  100. scaling=scaling,
  101. pred_fn=lambda x: x.max(1)[1],
  102. attach_to=attach_to )
  103. @staticmethod
  104. def monocular_depth(name='l1', scaling=1.0, attach_to=None):
  105. return GeneralTask(name=name,
  106. loss_fn=nn.L1Loss(),
  107. scaling=scaling,
  108. pred_fn=lambda x: x,
  109. attach_to=attach_to)
  110. @staticmethod
  111. def detection():
  112. raise NotImplementedError
  113. @staticmethod
  114. def distillation(name='kld', T=1.0, scaling=1.0, attach_to=None):
  115. return GeneralTask(name=name,
  116. loss_fn=loss.KLDiv(T=T),
  117. scaling=scaling,
  118. pred_fn=lambda x: x.max(1)[1],
  119. attach_to=attach_to)
  120. class StandardMetrics(object):
  121. @staticmethod
  122. def classification(attach_to=None):
  123. return metrics.MetricCompose(
  124. metric_dict={'acc': metrics.Accuracy(attach_to=attach_to)}
  125. )
  126. @staticmethod
  127. def regression(attach_to=None):
  128. return metrics.MetricCompose(
  129. metric_dict={'mse': metrics.MeanSquaredError(attach_to=attach_to)}
  130. )
  131. @staticmethod
  132. def segmentation(num_classes, ignore_idx=255, attach_to=None):
  133. confusion_matrix = metrics.ConfusionMatrix(num_classes=num_classes, ignore_idx=ignore_idx, attach_to=attach_to)
  134. return metrics.MetricCompose(
  135. metric_dict={'acc': metrics.Accuracy(attach_to=attach_to),
  136. 'confusion_matrix': confusion_matrix ,
  137. 'miou': metrics.mIoU(confusion_matrix)}
  138. )
  139. @staticmethod
  140. def monocular_depth(attach_to=None):
  141. return metrics.MetricCompose(
  142. metric_dict={
  143. 'rmse': metrics.RootMeanSquaredError(attach_to=attach_to),
  144. 'rmse_log': metrics.RootMeanSquaredError( log_scale=True,attach_to=attach_to ),
  145. 'rmse_scale_inv': metrics.ScaleInveriantMeanSquaredError(attach_to=attach_to),
  146. 'abs rel': metrics.AbsoluteRelativeDifference(attach_to=attach_to),
  147. 'sq rel': metrics.SquaredRelativeDifference(attach_to=attach_to),
  148. 'percents within thresholds': metrics.Threshold( thresholds=[1.25, 1.25**2, 1.25**3], attach_to=attach_to )
  149. }
  150. )
  151. @staticmethod
  152. def loss_metric(loss_fn):
  153. return metrics.MetricCompose(
  154. metric_dict={
  155. 'loss': metrics.AverageMetric( loss_fn )
  156. }
  157. )

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能