본문 바로가기
Paper Review/Style Transfer

FUNIT : Few-Shot Unsupervised Image-to-Image Translation (2019)

by HanByol Jang 2021. 10. 31.
728x90

기존의 딥러닝 기반 학습들의 문제점은, 학습에 사용하지 않은 데이터를 테스트 하려할때 제대로 되지 않는다는 한계점이 있습니다. Few-Shot Unsupervised Image-to-Image Translation은 이러한 한계점을 극복하고자 학습에 사용하지 않은 class 혹은 적은 수의 테스트 이미지만 가지고 있을때도 자연스러운 변환이 되도록 구현한 논문입니다. 

https://arxiv.org/abs/1905.01723


FUNIT Framework

FUNIT Training & Deployment

위 figure는 이 논문의 학습과 테스트시의 framework입니다. 먼저, source class라고 불리는 다양한 class의 이미지들을 준비합니다. 물론 모두 정확히 같은 포즈의 이미지들이 아니고 다양한 포즈, 다양한 종들의 강아지 사진들입니다. 그리고 테스트시에 사용하는 이미지는 Target class로 이는 학습에 사용하지 않은 굉장히 적은 수의 데이터들입니다.

학습시에는 source class에서 하나의 input image를 선택하고 해당 이미지의 class가 아닌 다른 class이미지를 선택해서 해당 class로 변환하도록 학습을 진행합니다. 그리고 테스트시에는 source class에 속하지 않는 target class로의 변환을 시도합니다.


Few-shot Image Translator

 

Generator Achitecture

위의 figure처럼 먼저 content image x를 통해 다수의 convolution layer를 통과시켜 spatial feture map을 나타내는 content code zx를 추출합니다. 그리고 content image 이외의 class image들은 다수의 convolution layer를 통과시킨 뒤 average polling을 통해 vector를 만들고 평균을 취해 class code zy를 생성합니다.

그리고 decoder는 이전 논문인 MUNIT의 방식처럼 content 이미지에 style을 입혀주는 AdaIN 방법을 통해 output 이미지를 생성합니다. 

https://hanstar4.tistory.com/12

 

MUNIT : Multimodal Unsupervised Image-to-Image Translation

Multimodal Unsupervised Image-to-Image Translation는 UNIT 논문을 발전한 형태로 unimodal이 아닌 multimodal로의 변환이 가능하도록 구현한 논문이다. UNIT와 마찬가지로 unpaired한 데이터들간의 변환이고,..

hanstar4.tistory.com

Content code는 pose같은 local한 부분들을 담당하고 class code는 global한 이미지의 특징들을 담당합니다.

이러한 네트워크 구조를 통해 테스트시에 학습에 사용하지 않는 class의 이미지가 들어와도 few-shot test image들의 class 특징을 추출해 content image에 입혀줌으로서 target image처럼 만들수 있게 되는 것입니다.


Multi-task Adversarial Discriminator

Source class가 multi이다보니 discriminator 역시 multi-class classification으로 진행되고, 각각에 class에 대해 binary로 진행합니다. 그리고 discriminator가 업데이트가 되는 시점은, x class에 속한 real image를 넣었을때 False라고 결정하거나, fake image를 넣었을때 True라고 결정할때만 페널티를 주고 discriminator를 업데이트합니다. 여기서 주의할 점은 x class의 classification 결과만 영향을 준다는 것입니다. 다른 class의 classification을 틀렸다 할지라도 이는 영향을 주지 않습니다.

generator 역시 해당 class의 결과만 영향을 받아 업데이트합니다.

 


Loss Function

FUNIT 모델의 학습에는 총 3가지의 Loss function이 사용됩니다.

 

1. Adversarial Loss (GAN Loss)

gan loss

일반적인 GAN loss의 형태이며, 앞서 설명했듯이 source class에 대해 해당 class에 대한 classification 결과만 반영하여 loss를 계산합니다.

 

2. Reconstruction Loss

reconstruction loss

두번째는 reconstruction loss로서, 이는 content image와 soucre class image가 같은 이미지가 들어갔을 경우입니다. socure class image와 content image가 같기 때문에 generator는 동일한 content image를 만들어야합니다. 위의 loss는 generator가 얼마나 자기자신을 그대로 복원해내는지에 대한 loss입니다.

 

3. Feature Matching Loss

feature matching loss

세번째는 feature matching loss로서, 이는 discriminator의 prediction 전 마지막 layer의 feature간의 차이를 계산한 loss입니다. 이 loss를 통해 discriminator는 단순하게 class가 실제냐 아니냐를 맞추는 정도만 학습하는 것이 아닌, 정말 같은 feature를 만들고 그 feature를 가지고 class 판단을 할 수 있도록 유도하는 역할을 합니다.

 

4. Total Loss

FUNIT loss

위 세가지의 loss를 합하여 최종적으로 FUNIT 모델은 학습을하게 됩니다.


Experiment Results

x content image를 통해 y1,y2 처럼 바뀐 x'의 모습을 볼 수 있습니다.

물론 위처럼 드라마틱하게 다른 형태의 이미지로의 변환을 시도하면 경과가 좋지 않은 것을 볼 수 있습니다. 이는 모든 image-to-image translation에서의 한계점으로 남아있습니다.

728x90

댓글