
feat: cgan for generating hand-written numbers

master
KSkun 3 years ago
commit 96a3cf9298
2 changed files with 165 additions and 1 deletion
  1. .gitignore (+3, -1)
  2. CGAN.py (+162, -0)

.gitignore (+3, -1)

@@ -1,2 +1,4 @@
 .idea/
-*.png
+**/*.png
+**/*.pkl
+saved/

CGAN.py (+162, -0)

@@ -0,0 +1,162 @@
import jittor as jt
from jittor import nn, transform
import numpy as np
from PIL import Image
import os
import time

jt.flags.use_cuda = 1

# hyper params
epochs = 100
batch_size = 64
lr = 0.0001
latent_dim = 100


# ref https://github.com/arturml/mnist-cgan/blob/master/mnist-cgan.ipynb
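# Conditioning (as implemented below): both networks embed the class label into a
# 10-dim vector via nn.Embedding(10, 10). The generator concatenates it with the
# 100-dim noise vector (input size 110); the discriminator concatenates it with the
# flattened 28x28 image (input size 794).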

# generator: maps a (noise, label) pair to a 1x28x28 image
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(10, 10)
        self.model = nn.Sequential(
            nn.Linear(110, 256),  # 110 = 100-dim noise + 10-dim label embedding
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()  # output in [-1, 1], matching the normalized MNIST images
        )

    def execute(self, noise, labels):
        labels = self.label_emb(labels)
        x = jt.concat([noise, labels], 1)
        out = self.model(x)
        return out.view(out.shape[0], 1, 28, 28)


# discriminator: scores an (image, label) pair as real (close to 1) or fake (close to 0)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_emb = nn.Embedding(10, 10)
        self.model = nn.Sequential(
            nn.Linear(794, 1024),  # 794 = 784 flattened pixels + 10-dim label embedding
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def execute(self, images, labels):
        images = images.view(images.shape[0], -1)
        labels = self.label_emb(labels)
        x = jt.concat([images, labels], 1)
        out = self.model(x)
        return out


generator = Generator()
discriminator = Discriminator()

trans = transform.Compose([
    transform.Gray(),
    transform.ImageNormalize([0.5], [0.5])
])
data = jt.dataset.MNIST(train=True, transform=trans).set_attrs(batch_size=batch_size, shuffle=True)
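# ImageNormalize([0.5], [0.5]) maps the grayscale images to [-1, 1], matching the
# generator's Tanh output range (hence the (x + 1) / 2 rescaling when saving images).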

loss = nn.BCELoss()
generator_opt = jt.optim.Adam(generator.parameters(), lr)
discriminator_opt = jt.optim.Adam(discriminator.parameters(), lr)

saved_path = 'saved'
if not os.path.exists(saved_path):
    os.mkdir(saved_path)


def save_image(filename):
    # generate one sample per digit 0-9 and save them as a single image strip
    noise = jt.array(np.random.randn(10, 100)).float32().stop_grad()
    labels = jt.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    images = generator(noise, labels).numpy()
    image_all = []
    for i in range(10):
        image_all.append(images[i])
    image_all = np.concatenate(image_all, 1)   # stack along height: (1, 280, 28)
    image_all = (image_all + 1) / 2 * 255      # map [-1, 1] back to [0, 255]
    image_all = image_all.transpose(1, 2, 0)
    image_all = image_all[:, :, 0]             # drop the channel axis
    Image.fromarray(np.uint8(image_all)).save(filename)

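# Each training batch alternates two updates:
#   1. Generator step: generate fakes for random labels and push the discriminator's
#      output towards the "real" target (1).
#   2. Discriminator step: score real images against target 1 and the detached fakes
#      against target 0, averaging the two BCE losses.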

def train():
    generator.train()
    discriminator.train()
    for epoch in range(epochs):
        start_time = time.time()
        for i, (images, labels) in enumerate(data):
            batch_size = images.shape[0]

            # targets: 1 for real images, 0 for generated ones
            valid = jt.ones([batch_size, 1]).float32().stop_grad()
            fake = jt.zeros([batch_size, 1]).float32().stop_grad()

            # train generator: make the discriminator label the fakes as real
            noise = jt.array(np.random.randn(batch_size, 100)).float32()
            fake_labels = jt.array(np.random.randint(0, 10, batch_size))
            fake_images = generator(noise, fake_labels)
            validity = discriminator(fake_images, fake_labels)
            generator_loss = loss(validity, valid)
            generator_loss.sync()
            generator_opt.step(generator_loss)

            # train discriminator
            # real images
            real_images = jt.array(images)
            labels = jt.array(labels)
            real_validity = discriminator(real_images, labels)
            real_loss = loss(real_validity, valid)
            # fake images (detached with stop_grad so only the discriminator is updated)
            fake_validity = discriminator(fake_images.stop_grad(), fake_labels)
            fake_loss = loss(fake_validity, fake)
            discriminator_loss = (real_loss + fake_loss) / 2
            discriminator_loss.sync()
            discriminator_opt.step(discriminator_loss)

            if i % 50 == 0:
                print('Epoch %d Batch %d Loss g %f d %f' % (epoch, i, generator_loss.data, discriminator_loss.data))

        print('Epoch %d done Time %s' % (epoch, time.time() - start_time))

        generator.save('saved/generator_%d.pkl' % epoch)
        discriminator.save('saved/discriminator_%d.pkl' % epoch)
        save_image('saved/gen_%d.png' % epoch)


def generate(numbers, epoch, filename):
    # load the generator checkpoint from the given epoch and render one image per label
    generator_g = Generator()
    generator_g.eval()
    generator_g.load('saved/generator_%d.pkl' % epoch)

    num_labels = len(numbers)
    noise = jt.array(np.random.randn(num_labels, 100)).float32().stop_grad()
    labels = jt.array(numbers).float32().stop_grad()
    images = generator_g(noise, labels).numpy()
    image_all = []
    for i in range(num_labels):
        image_all.append(images[i])
    image_all = np.concatenate(image_all, 2)   # stack along width: (1, 28, 28 * num_labels)
    image_all = (image_all + 1) / 2 * 255      # map [-1, 1] back to [0, 255]
    image_all = image_all.transpose(1, 2, 0)
    image_all = image_all[:, :, 0]             # drop the channel axis
    Image.fromarray(np.uint8(image_all)).save(filename)


if __name__ == '__main__':
    # train()
    generate([1, 2, 3, 3, 2, 1], 87, 'result.png')
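    # Note: this assumes train() has already produced 'saved/generator_87.pkl';
    # adjust the epoch argument to whichever checkpoint exists locally.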
