| @@ -4,6 +4,7 @@ import numpy as np | |||||
| from PIL import Image | from PIL import Image | ||||
| import os | import os | ||||
| import time | import time | ||||
| import sys | |||||
| jt.flags.use_cuda = 1 | jt.flags.use_cuda = 1 | ||||
| @@ -158,5 +159,13 @@ def generate(numbers, epoch, filename): | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| # train() | |||||
| generate([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], 87, 'result1.png') | |||||
| if len(sys.argv) < 2: | |||||
| print('usage: \npython CGAN.py train\npython CGAN.py eval [output file] [epoch]') | |||||
| exit(1) | |||||
| if sys.argv[1] == 'train': | |||||
| train() | |||||
| elif sys.argv[1] == 'eval': | |||||
| generate([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], int(sys.argv[3]), sys.argv[2]) | |||||
| else: | |||||
| print('usage: \npython CGAN.py train\npython CGAN.py eval [output file] [epoch]') | |||||
| exit(1) | |||||
| @@ -8,30 +8,42 @@ | |||||
| ## 安装 | ## 安装 | ||||
| #### 运行环境 | |||||
| ### 运行环境 | |||||
| - Jittor 1.3.4.12 | - Jittor 1.3.4.12 | ||||
| - Numpy | |||||
| - Pillow | |||||
| - Numpy 1.22.4 | |||||
| - Pillow 9.1.1 | |||||
| #### 安装依赖 | |||||
| ### 安装依赖 | |||||
| 执行以下命令安装 python 依赖 | 执行以下命令安装 python 依赖 | ||||
| ``` | |||||
| ```bash | |||||
| pip install -r requirements.txt | pip install -r requirements.txt | ||||
| ``` | ``` | ||||
| #### 预训练模型 | |||||
| ### 预训练模型 | |||||
| 预训练模型[下载地址](https://drive.google.com/drive/folders/1qnS1SXwtmR-2H-i8vDxW1W4H5k1E2ENI?usp=sharing)。请将预训练模型放入 `saved/` 文件夹中供代码加载,并不要改变其文件名。 | |||||
| 预训练模型[下载地址](https://drive.google.com/drive/folders/1qnS1SXwtmR-2H-i8vDxW1W4H5k1E2ENI?usp=sharing),可通过 `load` 导入系数。 | |||||
| ## 运行 | |||||
| ## 训练 | |||||
| ### 训练 | |||||
| ```bash | |||||
| python CGAN.py train | |||||
| ``` | |||||
| 运行上述指令即可开始训练,训练使用 Jittor 框架内置的 MNIST 数据集格式,训练过程中每一个 epoch 的模型系数和结果文件将保存在 `saved/` 文件夹下。 | |||||
| `CGAN.py` 中 `train()` 函数为用于训练的函数,调用它即可运行训练。相关参数见文件内。 | `CGAN.py` 中 `train()` 函数为用于训练的函数,调用它即可运行训练。相关参数见文件内。 | ||||
| ## 推理 | |||||
| ### 推理 | |||||
| ```bash | |||||
| python CGAN.py eval result.png 87 | |||||
| ``` | |||||
| 将参数放入 `saved` 文件夹内后,可以调用 `generate(numbers, epoch, filename)` 函数生成图像。其中 `numbers` 为要生成的数字数组,`epoch` 为模型训练的代数,预训练模型为 87,`filename` 为生成图像保存路径。 | |||||
| 将预训练模型或自行训练的系数 `pkl` 文件放入 `saved/` 文件夹下,并使用上述指令进行推理。推理命令的格式为 `python CGAN.py [output file] [epoch]`,预训练模型 epoch 数为 87,故给出的上述指令使用的是预训练模型推理。 | |||||
| 参见 `CGAN.py` 中的实现。 | 参见 `CGAN.py` 中的实现。 | ||||