How to do Novelty Detection in Keras with Generative Adversarial Network (Part 2)


The previous part introduced how the ALOCC model for novelty detection works, along with some background on autoencoders and GANs. In this post, we are going to implement it in Keras.

It helps to have a general understanding of how the model works before continuing. You can read part 1 here: How to do Novelty Detection in Keras with Generative Adversarial Network (Part 1).


Download the source code from my GitHub.

Building the model

[Figure: ALOCC architecture]

Let's start with the R network shown in the image above. The model is implemented with the Keras functional API.

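from keras.layers import Input, Conv2D, BatchNormalization, LeakyReLU, UpSampling2D
from keras.models import Model

def build_generator(self, input_shape):
    """R network: a convolutional encoder-decoder that reconstructs its input."""
    image = Input(shape=input_shape, name='z')
    # Encoder: strided convolutions, 28x28 -> 14x14 -> 7x7 -> 4x4 for MNIST.
    x = Conv2D(filters=self.df_dim * 2, kernel_size=5, strides=2, padding='same', name='g_encoder_h0_conv')(image)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    x = Conv2D(filters=self.df_dim * 4, kernel_size=5, strides=2, padding='same', name='g_encoder_h1_conv')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    x = Conv2D(filters=self.df_dim * 8, kernel_size=5, strides=2, padding='same', name='g_encoder_h2_conv')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

    # Decoder: 4x4 -> 8x8 -> 16x16 -> 14x14 -> 28x28.
    x = Conv2D(self.gf_dim * 1, kernel_size=5, activation='relu', padding='same')(x)
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(self.gf_dim * 1, kernel_size=5, activation='relu', padding='same')(x)
    x = UpSampling2D((2, 2))(x)
    # No padding here on purpose: 16x16 -> 14x14, so the last upsampling lands on 28x28.
    x = Conv2D(self.gf_dim * 2, kernel_size=3, activation='relu')(x)
    x = UpSampling2D((2, 2))(x)
    # Sigmoid output, one channel per input channel (c_dim).
    x = Conv2D(self.c_dim, kernel_size=5, activation='sigmoid', padding='same')(x)
    generator = Model(image, x, name='R')
    return generator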

Some key points are worth mentioning:

  • To improve the stability of the network, we use strided convolutions instead of pooling layers.
  • Each convolutional layer in the encoder is followed by a batch normalization operation, which further stabilizes training. To learn more, you can refer to my post dedicated to the topic, One simple trick to train Keras model faster with Batch Normalization.
  • UpSampling layers are used instead of Keras' Conv2DTranspose to reduce generated artifacts and make the output shape more deterministic.
  • We recommend using a LeakyReLU layer instead of a ReLU activation. It is similar to ReLU, but it relaxes sparsity constraints by allowing small negative activation values.
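The shape bookkeeping behind these layer choices can be checked by hand. The sketch below traces the spatial size through the R network for a 28x28 MNIST input, assuming TensorFlow-style output-size rules for Conv2D ('same': ceil(n / stride); 'valid': ceil((n - kernel + 1) / stride)) and exact doubling by UpSampling2D. `conv_out` and `trace_r_network` are hypothetical helpers, not part of the original code.

```python
import math

def conv_out(size, kernel, stride=1, padding="same"):
    """Output height/width of a 2D convolution (TensorFlow-style rules)."""
    if padding == "same":
        return math.ceil(size / stride)
    # 'valid': no padding, so the kernel must fit entirely inside the input.
    return math.ceil((size - kernel + 1) / stride)

def trace_r_network(size=28):
    """Return the spatial size after each stage of the R network."""
    sizes = [size]
    # Encoder: three 5x5 convolutions with stride 2 and 'same' padding.
    for _ in range(3):
        size = conv_out(size, kernel=5, stride=2, padding="same")
        sizes.append(size)
    # Decoder: 'same' convolution followed by 2x upsampling, twice.
    for _ in range(2):
        size = conv_out(size, kernel=5, padding="same")
        size *= 2  # UpSampling2D((2, 2)) doubles height and width
        sizes.append(size)
    # Unpadded 3x3 convolution trims 2 pixels...
    size = conv_out(size, kernel=3, padding="valid")
    sizes.append(size)
    # ...then one more 2x upsampling.
    size *= 2
    sizes.append(size)
    # Final 5x5 'same' convolution keeps the size.
    size = conv_out(size, kernel=5, padding="same")
    sizes.append(size)
    return sizes

print(trace_r_network(28))  # [28, 14, 7, 4, 8, 16, 14, 28, 28]
```

The unpadded 3x3 convolution is what brings 16 back down to 14, so the final upsampling lands exactly on 28 and the reconstruction can be compared pixel by pixel with the input. For a 32x32 input the same stack would end at 28, not 32, so the layer configuration would need adjusting.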

The D (discriminator) network is a sequence of convolutional layers that is trained to eventually distinguish novel or outlier samples without any supervision.

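from keras.layers import Input, Conv2D, BatchNormalization, LeakyReLU, Flatten, Dense
from keras.models import Model

def build_discriminator(self, input_shape):
    """D network: strided convolutions ending in a single sigmoid unit."""
    image = Input(shape=input_shape, name='d_input')
    # For 28x28 MNIST inputs: 28 -> 14 -> 7 -> 4 -> 2. No batch normalization on the first layer.
    x = Conv2D(filters=self.df_dim, kernel_size=5, strides=2, padding='same', name='d_h0_conv')(image)
    x = LeakyReLU()(x)

    x = Conv2D(filters=self.df_dim * 2, kernel_size=5, strides=2, padding='same', name='d_h1_conv')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

    x = Conv2D(filters=self.df_dim * 4, kernel_size=5, strides=2, padding='same', name='d_h2_conv')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

    x = Conv2D(filters=self.df_dim * 8, kernel_size=5, strides=2, padding='same', name='d_h3_conv')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

    x = Flatten()(x)
    # Single sigmoid unit: likelihood that the input belongs to the target class.
    x = Dense(1, activation='sigmoid', name='d_h3_lin')(x)
    return Model(image, x, name='D')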

The D network outputs a single floating-point number between 0 and 1 that reflects the likelihood that the input belongs to the target class.

Training the model

For simplicity and reproducibility, we train the model to recognize the MNIST handwritten digit "1" as the target (normal) class, so that in the test phase it can flag the other digits as novelties/anomalies.
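A minimal sketch of selecting the target class, assuming MNIST arrays shaped like those returned by `keras.datasets.mnist.load_data()` (uint8 images of shape (N, 28, 28) and integer labels) and assuming pixels are scaled to [0, 1] to match R's sigmoid output. `select_target_digits` is a hypothetical helper, and the random arrays below only stand in for the real dataset.

```python
import numpy as np

def select_target_digits(x, y, target=1):
    """Keep only images of the target digit, scaled to [0, 1] with a channel axis.

    x: uint8 array of shape (N, 28, 28); y: integer labels of shape (N,).
    """
    mask = (y == target)
    images = x[mask].astype("float32") / 255.0
    return images[..., np.newaxis]  # shape (M, 28, 28, 1), i.e. c_dim = 1

# Synthetic stand-in for MNIST training data.
x = np.random.randint(0, 256, size=(6, 28, 28), dtype=np.uint8)
y = np.array([1, 0, 1, 7, 1, 3])
ones_only = select_target_digits(x, y, target=1)
print(ones_only.shape)  # (3, 28, 28, 1)
```

Only these target images are used for training; the other digits are held back for the test phase.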

We train the R+D neural networks in an adversarial procedure.

When training the D network, it is shown both reconstructed and original images, labeled 0 and 1 respectively. D learns to discern real from generated images by minimizing the binary_crossentropy loss on both types of data.
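To make the D objective concrete, this NumPy sketch computes the two binary cross-entropy terms for one batch using made-up D outputs. The clipping constant mirrors Keras' default epsilon of 1e-7; the function is an illustrative stand-in, not the Keras implementation itself.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy, clipping predictions away from 0 and 1."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))))

# Hypothetical D outputs for one batch of 4 real and 4 reconstructed images.
d_on_real = np.array([0.9, 0.8, 0.95, 0.7])
d_on_fake = np.array([0.2, 0.1, 0.3, 0.05])

# Real images are labeled 1, reconstructions are labeled 0.
d_loss_real = binary_crossentropy(np.ones(4), d_on_real)
d_loss_fake = binary_crossentropy(np.zeros(4), d_on_fake)
print(round(d_loss_real, 4), round(d_loss_fake, 4))
```

A well-trained D pushes both terms toward zero; tracking them separately shows whether D struggles more with real images or with R's reconstructions.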

When training the R network, noise sampled from a normal distribution with a chosen standard deviation is added to the input to make R robust to noise and distortions in the input images. That is what the η stands for in the architecture image above. R is trained to jointly minimize the reconstruction loss and the adversarial loss of fooling the D network into classifying its output as the target class. A trade-off hyperparameter controls the relative importance of the two terms.
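A minimal NumPy sketch of the noise injection, assuming zero-mean Gaussian noise added independently to each pixel. `add_gaussian_noise` is a hypothetical helper, and its `sigma=0.1` default is an assumed value to tune, not one taken from the original code.

```python
import numpy as np

def add_gaussian_noise(images, sigma=0.1, seed=None):
    """Return a noisy copy: images + eta, with eta ~ N(0, sigma^2) drawn per pixel."""
    rng = np.random.default_rng(seed)
    eta = rng.normal(loc=0.0, scale=sigma, size=images.shape)
    return (images + eta).astype(images.dtype)

clean = np.full((2, 28, 28, 1), 0.5, dtype=np.float32)
noisy = add_gaussian_noise(clean, sigma=0.1, seed=0)
print(noisy.shape, float(np.std(noisy - clean)))
```

Only R's input is corrupted; the clean images remain R's reconstruction targets and D's "real" examples.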

The following code constructs and connects the discriminator and generator modules.

Notice that before compiling the combined adversarial_model, we set the discriminator's weights to be non-trainable, because in the combined model we only want to train the generator. This does not prevent the already compiled discriminator model from training. Also, self.r_alpha is a small floating-point number that trades off the relative importance of the two generator/R network losses.

from keras.layers import Input
from keras.models import Model
from keras.optimizers import RMSprop

image_dims = [self.input_height, self.input_width, self.c_dim]
# Keras 2.x signature; Keras 3 uses learning_rate= and has no decay= argument.
optimizer = RMSprop(lr=0.002, clipvalue=1.0, decay=1e-8)
# Construct the discriminator/D network; it takes an image and outputs a sigmoid probability.
self.discriminator = self.build_discriminator(image_dims)

# Model to train D to discriminate real images from reconstructions.
self.discriminator.compile(optimizer=optimizer, loss='binary_crossentropy')

# Construct the generator/R network.
self.generator = self.build_generator(image_dims)
img = Input(shape=image_dims)

reconstructed_img = self.generator(img)

# Freeze D inside the combined model only; the D model compiled above still trains.
self.discriminator.trainable = False
validity = self.discriminator(reconstructed_img)

# Model to train generator/R to minimize the reconstruction loss and trick D into
# seeing generated images as real ones.
self.adversarial_model = Model(img, [reconstructed_img, validity])
self.adversarial_model.compile(loss=['binary_crossentropy', 'binary_crossentropy'],
    loss_weights=[self.r_alpha, 1],  # reconstruction weighted by r_alpha, adversarial by 1
    optimizer=optimizer)

With the model constructed and compiled, we can start the training.

First, only the "1"s in the MNIST training set are extracted, and random noise is added to a copy of them to serve as the generator/R input.
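The one-batch training step below refers to batch_images, batch_noise_images, batch_clean_images, ones and zeros without defining them. Here is a minimal sketch of one way to produce them, assuming the clean "1"s and their noisy copies are already aligned NumPy arrays, and assuming batch_images and batch_clean_images both hold the clean images. `iterate_batches` is a hypothetical helper, not from the original source.

```python
import numpy as np

def iterate_batches(clean, noisy, batch_size, seed=None):
    """Yield (batch_images, batch_noise_images, batch_clean_images, ones, zeros) for one epoch.

    clean and noisy are aligned arrays: noisy[i] is the noisy copy of clean[i].
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(clean))  # shuffle once per epoch
    # Drop the last partial batch so ones/zeros always have batch_size rows.
    for start in range(0, len(order) - batch_size + 1, batch_size):
        idx = order[start:start + batch_size]
        batch_images = clean[idx]        # real inputs for D (label 1)
        batch_noise_images = noisy[idx]  # noisy inputs for R
        batch_clean_images = clean[idx]  # reconstruction targets for R
        ones = np.ones((batch_size, 1))
        zeros = np.zeros((batch_size, 1))
        yield batch_images, batch_noise_images, batch_clean_images, ones, zeros
```

Each yielded tuple can feed one call of the training step, for example inside `for batch_images, batch_noise_images, batch_clean_images, ones, zeros in iterate_batches(clean, noisy, batch_size=128):` where the batch size is an arbitrary choice.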

Here is the code to train one batch of data. The D network is trained first on real and generated images, labeled 1 and 0 respectively.

Then the R network is trained twice on the same batch of noisy data to minimize its losses.

# R's reconstructions of the noisy batch serve as the "generated" images for D.
batch_fake_images = self.generator.predict(batch_noise_images)

# Update D network, minimize real image inputs->D->ones, noisy z->R->D->zeros loss.
d_loss_real = self.discriminator.train_on_batch(batch_images, ones)
d_loss_fake = self.discriminator.train_on_batch(batch_fake_images, zeros)

# Update R network twice, minimize noisy z->R->D->ones and reconstruction loss.
# Only the losses from the second call are kept in g_loss.
self.adversarial_model.train_on_batch(batch_noise_images, [batch_clean_images, ones])
g_loss = self.adversarial_model.train_on_batch(batch_noise_images, [batch_clean_images, ones])

One final tip about the output g_loss variable: since the combined adversarial_model was compiled with two loss functions and no additional metrics, g_loss is a list of 3 numbers, [total_weighted_loss, loss_1, loss_2], where loss_1 is the reconstruction loss and loss_2 is the "fooling the D network" loss. Training a GAN longer generally produces better results, but in our case stopping too early leaves the network weights immature, while overtraining confuses the R network and yields undesirable outputs. We therefore need an appropriate criterion for stopping training.
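A minimal sketch of unpacking g_loss and checking how the total relates to the two terms, assuming the model has no regularization losses (so the total is exactly the weighted sum set by loss_weights=[self.r_alpha, 1]). `summarize_g_loss` is a hypothetical helper, and the r_alpha and loss values are made up for illustration.

```python
def summarize_g_loss(g_loss, r_alpha):
    """Unpack [total_weighted_loss, loss_1, loss_2] from the combined model.

    loss_1: reconstruction loss (weight r_alpha); loss_2: adversarial loss (weight 1).
    """
    total, recon_loss, adv_loss = g_loss
    weighted_sum = r_alpha * recon_loss + 1.0 * adv_loss
    return {"total": total, "reconstruction": recon_loss,
            "adversarial": adv_loss, "weighted_sum": weighted_sum}

# Hypothetical values, for illustration only.
stats = summarize_g_loss([0.75, 0.25, 0.7], r_alpha=0.2)
print(stats["reconstruction"], stats["weighted_sum"])
```

Because the total mixes both terms, it is loss_1 on its own that should be logged for the stopping criterion.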

The paper's authors propose stopping training when R can reconstruct its input with minimum error, which can be monitored by keeping track of loss_1, the reconstruction loss.
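This rule can be applied after the fact by logging loss_1 for every batch and picking the epoch with the lowest mean reconstruction loss. A minimal sketch with hypothetical logged values; `best_epoch` is a hypothetical helper, and averaging over each epoch is one possible choice of summary.

```python
import numpy as np

def best_epoch(recon_losses_per_epoch):
    """Return the 1-based epoch whose mean reconstruction loss (loss_1) is lowest."""
    means = [float(np.mean(losses)) for losses in recon_losses_per_epoch]
    return int(np.argmin(means)) + 1

# Hypothetical per-batch loss_1 values logged over 5 epochs.
history = [[0.30, 0.28], [0.22, 0.21], [0.18, 0.17], [0.19, 0.20], [0.21, 0.22]]
print(best_epoch(history))  # 3
```

This assumes the weights were saved after every epoch (for example with the models' save_weights method inside the custom training loop), so the checkpoint from the selected epoch can be reloaded for testing.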

Novelty Detection

The following graph shows the R network's reconstruction loss during 5 epochs of training. The reconstruction loss appears to reach its minimum at the end of epoch 3, so we use the model weights saved after epoch 3 for novelty detection. You can download and run the test-phase Jupyter notebook test.ipynb from my GitHub repository.

[Figure: R network reconstruction loss over 5 training epochs]

We can now test the reconstruction loss and the discriminator output. As shown below, a novel/abnormal image has a larger reconstruction loss and a smaller discriminator output value. Here the handwritten "1" is the target, and the other digits are novel/abnormal cases.
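A minimal sketch of scoring test images, assuming R's reconstructions and D's outputs are already available as NumPy arrays (in the ALOCC paper, D is applied to R's reconstruction of the test image). The per-image reconstruction error mirrors the binary cross-entropy used during training. `novelty_scores`, `is_novel` and the 0.5 threshold are hypothetical choices; in practice the threshold would be tuned on held-out data.

```python
import numpy as np

def novelty_scores(images, reconstructions, d_outputs, eps=1e-7):
    """Return per-image reconstruction error (mean binary cross-entropy) and D scores."""
    images = np.asarray(images, dtype=np.float64)
    r = np.clip(np.asarray(reconstructions, dtype=np.float64), eps, 1.0 - eps)
    bce = -(images * np.log(r) + (1.0 - images) * np.log(1.0 - r))
    recon_error = bce.reshape(len(images), -1).mean(axis=1)  # one value per image
    return recon_error, np.asarray(d_outputs, dtype=np.float64).ravel()

def is_novel(d_scores, threshold=0.5):
    """Flag images whose D score falls below a hypothetical threshold as novel."""
    return np.asarray(d_scores) < threshold

# Hypothetical 2x2 "images": the first is reconstructed well, the second poorly.
images = np.array([[[0.0, 1.0], [1.0, 0.0]],
                   [[0.0, 1.0], [1.0, 0.0]]])
reconstructions = np.array([[[0.05, 0.95], [0.95, 0.05]],
                            [[0.60, 0.40], [0.40, 0.60]]])
recon_error, d_scores = novelty_scores(images, reconstructions, d_outputs=[0.9, 0.2])
print(recon_error.round(3), is_novel(d_scores))
```

The two signals usually agree, as in the examples above: the poorly reconstructed image has both the larger reconstruction error and the lower D score.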

[Figure: reconstruction examples for the target digit "1" and novel digits]

Conclusion and further reading

We covered how to build the ALOCC novelty detection model in Keras, combining a generative adversarial network with an encoder-decoder network.

Check out the original paper: https://arxiv.org/abs/1802.09088.

Here is an interesting Q&A on Quora about whether GANs can do outlier/novelty detection, answered by the creator of GANs, Ian Goodfellow.

Don't forget to download the source code from my GitHub.
