import os
import tarfile

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# Select the compute device: prefer an NVIDIA GPU (CUDA) if available, otherwise the CPU
def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using NVIDIA GPU (CUDA)")
    else:
        device = torch.device("cpu")
        print("Using CPU")
    return device
# Define a three-layer convolutional neural network
class ThreeLayerCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        # LazyLinear infers its input size (128 * 4 * 4 for 32x32 CIFAR-10
        # images) on the first forward pass.
        self.fc1 = nn.LazyLinear(256)
        self.fc2 = nn.Linear(256, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # Each conv + ReLU + 2x2 max-pool halves the spatial size: 32 -> 16 -> 8 -> 4
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
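# Quick shape sanity check (a sketch, not part of the training flow): a
# CIFAR-10 sized batch of 3x32x32 images should map to 10 class logits.
#   _model = ThreeLayerCNN()
#   print(_model(torch.zeros(2, 3, 32, 32)).shape)  # expected: torch.Size([2, 10])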
# Data preprocessing: convert to tensor and normalize each channel to [-1, 1]
def get_transform():
    return transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
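# Optional: a train-time augmentation pipeline (a minimal sketch using standard
# torchvision transforms; get_train_transform is not wired into main() below
# and is shown only for illustration)
def get_train_transform():
    return transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )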
# Extract the dataset archive
def extract_dataset(data_tar_path, extract_path):
    if os.path.exists(extract_path) and check_dataset_files(extract_path):
        print("Dataset already extracted; skipping extraction")
        return True  # Nothing to do
    if os.path.exists(data_tar_path):
        try:
            os.makedirs(extract_path, exist_ok=True)
            with tarfile.open(data_tar_path, "r:gz") as tar:
                tar.extractall(path=extract_path)
            print("Dataset extracted successfully")
            return True
        except Exception as e:
            print(f"Error while extracting dataset: {e}")
            return False
    # The archive itself is missing, so there is nothing to extract
    return False
# Check that the extracted dataset files are complete. The standard
# cifar-10-python.tar.gz archive unpacks to a cifar-10-batches-py directory
# containing five training batches, a test batch, and a metadata file, which
# is the layout torchvision's CIFAR10 loader expects.
def check_dataset_files(extract_path):
    data_files = [f"data_batch_{i}" for i in range(1, 6)] + ["test_batch", "batches.meta"]
    for file in data_files:
        file_path = os.path.join(extract_path, "cifar-10-batches-py", file)
        if not os.path.exists(file_path):
            return False
    return True
# Load the train and test datasets
def load_datasets(data_tar_path, extract_path, transform):
    # If the extracted batch files are missing, try to extract the archive
    if not check_dataset_files(extract_path):
        if os.path.exists(data_tar_path):
            print("Dataset not yet extracted; extracting...")
            if not extract_dataset(data_tar_path, extract_path):
                raise RuntimeError(f"Failed to extract the dataset; please check {data_tar_path}")
        else:
            raise FileNotFoundError(
                f"Dataset files are incomplete and {data_tar_path} does not exist; cannot extract"
            )
    train_dataset = datasets.CIFAR10(
        root=extract_path, train=True, download=False, transform=transform
    )
    test_dataset = datasets.CIFAR10(
        root=extract_path, train=False, download=False, transform=transform
    )
    return train_dataset, test_dataset
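# Note (an alternative, assuming the archive sits directly in extract_path as
# cifar-10-python.tar.gz): torchvision can verify and unpack a local archive
# itself without re-downloading it, via
#   datasets.CIFAR10(root=extract_path, train=True, download=True, transform=transform)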
# Training loop
def train(model, train_loader, criterion, optimizer, device, epochs=3):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # Report the average loss over the last 100 steps
            if i % 100 == 99:
                print(
                    f"Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}"
                )
                running_loss = 0.0
# Evaluation loop
def test(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Accuracy of the model on the test images: {100 * correct / total:.2f}%")
# Save the model weights
def save_model(model, model_save_path):
    model_dir = os.path.dirname(model_save_path)
    if model_dir and not os.path.exists(model_dir):
        os.makedirs(model_dir, exist_ok=True)
    torch.save(model.state_dict(), model_save_path)
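# Reloading the weights later (a sketch): lazy modules such as LazyLinear can
# materialize their shapes from the incoming state dict, so loading into a
# fresh model should work directly.
#   model = ThreeLayerCNN()
#   model.load_state_dict(torch.load(model_save_path, map_location="cpu"))
#   model.eval()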
# Main entry point
def main():
    # Select device
    device = get_device()
    # Data preprocessing
    transform = get_transform()
    # Dataset paths
    data_tar_path = "/userhome/dataset/5/submissionScript/cifar-10-python.tar.gz"
    extract_path = "./data"
    # Load the datasets
    train_dataset, test_dataset = load_datasets(data_tar_path, extract_path, transform)
    # Build the data loaders
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    # Initialize the model
    model = ThreeLayerCNN().to(device)
    # Materialize the LazyLinear parameters with a dummy forward pass, so the
    # optimizer is constructed with fully initialized parameters
    with torch.no_grad():
        model(torch.zeros(1, 3, 32, 32, device=device))
    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Train
    train(model, train_loader, criterion, optimizer, device, epochs=500)
    # Evaluate
    test(model, test_loader, device)
    # Save the trained model
    model_save_path = "/model/three_layer_cnn_cifar10_500epo_octopus.pth"
    save_model(model, model_save_path)
if __name__ == "__main__":
    main()