#
# Import the Necessary Libraries
#
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.models import Sequential
import numpy as np
#
# Load and Preprocess the Data
# - Load the MNIST dataset and preprocess it.
# - Preprocessing scales pixel values from [0, 255] to [-1, 1], matching the range of the generator's tanh output, which helps training.
#
# Keep only the training images; the labels and the test split are discarded.
(X_train, _), (_, _) = mnist.load_data()
# Rescale uint8 pixels from [0, 255] into [-1, 1] (same range as a tanh output),
# then append a trailing single-channel axis: (N, 28, 28) -> (N, 28, 28, 1).
X_train = ((X_train.astype(np.float32) - 127.5) / 127.5)[..., np.newaxis]
#
# Build the Generator and Discriminator
# - Define the generator and discriminator models.
# - Generator takes a random noise vector as input and outputs an image.
# - Discriminator takes an image as input and outputs the probability of the image being real.
#
# Generator
def create_generator(latent_dim=100, img_shape=(28, 28, 1)):
    """Build the (uncompiled) fully connected generator.

    The network maps a latent noise vector to an image. It uses three ReLU
    hidden layers (256 -> 512 -> 1024 units) and a tanh output layer, so
    generated pixels lie in [-1, 1], the same range the training images are
    scaled to.

    Args:
        latent_dim: Length of the input noise vector. Defaults to 100, the
            noise size fed in elsewhere in this script.
        img_shape: Shape of each generated image, (height, width, channels).
            Defaults to single-channel 28x28.

    Returns:
        A ``Sequential`` model mapping ``(batch, latent_dim)`` to
        ``(batch, *img_shape)``.
    """
    model = Sequential()
    # Declare the input with an explicit Input rather than the ``input_dim=``
    # layer kwarg, which newer Keras versions warn against for Sequential models.
    model.add(Input(shape=(latent_dim,)))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(512, activation='relu'))
    model.add(Dense(1024, activation='relu'))
    # One tanh unit per output pixel, then fold the flat vector back into an image.
    model.add(Dense(int(np.prod(img_shape)), activation='tanh'))
    model.add(Reshape(tuple(img_shape)))
    return model
# Discriminator
def create_discriminator(img_shape=(28, 28, 1)):
    """Build the (uncompiled) fully connected discriminator.

    The network flattens an image and passes it through three ReLU hidden
    layers (1024 -> 512 -> 256 units). A single sigmoid unit follows, read as
    the probability that the input image is real.

    Args:
        img_shape: Shape of each input image, (height, width, channels).
            It must match the generator's output shape. Defaults to
            single-channel 28x28.

    Returns:
        A ``Sequential`` model mapping ``(batch, *img_shape)`` to ``(batch, 1)``.
    """
    model = Sequential()
    # Declare the input with an explicit Input rather than the ``input_shape=``
    # layer kwarg, which newer Keras versions warn against for Sequential models.
    model.add(Input(shape=tuple(img_shape)))
    model.add(Flatten())
    model.add(Dense(1024, activation='relu'))
    model.add(Dense(512, activation='relu'))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    return model
#
# Compile the Models
# - Compile the models, which involves defining the loss function and the optimizer.
# - The loss function evaluates the model's performance, while the optimizer aims to minimize the loss.
#
from tensorflow.keras.models import Sequential, Model
# Discriminator: trained directly on real vs. generated batches, so it is
# compiled while it is still trainable.
discriminator = create_discriminator()
discriminator.compile(optimizer='adam', loss='binary_crossentropy')

# Generator: compiled here as well, although the visible training loop only
# updates it through the combined ``gan`` model below.
generator = create_generator()
generator.compile(optimizer='adam', loss='binary_crossentropy')

