import jittor as jt
from jittor import nn, transform
import numpy as np
from PIL import Image
import os
import time
import sys

jt.flags.use_cuda = 1

# hyper-parameters
epochs = 100
batch_size = 64
lr = 0.0001
latent_dim = 100

# ref: https://github.com/arturml/mnist-cgan/blob/master/mnist-cgan.ipynb


class Generator(nn.Module):
    """MLP generator: maps (noise, label embedding) to a 1x28x28 image."""

    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(10, 10)
        self.model = nn.Sequential(
            nn.Linear(110, 256),   # 100-dim noise + 10-dim label embedding
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def execute(self, noise, labels):
        labels = self.label_emb(labels)
        x = jt.concat([noise, labels], 1)
        out = self.model(x)
        return out.view(out.shape[0], 1, 28, 28)


class Discriminator(nn.Module):
    """MLP discriminator: scores (flattened image, label embedding) pairs as real or fake."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_emb = nn.Embedding(10, 10)
        self.model = nn.Sequential(
            nn.Linear(794, 1024),  # 784-dim flattened image + 10-dim label embedding
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def execute(self, images, labels):
        images = images.view(images.shape[0], -1)
        labels = self.label_emb(labels)
        x = jt.concat([images, labels], 1)
        out = self.model(x)
        return out


generator = Generator()
discriminator = Discriminator()

trans = transform.Compose([
    transform.Gray(),
    transform.ImageNormalize([0.5], [0.5])
])
data = jt.dataset.MNIST(train=True, transform=trans).set_attrs(batch_size=batch_size, shuffle=True)

loss = nn.BCELoss()
generator_opt = jt.optim.Adam(generator.parameters(), lr)
discriminator_opt = jt.optim.Adam(discriminator.parameters(), lr)

saved_path = 'saved'
if not os.path.exists(saved_path):
    os.mkdir(saved_path)


def save_image(filename):
    """Generate one sample per digit (0-9) and save them stacked into a single image."""
    noise = jt.array(np.random.randn(10, 100)).float32().stop_grad()
    labels = jt.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    images = generator(noise, labels).numpy()
    image_all = []
    for i in range(10):
        image_all.append(images[i])
    image_all = np.concatenate(image_all, 1)       # stack the 10 digits vertically
    image_all = (image_all + 1) / 2 * 255          # map tanh output [-1, 1] back to [0, 255]
    image_all = image_all.transpose(1, 2, 0)
    image_all = image_all[:, :, 0]
    Image.fromarray(np.uint8(image_all)).save(filename)


def train():
    generator.train()
    discriminator.train()
    for epoch in range(epochs):
        start_time = time.time()
        for i, (images, labels) in enumerate(data):
            batch_size = images.shape[0]           # the last batch may be smaller
            valid = jt.ones([batch_size, 1]).float32().stop_grad()
            fake = jt.zeros([batch_size, 1]).float32().stop_grad()

            # train the generator: make the discriminator call generated images real
            noise = jt.array(np.random.randn(batch_size, 100))
            fake_labels = jt.array(np.random.randint(0, 10, batch_size))
            fake_images = generator(noise, fake_labels)
            validity = discriminator(fake_images, fake_labels)
            generator_loss = loss(validity, valid)
            generator_loss.sync()
            generator_opt.step(generator_loss)

            # train the discriminator
            # real images
            real_images = jt.array(images)
            labels = jt.array(labels)
            real_validity = discriminator(real_images, labels)
            real_loss = loss(real_validity, valid)
            # fake images (detached so gradients do not flow into the generator)
            fake_validity = discriminator(fake_images.stop_grad(), fake_labels)
            fake_loss = loss(fake_validity, fake)
            discriminator_loss = (real_loss + fake_loss) / 2
            discriminator_loss.sync()
            discriminator_opt.step(discriminator_loss)

            if i % 50 == 0:
                print('Epoch %d Batch %d Loss g %f d %f' % (epoch, i, generator_loss.data, discriminator_loss.data))
        print('Epoch %d done Time %s' % (epoch, time.time() - start_time))
        generator.save('saved/generator_%d.pkl' % epoch)
        discriminator.save('saved/discriminator_%d.pkl' % epoch)
        save_image('saved/gen_%d.png' % epoch)


def generate(numbers, epoch, filename):
    """Load the generator checkpoint from the given epoch and render the requested digits."""
    generator_g = Generator()
    generator_g.eval()
    generator_g.load('saved/generator_%d.pkl' % epoch)
    num_labels = len(numbers)
    noise = jt.array(np.random.randn(num_labels, 100)).float32().stop_grad()
    labels = jt.array(numbers).stop_grad()         # integer class indices for the label embedding
    images = generator_g(noise, labels).numpy()
    image_all = []
    for i in range(num_labels):
        image_all.append(images[i])
    image_all = np.concatenate(image_all, 2)       # place the digits side by side
    image_all = (image_all + 1) / 2 * 255          # map tanh output [-1, 1] back to [0, 255]
    image_all = image_all.transpose(1, 2, 0)
    image_all = image_all[:, :, 0]
    Image.fromarray(np.uint8(image_all)).save(filename)


if __name__ == '__main__':
    usage = 'usage:\npython CGAN.py train\npython CGAN.py eval [output file] [epoch]'
    if len(sys.argv) < 2:
        print(usage)
        exit(1)
    if sys.argv[1] == 'train':
        train()
    elif sys.argv[1] == 'eval' and len(sys.argv) >= 4:
        generate([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], int(sys.argv[3]), sys.argv[2])
    else:
        print(usage)
        exit(1)