내가 현재 수강하고 있는 부트캠프에서 최종 프로젝트로 super resolution을 할 수 있는 모델을 GAN으로 구현해 웹 애플리케이션을 만들기로 했다. 나 같은 경우는 GAN도 직접 구현해본적이 없었고 super resolution 자체를 머신러닝으로 시도해본 적이 없었다. 그래서 학습을 위해 SRGAN을 직접 코드로 구현해보기로 했다.
SRGAN이란 무엇인가?
SRGAN은 super resolution을 GAN 프레임워크를 빌려 구현한 딥러닝 모델 중 하나이다. 일단 GAN이란게 무엇이란 말인가... GAN은 Generative Adversarial Network의 약자로, 한국말론 적대적 생성 신경망이라는 뜻이다. GAN은 말 그대로 무언가를 '생성'하는데 특화되어 있다. 다양한 데이터나 이미지, 텍스트 등을 GAN을 통해 생성할 수 있다. 이런 특징을 이용해서 GAN은 예술, 학습데이터 증강, 이미지 보정 등 다양한 Task에 활용된다. GAN의 또 다른 특징은 두 종류의 모델을 동시에, 그리고 적대적으로 학습시킨다는 것이다. 이 두 종류의 모델을 주로 Generator와 Discriminator라고 부른다. Generator는 데이터를 생성하는 네트워크이고 Discriminator는 데이터가 원본 데이터인지 생성된 데이터인지 판별하는 네트워크이다. Adversarial loss라는 GAN만의 특별한 loss function을 통해, Generator는 Discriminator를 속이는 방향으로 학습하고 Discriminator는 원본을 더 잘 구별하는 방법으로 학습한다. Generator와 Discriminator가 균형을 잘 잡을 수 있도록 학습시키면 Generator는 사람이 보기에 정말 그럴듯한 데이터들을 생성할 수 있게 된다.
SRGAN은 super resolution task를 GAN을 이용해서 풀어냈다. SRGAN이 등장하기 전에 저해상도의 이미지를 고해상도 이미지로 바꾸기 위한 state of arts 기법은 SRResNet이었다. SRResNet은 ResNet 구조를 이용한 super resolution 네트워크이다. 이 모델은 metric 스코어 측면에서 성능이 좋았지만 픽셀이 얼버무려지는 모습이 어색한 부분이 있어서 일반 사람이 보기에 아쉬운 정도의 수준이었다. SRGAN은 metric을 사용한 성능 측정을 벗어나서, 일반 사람들이 보기에도 자연스러운 super resolution 작업을 위해서 SRResNet에 GAN 구조를 도입한 모델이다.
모델의 구조
이제까지 내가 구현해본 딥러닝 모델과 달리 서로 다른 모델 두개를 사용하는 모델이라서 조금 낯설긴 했다. 나로써는 SRGAN으로 GAN 구조를 처음 접해본 것이니 학습 과정도 신기했다. 나는 우선 논문 본문에 있는 모델 그림을 참고해서 모듈 클래스를 만들기로 했다. 처음엔 다른 코드를 참고하지 않고 내 힘으로 모델을 구현했는데 반복되는 레이어가 많아서 코드를 짜기가 꽤 번거로웠다. 특히 generator의 경우에는 residual layer가 많아서 구조가 복잡스러웠다. 근데 추후에 깃허브에 올라있던 코드들을 참고해보니 residual layer를 만드는 함수를 내부 메소드로 구현해두고 for문을 이용해 residual block을 만들고 있었다(...) 아무래도 이런게 모델 구현의 노하우라는 것인가 보다. 아래는 내가 직접 만든 조금은 부끄러운 코드이다. 다른 코드를 참고한 후 고쳐볼까 했으나 내가 직접 짠 코드를 남겨 놓는게 더 좋지 않을까 싶어 그냥 놔뒀다.
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.input_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=9, padding=4),
nn.PReLU()
)
self.resid_block1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.PReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64))
self.resid_block2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.PReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64))
self.resid_block3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.PReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64))
self.resid_block4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.PReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64))
self.resid_block5 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.PReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64))
self.output_layer1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64))
self.output_layer2 = nn.Sequential(nn.Conv2d(64, 256, kernel_size=3, padding=1),
nn.PixelShuffle(2),
nn.PReLU(),
nn.Conv2d(64, 256, kernel_size=3, padding=1),
nn.PixelShuffle(2),
nn.PReLU(),
nn.Conv2d(64, 3, kernel_size=3, padding=1))
def forward(self, x):
x = self.input_layer(x)
r1 = self.resid_block1(x) + x
r2 = self.resid_block2(r1) + r1
r3 = self.resid_block3(r2) + r2
r4 = self.resid_block4(r3) + r3
r5 = self.resid_block5(r4) + r4
output = self.output_layer1(r5) + x
output = self.output_layer2(output)
return output
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.output_shape = (1, int(512 / (2 ** 4)), int(512 / (2 ** 4)))
self.input_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.LeakyReLU())
self.conv_block1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU())
self.conv_block2 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU())
self.conv_block3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU())
self.conv_block4 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU())
self.conv_block5 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU())
self.conv_block6 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU())
self.output_layer = nn.Sequential(nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1))
def forward(self, x):
x = self.input_layer(x)
x = self.conv_block1(x)
x = self.conv_block2(x)
x = self.conv_block3(x)
x = self.conv_block4(x)
x = self.conv_block5(x)
x = self.conv_block6(x)
x = self.output_layer(x)
return x
이 모델이 논문의 모델과 다른 점은, 논문 모델의 경우 discriminator 모델의 마지막 레이어는 fc layer인데 반해, 나는 마지막 레이어를 CNN layer로 처리했다. 이것은 patchGAN이라는 모델의 구조를 이용한 것으로, fc layer가 유발하는 패러미터 수 폭발을 피할 수 있다. patchGAN은 마지막에 출력되는 값들을 진짜 이미지일 경우 1로 채워진 텐서, 생성된 이미지일 경우 0으로 채워진 텐서와 MSE를 구하여 모델 패러미터를 업데이트 한다.
모델 학습
GAN 모델은 학습 과정도 좀 특이한데 optimizer 객체를 두 개 사용한다. 하나는 generator를 업데이트하는 optimizer, 하나는 discriminator를 업데이트하는 optimizer이다. 이들을 이용해 1 epoch당 generator와 discriminator를 한번씩 패러미터를 업데이트 하게 된다. 나중에 안 사실이지만 GAN 학습에는 다양한 방법이 있는 모양이다. 다른 레퍼런스들을 참고해보니 discriminator를 어느 정도 학습하고 나서 generator와 같이 학습이 들어가기도 하고 generator를 pre-trained 모델로 사용하기도 했다. 논문의 경우는 pre-trained SRResNet을 사용했다. 나는 그냥 1 epoch당 두 모델을 동시에 업데이트 하는 방향으로 학습시켰다.
def train_model(dataloader, generator, discriminator, feature_extractor, epochs, learning_rate=1e-4):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
generator.to(device)
discriminator.to(device)
feature_extractor.to(device)
feature_extractor.eval()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)
content_criterion = nn.MSELoss()
adv_criterion = nn.MSELoss()
discriminator_criterion = nn.MSELoss()
for epoch in range(epochs):
for i, imgs in enumerate(dataloader):
LR = imgs['LR'].to(device)
HR = imgs['HR'].to(device)
# discriminator labels
real = torch.tensor(()).new_ones((LR.size()[0], *discriminator.output_shape), requires_grad=False).to(device)
fake = torch.tensor(()).new_zeros((LR.size()[0], *discriminator.output_shape), requires_grad=False).to(device)
# Generator 학습
optimizer_G.zero_grad()
generated_image = generator(LR)
# content loss
gen_features = feature_extractor(generated_image)
real_features = feature_extractor(HR)
content_loss = content_criterion(gen_features, real_features.detach())
# Adversarial loss
adv_loss = adv_criterion(discriminator(generated_image), real)
# Generator loss
generator_loss = 1e-3 * content_loss + adv_loss
generator_loss.backward()
optimizer_G.step()
# Discriminator 학습
optimizer_D.zero_grad()
real_loss = discriminator_criterion(discriminator(HR), real)
fake_loss = discriminator_criterion(discriminator(generated_image.detach()), fake)
discriminator_loss = (real_loss + fake_loss) / 2
discriminator_loss.backward()
optimizer_D.step()
# print loss
print(f'EPOCH: {epoch + 1}, BATCH: {i + 1}, Discriminator loss: {discriminator_loss.item()}, Generator loss: {generator_loss.item()}')
if (epoch+1) % 5 == 0:
torch.save(generator.state_dict(), f'saved_model/generator_{epoch + 1}.pth')
torch.save(discriminator.state_dict(), f'saved_model/discriminator_{epoch + 1}.pth')
return generator, discriminator
중간에 나오는 feature extractor의 경우 SRGAN만의 특징인 VGG loss를 구하기 위한 VGG 모델이다. SRGAN은 VGG loss와 adversarial loss를 특정 비율로 반영해 패러미터를 업데이트 한다.
후기
처음으로 논문을 처음부터 끝까지 내 힘으로 구현해봤고 모듈 구조로 여러 파일로 나누어 작성해본 것도 처음이다. pytorch에 익숙해진지도 얼마 되지 않았는데 그래도 내 힘으로 이런 코드를 짜봤다는게 뿌듯하긴 하다. 앞으로 여러 논문들의 모델들을 직접 구현해보고 싶다.
'Project' 카테고리의 다른 글
[프로젝트]모델의 inference time을 줄이기 위한 발버둥 (0) | 2022.04.19 |
---|---|
KoGPT2를 활용해 K-유튜브 제목을 생성해보자 - 2편 파인튜닝 (0) | 2022.03.06 |
KoGPT2를 활용해 K-유튜브 제목을 생성해보자 - 1편 EDA (0) | 2022.02.24 |