
example_torchcode.py 890 B

import torch
import torch.nn as nn


class Model(nn.Module):
    """
    Model that performs matrix multiplication followed by ReLU activation.
    """

    def __init__(self, weight):
        super(Model, self).__init__()
        self.weight = nn.Parameter(weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs matrix multiplication and applies ReLU activation.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, input_dim]

        Returns:
            torch.Tensor: Output tensor of shape [batch_size, output_dim]
        """
        x = torch.matmul(x, self.weight)
        return torch.relu(x)


batch_size = 16
input_dim = 1024
output_dim = 2048


def get_inputs():
    x = torch.randn(batch_size, input_dim)
    return [x]


def get_init_inputs():
    weight = torch.randn(input_dim, output_dim)
    return [weight]
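
For context, a minimal driver sketch showing how these helpers could be wired together; it is not part of the file above, and the module import assumes the listing is saved as example_torchcode.py:

# Illustrative driver (assumption: the listing above is saved as example_torchcode.py).
# Builds the model from get_init_inputs() and runs one forward pass on get_inputs().
import torch

from example_torchcode import Model, get_inputs, get_init_inputs

weight, = get_init_inputs()   # weight tensor of shape [input_dim, output_dim]
model = Model(weight)         # wraps the weight in an nn.Parameter

x, = get_inputs()             # input batch of shape [batch_size, input_dim]
with torch.no_grad():
    out = model(x)            # matmul followed by ReLU

print(out.shape)              # expected: torch.Size([16, 2048])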