Browse Source

chore: improve README

master
KSkun 3 years ago
parent
commit
d5299d3daf
2 changed files with 33 additions and 12 deletions
  1. +11
    -2
      CGAN.py
  2. +22
    -10
      README.md

+ 11
- 2
CGAN.py View File

@@ -4,6 +4,7 @@ import numpy as np
from PIL import Image
import os
import time
import sys

jt.flags.use_cuda = 1

@@ -158,5 +159,13 @@ def generate(numbers, epoch, filename):


if __name__ == '__main__':
# train()
generate([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], 87, 'result1.png')
if len(sys.argv) < 2:
print('usage: \npython CGAN.py train\npython CGAN.py eval [output file] [epoch]')
exit(1)
if sys.argv[1] == 'train':
train()
elif sys.argv[1] == 'eval':
generate([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], int(sys.argv[3]), sys.argv[2])
else:
print('usage: \npython CGAN.py train\npython CGAN.py eval [output file] [epoch]')
exit(1)

+ 22
- 10
README.md View File

@@ -8,30 +8,42 @@

## 安装

#### 运行环境
### 运行环境

- Jittor 1.3.4.12
- Numpy
- Pillow
- Numpy 1.22.4
- Pillow 9.1.1

#### 安装依赖
### 安装依赖
执行以下命令安装 python 依赖

```
```bash
pip install -r requirements.txt
```

#### 预训练模型
### 预训练模型

预训练模型[下载地址](https://drive.google.com/drive/folders/1qnS1SXwtmR-2H-i8vDxW1W4H5k1E2ENI?usp=sharing)。请将预训练模型放入 `saved/` 文件夹中供代码加载,并不要改变其文件名。

预训练模型[下载地址](https://drive.google.com/drive/folders/1qnS1SXwtmR-2H-i8vDxW1W4H5k1E2ENI?usp=sharing),可通过 `load` 导入系数。
## 运行

## 训练
### 训练

```bash
python CGAN.py train
```

运行上述指令即可开始训练,训练使用 Jittor 框架内置的 MNIST 数据集格式,训练过程中每一个 epoch 的模型系数和结果文件将保存在 `saved/` 文件夹下。

`CGAN.py` 中 `train()` 函数为用于训练的函数,调用它即可运行训练。相关参数见文件内。

## 推理
### 推理

```bash
python CGAN.py eval result.png 87
```

将参数放入 `saved` 文件夹内后,可以调用 `generate(numbers, epoch, filename)` 函数生成图像。其中 `numbers` 为要生成的数字数组,`epoch` 为模型训练的代数,预训练模型为 87,`filename` 为生成图像保存路径。
预训练模型或自行训练的系数 `pkl` 文件放入 `saved/` 文件夹下,并使用上述指令进行推理。推理命令的格式为 `python CGAN.py [output file] [epoch]`,预训练模型 epoch 数为 87,故给出的上述指令使用的是预训练模型推理

参见 `CGAN.py` 中的实现。



Loading…
Cancel
Save