import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import tarfile

# Check which device is available
def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using NVIDIA GPU (CUDA)")
    else:
        device = torch.device("cpu")
        print("Using CPU")
    return device
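
# Note: the original comment here mentioned an NPU, but the code only checks CUDA.
# On Huawei Ascend hardware this function would typically also try `import torch_npu`
# and `torch.npu.is_available()` before falling back to CPU (an assumption about the
# target environment; that branch is not used below).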

# Define a three-layer convolutional neural network
class ThreeLayerCNN(nn.Module):
    def __init__(self):
        super(ThreeLayerCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.LazyLinear(256)
        self.fc2 = nn.Linear(256, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
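
# Shape flow for a CIFAR-10 input (3x32x32): each conv keeps the spatial size
# (3x3 kernel, padding 1) and each 2x2 max-pool halves it, giving
# 32x16x16 -> 64x8x8 -> 128x4x4 = 2048 features after flatten.
# nn.LazyLinear infers that in_features value on the first forward pass.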

# Data preprocessing
def get_transform():
    return transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
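
# ToTensor scales pixel values to [0, 1]; Normalize with per-channel mean 0.5 and
# std 0.5 then maps them to [-1, 1].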

# Extract the dataset archive
def extract_dataset(data_tar_path, extract_path):
    if os.path.exists(extract_path) and check_dataset_files(extract_path):
        print("Dataset already extracted, skipping extraction")
        return True  # Nothing to do
    if os.path.exists(data_tar_path):
        try:
            os.makedirs(extract_path, exist_ok=True)
            with tarfile.open(data_tar_path, "r:gz") as tar:
                tar.extractall(path=extract_path)
            print("Dataset extracted successfully")
            return True
        except Exception as e:
            print(f"Error while extracting dataset: {e}")
            return False
    return False  # Archive not found

# Check that the extracted dataset files are complete.
# The cifar-10-python.tar.gz archive unpacks into cifar-10-batches-py/, which is
# also the layout torchvision's CIFAR10 dataset expects under `root`.
def check_dataset_files(extract_path):
    data_files = [f"data_batch_{i}" for i in range(1, 6)] + ["test_batch", "batches.meta"]
    for file in data_files:
        file_path = os.path.join(extract_path, "cifar-10-batches-py", file)
        if not os.path.exists(file_path):
            return False
    return True

# Load the CIFAR-10 train/test datasets
def load_datasets(data_tar_path, extract_path, transform):
    # If the expected batch files are missing, try to extract the archive first
    if not check_dataset_files(extract_path):
        if os.path.exists(data_tar_path):
            print("Dataset not extracted yet, extracting...")
            if not extract_dataset(data_tar_path, extract_path):
                raise RuntimeError(f"Failed to extract the dataset, please check {data_tar_path}")
        else:
            raise FileNotFoundError(
                f"Dataset files are incomplete and {data_tar_path} does not exist, cannot extract"
            )
    train_dataset = datasets.CIFAR10(
        root=extract_path, train=True, download=False, transform=transform
    )
    test_dataset = datasets.CIFAR10(
        root=extract_path, train=False, download=False, transform=transform
    )
    return train_dataset, test_dataset
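
# Alternative (sketch): with internet access, torchvision can fetch CIFAR-10 itself
# instead of extracting a local tarball, e.g.
#   datasets.CIFAR10(root=extract_path, train=True, download=True, transform=transform)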

# Training loop
def train(model, train_loader, criterion, optimizer, device, epochs=3):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # Print the average loss every 100 mini-batches
            if i % 100 == 99:
                print(
                    f"Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}"
                )
                running_loss = 0.0

# Evaluation on the test set
def test(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Accuracy of the model on the test images: {100 * correct / total:.2f}%")

# Save the trained weights
def save_model(model, model_save_path):
    model_dir = os.path.dirname(model_save_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    torch.save(model.state_dict(), model_save_path)
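
# Reloading sketch (assumes the same ThreeLayerCNN definition is importable):
#   model = ThreeLayerCNN()
#   model.load_state_dict(torch.load(model_save_path, map_location="cpu"))
#   model.eval()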

# Main entry point
def main():
    # Select the device
    device = get_device()
    # Data preprocessing
    transform = get_transform()
    # Dataset paths
    data_tar_path = "/userhome/dataset/5/submissionScript/cifar-10-python.tar.gz"
    extract_path = "./data"
    # Load the datasets
    train_dataset, test_dataset = load_datasets(data_tar_path, extract_path, transform)
    # Create the data loaders
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    # Initialize the model
    model = ThreeLayerCNN().to(device)
    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Train
    train(model, train_loader, criterion, optimizer, device, epochs=500)
    # Evaluate
    test(model, test_loader, device)
    # Save the model
    model_save_path = "/model/three_layer_cnn_cifar10_500epo_octopus.pth"
    save_model(model, model_save_path)


if __name__ == "__main__":
    main()