|
- import jittor as jt
- from jittor import nn, transform
- import numpy as np
- from PIL import Image
- import os
- import time
- import sys
-
jt.flags.use_cuda = 1  # run on GPU; set to 0 to fall back to CPU

# hyper params
epochs = 100      # number of passes over the MNIST training set
batch_size = 64   # nominal batch size (last batch of an epoch may be smaller)
lr = 0.0001       # Adam learning rate, shared by generator and discriminator
latent_dim = 100  # NOTE(review): defined but unused — the noise size 100 is hardcoded below
-
-
- # ref https://github.com/arturml/mnist-cgan/blob/master/mnist-cgan.ipynb
-
- # generator
class Generator(nn.Module):
    """Conditional generator: maps (noise, digit label) to a 28x28 image.

    The label is embedded into a 10-dim vector and concatenated with the
    100-dim noise vector, giving the 110-dim input to the MLP. The final
    Tanh keeps outputs in [-1, 1], matching the normalized training data.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(10, 10)
        layers = [
            nn.Linear(110, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def execute(self, noise, labels):
        # Condition on the class by concatenating the label embedding to the noise.
        embedded = self.label_emb(labels)
        flat = self.model(jt.concat([noise, embedded], 1))
        # Reshape the 784-vector into a single-channel 28x28 image.
        return flat.view(flat.shape[0], 1, 28, 28)
-
-
class Discriminator(nn.Module):
    """Conditional discriminator: scores an (image, label) pair as real/fake.

    The 28x28 image is flattened to 784 values and concatenated with a
    10-dim label embedding (784 + 10 = 794 inputs). The Sigmoid output is
    the probability that the pair is real, consumed by BCELoss.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_emb = nn.Embedding(10, 10)
        layers = [
            nn.Linear(794, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def execute(self, images, labels):
        # Flatten each image and append its label embedding before scoring.
        flat_images = images.view(images.shape[0], -1)
        embedded = self.label_emb(labels)
        return self.model(jt.concat([flat_images, embedded], 1))
-
-
# Module-level setup: models, data pipeline, loss, optimizers, output dir.
generator = Generator()
discriminator = Discriminator()

# Normalize MNIST pixels to [-1, 1] so they match the generator's Tanh range.
trans = transform.Compose([
    transform.Gray(),
    transform.ImageNormalize([0.5], [0.5])
])
data = jt.dataset.MNIST(train=True, transform=trans).set_attrs(batch_size=batch_size, shuffle=True)

loss = nn.BCELoss()
generator_opt = jt.optim.Adam(generator.parameters(), lr)
discriminator_opt = jt.optim.Adam(discriminator.parameters(), lr)

saved_path = 'saved'
# makedirs(exist_ok=True) replaces the exists()+mkdir() pair: no
# check-then-create race, and it also creates missing parent dirs.
os.makedirs(saved_path, exist_ok=True)
-
-
def save_image(filename):
    """Generate one sample per digit (0-9) and save them as one image.

    Samples are stacked vertically into a single grayscale strip.
    Uses the module-level `generator`.
    """
    noise = jt.array(np.random.randn(10, 100)).float32().stop_grad()
    labels = jt.array(list(range(10)))
    images = generator(noise, labels).numpy()
    # Stack the ten (1, 28, 28) samples along the height axis -> (1, 280, 28).
    strip = np.concatenate([images[i] for i in range(10)], 1)
    # Undo the [-1, 1] normalization back to 8-bit pixel values.
    strip = (strip + 1) / 2 * 255
    strip = strip.transpose(1, 2, 0)[:, :, 0]
    Image.fromarray(np.uint8(strip)).save(filename)
-
-
def train():
    """Run the adversarial training loop over `epochs` epochs.

    Alternates a generator step (maximize D's "real" score on fakes) and a
    discriminator step (real -> 1, fake -> 0). After each epoch, saves both
    model checkpoints and a sample image under `saved/`.
    """
    generator.train()
    discriminator.train()
    for epoch in range(epochs):
        start_time = time.time()
        for i, (images, labels) in enumerate(data):
            # Renamed from `batch_size` to avoid shadowing the module-level
            # hyperparameter; the last batch of an epoch may be smaller.
            cur_batch = images.shape[0]

            # BCE targets — stop_grad because they are constants.
            valid = jt.ones([cur_batch, 1]).float32().stop_grad()
            fake = jt.zeros([cur_batch, 1]).float32().stop_grad()

            # --- train generator: make D score the fakes as "valid" ---
            noise = jt.array(np.random.randn(cur_batch, 100))
            fake_labels = jt.array(np.random.randint(0, 10, cur_batch))
            fake_images = generator(noise, fake_labels)
            validity = discriminator(fake_images, fake_labels)
            generator_loss = loss(validity, valid)
            generator_loss.sync()
            generator_opt.step(generator_loss)

            # --- train discriminator ---
            # Real images should score "valid".
            real_images = jt.array(images)
            labels = jt.array(labels)
            real_validity = discriminator(real_images, labels)
            real_loss = loss(real_validity, valid)
            # Fake images (stop_grad detaches them so only D is updated).
            # Removed a dead `noise = jt.array(...)` draw here: it was never
            # used — D is scored on the fakes generated above.
            fake_validity = discriminator(fake_images.stop_grad(), fake_labels)
            fake_loss = loss(fake_validity, fake)
            discriminator_loss = (real_loss + fake_loss) / 2
            discriminator_loss.sync()
            discriminator_opt.step(discriminator_loss)

            if i % 50 == 0:
                print('Epoch %d Batch %d Loss g %f d %f' % (epoch, i, generator_loss.data, discriminator_loss.data))

        print('Epoch %d done Time %s' % (epoch, time.time() - start_time))

        generator.save('saved/generator_%d.pkl' % epoch)
        discriminator.save('saved/discriminator_%d.pkl' % epoch)
        save_image('saved/gen_%d.png' % epoch)
-
-
def generate(numbers, epoch, filename):
    """Generate one image per label in `numbers` from a saved checkpoint.

    Args:
        numbers: iterable of digit labels (0-9) to condition on.
        epoch: epoch index of the saved generator checkpoint to load.
        filename: output path for the horizontal image strip.
    """
    generator_g = Generator()
    generator_g.eval()
    generator_g.load('saved/generator_%d.pkl' % epoch)

    num_labels = len(numbers)
    noise = jt.array(np.random.randn(num_labels, 100)).float32().stop_grad()
    # Fix: labels are nn.Embedding indices and must be integers — the
    # previous .float32() cast broke the lookup (training uses int labels).
    labels = jt.array(np.array(numbers, dtype=np.int32)).stop_grad()
    images = generator_g(noise, labels).numpy()
    # Concatenate the (1, 28, 28) samples along the width axis -> one strip.
    image_all = np.concatenate([images[i] for i in range(num_labels)], 2)
    # Undo the [-1, 1] normalization back to 8-bit pixel values.
    image_all = (image_all + 1) / 2 * 255
    image_all = image_all.transpose(1, 2, 0)[:, :, 0]
    Image.fromarray(np.uint8(image_all)).save(filename)
-
-
def _usage():
    # Single definition of the usage text so all error paths stay in sync.
    print('usage: \npython CGAN.py train\npython CGAN.py eval [output file] [epoch]')


if __name__ == '__main__':
    if len(sys.argv) < 2:
        _usage()
        exit(1)
    if sys.argv[1] == 'train':
        train()
    elif sys.argv[1] == 'eval':
        # Fix: 'eval' needs both the output file and the epoch number;
        # previously missing args raised an IndexError instead of usage.
        if len(sys.argv) < 4:
            _usage()
            exit(1)
        generate([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], int(sys.argv[3]), sys.argv[2])
    else:
        _usage()
        exit(1)
|