You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

jittor_model.py 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
  2. if _NEED_IMPORT_JITTOR:
  3. from jittor import Module, nn
  4. else:
  5. from fastNLP.core.utils.dummy_class import DummyClass as Module
  6. class JittorNormalModel_Classification_1(Module):
  7. """
  8. 基础的 jittor 分类模型
  9. """
  10. def __init__(self, num_labels, feature_dimension):
  11. super(JittorNormalModel_Classification_1, self).__init__()
  12. self.num_labels = num_labels
  13. self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
  14. self.ac1 = nn.ReLU()
  15. self.linear2 = nn.Linear(in_features=64, out_features=32)
  16. self.ac2 = nn.ReLU()
  17. self.output = nn.Linear(in_features=32, out_features=num_labels)
  18. self.loss_fn = nn.CrossEntropyLoss()
  19. def execute(self, x):
  20. x = self.ac1(self.linear1(x))
  21. x = self.ac2(self.linear2(x))
  22. x = self.output(x)
  23. return x
  24. def train_step(self, x, y):
  25. x = self(x)
  26. return {"loss": self.loss_fn(x, y)}
  27. def evaluate_step(self, x, y):
  28. x = self(x)
  29. return {"pred": x, "target": y.reshape((-1,))}
  30. class JittorNormalModel_Classification_2(Module):
  31. """
  32. 基础的 jittor 分类模型,只实现 execute 函数测试用户自己初始化了分布式的场景
  33. """
  34. def __init__(self, num_labels, feature_dimension):
  35. super(JittorNormalModel_Classification_2, self).__init__()
  36. self.num_labels = num_labels
  37. self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
  38. self.ac1 = nn.ReLU()
  39. self.linear2 = nn.Linear(in_features=64, out_features=32)
  40. self.ac2 = nn.ReLU()
  41. self.output = nn.Linear(in_features=32, out_features=num_labels)
  42. self.loss_fn = nn.CrossEntropyLoss()
  43. def execute(self, x, y):
  44. x = self.ac1(self.linear1(x))
  45. x = self.ac2(self.linear2(x))
  46. x = self.output(x)
  47. return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))}