You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

torch_model.py 3.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from fastNLP.envs.imports import _NEED_IMPORT_TORCH
  2. if _NEED_IMPORT_TORCH:
  3. import torch
  4. from torch.nn import Module
  5. import torch.nn as nn
  6. else:
  7. from fastNLP.core.utils.dummy_class import DummyClass as Module
  8. # 1. 最为基础的分类模型
  9. class TorchNormalModel_Classification_1(Module):
  10. """
  11. 单独实现 train_step 和 evaluate_step;
  12. """
  13. def __init__(self, num_labels, feature_dimension):
  14. super(TorchNormalModel_Classification_1, self).__init__()
  15. self.num_labels = num_labels
  16. self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)
  17. self.ac1 = nn.ReLU()
  18. self.linear2 = nn.Linear(in_features=10, out_features=10)
  19. self.ac2 = nn.ReLU()
  20. self.output = nn.Linear(in_features=10, out_features=num_labels)
  21. self.loss_fn = nn.CrossEntropyLoss()
  22. def forward(self, x):
  23. x = self.ac1(self.linear1(x))
  24. x = self.ac2(self.linear2(x))
  25. x = self.output(x)
  26. return x
  27. def train_step(self, x, y):
  28. x = self(x)
  29. return {"loss": self.loss_fn(x, y)}
  30. def evaluate_step(self, x, y):
  31. """
  32. 如果不加参数 y,那么应该在 trainer 中设置 output_mapping = {"y": "target"};
  33. """
  34. x = self(x)
  35. x = torch.max(x, dim=-1)[1]
  36. return {"preds": x, "target": y}
  37. class TorchNormalModel_Classification_2(Module):
  38. """
  39. 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景;
  40. """
  41. def __init__(self, num_labels, feature_dimension):
  42. super(TorchNormalModel_Classification_2, self).__init__()
  43. self.num_labels = num_labels
  44. self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)
  45. self.ac1 = nn.ReLU()
  46. self.linear2 = nn.Linear(in_features=10, out_features=10)
  47. self.ac2 = nn.ReLU()
  48. self.output = nn.Linear(in_features=10, out_features=num_labels)
  49. self.loss_fn = nn.CrossEntropyLoss()
  50. def forward(self, x, y):
  51. x = self.ac1(self.linear1(x))
  52. x = self.ac2(self.linear2(x))
  53. x = self.output(x)
  54. loss = self.loss_fn(x, y)
  55. x = torch.max(x, dim=-1)[1]
  56. return {"loss": loss, "preds": x, "target": y}
  57. class TorchNormalModel_Classification_3(Module):
  58. """
  59. 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景;
  60. 关闭 auto_param_call,forward 只有一个 batch 参数;
  61. """
  62. def __init__(self, num_labels, feature_dimension):
  63. super(TorchNormalModel_Classification_3, self).__init__()
  64. self.num_labels = num_labels
  65. self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)
  66. self.ac1 = nn.ReLU()
  67. self.linear2 = nn.Linear(in_features=10, out_features=10)
  68. self.ac2 = nn.ReLU()
  69. self.output = nn.Linear(in_features=10, out_features=num_labels)
  70. self.loss_fn = nn.CrossEntropyLoss()
  71. def forward(self, batch):
  72. x = batch["x"]
  73. y = batch["y"]
  74. x = self.ac1(self.linear1(x))
  75. x = self.ac2(self.linear2(x))
  76. x = self.output(x)
  77. loss = self.loss_fn(x, y)
  78. x = torch.max(x, dim=-1)[1]
  79. return {"loss": loss, "preds": x, "target": y}