# NOTE(review): this file arrived as a whitespace-mangled git patch; the patch also
# updated .gitignore to ignore **/*.png, **/*.pkl and saved/ (kept here for the record).
"""Conditional GAN (CGAN) on MNIST, implemented with Jittor.

Both the generator and the discriminator are conditioned on the digit label
through a learned 10-dim class embedding concatenated to their primary input.
Reference: https://github.com/arturml/mnist-cgan/blob/master/mnist-cgan.ipynb
"""
import os
import time

import numpy as np
from PIL import Image

import jittor as jt
from jittor import nn, transform

jt.flags.use_cuda = 1

# hyper params
epochs = 100
batch_size = 64
lr = 0.0001
latent_dim = 100  # noise-vector size; the first Linear below expects latent_dim + 10 inputs


class Generator(nn.Module):
    """Generate a 1x28x28 image (Tanh output, range [-1, 1]) from noise + class label."""

    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(10, 10)  # one learned 10-dim vector per digit class
        self.model = nn.Sequential(
            nn.Linear(110, 256),  # 100 noise dims + 10 label-embedding dims
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def execute(self, noise, labels):
        """noise: (batch, latent_dim) float; labels: (batch,) int class ids."""
        labels = self.label_emb(labels)
        x = jt.concat([noise, labels], 1)
        out = self.model(x)
        return out.view(out.shape[0], 1, 28, 28)


class Discriminator(nn.Module):
    """Score (image, label) pairs with a real/fake probability in (0, 1)."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_emb = nn.Embedding(10, 10)
        self.model = nn.Sequential(
            nn.Linear(794, 1024),  # 784 flattened pixels + 10 label-embedding dims
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def execute(self, images, labels):
        """images: (batch, 1, 28, 28); labels: (batch,) int class ids."""
        images = images.view(images.shape[0], -1)  # flatten to (batch, 784)
        labels = self.label_emb(labels)
        x = jt.concat([images, labels], 1)
        out = self.model(x)
        return out


generator = Generator()
discriminator = Discriminator()

trans = transform.Compose([
    transform.Gray(),
    transform.ImageNormalize([0.5], [0.5]),  # map pixels to [-1, 1] to match Tanh output
])
data = jt.dataset.MNIST(train=True, transform=trans).set_attrs(batch_size=batch_size, shuffle=True)

loss = nn.BCELoss()
generator_opt = jt.optim.Adam(generator.parameters(), lr)
discriminator_opt = jt.optim.Adam(discriminator.parameters(), lr)

saved_path = 'saved'
# exist_ok avoids the check-then-create race of the original os.path.exists/os.mkdir pair
os.makedirs(saved_path, exist_ok=True)


def _save_grid(images, filename, axis):
    """Concatenate per-sample CHW images along `axis`, rescale [-1, 1] -> [0, 255], save PNG."""
    grid = np.concatenate(list(images), axis)
    grid = (grid + 1) / 2 * 255              # undo Tanh / ImageNormalize scaling
    grid = grid.transpose(1, 2, 0)[:, :, 0]  # CHW -> HW (single gray channel)
    Image.fromarray(np.uint8(grid)).save(filename)


def save_image(filename):
    """Render one generated sample of each digit 0-9 (stacked vertically) to `filename`."""
    noise = jt.array(np.random.randn(10, latent_dim)).float32().stop_grad()
    labels = jt.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    images = generator(noise, labels).numpy()
    _save_grid(images, filename, 1)  # axis 1 = height: vertical strip


def train():
    """Alternate generator/discriminator updates over the MNIST training set.

    Saves both model checkpoints and a sample image grid after every epoch.
    """
    generator.train()
    discriminator.train()
    for epoch in range(epochs):
        start_time = time.time()
        for i, (images, labels) in enumerate(data):
            # renamed from `batch_size` to avoid shadowing the module hyperparameter;
            # the final batch of an epoch may be smaller
            cur_batch = images.shape[0]

            valid = jt.ones([cur_batch, 1]).float32().stop_grad()
            fake = jt.zeros([cur_batch, 1]).float32().stop_grad()

            # --- generator step: make the discriminator accept fake images ---
            # .float32() added: np.random.randn yields float64, network weights are float32
            noise = jt.array(np.random.randn(cur_batch, latent_dim)).float32()
            fake_labels = jt.array(np.random.randint(0, 10, cur_batch))
            fake_images = generator(noise, fake_labels)
            validity = discriminator(fake_images, fake_labels)
            generator_loss = loss(validity, valid)
            generator_loss.sync()
            generator_opt.step(generator_loss)

            # --- discriminator step: real vs (detached) fake ---
            real_validity = discriminator(jt.array(images), jt.array(labels))
            real_loss = loss(real_validity, valid)
            # stop_grad() detaches the generator so only the discriminator is updated here
            # (a dead, unused `noise = ...` reassignment from the original was removed)
            fake_validity = discriminator(fake_images.stop_grad(), fake_labels)
            fake_loss = loss(fake_validity, fake)
            discriminator_loss = (real_loss + fake_loss) / 2
            discriminator_loss.sync()
            discriminator_opt.step(discriminator_loss)

            if i % 50 == 0:
                print('Epoch %d Batch %d Loss g %f d %f'
                      % (epoch, i, generator_loss.data, discriminator_loss.data))

        print('Epoch %d done Time %s' % (epoch, time.time() - start_time))

        generator.save('saved/generator_%d.pkl' % epoch)
        discriminator.save('saved/discriminator_%d.pkl' % epoch)
        save_image('saved/gen_%d.png' % epoch)


def generate(numbers, epoch, filename):
    """Load the generator checkpoint from `epoch` and render `numbers` side by side.

    numbers: iterable of digit class ids (0-9); epoch: checkpoint index;
    filename: output PNG path.
    """
    generator_g = Generator()
    generator_g.eval()
    generator_g.load('saved/generator_%d.pkl' % epoch)

    num_labels = len(numbers)
    noise = jt.array(np.random.randn(num_labels, latent_dim)).float32().stop_grad()
    # BUG FIX: the original cast labels to .float32(), but nn.Embedding lookups
    # require integer indices (as used everywhere else in this file)
    labels = jt.array(numbers).stop_grad()
    images = generator_g(noise, labels).numpy()
    _save_grid(images, filename, 2)  # axis 2 = width: horizontal strip


if __name__ == '__main__':
    # train()
    generate([1, 2, 3, 3, 2, 1], 87, 'result.png')