# Combined model: noise -> generator -> discriminator. The discriminator is
# frozen *before* this compile so that the combined model treats its weights
# as fixed and training ``gan`` adjusts the generator.
discriminator.trainable = False
gan_input = Input(shape=(100,))
x = generator(gan_input)
gan_output = discriminator(x)
gan = Model(gan_input, gan_output)
gan.compile(optimizer='adam', loss='binary_crossentropy')
#
# Train the Models
# - Train the models by alternating two steps for every batch: update the discriminator on real and generated images, then update the generator through the combined model.
# - The primary aim is for the generator to create images that the discriminator cannot distinguish from real ones.
#
def train(epochs=1, batch_size=128, latent_dim=100):
    """Train the GAN by alternating discriminator and generator updates.

    Uses the module-level ``generator``, ``discriminator`` and combined ``gan``
    models and updates their weights in place. For every batch:

    1. The discriminator is trained on real images with smoothed "real"
       targets drawn from (0.9, 1.0]. It is then trained on generator output
       labelled 0.
    2. The generator is trained through ``gan`` to make the discriminator
       output "real" for the same noise vectors.

    Args:
        epochs: Number of full passes over the MNIST training images.
        batch_size: Images per update step. A trailing partial batch in each
            epoch is skipped.
        latent_dim: Length of each noise vector. It must match the generator's
            input size (100 as built in this script).

    Raises:
        ValueError: If ``batch_size`` is not positive or exceeds the number of
            training images, so that no full batch could be formed.
    """
    # Load the training images and scale them to [-1, 1], like the generator's
    # tanh output.
    (X_train, _), (_, _) = mnist.load_data()
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)

    n_samples = X_train.shape[0]
    # Previously batch_size == 0 crashed with ZeroDivisionError and an
    # oversized batch_size silently ran zero updates; fail clearly instead.
    if batch_size <= 0 or batch_size > n_samples:
        raise ValueError(
            f'batch_size must be between 1 and {n_samples}, got {batch_size}'
        )
    n_batches = n_samples // batch_size

    # Fixed targets reused at every step.
    y_real = np.ones((batch_size, 1))
    y_fake = np.zeros((batch_size, 1))

    for e in range(epochs):
        # Reshuffle each epoch so batches differ in composition and order
        # between epochs, instead of replaying the same fixed slices.
        order = np.random.permutation(n_samples)
        for i in range(n_batches):
            # --- Discriminator step ---
            discriminator.trainable = True
            X_batch = X_train[order[i * batch_size:(i + 1) * batch_size]]
            # One-sided label smoothing: real targets fall in (0.9, 1.0].
            y_real_smooth = y_real * (1 - 0.1 * np.random.rand(batch_size, 1))
            d_loss_real = discriminator.train_on_batch(x=X_batch, y=y_real_smooth)
            z_noise = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))
            X_fake = generator.predict_on_batch(z_noise)
            d_loss_fake = discriminator.train_on_batch(x=X_fake, y=y_fake)
            d_loss = 0.5 * (d_loss_real + d_loss_fake)

            # --- Generator step ---
            # Freeze the discriminator; the generator is rewarded when the
            # discriminator labels its fakes as real.
            discriminator.trainable = False
            g_loss = gan.train_on_batch(x=z_noise, y=y_real)
            print(f'Epoch: {e+1}, Batch: {i}, D Loss: {d_loss}, G Loss: {g_loss}')
#
# Execute the Training
#
# Launch training.
# `epochs` sets how many times the learning algorithm works through the entire
# training dataset; `batch_size` is the number of samples propagated through
# the networks for each weight update.
train(batch_size=128, epochs=50)
#
# Generate New Images and Evaluate the Model's Performance
# - Generate new images and inspect them visually to judge how well the GAN has learned.
# - Feed random noise vectors into the trained generator to create the new images.
#
import matplotlib.pyplot as plt

# Draw one sample per cell of a grid_rows x grid_cols grid. Deriving the
# number of noise vectors from the grid keeps the subplot layout and the
# sample count in sync.
grid_rows, grid_cols = 10, 10
# Generate random noise as an input to the generator (the latent size of 100
# must match the generator's input).
random_noise = np.random.normal(0, 1, [grid_rows * grid_cols, 100])
# Generate the images from the noise
generated_images = generator.predict(random_noise)

# Visualize the generated images. The generator's tanh output lies in [-1, 1],
# so pin the grey scale to that range. Otherwise imshow stretches each tile
# to its own min/max, giving every image a different contrast.
plt.figure(figsize=(10, 10))
for idx, image in enumerate(generated_images):
    plt.subplot(grid_rows, grid_cols, idx + 1)
    plt.imshow(image[:, :, 0], cmap='gray', vmin=-1, vmax=1)
    plt.axis('off')
plt.tight_layout()
plt.show()