import torch import torch.nn as nn class Model(nn.Module): """ Model that performs matrix multiplication followed by ReLU activation. """ def __init__(self, weight): super(Model, self).__init__() self.weight = nn.Parameter(weight) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Performs matrix multiplication and applies ReLU activation. Args: x (torch.Tensor): Input tensor of shape [batch_size, input_dim] Returns: torch.Tensor: Output tensor of shape [batch_size, output_dim] """ x = torch.matmul(x, self.weight) return torch.relu(x) batch_size = 16 input_dim = 1024 output_dim = 2048 def get_inputs(): x = torch.randn(batch_size, input_dim) return [x] def get_init_inputs(): weight = torch.randn(input_dim, output_dim) return [weight]