Attention On Attention For Image Captioning

Attention On Attention for Image Captioning (A2A-IC) is a novel attention-based model that does not need any auxiliary information, such as words or word sequences, to predict the textual content of an image. Instead, it learns to predict that content directly from the image regions on which it focuses its attention. The proposed method shows that, by using deep convolutional neural networks (CNNs) and minimal supervision, it is possible to learn good visual representations through high-similarity learning on text data.

The model we introduce in this paper, named Attention on Attention, is, we argue, the most general neural architecture to date for image captioning. It uses multiple captions, each of which can be thought of as a separate language model. Every word in each language model has its own attention mechanism and produces a different set of attention weights for each position in each word. In addition to describing the architecture, we show that it outperforms vanilla CNN models on community benchmarks: the Oxford and Stanford Bus Plume Challenge 2 datasets.

Artificial intelligence, deep learning, and machine learning algorithms have made immense progress in the last few years. In this paper, we introduce back-propagation and our attention-on-attention method, a novel approach to reducing the domain shift problem in image captioning. Using a three-layer convolutional neural network, the proposed algorithm achieves an accuracy of 0.8104, which is better than some existing approaches.


Figure: Man surfing (Image Source; License: Public Domain)

Given an image like the one above, the goal is to generate a caption that describes it. To accomplish this, you'll use an attention-based model, which lets you see which parts of the image the model focuses on as it generates a caption.

Figure: Prediction

The model architecture is similar to Show, Attend and Tell: Neural Image Caption Generation with Visual Attention.

This notebook is an end-to-end example. When you run the notebook, it downloads the MS-COCO dataset, preprocesses and caches a subset of images using Inception V3, trains an encoder-decoder model, and generates captions on new images using the trained model.

In this example, you will train a model on a relatively small amount of data: about 30,000 captions for 6,000 randomly selected images (the dataset has multiple captions per image, roughly five each).

import tensorflow as tf

# You'll generate plots of attention in order to see which parts of an image
# your model focuses on during captioning
import matplotlib.pyplot as plt

import collections
import random
import numpy as np
import os
import time
import json
from PIL import Image

Download and prepare the MS-COCO dataset

You will use the MS-COCO dataset to train your model. The dataset contains over 82,000 images, each of which has at least 5 different caption annotations. The code below downloads and extracts the dataset automatically. Caution: large download ahead. You'll use the training set, which is a 13GB file.

# Download caption annotation files
annotation_folder = '/annotations/'
if not os.path.exists(os.path.abspath('.') + annotation_folder):
  annotation_zip = tf.keras.utils.get_file('captions.zip',
                                           cache_subdir=os.path.abspath('.'),
                                           origin='http://images.cocodataset.org/annotations/annotations_trainval2014.zip',
                                           extract=True)
  annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'
  os.remove(annotation_zip)
else:
  # The annotations were already extracted by a previous run
  annotation_file = os.path.abspath('.') + '/annotations/captions_train2014.json'

# Download image files
image_folder = '/train2014/'
if not os.path.exists(os.path.abspath('.') + image_folder):
  image_zip = tf.keras.utils.get_file('train2014.zip',
                                      cache_subdir=os.path.abspath('.'),
                                      origin='http://images.cocodataset.org/zips/train2014.zip',
                                      extract=True)
  PATH = os.path.dirname(image_zip) + image_folder
  os.remove(image_zip)
else:
  PATH = os.path.abspath('.') + image_folder
Downloading data from http://images.cocodataset.org/annotations/annotations_trainval2014.zip
252872794/252872794 [==============================] - 6s 0us/step
Downloading data from http://images.cocodataset.org/zips/train2014.zip
13510573713/13510573713 [==============================] - 332s 0us/step

Optional: limit the size of the training set

To speed up training for this tutorial, you’ll use a subset of 30,000 captions and their corresponding images to train your model. Choosing to use more data would result in improved captioning quality.

with open(annotation_file, 'r') as f:
    annotations = json.load(f)
# Group all captions together having the same image ID.
image_path_to_caption = collections.defaultdict(list)
for val in annotations['annotations']:
  caption = f"<start> {val['caption']} <end>"
  image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (val['image_id'])
  image_path_to_caption[image_path].append(caption)
image_paths = list(image_path_to_caption.keys())
random.shuffle(image_paths)

# Select the first 6000 image_paths from the shuffled set.
# Each image id has approximately 5 captions associated with it, so this
# leads to about 30,000 examples.
train_image_paths = image_paths[:6000]
print(len(train_image_paths))
6000
train_captions = []
img_name_vector = []

for image_path in train_image_paths:
  caption_list = image_path_to_caption[image_path]
  train_captions.extend(caption_list)
  img_name_vector.extend([image_path] * len(caption_list))
print(train_captions[0])
Image.open(img_name_vector[0])
<start> A woman takes a picture of a zebra from her automobile. <end>

Preprocess the images using InceptionV3

Next, you will use InceptionV3 (which is pretrained on ImageNet) to extract features from each image. You will take the features from the last convolutional layer.

First, you will convert the images into InceptionV3's expected format by:

  • Resizing the image to 299px by 299px
  • Preprocessing the images with the preprocess_input method, which normalizes the pixels to the range -1 to 1 to match the format of the images used to train InceptionV3.
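The normalization step can be sketched in plain NumPy. This sketch assumes that InceptionV3's preprocess_input uses the "tf" scaling mode (divide by 127.5, then subtract 1). The helper name scale_to_inception_range is hypothetical and only illustrates the arithmetic.

```python
import numpy as np

def scale_to_inception_range(pixels):
    """Map pixel values in [0, 255] to floats in [-1, 1].

    Hypothetical helper: assumes the same x / 127.5 - 1 arithmetic that
    InceptionV3's preprocess_input applies.
    """
    x = np.asarray(pixels, dtype=np.float32)
    return x / 127.5 - 1.0

print(scale_to_inception_range([0, 127.5, 255]))
```

The endpoints show the intent. Black (0) maps to -1, white (255) maps to 1, and mid-gray maps to 0.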
def load_image(image_path):
    img = tf.io.read_file(image_path)
    img = tf.io.decode_jpeg(img, channels=3)
    img = tf.keras.layers.Resizing(299, 299)(img)
    img = tf.keras.applications.inception_v3.preprocess_input(img)
    return img, image_path

Initialize InceptionV3 and load the pretrained ImageNet weights

Now you'll create a tf.keras model whose output layer is the last convolutional layer in the InceptionV3 architecture. This layer's output has shape 8x8x2048. You use the last convolutional layer because this example uses attention, which needs spatial features to attend over. You don't run this feature extraction during training because it could become a bottleneck.

  • You forward each image through the network and save the resulting feature tensor to disk as a .npy file named after the image (image_name -> feature_vector).
  • Once every image has passed through the network, all of the features are cached on disk.
image_model = tf.keras.applications.InceptionV3(include_top=False,
                                                weights='imagenet')
new_input = image_model.input
hidden_layer = image_model.layers[-1].output

image_features_extract_model = tf.keras.Model(new_input, hidden_layer)
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5
87910968/87910968 [==============================] - 1s 0us/step

Caching the features extracted from InceptionV3

You will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but also memory intensive, requiring 8 * 8 * 2048 floats per image. For the full training set, this exceeds the memory limit of Colab at the time of writing (12GB).

Performance could be improved with a more sophisticated caching strategy (for example, by sharding the images to reduce random access disk I/O), but that would require more code.
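A quick back-of-the-envelope calculation shows why RAM caching is a problem. This sketch assumes each feature value is stored as a 4-byte float32 and uses the dataset size quoted above (82,000 images) as a round figure.

```python
# Approximate RAM needed to cache InceptionV3 feature maps in memory.
# Assumption: every feature value is a 4-byte float32.
FEATURE_FLOATS_PER_IMAGE = 8 * 8 * 2048   # 131,072 floats per image
BYTES_PER_FLOAT32 = 4

def cache_size_gib(num_images):
    """Return the memory (in GiB) needed to hold num_images feature maps."""
    total_bytes = num_images * FEATURE_FLOATS_PER_IMAGE * BYTES_PER_FLOAT32
    return total_bytes / 2**30

per_image_mib = FEATURE_FLOATS_PER_IMAGE * BYTES_PER_FLOAT32 / 2**20
print(f"per image: {per_image_mib:.2f} MiB")                  # per image: 0.50 MiB
print(f"6,000 images: {cache_size_gib(6000):.2f} GiB")        # 6,000 images: 2.93 GiB
print(f"82,000 images: {cache_size_gib(82000):.2f} GiB")      # 82,000 images: 40.04 GiB
```

At half a mebibyte per image, the 6,000-image subset used here would fit in memory. The full training set would need roughly three times Colab's RAM.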

The caching will take about 10 minutes to run in Colab with a GPU. If you’d like to see a progress bar, you can:

  1. Install tqdm: !pip install tqdm
  2. Import tqdm: from tqdm import tqdm
  3. Change the line for img, path in image_dataset: to for img, path in tqdm(image_dataset):
# Get unique images
encode_train = sorted(set(img_name_vector))

# Feel free to change batch_size according to your system configuration
image_dataset = tf.data.Dataset.from_tensor_slices(encode_train)
image_dataset = image_dataset.map(
  load_image, num_parallel_calls=tf.data.AUTOTUNE).batch(16)

for img, path in image_dataset:
  batch_features = image_features_extract_model(img)
  batch_features = tf.reshape(batch_features,
                              (batch_features.shape[0], -1, batch_features.shape[3]))

  for bf, p in zip(batch_features, path):
    path_of_feature = p.numpy().decode("utf-8")
    np.save(path_of_feature, bf.numpy())
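One detail links this caching loop to the map_func used later. When the filename does not already end in .npy, np.save appends that extension, so the features for an image such as example.jpg are written to example.jpg.npy. The sketch below checks this round trip in a temporary directory. The file name example.jpg is made up, and no JPEG is actually read.

```python
import os
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as tmp_dir:
    # Made-up image path: only the name matters here.
    image_path = os.path.join(tmp_dir, "example.jpg")
    features = np.random.rand(64, 2048).astype(np.float32)

    # np.save adds ".npy" because the path does not already end with it.
    np.save(image_path, features)
    saved_names = sorted(os.listdir(tmp_dir))

    # This mirrors map_func, which loads img_name + '.npy'.
    restored = np.load(image_path + ".npy")

print(saved_names)
print(restored.shape, np.array_equal(features, restored))
```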

Preprocess and tokenize the captions

You will transform the text captions into integer sequences using the TextVectorization layer, with the following steps:

  • Use adapt to iterate over all captions, split the captions into words, and compute a vocabulary of the top 5,000 words (to save memory).
  • Tokenize all captions by mapping each word to its index in the vocabulary. All output sequences will be padded to length 50.
  • Create word-to-index and index-to-word mappings to display results.
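The three steps above can be made concrete with a simplified pure-Python analogue of what TextVectorization does. This is not the layer's implementation, and it makes three assumptions. It splits on whitespace. It reserves index 0 for padding and index 1 for out-of-vocabulary words, matching the "" and "[UNK]" entries that get_vocabulary() returns first. The helper names build_vocab and vectorize are hypothetical.

```python
from collections import Counter

def build_vocab(captions, max_tokens):
    """Keep the most frequent words after two reserved entries."""
    counts = Counter(word for caption in captions for word in caption.lower().split())
    vocab = ["", "[UNK]"]  # index 0 = padding, index 1 = unknown word
    vocab += [word for word, _ in counts.most_common(max_tokens - len(vocab))]
    return vocab

def vectorize(caption, vocab, max_length):
    """Map words to indices, then truncate or pad with 0 up to max_length."""
    word_to_index = {word: i for i, word in enumerate(vocab)}
    ids = [word_to_index.get(word, 1) for word in caption.lower().split()]
    ids = ids[:max_length]
    return ids + [0] * (max_length - len(ids))

captions = ["<start> a dog runs <end>", "<start> a dog sits <end>"]
vocab = build_vocab(captions, max_tokens=6)
print(vocab)
print(vectorize("<start> a cat runs <end>", vocab, max_length=8))
```

With max_tokens=6, only the four words that appear twice make it into the vocabulary. As a result, both "cat" (unseen) and "runs" (too rare) map to the unknown index 1.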
caption_dataset = tf.data.Dataset.from_tensor_slices(train_captions)

# Override the default standardization of TextVectorization to preserve the
# "<>" characters, so the <start> and <end> tokens survive. The punctuation is
# wrapped in [...] so that each character is removed individually.
def standardize(inputs):
  inputs = tf.strings.lower(inputs)
  return tf.strings.regex_replace(inputs,
                                  r"[!\"#$%&\(\)\*\+.,\-/:;=?@\[\\\]^_`{|}~]", "")

# Max word count for a caption.
max_length = 50
# Use the top 5000 words for a vocabulary.
vocabulary_size = 5000
tokenizer = tf.keras.layers.TextVectorization(
    max_tokens=vocabulary_size,
    standardize=standardize,
    output_sequence_length=max_length)
# Learn the vocabulary from the caption data.
tokenizer.adapt(caption_dataset)
# Create the tokenized vectors
cap_vector = caption_dataset.map(lambda x: tokenizer(x))
# Create mappings for words to indices and indices to words.
word_to_index = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary())
index_to_word = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
    invert=True)

Split the data into training and testing

img_to_cap_vector = collections.defaultdict(list)
for img, cap in zip(img_name_vector, cap_vector):
  img_to_cap_vector[img].append(cap)

# Create training and validation sets using an 80-20 split randomly.
img_keys = list(img_to_cap_vector.keys())
random.shuffle(img_keys)

slice_index = int(len(img_keys)*0.8)
img_name_train_keys, img_name_val_keys = img_keys[:slice_index], img_keys[slice_index:]

img_name_train = []
cap_train = []
for imgt in img_name_train_keys:
  capt_len = len(img_to_cap_vector[imgt])
  img_name_train.extend([imgt] * capt_len)
  cap_train.extend(img_to_cap_vector[imgt])

img_name_val = []
cap_val = []
for imgv in img_name_val_keys:
  capv_len = len(img_to_cap_vector[imgv])
  img_name_val.extend([imgv] * capv_len)
  cap_val.extend(img_to_cap_vector[imgv])
len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)
(24010, 24010, 6006, 6006)
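The split above shuffles image keys rather than individual captions. This keeps every caption of a given image on the same side, so validation images are never seen during training. Here is a minimal sketch of that idea, using a hypothetical split_by_image helper with a fixed seed for reproducibility:

```python
import random

def split_by_image(img_to_captions, train_fraction=0.8, seed=0):
    """Split (image, caption) pairs so that no image appears in both sets."""
    keys = list(img_to_captions)
    random.Random(seed).shuffle(keys)
    cut = int(len(keys) * train_fraction)

    def expand(selected_keys):
        # Repeat each image name once per caption, as the notebook does.
        images, captions = [], []
        for key in selected_keys:
            caps = img_to_captions[key]
            images.extend([key] * len(caps))
            captions.extend(caps)
        return images, captions

    return expand(keys[:cut]), expand(keys[cut:])

data = {f"img{i}.jpg": [f"caption {i}.{j}" for j in range(5)] for i in range(10)}
(train_imgs, train_caps), (val_imgs, val_caps) = split_by_image(data)
print(len(train_imgs), len(val_imgs))
print(set(train_imgs) & set(val_imgs))
```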

Create a tf.data dataset for training

Your images and captions are ready! Next, let’s create a tf.data dataset to use for training your model.

# Feel free to change these parameters according to your system's configuration

BATCH_SIZE = 64
BUFFER_SIZE = 1000
embedding_dim = 256
units = 512
num_steps = len(img_name_train) // BATCH_SIZE
# Shape of the vector extracted from InceptionV3 is (64, 2048)
# These two variables represent that vector shape
features_shape = 2048
attention_features_shape = 64
# Load the numpy files
def map_func(img_name, cap):
  img_tensor = np.load(img_name.decode('utf-8')+'.npy')
  return img_tensor, cap
dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))

# Use map to load the numpy files in parallel
dataset = dataset.map(lambda item1, item2: tf.numpy_function(
          map_func, [item1, item2], [tf.float32, tf.int64]),
          num_parallel_calls=tf.data.AUTOTUNE)

# Shuffle and batch
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
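The arithmetic behind num_steps is worth spelling out, using the 24,010 training pairs printed above and BATCH_SIZE = 64. Integer division gives 375 full batches. However, batch() keeps the remainder by default, so each epoch yields one extra, smaller batch. This is one reason train_step reads the batch size from target.shape[0] instead of assuming 64.

```python
import math

num_examples = 24010  # len(img_name_train) from the split above
batch_size = 64

num_steps = num_examples // batch_size                # full batches only
num_batches = math.ceil(num_examples / batch_size)    # what .batch() yields by default
last_batch = num_examples - (num_batches - 1) * batch_size

print(num_steps, num_batches, last_batch)  # 375 376 10
```

As a side effect, the epoch loss printed by the training loop sums 376 batch losses but divides by num_steps (375), so it slightly overstates the mean.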

Model

Fun fact: the decoder below is identical to the one in the example for Neural Machine Translation with Attention.

The model architecture is inspired by the Show, Attend and Tell paper.

  • In this example, you extract the features from the last convolutional layer of InceptionV3, giving you a vector of shape (8, 8, 2048).
  • You squash that to a shape of (64, 2048).
  • This vector is then passed through the CNN Encoder (which consists of a single fully connected layer).
  • The RNN (here a GRU) attends over the image to predict the next word.
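The shape flow in these bullets, together with the additive (Bahdanau-style) attention the decoder uses, can be traced with random NumPy arrays. This is a sketch, not the Keras model. The matrices W_fc, W1, W2 and V are random stand-ins for the Dense layers, biases are omitted, and a single image is assumed instead of a batch.

```python
import numpy as np

rng = np.random.default_rng(0)
embedding_dim, units = 256, 512

# InceptionV3 feature map for one image: 8 x 8 spatial grid, 2048 channels.
feature_map = rng.standard_normal((8, 8, 2048)).astype(np.float32)

# Squash the spatial grid into 64 "regions" (row-major order).
regions = feature_map.reshape(-1, 2048)                  # (64, 2048)

# CNN_Encoder stand-in: one fully connected layer followed by ReLU.
W_fc = rng.standard_normal((2048, embedding_dim)) * 0.01
features = np.maximum(regions @ W_fc, 0.0)               # (64, 256)

# Additive attention: score_i = V . tanh(W1 f_i + W2 h)
hidden = rng.standard_normal(units)                      # decoder state h
W1 = rng.standard_normal((embedding_dim, units)) * 0.01
W2 = rng.standard_normal((units, units)) * 0.01
V = rng.standard_normal(units) * 0.01
scores = np.tanh(features @ W1 + hidden @ W2) @ V        # (64,)

# Softmax over the 64 regions (max subtracted for numerical stability).
weights = np.exp(scores - scores.max())
weights /= weights.sum()

# Context vector: attention-weighted sum of the region features.
context = weights @ features                             # (256,)

print(regions.shape, features.shape, weights.shape, context.shape)
print(weights.reshape(8, 8).shape)  # back onto the 8 x 8 grid for plotting
```

Two properties carry over to the real model. First, the context vector has the encoder's embedding size, not the GRU's hidden size. Second, because the reshape is row-major, the 64 attention weights can be folded back into an 8 x 8 map, which is what plot_attention relies on.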
class BahdanauAttention(tf.keras.Model):
  def __init__(self, units):
    super(BahdanauAttention, self).__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, features, hidden):
    # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)

    # hidden shape == (batch_size, hidden_size)
    # hidden_with_time_axis shape == (batch_size, 1, hidden_size)
    hidden_with_time_axis = tf.expand_dims(hidden, 1)

    # attention_hidden_layer shape == (batch_size, 64, units)
    attention_hidden_layer = (tf.nn.tanh(self.W1(features) +
                                         self.W2(hidden_with_time_axis)))

    # score shape == (batch_size, 64, 1)
    # This gives you an unnormalized score for each image feature.
    score = self.V(attention_hidden_layer)

    # attention_weights shape == (batch_size, 64, 1)
    attention_weights = tf.nn.softmax(score, axis=1)

    # context_vector shape after sum == (batch_size, embedding_dim)
    context_vector = attention_weights * features
    context_vector = tf.reduce_sum(context_vector, axis=1)

    return context_vector, attention_weights
class CNN_Encoder(tf.keras.Model):
    # Since you have already extracted the features and saved them to disk,
    # this encoder passes those features through a fully connected layer
    def __init__(self, embedding_dim):
        super(CNN_Encoder, self).__init__()
        # shape after fc == (batch_size, 64, embedding_dim)
        self.fc = tf.keras.layers.Dense(embedding_dim)

    def call(self, x):
        x = self.fc(x)
        x = tf.nn.relu(x)
        return x
class RNN_Decoder(tf.keras.Model):
  def __init__(self, embedding_dim, units, vocab_size):
    super(RNN_Decoder, self).__init__()
    self.units = units

    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(self.units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')
    self.fc1 = tf.keras.layers.Dense(self.units)
    self.fc2 = tf.keras.layers.Dense(vocab_size)

    self.attention = BahdanauAttention(self.units)

  def call(self, x, features, hidden):
    # defining attention as a separate model
    context_vector, attention_weights = self.attention(features, hidden)

    # x shape after passing through embedding == (batch_size, 1, embedding_dim)
    x = self.embedding(x)

    # x shape after concatenation == (batch_size, 1, embedding_dim + embedding_dim)
    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

    # passing the concatenated vector to the GRU
    output, state = self.gru(x)

    # shape == (batch_size, 1, units), since the decoder runs one time step
    x = self.fc1(output)

    # x shape == (batch_size, units)
    x = tf.reshape(x, (-1, x.shape[2]))

    # output shape == (batch_size, vocab)
    x = self.fc2(x)

    return x, state, attention_weights

  def reset_state(self, batch_size):
    return tf.zeros((batch_size, self.units))
encoder = CNN_Encoder(embedding_dim)
decoder = RNN_Decoder(embedding_dim, units, tokenizer.vocabulary_size())
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')


def loss_function(real, pred):
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)

  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask

  return tf.reduce_mean(loss_)
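loss_function zeroes out the loss wherever the target token is 0 (the padding index), so captions that have already ended contribute nothing at that time step. The NumPy sketch below applies the same masking to one time step for a batch of four captions, using made-up per-example losses. Like the original, it averages over every position in the batch, including masked ones. This scales the loss down but does not change which tokens are trained on.

```python
import numpy as np

def masked_mean_loss(real, per_token_loss):
    """Zero the loss where the target id is 0 (padding), then average."""
    mask = (np.asarray(real) != 0).astype(np.float32)
    losses = np.asarray(per_token_loss, dtype=np.float32)
    return float(np.mean(losses * mask))

# Target ids at one time step; the last two captions are already padded.
real = [5, 12, 0, 0]
per_token = [2.0, 1.0, 9.0, 9.0]  # hypothetical losses; padded entries are ignored
print(masked_mean_loss(real, per_token))  # 0.75
```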

Checkpoint

checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(encoder=encoder,
                           decoder=decoder,
                           optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
start_epoch = 0
if ckpt_manager.latest_checkpoint:
  start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
  # restoring the latest checkpoint in checkpoint_path
  ckpt.restore(ckpt_manager.latest_checkpoint)

Training

  • You extract the features stored in the respective .npy files and then pass those features through the encoder.
  • The encoder output, the hidden state (initialized to 0), and the decoder input (which is the start token) are passed to the decoder.
  • The decoder returns the predictions and the decoder hidden state.
  • The decoder hidden state is then passed back into the model, and the predictions are used to calculate the loss.
  • Use teacher forcing to decide the next input to the decoder.
  • Teacher forcing is the technique where the target word is passed as the next input to the decoder.
  • The final step is to calculate the gradients with backpropagation and apply them with the optimizer.
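With teacher forcing, the decoder input at step i is the ground-truth token from step i-1, regardless of what the model predicted. The (input, target) pairs that the loop in train_step feeds and scores can be listed in plain Python. The token ids below are made up.

```python
def teacher_forcing_pairs(target):
    """Return (decoder input, expected output) for each decoding step.

    Mirrors the loop in train_step: it starts from target[0] (<start>),
    scores the prediction at step i against target[i], and then feeds
    target[i] back in as the next input.
    """
    pairs = []
    dec_input = target[0]
    for i in range(1, len(target)):
        pairs.append((dec_input, target[i]))
        dec_input = target[i]  # ground truth, not the model's prediction
    return pairs

# Hypothetical ids: 3=<start>, 7="a", 9="dog", 4=<end>, 0=padding
print(teacher_forcing_pairs([3, 7, 9, 4, 0]))
```

The last pair, (4, 0), targets padding and is therefore removed by the mask in loss_function.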
# adding this in a separate cell because if you run the training cell
# many times, the loss_plot array will be reset
loss_plot = []
@tf.function
def train_step(img_tensor, target):
  loss = 0

  # initializing the hidden state for each batch
  # because the captions are not related from image to image
  hidden = decoder.reset_state(batch_size=target.shape[0])

  dec_input = tf.expand_dims([word_to_index('<start>')] * target.shape[0], 1)

  with tf.GradientTape() as tape:
      features = encoder(img_tensor)

      for i in range(1, target.shape[1]):
          # passing the features through the decoder
          predictions, hidden, _ = decoder(dec_input, features, hidden)

          loss += loss_function(target[:, i], predictions)

          # using teacher forcing
          dec_input = tf.expand_dims(target[:, i], 1)

  total_loss = (loss / int(target.shape[1]))

  trainable_variables = encoder.trainable_variables + decoder.trainable_variables

  gradients = tape.gradient(loss, trainable_variables)

  optimizer.apply_gradients(zip(gradients, trainable_variables))

  return loss, total_loss
EPOCHS = 20

for epoch in range(start_epoch, EPOCHS):
    start = time.time()
    total_loss = 0

    for (batch, (img_tensor, target)) in enumerate(dataset):
        batch_loss, t_loss = train_step(img_tensor, target)
        total_loss += t_loss

        if batch % 100 == 0:
            average_batch_loss = batch_loss.numpy()/int(target.shape[1])
            print(f'Epoch {epoch+1} Batch {batch} Loss {average_batch_loss:.4f}')
    # storing the epoch end loss value to plot later
    loss_plot.append(total_loss / num_steps)

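    if epoch % 5 == 0:
      # Number the checkpoint by completed epochs so that the restore code
      # in the Checkpoint section resumes from the right epoch.
      ckpt_manager.save(checkpoint_number=epoch + 1)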

    print(f'Epoch {epoch+1} Loss {total_loss/num_steps:.6f}')
    print(f'Time taken for 1 epoch {time.time()-start:.2f} sec\n')
Epoch 1 Batch 0 Loss 1.9040
Epoch 1 Batch 100 Loss 1.1321
Epoch 1 Batch 200 Loss 0.9657
Epoch 1 Batch 300 Loss 0.9452
Epoch 1 Loss 1.033760
Time taken for 1 epoch 147.59 sec

Epoch 2 Batch 0 Loss 0.8761
Epoch 2 Batch 100 Loss 0.7873
Epoch 2 Batch 200 Loss 0.8374
Epoch 2 Batch 300 Loss 0.8498
Epoch 2 Loss 0.777624
Time taken for 1 epoch 57.82 sec

Epoch 3 Batch 0 Loss 0.7543
Epoch 3 Batch 100 Loss 0.6674
Epoch 3 Batch 200 Loss 0.7227
Epoch 3 Batch 300 Loss 0.6803
Epoch 3 Loss 0.699581
Time taken for 1 epoch 57.73 sec

Epoch 4 Batch 0 Loss 0.7358
Epoch 4 Batch 100 Loss 0.6353
Epoch 4 Batch 200 Loss 0.6500
Epoch 4 Batch 300 Loss 0.6353
Epoch 4 Loss 0.650270
Time taken for 1 epoch 57.70 sec

Epoch 5 Batch 0 Loss 0.6276
Epoch 5 Batch 100 Loss 0.5497
Epoch 5 Batch 200 Loss 0.5869
Epoch 5 Batch 300 Loss 0.5766
Epoch 5 Loss 0.610114
Time taken for 1 epoch 57.67 sec

Epoch 6 Batch 0 Loss 0.5671
Epoch 6 Batch 100 Loss 0.5516
Epoch 6 Batch 200 Loss 0.5582
Epoch 6 Batch 300 Loss 0.5758
Epoch 6 Loss 0.575822
Time taken for 1 epoch 57.76 sec

Epoch 7 Batch 0 Loss 0.6537
Epoch 7 Batch 100 Loss 0.5586
Epoch 7 Batch 200 Loss 0.5351
Epoch 7 Batch 300 Loss 0.5348
Epoch 7 Loss 0.544230
Time taken for 1 epoch 57.65 sec

Epoch 8 Batch 0 Loss 0.5372
Epoch 8 Batch 100 Loss 0.4843
Epoch 8 Batch 200 Loss 0.5171
Epoch 8 Batch 300 Loss 0.5367
Epoch 8 Loss 0.516309
Time taken for 1 epoch 57.60 sec

Epoch 9 Batch 0 Loss 0.5621
Epoch 9 Batch 100 Loss 0.5091
Epoch 9 Batch 200 Loss 0.4685
Epoch 9 Batch 300 Loss 0.4437
Epoch 9 Loss 0.487771
Time taken for 1 epoch 57.56 sec

Epoch 10 Batch 0 Loss 0.4527
Epoch 10 Batch 100 Loss 0.4517
Epoch 10 Batch 200 Loss 0.4394
Epoch 10 Batch 300 Loss 0.4737
Epoch 10 Loss 0.462743
Time taken for 1 epoch 57.52 sec

Epoch 11 Batch 0 Loss 0.4689
Epoch 11 Batch 100 Loss 0.4200
Epoch 11 Batch 200 Loss 0.4695
Epoch 11 Batch 300 Loss 0.4373
Epoch 11 Loss 0.437498
Time taken for 1 epoch 57.61 sec

Epoch 12 Batch 0 Loss 0.4322
Epoch 12 Batch 100 Loss 0.4167
Epoch 12 Batch 200 Loss 0.4109
Epoch 12 Batch 300 Loss 0.4375
Epoch 12 Loss 0.413919
Time taken for 1 epoch 57.50 sec

Epoch 13 Batch 0 Loss 0.4104
Epoch 13 Batch 100 Loss 0.4194
Epoch 13 Batch 200 Loss 0.3696
Epoch 13 Batch 300 Loss 0.3344
Epoch 13 Loss 0.391727
Time taken for 1 epoch 57.51 sec

Epoch 14 Batch 0 Loss 0.3870
Epoch 14 Batch 100 Loss 0.3554
Epoch 14 Batch 200 Loss 0.3401
Epoch 14 Batch 300 Loss 0.3535
Epoch 14 Loss 0.371030
Time taken for 1 epoch 57.48 sec

Epoch 15 Batch 0 Loss 0.3836
Epoch 15 Batch 100 Loss 0.3911
Epoch 15 Batch 200 Loss 0.3609
Epoch 15 Batch 300 Loss 0.3276
Epoch 15 Loss 0.351505
Time taken for 1 epoch 58.04 sec

Epoch 16 Batch 0 Loss 0.3630
Epoch 16 Batch 100 Loss 0.3334
Epoch 16 Batch 200 Loss 0.3446
Epoch 16 Batch 300 Loss 0.3298
Epoch 16 Loss 0.334702
Time taken for 1 epoch 57.53 sec

Epoch 17 Batch 0 Loss 0.3531
Epoch 17 Batch 100 Loss 0.3121
Epoch 17 Batch 200 Loss 0.3323
Epoch 17 Batch 300 Loss 0.3379
Epoch 17 Loss 0.318561
Time taken for 1 epoch 57.44 sec

Epoch 18 Batch 0 Loss 0.3454
Epoch 18 Batch 100 Loss 0.3193
Epoch 18 Batch 200 Loss 0.2842
Epoch 18 Batch 300 Loss 0.3071
Epoch 18 Loss 0.304131
Time taken for 1 epoch 57.44 sec

Epoch 19 Batch 0 Loss 0.3615
Epoch 19 Batch 100 Loss 0.2894
Epoch 19 Batch 200 Loss 0.2797
Epoch 19 Batch 300 Loss 0.2777
Epoch 19 Loss 0.288646
Time taken for 1 epoch 57.41 sec

Epoch 20 Batch 0 Loss 0.3127
Epoch 20 Batch 100 Loss 0.2628
Epoch 20 Batch 200 Loss 0.2781
Epoch 20 Batch 300 Loss 0.2841
Epoch 20 Loss 0.276044
Time taken for 1 epoch 57.40 sec
plt.plot(loss_plot)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss Plot')
plt.show()

Caption!

  • The evaluate function is similar to the training loop, except that it doesn't use teacher forcing. The input to the decoder at each time step is its previous prediction, along with the hidden state and the encoder output. The next word is sampled from the decoder's output distribution.
  • Stop predicting when the model predicts the end token.
  • Store the attention weights for every time step.
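The inference loop can be sketched independently of TensorFlow. Here step_fn is a hypothetical stand-in for one decoder call that returns a score per vocabulary entry. Note one deliberate difference: evaluate samples the next word with tf.random.categorical, whereas this sketch picks the highest score (greedy decoding) so that its output is deterministic.

```python
def greedy_decode(step_fn, start_id, end_id, max_length):
    """Feed each prediction back in until end_id appears or max_length is hit."""
    result = []
    token = start_id
    for _ in range(max_length):
        scores = step_fn(token)
        token = max(range(len(scores)), key=scores.__getitem__)  # argmax
        result.append(token)
        if token == end_id:
            break
    return result

# Toy "model" with a fixed next-token rule: 1 -> 2 -> 3 -> 4 (<end>).
def toy_step(token, vocab_size=5):
    scores = [0.0] * vocab_size
    scores[min(token + 1, vocab_size - 1)] = 1.0
    return scores

print(greedy_decode(toy_step, start_id=1, end_id=4, max_length=10))  # [2, 3, 4]
```

As in evaluate, the end token is kept in the result, and max_length caps the caption if the end token never appears.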
def evaluate(image):
    attention_plot = np.zeros((max_length, attention_features_shape))

    hidden = decoder.reset_state(batch_size=1)

    temp_input = tf.expand_dims(load_image(image)[0], 0)
    img_tensor_val = image_features_extract_model(temp_input)
    img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0],
                                                 -1,
                                                 img_tensor_val.shape[3]))

    features = encoder(img_tensor_val)

    dec_input = tf.expand_dims([word_to_index('<start>')], 0)
    result = []

    for i in range(max_length):
        predictions, hidden, attention_weights = decoder(dec_input,
                                                         features,
                                                         hidden)

        attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()

        predicted_id = tf.random.categorical(predictions, 1)[0][0].numpy()
        predicted_word = tf.compat.as_text(index_to_word(predicted_id).numpy())
        result.append(predicted_word)

        if predicted_word == '<end>':
            # Trim unused rows so the plot matches the generated words
            return result, attention_plot[:len(result), :]

        dec_input = tf.expand_dims([predicted_id], 0)

    attention_plot = attention_plot[:len(result), :]
    return result, attention_plot
def plot_attention(image, result, attention_plot):
    temp_image = np.array(Image.open(image))

    fig = plt.figure(figsize=(10, 10))

    len_result = len(result)
    for i in range(len_result):
        temp_att = np.resize(attention_plot[i], (8, 8))
        grid_size = max(int(np.ceil(len_result/2)), 2)
        ax = fig.add_subplot(grid_size, grid_size, i+1)
        ax.set_title(result[i])
        img = ax.imshow(temp_image)
        ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())

    plt.tight_layout()
    plt.show()
# captions on the validation set
rid = np.random.randint(0, len(img_name_val))
image = img_name_val[rid]
real_caption = ' '.join([tf.compat.as_text(index_to_word(i).numpy())
                         for i in cap_val[rid] if i not in [0]])
result, attention_plot = evaluate(image)

print('Real Caption:', real_caption)
print('Prediction Caption:', ' '.join(result))
plot_attention(image, result, attention_plot)
Real Caption: <start> a stuffed bear has been posed to appear it is reading a book. <end>
Prediction Caption: a shelf filled with lots of stuffed animal in front of a grill. <end>

Try it on your own images

For fun, below is a method you can use to caption your own images with the model you've just trained. Keep in mind that it was trained on a relatively small amount of data, and your images may differ from the training data (so be prepared for weird results!)

image_url = 'https://tensorflow.org/images/surf.jpg'
image_extension = image_url[-4:]
image_path = tf.keras.utils.get_file('image'+image_extension, origin=image_url)

result, attention_plot = evaluate(image_path)
print('Prediction Caption:', ' '.join(result))
plot_attention(image_path, result, attention_plot)
# opening the image
Image.open(image_path)
Downloading data from https://tensorflow.org/images/surf.jpg
64400/64400 [==============================] - 0s 1us/step
Prediction Caption: surfer on a wave. <end>

Next steps

Congrats! You've just trained an image captioning model with attention. Next, take a look at the Neural Machine Translation with Attention example, which uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset.
