機械学習好きのメモ帳

MMのメモ帳

機械学習、データ分析、アルゴリズムなどを扱っていきます。雑記も含まれます。

DCGANでガンプラの箱を生成

はじめに

n番煎じではありますが、GANを用いた画像の生成に挑戦していきます。今回使う画像はガンプラの箱の画像です。

DCGANとは

GAN(Generative adversarial network)は生成モデルの一種です。Generator(G)とDiscriminator(D)の二つのネットワークが互いに相反する目的の元学習していき、与えられた画像によく似た画像を生成します。GはDが見分けもつかないような精巧な画像を生成すること、Dは本物の画像と偽物の(生成された)画像を正しく見分けることを目的とします。
DCGANの場合、Gは一様分布、または正規分布から生成されたノイズを入力として受け取り、転置畳み込み(fractionally-strided convolutions)を行って画像を生成します。Dはいわゆる普通のCNNのような識別器です。

f:id:mugimike:20190302143222p:plain
https://medium.freecodecamp.org/an-intuitive-introduction-to-generative-adversarial-networks-gans-7a2264a81394

データセット

今回使うデータセットガンプラの箱画像です。以下のような画像をネットから拾ってきました。だいたい160枚くらいです。
f:id:mugimike:20190302144635j:plain

これらの画像を128x128にリサイズします。遠目で見ればわかる程度のぼやけ具合ですかね笑
f:id:mugimike:20190302144907j:plain

実装

以下がGeneratorの実装です。実装には以下の記事やGithubリポジトリを参考にしました。
GANについて概念から実装まで ~DCGANによるキルミーベイベー生成~ - Qiita

image-generator-with-keras-dcgan/keras-dcgan.py at master · elm200/image-generator-with-keras-dcgan · GitHub

画像のサイズに関わるところを変更しました。また、Google Colaboratoryで実行したところメモリ不足によるエラーを吐いたので、最初のチャンネル数の箇所を128 → 32に変更しました。

def generator_model():
  model = Sequential()
  
  model.add(Dense(1024, input_shape=(100,)))
  model.add(Activation('relu'))
  
  model.add(Dense(32 * 32 * 32))
  model.add(BatchNormalization())
  model.add(Activation('relu'))
  
  model.add(Reshape((32, 32, 32)))
  model.add(UpSampling2D(size=(2, 2)))
  model.add(Conv2D(64, (5, 5), padding='same'))
  model.add(Activation('relu'))
  model.add(BatchNormalization())
  
  model.add(UpSampling2D(size=(2, 2)))
  model.add(Conv2D(32, (5, 5), padding='same'))
  model.add(Activation('relu'))
  model.add(BatchNormalization())
  
  #model.add(UpSampling2D(size=(2, 2)))
  model.add(Conv2D(n_colors, (5, 5), padding='same'))
  model.add(Activation('tanh'))
  return model

次にDiscriminatorの実装です。

def discriminator_model():
  model = Sequential()
  
  model.add(Conv2D(32, (5, 5), strides=2, input_shape=(128, 128, n_colors), padding='same'))
  model.add(LeakyReLU(alpha=0.2))
  model.add(Dropout(0.25))
  
  model.add(Conv2D(64, (5, 5), strides=2, padding="same"))
  model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
  model.add(LeakyReLU(alpha=0.2))
  model.add(Dropout(0.25))
  model.add(BatchNormalization(momentum=0.8))
  
  model.add(Conv2D(128, (5, 5), strides=2, padding="same"))
  model.add(LeakyReLU(alpha=0.2))
  model.add(Dropout(0.25))
  model.add(BatchNormalization(momentum=0.8))

  model.add(Conv2D(256, (5, 5), strides=1, padding="same"))
  model.add(LeakyReLU(alpha=0.2))
  model.add(Dropout(0.25))
  
  model.add(Flatten())
  model.add(Dense(1, activation="sigmoid"))
  
  return model

結果

epoch: 0
f:id:mugimike:20190302150546j:plain

epoch: 1000
f:id:mugimike:20190302150634j:plain

epoch: 10000
f:id:mugimike:20190302150732j:plain

epoch: 30000
f:id:mugimike:20190302150757j:plain

うーん、ガンプラのガの字も出てこない感じですね・・・
やはり128x128という大きめの画像に対してデータ数160は少なすぎたのか、安易にネットワークの構造を変えたのがよくなかったのか。
ひとまずデータ数を増やして再実行したいと思います。また追記します。