|
|
@@ -0,0 +1,162 @@ |
|
|
|
import jittor as jt |
|
|
|
from jittor import nn, transform |
|
|
|
import numpy as np |
|
|
|
from PIL import Image |
|
|
|
import os |
|
|
|
import time |
|
|
|
|
|
|
|
jt.flags.use_cuda = 1 |
|
|
|
|
|
|
|
# hyper params |
|
|
|
epochs = 100 |
|
|
|
batch_size = 64 |
|
|
|
lr = 0.0001 |
|
|
|
latent_dim = 100 |
|
|
|
|
|
|
|
|
|
|
|
# ref https://github.com/arturml/mnist-cgan/blob/master/mnist-cgan.ipynb |
|
|
|
|
|
|
|
# generator |
|
|
|
class Generator(nn.Module): |
|
|
|
def __init__(self): |
|
|
|
super(Generator, self).__init__() |
|
|
|
self.label_emb = nn.Embedding(10, 10) |
|
|
|
self.model = nn.Sequential( |
|
|
|
nn.Linear(110, 256), |
|
|
|
nn.LeakyReLU(0.2), |
|
|
|
nn.Linear(256, 512), |
|
|
|
nn.LeakyReLU(0.2), |
|
|
|
nn.Linear(512, 1024), |
|
|
|
nn.LeakyReLU(0.2), |
|
|
|
nn.Linear(1024, 784), |
|
|
|
nn.Tanh() |
|
|
|
) |
|
|
|
|
|
|
|
def execute(self, noise, labels): |
|
|
|
labels = self.label_emb(labels) |
|
|
|
x = jt.concat([noise, labels], 1) |
|
|
|
out = self.model(x) |
|
|
|
return out.view(out.shape[0], 1, 28, 28) |
|
|
|
|
|
|
|
|
|
|
|
class Discriminator(nn.Module): |
|
|
|
def __init__(self): |
|
|
|
super(Discriminator, self).__init__() |
|
|
|
self.label_emb = nn.Embedding(10, 10) |
|
|
|
self.model = nn.Sequential( |
|
|
|
nn.Linear(794, 1024), |
|
|
|
nn.LeakyReLU(0.2), |
|
|
|
nn.Linear(1024, 512), |
|
|
|
nn.LeakyReLU(0.2), |
|
|
|
nn.Linear(512, 256), |
|
|
|
nn.LeakyReLU(0.2), |
|
|
|
nn.Linear(256, 1), |
|
|
|
nn.Sigmoid() |
|
|
|
) |
|
|
|
|
|
|
|
def execute(self, images, labels): |
|
|
|
images = images.view(images.shape[0], -1) |
|
|
|
labels = self.label_emb(labels) |
|
|
|
x = jt.concat([images, labels], 1) |
|
|
|
out = self.model(x) |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
generator = Generator() |
|
|
|
discriminator = Discriminator() |
|
|
|
|
|
|
|
trans = transform.Compose([ |
|
|
|
transform.Gray(), |
|
|
|
transform.ImageNormalize([0.5], [0.5]) |
|
|
|
]) |
|
|
|
data = jt.dataset.MNIST(train=True, transform=trans).set_attrs(batch_size=batch_size, shuffle=True) |
|
|
|
|
|
|
|
loss = nn.BCELoss() |
|
|
|
generator_opt = jt.optim.Adam(generator.parameters(), lr) |
|
|
|
discriminator_opt = jt.optim.Adam(discriminator.parameters(), lr) |
|
|
|
|
|
|
|
saved_path = 'saved' |
|
|
|
if not os.path.exists(saved_path): |
|
|
|
os.mkdir(saved_path) |
|
|
|
|
|
|
|
|
|
|
|
def save_image(filename): |
|
|
|
noise = jt.array(np.random.randn(10, 100)).float32().stop_grad() |
|
|
|
labels = jt.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) |
|
|
|
images = generator(noise, labels).numpy() |
|
|
|
image_all = [] |
|
|
|
for i in range(10): |
|
|
|
image_all.append(images[i]) |
|
|
|
image_all = np.concatenate(image_all, 1) |
|
|
|
image_all = (image_all + 1) / 2 * 255 |
|
|
|
image_all = image_all.transpose(1, 2, 0) |
|
|
|
image_all = image_all[:, :, 0] |
|
|
|
Image.fromarray(np.uint8(image_all)).save(filename) |
|
|
|
|
|
|
|
|
|
|
|
def train(): |
|
|
|
generator.train() |
|
|
|
discriminator.train() |
|
|
|
for epoch in range(epochs): |
|
|
|
start_time = time.time() |
|
|
|
for i, (images, labels) in enumerate(data): |
|
|
|
batch_size = images.shape[0] |
|
|
|
|
|
|
|
valid = jt.ones([batch_size, 1]).float32().stop_grad() |
|
|
|
fake = jt.zeros([batch_size, 1]).float32().stop_grad() |
|
|
|
|
|
|
|
# train generator |
|
|
|
noise = jt.array(np.random.randn(batch_size, 100)) |
|
|
|
fake_labels = jt.array(np.random.randint(0, 10, batch_size)) |
|
|
|
fake_images = generator(noise, fake_labels) |
|
|
|
validity = discriminator(fake_images, fake_labels) |
|
|
|
generator_loss = loss(validity, valid) |
|
|
|
generator_loss.sync() |
|
|
|
generator_opt.step(generator_loss) |
|
|
|
|
|
|
|
# train discriminator |
|
|
|
# real images |
|
|
|
real_images = jt.array(images) |
|
|
|
labels = jt.array(labels) |
|
|
|
real_validity = discriminator(real_images, labels) |
|
|
|
real_loss = loss(real_validity, valid) |
|
|
|
# fake images |
|
|
|
noise = jt.array(np.random.randn(batch_size, 100)) |
|
|
|
fake_validity = discriminator(fake_images.stop_grad(), fake_labels) |
|
|
|
fake_loss = loss(fake_validity, fake) |
|
|
|
discriminator_loss = (real_loss + fake_loss) / 2 |
|
|
|
discriminator_loss.sync() |
|
|
|
discriminator_opt.step(discriminator_loss) |
|
|
|
|
|
|
|
if i % 50 == 0: |
|
|
|
print('Epoch %d Batch %d Loss g %f d %f' % (epoch, i, generator_loss.data, discriminator_loss.data)) |
|
|
|
|
|
|
|
print('Epoch %d done Time %s' % (epoch, time.time() - start_time)) |
|
|
|
|
|
|
|
generator.save('saved/generator_%d.pkl' % epoch) |
|
|
|
discriminator.save('saved/discriminator_%d.pkl' % epoch) |
|
|
|
save_image('saved/gen_%d.png' % epoch) |
|
|
|
|
|
|
|
|
|
|
|
def generate(numbers, epoch, filename): |
|
|
|
generator_g = Generator() |
|
|
|
generator_g.eval() |
|
|
|
generator_g.load('saved/generator_%d.pkl' % epoch) |
|
|
|
|
|
|
|
num_labels = len(numbers) |
|
|
|
noise = jt.array(np.random.randn(num_labels, 100)).float32().stop_grad() |
|
|
|
labels = jt.array(numbers).float32().stop_grad() |
|
|
|
images = generator_g(noise, labels).numpy() |
|
|
|
image_all = [] |
|
|
|
for i in range(num_labels): |
|
|
|
image_all.append(images[i]) |
|
|
|
image_all = np.concatenate(image_all, 2) |
|
|
|
image_all = (image_all + 1) / 2 * 255 |
|
|
|
image_all = image_all.transpose(1, 2, 0) |
|
|
|
image_all = image_all[:, :, 0] |
|
|
|
Image.fromarray(np.uint8(image_all)).save(filename) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
# train() |
|
|
|
generate([1, 2, 3, 3, 2, 1], 87, 'result.png') |