diff --git a/CGAN.py b/CGAN.py index b14bff0..6a07b3e 100644 --- a/CGAN.py +++ b/CGAN.py @@ -4,6 +4,7 @@ import numpy as np from PIL import Image import os import time +import sys jt.flags.use_cuda = 1 @@ -158,5 +159,13 @@ def generate(numbers, epoch, filename): if __name__ == '__main__': - # train() - generate([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], 87, 'result1.png') + if len(sys.argv) < 2: + print('usage: \npython CGAN.py train\npython CGAN.py eval [output file] [epoch]') + exit(1) + if sys.argv[1] == 'train': + train() + elif sys.argv[1] == 'eval': + generate([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], int(sys.argv[3]), sys.argv[2]) + else: + print('usage: \npython CGAN.py train\npython CGAN.py eval [output file] [epoch]') + exit(1) diff --git a/README.md b/README.md index 9e6860a..1474e38 100644 --- a/README.md +++ b/README.md @@ -8,30 +8,42 @@ ## 安装 -#### 运行环境 +### 运行环境 - Jittor 1.3.4.12 -- Numpy -- Pillow +- Numpy 1.22.4 +- Pillow 9.1.1 -#### 安装依赖 +### 安装依赖 执行以下命令安装 python 依赖 -``` +```bash pip install -r requirements.txt ``` -#### 预训练模型 +### 预训练模型 + +预训练模型[下载地址](https://drive.google.com/drive/folders/1qnS1SXwtmR-2H-i8vDxW1W4H5k1E2ENI?usp=sharing)。请将预训练模型放入 `saved/` 文件夹中供代码加载,并不要改变其文件名。 -预训练模型[下载地址](https://drive.google.com/drive/folders/1qnS1SXwtmR-2H-i8vDxW1W4H5k1E2ENI?usp=sharing),可通过 `load` 导入系数。 +## 运行 -## 训练 +### 训练 + +```bash +python CGAN.py train +``` + +运行上述指令即可开始训练,训练使用 Jittor 框架内置的 MNIST 数据集格式,训练过程中每一个 epoch 的模型系数和结果文件将保存在 `saved/` 文件夹下。 `CGAN.py` 中 `train()` 函数为用于训练的函数,调用它即可运行训练。相关参数见文件内。 -## 推理 +### 推理 + +```bash +python CGAN.py eval result.png 87 +``` -将参数放入 `saved` 文件夹内后,可以调用 `generate(numbers, epoch, filename)` 函数生成图像。其中 `numbers` 为要生成的数字数组,`epoch` 为模型训练的代数,预训练模型为 87,`filename` 为生成图像保存路径。 +将预训练模型或自行训练的系数 `pkl` 文件放入 `saved/` 文件夹下,并使用上述指令进行推理。推理命令的格式为 `python CGAN.py [output file] [epoch]`,预训练模型 epoch 数为 87,故给出的上述指令使用的是预训练模型推理。 参见 `CGAN.py` 中的实现。