@@ -4,6 +4,7 @@ import numpy as np | |||||
from PIL import Image | from PIL import Image | ||||
import os | import os | ||||
import time | import time | ||||
import sys | |||||
jt.flags.use_cuda = 1 | jt.flags.use_cuda = 1 | ||||
@@ -158,5 +159,13 @@ def generate(numbers, epoch, filename): | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
# train() | |||||
generate([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], 87, 'result1.png') | |||||
if len(sys.argv) < 2: | |||||
print('usage: \npython CGAN.py train\npython CGAN.py eval [output file] [epoch]') | |||||
exit(1) | |||||
if sys.argv[1] == 'train': | |||||
train() | |||||
elif sys.argv[1] == 'eval': | |||||
generate([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], int(sys.argv[3]), sys.argv[2]) | |||||
else: | |||||
print('usage: \npython CGAN.py train\npython CGAN.py eval [output file] [epoch]') | |||||
exit(1) |
@@ -8,30 +8,42 @@ | |||||
## 安装 | ## 安装 | ||||
#### 运行环境 | |||||
### 运行环境 | |||||
- Jittor 1.3.4.12 | - Jittor 1.3.4.12 | ||||
- Numpy | |||||
- Pillow | |||||
- Numpy 1.22.4 | |||||
- Pillow 9.1.1 | |||||
#### 安装依赖 | |||||
### 安装依赖 | |||||
执行以下命令安装 python 依赖 | 执行以下命令安装 python 依赖 | ||||
``` | |||||
```bash | |||||
pip install -r requirements.txt | pip install -r requirements.txt | ||||
``` | ``` | ||||
#### 预训练模型 | |||||
### 预训练模型 | |||||
预训练模型[下载地址](https://drive.google.com/drive/folders/1qnS1SXwtmR-2H-i8vDxW1W4H5k1E2ENI?usp=sharing)。请将预训练模型放入 `saved/` 文件夹中供代码加载,并不要改变其文件名。 | |||||
预训练模型[下载地址](https://drive.google.com/drive/folders/1qnS1SXwtmR-2H-i8vDxW1W4H5k1E2ENI?usp=sharing),可通过 `load` 导入系数。 | |||||
## 运行 | |||||
## 训练 | |||||
### 训练 | |||||
```bash | |||||
python CGAN.py train | |||||
``` | |||||
运行上述指令即可开始训练,训练使用 Jittor 框架内置的 MNIST 数据集格式,训练过程中每一个 epoch 的模型系数和结果文件将保存在 `saved/` 文件夹下。 | |||||
`CGAN.py` 中 `train()` 函数为用于训练的函数,调用它即可运行训练。相关参数见文件内。 | `CGAN.py` 中 `train()` 函数为用于训练的函数,调用它即可运行训练。相关参数见文件内。 | ||||
## 推理 | |||||
### 推理 | |||||
```bash | |||||
python CGAN.py eval result.png 87 | |||||
``` | |||||
将参数放入 `saved` 文件夹内后,可以调用 `generate(numbers, epoch, filename)` 函数生成图像。其中 `numbers` 为要生成的数字数组,`epoch` 为模型训练的代数,预训练模型为 87,`filename` 为生成图像保存路径。 | |||||
将预训练模型或自行训练的系数 `pkl` 文件放入 `saved/` 文件夹下,并使用上述指令进行推理。推理命令的格式为 `python CGAN.py [output file] [epoch]`,预训练模型 epoch 数为 87,故给出的上述指令使用的是预训练模型推理。 | |||||
参见 `CGAN.py` 中的实现。 | 参见 `CGAN.py` 中的实现。 | ||||