DCGANでガンプラの箱を生成
はじめに
n番煎じではありますが、GANを用いた画像の生成に挑戦していきます。今回使う画像はガンプラの箱の画像です。
DCGANとは
GAN(Generative adversarial network)は生成モデルの一種です。Generator(G)とDiscriminator(D)の二つのネットワークが互いに相反する目的の元学習していき、与えられた画像によく似た画像を生成します。GはDが見分けもつかないような精巧な画像を生成すること、Dは本物の画像と偽物の(生成された)画像を正しく見分けることを目的とします。
DCGANの場合、Gは一様分布、または正規分布から生成されたノイズを入力として受け取り、転置畳み込み(fractionally-strided convolutions)を行って画像を生成します。Dはいわゆる普通のCNNのような識別器です。
データセット
今回使うデータセットはガンプラの箱画像です。以下のような画像をネットから拾ってきました。だいたい160枚くらいです。
これらの画像を128x128にリサイズします。遠目で見ればわかる程度のぼやけ具合ですかね笑
実装
以下がGeneratorの実装です。実装には以下の記事やGithubのリポジトリを参考にしました。
GANについて概念から実装まで ~DCGANによるキルミーベイベー生成~ - Qiita
画像のサイズに関わるところを変更しました。また、Google Colaboratoryで実行したところメモリ不足によるエラーを吐いたので、最初のチャンネル数の箇所を128 → 32に変更しました。
def generator_model(): model = Sequential() model.add(Dense(1024, input_shape=(100,))) model.add(Activation('relu')) model.add(Dense(32 * 32 * 32)) model.add(BatchNormalization()) model.add(Activation('relu')) model.add(Reshape((32, 32, 32))) model.add(UpSampling2D(size=(2, 2))) model.add(Conv2D(64, (5, 5), padding='same')) model.add(Activation('relu')) model.add(BatchNormalization()) model.add(UpSampling2D(size=(2, 2))) model.add(Conv2D(32, (5, 5), padding='same')) model.add(Activation('relu')) model.add(BatchNormalization()) #model.add(UpSampling2D(size=(2, 2))) model.add(Conv2D(n_colors, (5, 5), padding='same')) model.add(Activation('tanh')) return model
次にDiscriminatorの実装です。
def discriminator_model(): model = Sequential() model.add(Conv2D(32, (5, 5), strides=2, input_shape=(128, 128, n_colors), padding='same')) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Conv2D(64, (5, 5), strides=2, padding="same")) model.add(ZeroPadding2D(padding=((0, 1), (0, 1)))) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(BatchNormalization(momentum=0.8)) model.add(Conv2D(128, (5, 5), strides=2, padding="same")) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(BatchNormalization(momentum=0.8)) model.add(Conv2D(256, (5, 5), strides=1, padding="same")) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(1, activation="sigmoid")) return model
結果
epoch: 0
epoch: 1000
epoch: 10000
epoch: 30000
うーん、ガンプラのガの字も出てこない感じですね・・・
やはり128x128という大きめの画像に対してデータ数160は少なすぎたのか、安易にネットワークの構造を変えたのがよくなかったのか。
ひとまずデータ数を増やして再実行したいと思います。また追記します。