import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, LeakyReLU
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt

def build_generator(latent_dim=100):
    """Build the convolutional generator that maps noise vectors to 28x28x1 images.

    Args:
        latent_dim: Length of each input noise vector. Defaults to 100, the
            size the training loop in this file samples.

    Returns:
        An uncompiled ``Sequential`` model whose output has shape
        (batch, 28, 28, 1) with values in [-1, 1] (tanh activation).
    """
    model = Sequential([
        # Project the noise vector and reshape it into a 7x7 map with 128 channels.
        Dense(128 * 7 * 7, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        Reshape((7, 7, 128)),
        # Two stride-2 transposed convolutions upsample 7x7 -> 14x14 -> 28x28.
        Conv2DTranspose(128, kernel_size=4, strides=2, padding='same'),
        LeakyReLU(alpha=0.2),
        Conv2DTranspose(64, kernel_size=4, strides=2, padding='same'),
        LeakyReLU(alpha=0.2),
        # Collapse to one channel; tanh matches the [-1, 1] scaling applied to the MNIST pixels.
        Conv2D(1, kernel_size=7, padding='same', activation='tanh'),
    ])
    return model

def build_discriminator():
    """Return an uncompiled binary classifier for 28x28x1 images.

    Two stride-2 convolutions (64 then 128 filters, each followed by a
    LeakyReLU with slope 0.2) feed a single sigmoid unit. The training loop
    labels real images 1 and generated images 0.
    """
    disc = Sequential()
    # 28x28 -> 14x14; the first layer declares the input image shape.
    disc.add(Conv2D(64, kernel_size=3, strides=2, padding='same', input_shape=(28, 28, 1)))
    disc.add(LeakyReLU(alpha=0.2))
    # 14x14 -> 7x7
    disc.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))
    disc.add(LeakyReLU(alpha=0.2))
    # Flatten the feature map into a single real/fake probability.
    disc.add(Flatten())
    disc.add(Dense(1, activation='sigmoid'))
    return disc

def build_gan(generator, discriminator):
    """Stack ``generator`` and ``discriminator`` into one model for generator updates.

    Side effect: sets ``discriminator.trainable = False``. As a result, when the
    returned model is compiled and trained, only the generator's weights are
    updated.

    NOTE: tf.keras records each layer's ``trainable`` flag when a model is
    compiled. Compile ``discriminator`` on its own *before* calling this.
    Compiling it afterwards freezes it for its own ``train_on_batch`` calls.
    """
    # Freeze first, so the combined model sees the discriminator as non-trainable.
    discriminator.trainable = False
    return Sequential([generator, discriminator])

# Hyperparameters
lr = 0.0002
batch_size = 64
epochs = 10
latent_dim = 100       # noise length; must match the generator's input size (build_generator() uses 100)
sample_interval = 5    # save a grid of generated digits every N epochs

# Load MNIST and scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, axis=-1).astype('float32')
x_train = (x_train - 127.5) / 127.5

# Build and compile models.
generator = build_generator()
discriminator = build_discriminator()
# Compile the discriminator while it is still trainable. tf.keras snapshots each
# layer's `trainable` flag at compile time, and build_gan() sets
# discriminator.trainable = False. Compiling the discriminator after
# build_gan() would freeze it, so its train_on_batch calls would never
# update its weights.
discriminator.compile(optimizer=tf.keras.optimizers.Adam(lr), loss='binary_crossentropy', metrics=['accuracy'])
gan = build_gan(generator, discriminator)
gan.compile(optimizer=tf.keras.optimizers.Adam(lr), loss='binary_crossentropy')

# Targets are identical every step, so build them once outside the loop.
labels_real = np.ones((batch_size, 1))
labels_fake = np.zeros((batch_size, 1))
steps_per_epoch = len(x_train) // batch_size

# Training loop
for epoch in range(epochs):
    for _ in range(steps_per_epoch):
        # Discriminator step: a real batch labelled 1 and a generated batch labelled 0.
        noise = np.random.randn(batch_size, latent_dim)
        # predict_on_batch avoids predict()'s per-call setup and progress-bar output.
        fake_images = generator.predict_on_batch(noise)
        real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]

        # Set the flag explicitly for each phase, so the intended state holds
        # whether or not the installed Keras snapshots `trainable` at compile time.
        discriminator.trainable = True
        d_loss_real = discriminator.train_on_batch(real_images, labels_real)
        d_loss_fake = discriminator.train_on_batch(fake_images, labels_fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Generator step: train through the frozen discriminator, targeting "real" labels.
        discriminator.trainable = False
        noise = np.random.randn(batch_size, latent_dim)
        g_loss = gan.train_on_batch(noise, labels_real)

    # Reported losses/accuracy are from the last batch of the epoch.
    print(f'Epoch [{epoch+1}/{epochs}], D Loss: {d_loss[0]}, D Acc: {d_loss[1]}, G Loss: {g_loss}')

    if (epoch + 1) % sample_interval == 0:
        gen_images = generator.predict_on_batch(np.random.randn(25, latent_dim))
        gen_images = (gen_images + 1) / 2.0  # tanh range [-1, 1] -> [0, 1] for display
        plt.figure(figsize=(5, 5))
        for i in range(25):
            plt.subplot(5, 5, i+1)
            plt.imshow(gen_images[i].reshape(28, 28), cmap='gray')
            plt.axis('off')
        plt.savefig(f'stylegan_images_{epoch+1}.png')
        plt.close()