{"id":8416,"date":"2023-12-11T12:21:08","date_gmt":"2023-12-11T20:21:08","guid":{"rendered":"https:\/\/live-cometml.pantheonsite.io\/?p=8416"},"modified":"2025-04-24T17:03:55","modified_gmt":"2025-04-24T17:03:55","slug":"image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization","status":"publish","type":"post","link":"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization\/","title":{"rendered":"Image Captioning Model with TensorFlow, Transformers, and Kangas for Image Visualization"},"content":{"rendered":"\n<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*7tQcn_GPZzHb7fj22o7xDg.jpeg\" alt=\"scrabble tiles spelling caption\"\/><figcaption class=\"wp-element-caption\">Photo by <a class=\"af gy\" href=\"https:\/\/unsplash.com\/es\/@monicadear?utm_source=unsplash&amp;utm_medium=referral&amp;utm_content=creditCopyText\" target=\"_blank\" rel=\"noopener ugc nofollow\">Monica Flores<\/a> on <a class=\"af gy\" href=\"https:\/\/unsplash.com\/photos\/p4mFOzM-asQ?utm_source=unsplash&amp;utm_medium=referral&amp;utm_content=creditCopyText\" target=\"_blank\" rel=\"noopener ugc nofollow\">Unsplash<\/a><\/figcaption><\/figure>\n\n\n\n<p><\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"5f03\">Image captioning is a compelling field that connects computer vision and natural language processing, enabling machines to generate textual descriptions of visual content. In an era dominated by visual content, the ability of machines to understand and describe images is a powerful stride towards human-like intelligence. This article will explore image captioning using <strong class=\"be fs\">TensorFlow<\/strong>. We will explore the process of training an image captioning model to generate descriptive captions for images, highlighting the critical steps involved. The model leverages an <strong class=\"be fs\">Encoder <\/strong>and <strong class=\"be fs\">Decoder <\/strong>based on the <strong class=\"be fs\">Transformer <\/strong>architecture as covered in &#8220;<a class=\"af gy\" href=\"https:\/\/arxiv.org\/abs\/1706.03762\" target=\"_blank\" rel=\"noopener ugc nofollow\"><em class=\"abi\">Attention is all you need<\/em><\/a>,&#8221; so some knowledge can come in handy. Still, we will implement them here for understanding.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"4918\">Also, please acquaint yourself with <a href=\"https:\/\/github.com\/comet-ml\/kangas\"><strong class=\"be fs\">Kangas<\/strong><\/a>, as we will use it <strong class=\"be fs\">to visualize image data in this article<\/strong>. 
Below are resources to get you started:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><a href=\"https:\/\/heartbeat.comet.ml\/create-and-visualize-image-data-with-kangas-for-computer-vision-tasks-83fdad794d\"><strong class=\"be fs\"><em class=\"abi\">Creating and visualizing image data with Kangas.<\/em><\/strong><\/a><\/li>\n\n\n\n<li><a href=\"https:\/\/heartbeat.comet.ml\/constructing-and-visualizing-kangas-datagrid-on-kangas-ui-f63d4350ab61\"><strong class=\"be fs\"><em class=\"abi\">Constructing and visualizing Kangas DataGrid on Kangas UI.<\/em><\/strong><\/a><\/li>\n<\/ul>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"9756\">You can follow along on <code class=\"ef abr abs abt abu b\"><a class=\"af gy\" href=\"https:\/\/colab.research.google.com\/drive\/1BYMUagWjbIV4yJbtzlkm7ZvaBN5CJ2x9\" target=\"_blank\" rel=\"noopener ugc nofollow\">this notebook<\/a>.<\/code><\/p>\n\n\n\n<h2 class=\"wp-block-heading abv abw ug be abx aby abz uw mf aca acb uz mk acc acd ace acf acg ach aci acj ack acl acm acn aco bj\" id=\"0d1b\">What Exactly Is An Image Captioning Model?<\/h2>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"6786\">An image captioning model is a model that can effectively generate a descriptive sentence based on the contents of a particular image.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"221a\">In recent years, image captioning has improved tremendously, fueled by the advancements in machine translation, where the encoder and decoder can generate more coherent sentences. Such progress comes from the introduction of Transformer encoder and decoder models, which have remarkably improved performance compared to traditional RNN-based encoder and decoder models.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"18cc\">A perfect image captioning model should:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong class=\"be fs\">Understand the context of a given image.<\/strong><\/li>\n\n\n\n<li><strong class=\"be fs\">Accurately represent that understanding as a textual description.<\/strong><\/li>\n<\/ul>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"57e1\">For instance, given the following image, the model should be able to produce acceptable captions describing the contents of the image. 
The captions should be good since various interpretations of the same image can exist.<\/p>\n\n\n\n<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:640\/1*3TK0eLFiWH9fBqc51zBVoQ.jpeg\" alt=\"brown dog lying down with a cat lying on top\"\/><figcaption class=\"wp-element-caption\">Photo by <a class=\"af gy\" href=\"https:\/\/unsplash.com\/@glomadmarketing?utm_source=unsplash&amp;utm_medium=referral&amp;utm_content=creditCopyText\" target=\"_blank\" rel=\"noopener ugc nofollow\">Glomad Marketing<\/a> on <a class=\"af gy\" href=\"https:\/\/unsplash.com\/photos\/6VQlKJp2vpo?utm_source=unsplash&amp;utm_medium=referral&amp;utm_content=creditCopyText\" target=\"_blank\" rel=\"noopener ugc nofollow\">Unsplash<\/a><\/figcaption><\/figure>\n\n\n\n<p><\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"b1d4\">The captions for the above image can be:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><em class=\"abi\">A cat and a dog are sleeping on the floor<\/em>.<\/li>\n\n\n\n<li><em class=\"abi\">A black cat is resting on a brown dog.<\/em><\/li>\n\n\n\n<li><em class=\"abi\">A cat and a dog are resting at the garage entrance<\/em>.<\/li>\n<\/ul>\n\n\n\n<h2 class=\"wp-block-heading abv abw ug be abx aby abz uw mf aca acb uz mk acc acd ace acf acg ach aci acj ack acl acm acn aco bj\" id=\"72bc\">Approach to Creating the Model<\/h2>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"38a6\">The model is inspired by <a href=\"https:\/\/keras.io\/examples\/vision\/image_captioning\/\">Implementing an image captioning model using a CNN and a Transformer<\/a> and <a href=\"https:\/\/www.tensorflow.org\/tutorials\/text\/image_captioning\">Image captioning with visual attention on TensorFlow<\/a>. 
Some of the processes we will undertake:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Source a dataset that has <code class=\"ef abr abs abt abu b\">image, caption<\/code> pairs.<\/li>\n\n\n\n<li>Visualize the dataset with Kangas to see its representation.<\/li>\n\n\n\n<li>Preprocess the images and captions.<\/li>\n\n\n\n<li>Resizing the images for pixel consistency through the model.<\/li>\n\n\n\n<li>Using a pre-trained CNN model to obtain image features.<\/li>\n\n\n\n<li>Create a Transformer encoder and decoder.<\/li>\n\n\n\n<li>Training the model.<\/li>\n\n\n\n<li>Generating captions using the trained model.<\/li>\n\n\n\n<li>Visualizing the model&#8217;s accuracy and loss.<\/li>\n<\/ul>\n\n\n\n<h2 class=\"wp-block-heading abv abw ug be abx aby abz uw mf aca acb uz mk acc acd ace acf acg ach aci acj ack acl acm acn aco bj\" id=\"dcbd\">The Dataset<\/h2>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"eb00\">There are several datasets available for image captioning tasks:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><a href=\"https:\/\/www.kaggle.com\/datasets\/adityajn105\/flickr8k\">Flickr8k<\/a>: Has a little above 8k images paired with their respective captions.<\/li>\n\n\n\n<li><a href=\"https:\/\/www.kaggle.com\/datasets\/adityajn105\/flickr30k\">Flickr30k<\/a>: Has over 30k images paired with their respective captions.<\/li>\n\n\n\n<li><a href=\"https:\/\/www.kaggle.com\/datasets\/sabahesaraki\/2017-2017\">MSCOCO<\/a>: Have over 160k images paired with their respective captions.<\/li>\n<\/ul>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"bfa9\">These datasets have been widely used and are reliable in learning or building the image captioning model. We will stick with the <a href=\"https:\/\/www.kaggle.com\/datasets\/adityajn105\/flickr8k\"><strong class=\"be fs\">Flickr8k dataset<\/strong><\/a> as it is more convenient for a broader range of audiences with inadequate resources for preparing and training more complicated datasets.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"d16f\"><strong class=\"be fs\">Download <\/strong>the dataset, and let&#8217;s get started!<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"7fa9\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\">%pip install opendatasets <span class=\"hljs-comment\"># to help download data directly from Kaggle<\/span><\/span><\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"34d3\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-keyword\">import<\/span> opendatasets <span class=\"hljs-keyword\">as<\/span> od\n\n<span class=\"hljs-comment\"># download<\/span>\n<span class=\"hljs-comment\"># Kaggle API key required<\/span>\nod.download(<span class=\"hljs-string\">\"https:\/\/www.kaggle.com\/datasets\/adityajn105\/flickr8k\"<\/span>)<\/span><\/pre>\n\n\n\n<h2 class=\"wp-block-heading abv abw ug be abx aby abz uw mf aca acb uz mk acc acd ace acf acg ach aci acj ack acl acm acn aco bj\" id=\"205c\">Visualizing the Dataset With Kangas<\/h2>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"7bc9\">Kangas comes in handy for visualizing multimedia data. 
Unlike Pandas, Kangas comes packed with an effortless and straightforward way of visualizing image data (<strong class=\"be fs\">Kangas UI<\/strong>), and we do not have to rely on other libraries and packages to do so. I have provided the well-structured resources above to help you get started quickly.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"4e22\">First, <strong class=\"be fs\">install Kangas<\/strong>:<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"c6c4\"><code class=\"ef abr abs abt abu b\">%pip install kangas<\/code><\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"c8b0\">Next, <strong class=\"be fs\">import Kangas<\/strong> with an alias &#8220;<em class=\"abi\">kg<\/em>&#8220;:<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"32f8\"><code class=\"ef abr abs abt abu b\">import kangas as kg<\/code><\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"ab47\">The base structure of Kangas is a DataGrid. However, we will first read the data as a Pandas DataFrame to process and add a column, after which we will read the DataFrame with Kangas to get the DataGrid.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"dd33\">Read the data. I am using Google Colab, hence the paths:<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"ee64\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\">captions_file = <span class=\"hljs-string\">'\/content\/flickr8k\/captions.txt'<\/span>\ndf_captioned = pd.read_csv(captions_file)\n\n<span class=\"hljs-comment\"># Add actual image path <\/span>\ndf_captioned[<span class=\"hljs-string\">'image'<\/span>] = df_captioned[<span class=\"hljs-string\">'image'<\/span>].apply(\n<span class=\"hljs-keyword\">lambda<\/span> x: <span class=\"hljs-string\">f'\/content\/flickr8k\/Images\/<span class=\"hljs-subst\">{x}<\/span>'<\/span>)\n\n<span class=\"hljs-comment\"># Rename the 'image' column to 'image_path'<\/span>\ndf_captioned.rename({<span class=\"hljs-string\">'image'<\/span>:<span class=\"hljs-string\">'image_path'<\/span>}, axis=<span class=\"hljs-number\">1<\/span>, inplace=<span class=\"hljs-literal\">True<\/span>)\ndf_captioned.head()<\/span><\/pre>\n\n\n\n<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:675\/1*Y1UUn0j58z8-oYsegyOm5Q.png\" alt=\"DataFrame with image paths and caption\"\/><figcaption class=\"wp-element-caption\">DataFrame with image paths and caption<\/figcaption><\/figure>\n\n\n\n<p><\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"e155\">To visualize the images in Kangas, we need to convert the images to <strong class=\"be fs\">Kangas image assets<\/strong> with <code class=\"ef abr abs abt abu b\">Image()<\/code> or convert them to Pillow images(PIL).<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"a5bc\" class=\"acy abw ug abu b bf acz ada l 
adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-comment\"># convert the images from the image paths<\/span>\n<span class=\"hljs-comment\"># to Kangas image assets<\/span>\nimages= df_captioned[<span class=\"hljs-string\">'image_path'<\/span>].<span class=\"hljs-built_in\">map<\/span>(\n    <span class=\"hljs-keyword\">lambda<\/span> x: kg.Image(x)\n)\n\n<span class=\"hljs-comment\"># Add a new column with the image assets(actual images)<\/span>\ndf_captioned.insert(loc=<span class=\"hljs-number\">1<\/span>, column=<span class=\"hljs-string\">'image'<\/span>, value=images)\n\ndf_captioned.head()<\/span><\/pre>\n\n\n\n<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*zQSBSCSlTNYyLn-xxqXWgg.png\" alt=\"DataFrame with Kangas Image assets\"\/><figcaption class=\"wp-element-caption\">DataFrame with Kangas Image assets<\/figcaption><\/figure>\n\n\n\n<p><\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"0df3\">Let&#8217;s visualize some of the images with Kangas:<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"912a\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">viewRandomImages<\/span>(<span class=\"hljs-params\">samples=<span class=\"hljs-number\">1<\/span><\/span>):\n  random_rows = df_captioned.sample(samples) <span class=\"hljs-comment\">#random images<\/span>\n\n  <span class=\"hljs-keyword\">for<\/span> idx, row <span class=\"hljs-keyword\">in<\/span> random_rows.iterrows():\n    <span class=\"hljs-comment\"># view with Kangas<\/span>\n    image = kg.Image(row[<span class=\"hljs-string\">'image_path'<\/span>])\n    image.show()\n    <span class=\"hljs-built_in\">print<\/span>(<span class=\"hljs-string\">'\\n'<\/span>, row[<span class=\"hljs-string\">'caption'<\/span>],<span class=\"hljs-string\">'\\n'<\/span>)<\/span><\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"7e7a\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\">viewRandomImages(<span class=\"hljs-number\">2<\/span>) <span class=\"hljs-comment\">#view two images with captions<\/span><\/span><\/pre>\n\n\n\n<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:481\/1*E7b-lTtEoPrqOOw9qEEAYQ.png\" alt=\"Images with captions. Viewed with Kangas\"\/><figcaption class=\"wp-element-caption\">Images with captions. Viewed with Kangas<\/figcaption><\/figure>\n\n\n\n<p><\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"5f84\">Kangas can read data in various formats into a DataGrid. Since we have the DataFrame, we will use Kangas&#8217;s <code class=\"ef abr abs abt abu b\">read_dataframe()<\/code> method to return a DataGrid. The best part of Kangas is the <strong class=\"be fs\">interactive Kangas UI<\/strong>. 
Instead of visualizing them individually, the UI creates a central place to view the images.<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"eb91\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-comment\"># view a shuffled DataGrid<\/span>\ndg_captioned = kg.read_dataframe(df_captioned.sample(frac=<span class=\"hljs-number\">1<\/span>))\n\n<span class=\"hljs-comment\"># The dg.show() method to fire up the UI<\/span>\ndg_captioned.show()<\/span><\/pre>\n\n\n\n<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*9Hknek5yuS_pmvkphuhX7Q.png\" alt=\"Image Captioning data on Kangas UI\"\/><figcaption class=\"wp-element-caption\">Image Captioning data on Kangas UI<\/figcaption><\/figure>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"0525\">You can see that each image has a corresponding caption. On the UI, you can click on any image to view\/zoom\/apply grayscale, sort, or group the data as you wish to explore.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"5903\">For instance, we can <strong class=\"be fs\">view the data<\/strong> <strong class=\"be fs\">without the &#8220;<em class=\"abi\">image_path<\/em>&#8221; column<\/strong>. Just click on the &#8220;<em class=\"abi\">columns<\/em>&#8221; tab and remove the row.<\/p>\n\n\n\n<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*4V6Gef5xuRwucN-PoHCxgA.png\" alt=\"Removed the &quot;image_path&quot; column: Kangas UI\"\/><figcaption class=\"wp-element-caption\">Removed the &#8220;<em class=\"adj\">image_path<\/em>&#8221; column: Kangas UI<\/figcaption><\/figure>\n\n\n\n<p><\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"30e2\">Perfect! Now that you have visualized how the data has been represented, it is time to create the model. 
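<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph\">Before moving on to the model, note that you can optionally persist the DataGrid so the image assets do not have to be rebuilt in a later session. A minimal sketch (the no-argument <code>save()<\/code> method is assumed from the Kangas documentation; adjust if your version differs):<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"># Optional: write the DataGrid to a .datagrid file for reuse in later sessions\n# (method per the Kangas docs; assumed here)\ndg_captioned.save()<\/pre>\n\n\n\n<p class=\"pw-post-body-paragraph\">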
But let&#8217;s first import all the libraries we will require.<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"87cc\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-keyword\">import<\/span> pandas <span class=\"hljs-keyword\">as<\/span> pd\n<span class=\"hljs-keyword\">import<\/span> numpy <span class=\"hljs-keyword\">as<\/span> np\n<span class=\"hljs-keyword\">import<\/span> matplotlib.pyplot <span class=\"hljs-keyword\">as<\/span> plt\n<span class=\"hljs-keyword\">import<\/span> kangas <span class=\"hljs-keyword\">as<\/span> kg\n\n<span class=\"hljs-keyword\">import<\/span> re\n<span class=\"hljs-keyword\">import<\/span> tensorflow\n<span class=\"hljs-keyword\">import<\/span> tensorflow <span class=\"hljs-keyword\">as<\/span> tf\n<span class=\"hljs-keyword\">from<\/span> tensorflow <span class=\"hljs-keyword\">import<\/span> keras\n<span class=\"hljs-keyword\">from<\/span> tensorflow.keras <span class=\"hljs-keyword\">import<\/span> layers\n<span class=\"hljs-keyword\">from<\/span> tensorflow.keras.layers <span class=\"hljs-keyword\">import<\/span> TextVectorization\n\n<span class=\"hljs-keyword\">from<\/span> tensorflow.keras.applications <span class=\"hljs-keyword\">import<\/span> efficientnet <span class=\"hljs-comment\">#Image feature extractor<\/span><\/span><\/pre>\n\n\n\n<h2 class=\"wp-block-heading abv abw ug be abx aby abz uw mf aca acb uz mk acc acd ace acf acg ach aci acj ack acl acm acn aco bj\" id=\"e8ed\">Preparing the Dataset<\/h2>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"4a50\">The first step in building any model is converting the data into a carefully curated dataset to suit the model requirements before training. We require a paired dataset with images and their respective captions for an image captioning dataset.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"38e1\">Looking at the <code class=\"ef abr abs abt abu b\">captions.txt<\/code> file:<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"58a2\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-keyword\">with<\/span> <span class=\"hljs-built_in\">open<\/span>(captions_file) <span class=\"hljs-keyword\">as<\/span> caption_data:\n  caption_data = caption_data.readlines()\n  <span class=\"hljs-keyword\">for<\/span> data <span class=\"hljs-keyword\">in<\/span> caption_data[<span class=\"hljs-number\">20<\/span>:<span class=\"hljs-number\">23<\/span>]:\n    <span class=\"hljs-built_in\">print<\/span>(data)<\/span><\/pre>\n\n\n\n<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:655\/1*mz8L6w_ShSW1lXYrEEWaiw.png\" alt=\"Some data from captions.txt\"\/><figcaption class=\"wp-element-caption\">Some data from captions.txt<\/figcaption><\/figure>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"6dd0\">You notice that commas separate each image from its corresponding caption. 
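<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph\">The parsing helper we write below also relies on a few globals that are assumed to be defined earlier in the notebook; the values here are illustrative, not taken from the article:<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\">import os\n\n# Assumed setup for the caption-parsing function below (illustrative values)\nimage_paths = '\/content\/flickr8k\/Images'   # directory containing the .jpg image files\nsequenceLength = 25                         # maximum caption length, in tokens<\/pre>\n\n\n\n<p class=\"pw-post-body-paragraph\">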
Our goal is to separate the two entities.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"e815\">Since we know that each image in the dataset has at least five captions to choose from, we will <strong class=\"be fs\">create a dictionary that maps each image (<\/strong>as keys<strong class=\"be fs\">) to its corresponding captions (<\/strong>as values<strong class=\"be fs\">)<\/strong>. Also, for better consistency and model training, we will<strong class=\"be fs\"> filter out the captions that are too short and those that are too long <\/strong>(marked as outliers) by predefining a sequence length.<\/p>\n\n\n\n
<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"e6d4\">If you are familiar with sequence-to-sequence tasks like machine translation, <strong class=\"be fs\">adding the start and end tokens<\/strong> to the captions will not surprise you. The start and end tokens act as explicit delimiters for the beginning and end of a sequence, thus helping the model identify the boundaries of the input sequence during training and inference.<\/p>\n\n\n\n
<pre class=\"wp-block-preformatted\"><span id=\"a832\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">load_captions<\/span>(<span class=\"hljs-params\">caption_filename<\/span>):\n\n  <span class=\"hljs-keyword\">with<\/span> <span class=\"hljs-built_in\">open<\/span>(caption_filename) <span class=\"hljs-keyword\">as<\/span> caption_data:\n    caption_data = caption_data.readlines()\n\n    mapping_dict = {} <span class=\"hljs-comment\"># dict to store image to caption mapping<\/span>\n    text_data = [] <span class=\"hljs-comment\"># stores a list of processed captions<\/span>\n    outlier_imgs = <span class=\"hljs-built_in\">set<\/span>()\n\n    <span class=\"hljs-keyword\">for<\/span> line <span class=\"hljs-keyword\">in<\/span> caption_data:\n      line = line.strip(<span class=\"hljs-string\">'\\n'<\/span>).split(<span class=\"hljs-string\">','<\/span>) <span class=\"hljs-comment\"># split image and caption at the commas<\/span>\n      image_codeName, caption = line[<span class=\"hljs-number\">0<\/span>], line[<span class=\"hljs-number\">1<\/span>]\n      image_name = os.path.join(image_paths, image_codeName)<span class=\"hljs-comment\"># create full path to image<\/span>\n\n      caption_tokens = caption.strip().split() <span class=\"hljs-comment\"># create tokens<\/span>\n\n      <span class=\"hljs-comment\"># filter the images using the caption lengths<\/span>\n      <span class=\"hljs-keyword\">if<\/span> <span class=\"hljs-built_in\">len<\/span>(caption_tokens) &lt; <span class=\"hljs-number\">5<\/span> <span class=\"hljs-keyword\">or<\/span> <span class=\"hljs-built_in\">len<\/span>(caption_tokens) &gt; sequenceLength:\n        outlier_imgs.add(image_name)\n        <span class=\"hljs-keyword\">continue<\/span>\n\n      <span class=\"hljs-comment\"># get all .jpg images<\/span>\n      <span class=\"hljs-comment\"># add START and END tokens to each caption<\/span>\n      <span class=\"hljs-comment\"># convert the captions to lowercase<\/span>\n      <span class=\"hljs-keyword\">if<\/span> image_name.endswith(<span class=\"hljs-string\">'.jpg'<\/span>) <span class=\"hljs-keyword\">and<\/span> image_name <span class=\"hljs-keyword\">not<\/span> <span class=\"hljs-keyword\">in<\/span> outlier_imgs:\n        caption = <span class=\"hljs-string\">\"&lt;START&gt; \"<\/span> + caption.strip().lower() + <span class=\"hljs-string\">\" &lt;END&gt;\"<\/span>\n        text_data.append(caption)\n\n        <span class=\"hljs-keyword\">if<\/span> image_name <span class=\"hljs-keyword\">in<\/span> mapping_dict:\n            mapping_dict[image_name].append(caption)\n        <span class=\"hljs-keyword\">else<\/span>:\n            mapping_dict[image_name] = [caption]\n\n    <span class=\"hljs-keyword\">for<\/span> image_name <span class=\"hljs-keyword\">in<\/span> outlier_imgs:\n      <span class=\"hljs-keyword\">if<\/span> image_name <span class=\"hljs-keyword\">in<\/span> mapping_dict:\n        <span class=\"hljs-keyword\">del<\/span> mapping_dict[image_name]\n\n\n    <span class=\"hljs-keyword\">return<\/span> mapping_dict, text_data<\/span><\/pre>\n\n\n\n
<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"7bf8\"><code class=\"ef abr abs abt abu b\">mapping_dict<\/code> contains images (keys) mapped to their captions (values), while the <code class=\"ef abr abs abt abu b\">text_data<\/code> has all the preprocessed captions.<\/p>\n\n\n\n
<pre class=\"wp-block-preformatted\"><span id=\"edb5\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-comment\"># mapped images to their caption<\/span>\nmapping_dict, text_data = load_captions(captions_file)\n\n<span class=\"hljs-built_in\">list<\/span>(mapping_dict.keys())[:<span class=\"hljs-number\">2<\/span>], <span class=\"hljs-built_in\">list<\/span>(mapping_dict.values())[:<span class=\"hljs-number\">2<\/span>]<\/span><\/pre>\n\n\n\n
<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:693\/1*o3tZAIsTVKcAYsHI2TI7Uw.png\" alt=\"Images and captions\"\/><figcaption class=\"wp-element-caption\">Images and captions<\/figcaption><\/figure>\n\n\n\n
<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"b354\">Let&#8217;s see the captions of one of the images:<\/p>\n\n\n\n
<pre class=\"wp-block-preformatted\"><span id=\"f42e\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\">mapping_dict[<span class=\"hljs-string\">'\/content\/flickr8k\/Images\/1000268201_693b08cb0e.jpg'<\/span>]<\/span><\/pre>\n\n\n\n
<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:683\/1*R6QlB3CM6TlmNmrcKEHIdg.png\" alt=\"Some captions for a single image in the mapping_dict\"\/><figcaption class=\"wp-element-caption\">Some captions for a single image in the mapping_dict<\/figcaption><\/figure>\n\n\n\n
<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"1a16\">Each image is mapped to five corresponding captions.<\/p>\n\n\n\n
<h2 class=\"wp-block-heading adn abw ug be abx mb ado mc mf mg adp mh mk ml adq mm mp mq adr mr mu mv ads mw mz adt bj\" id=\"d35c\">Split the Data Into Training and Validation Sets<\/h2>\n\n\n\n
<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"cf04\">We will split the captioning data into two separate dictionaries for the training and validation
data.<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"4d84\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">train_val_split<\/span>(<span class=\"hljs-params\">caption_data, train_sample=<span class=\"hljs-number\">0.8<\/span><\/span>):\n\n  images = <span class=\"hljs-built_in\">list<\/span>(caption_data.keys()) <span class=\"hljs-comment\"># gather all images<\/span>\n\n  train_sample = <span class=\"hljs-built_in\">int<\/span>(<span class=\"hljs-built_in\">len<\/span>(caption_data) * train_sample) <span class=\"hljs-comment\"># split<\/span>\n\n  training_set = {\n      image_name: caption_data[image_name] <span class=\"hljs-keyword\">for<\/span> image_name <span class=\"hljs-keyword\">in<\/span> images[:train_sample]\n  }\n  validation_set = {\n      image_name: caption_data[image_name] <span class=\"hljs-keyword\">for<\/span> image_name <span class=\"hljs-keyword\">in<\/span> images[train_sample:]\n  }\n\n  <span class=\"hljs-keyword\">return<\/span> training_set, validation_set<\/span><\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"3cec\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\">training_set, validation_set = train_val_split(mapping_dict)\n<span class=\"hljs-built_in\">print<\/span>(<span class=\"hljs-string\">f\"Training data: <span class=\"hljs-subst\">{<span class=\"hljs-built_in\">len<\/span>(training_set)}<\/span>\\nValidation data: <span class=\"hljs-subst\">{<span class=\"hljs-built_in\">len<\/span>(validation_set)}<\/span>\"<\/span>)<\/span><\/pre>\n\n\n\n<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:303\/1*b4CAzmGtLAgfTLEoVdiJLQ.png\" alt=\"\"\/><figcaption class=\"wp-element-caption\">Training and validation data<\/figcaption><\/figure>\n\n\n\n<h2 class=\"wp-block-heading adn abw ug be abx mb ado mc mf mg adp mh mk ml adq mm mp mq adr mr mu mv ads mw mz adt bj\" id=\"0a41\">Vectorizing the Data<\/h2>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"fc58\">To feed the data into the model, we need to vectorize it. That means that we need to convert the strings into integer sequences where each integer represents the index of a word in a vocabulary. TensorFlow provides the <a class=\"af gy\" href=\"https:\/\/www.tensorflow.org\/api_docs\/python\/tf\/keras\/layers\/TextVectorization\" target=\"_blank\" rel=\"noopener ugc nofollow\">TextVectorization<\/a> layer for this.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"3646\">The layer learns the vocabulary from the captions through the <code class=\"ef abr abs abt abu b\">adapt()<\/code>method. 
The <code class=\"ef abr abs abt abu b\">adapt()<\/code> method iterates over all captions, splits them into words, checks the frequency of each string value in the captions, and computes a vocabulary of the most frequently used words. Note that the custom standardization below strips punctuation but keeps the angle brackets, so the &lt;START&gt; and &lt;END&gt; tokens survive vectorization.<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"c824\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\">VOCAB_SIZE = <span class=\"hljs-number\">10000<\/span>\n\n<span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">standardization<\/span>(<span class=\"hljs-params\"><span class=\"hljs-built_in\">input<\/span><\/span>):\n    lowercase = tf.strings.lower(<span class=\"hljs-built_in\">input<\/span>)\n    <span class=\"hljs-keyword\">return<\/span> tf.strings.regex_replace(lowercase, <span class=\"hljs-string\">\"[%s]\"<\/span> % re.escape(strip_chars), <span class=\"hljs-string\">\"\"<\/span>)\n\n\nstrip_chars = <span class=\"hljs-string\">\"!\\\"#$%&amp;'()*+,-.\/:;&lt;=&gt;?@[\\]^_`{|}~\"<\/span>\nstrip_chars = strip_chars.replace(<span class=\"hljs-string\">\"&lt;\"<\/span>, <span class=\"hljs-string\">\"\"<\/span>)\nstrip_chars = strip_chars.replace(<span class=\"hljs-string\">\"&gt;\"<\/span>, <span class=\"hljs-string\">\"\"<\/span>)\n\nvectorization = TextVectorization(\n    max_tokens=VOCAB_SIZE,\n    output_mode=<span class=\"hljs-string\">\"int\"<\/span>,\n    output_sequence_length=sequenceLength,\n    standardize=standardization,\n)\nvectorization.adapt(text_data)<\/span><\/pre>\n\n\n\n
<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"ecef\">We can check some of the vocabulary that has been computed after vectorization.<\/p>\n\n\n\n
<pre class=\"wp-block-preformatted\"><span id=\"3b28\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-comment\"># Get some vocabulary<\/span>\n<span class=\"hljs-built_in\">print<\/span>(vectorization.get_vocabulary()[:<span class=\"hljs-number\">15<\/span>])<\/span><\/pre>\n\n\n\n
<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:555\/1*EtyE1auJAb9tBSOQS4RjDQ.png\" alt=\"\"\/><figcaption class=\"wp-element-caption\">Vocabulary examples from the vectorization layer<\/figcaption><\/figure>\n\n\n\n
<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"c562\">Let&#8217;s apply vectorization to some data to see the output sequences.<\/p>\n\n\n\n
<pre class=\"wp-block-preformatted\"><span id=\"daa6\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\">vectorizer = vectorization([[<span class=\"hljs-string\">'a dog sleeping under a tree'<\/span>], [<span class=\"hljs-string\">'a bird feeding small chicks'<\/span>]])\nvectorizer<\/span><\/pre>\n\n\n\n
<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:588\/1*SIKHzduo7OwtRtrIVcSqcA.png\" alt=\"Example Integer sequences from vectorization layer\"\/><figcaption class=\"wp-element-caption\">Example Integer sequences from vectorization layer<\/figcaption><\/figure>\n\n\n\n
<h2 class=\"wp-block-heading adn abw ug be abx mb ado mc mf mg adp mh mk ml adq mm mp mq adr mr mu mv ads mw mz adt bj\" id=\"0559\">Create the tf.data.Dataset Pipeline<\/h2>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu acp
zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"9bf0\">At this point, we need to transform, preprocess, and prepare the training and validation data for model training. We do this by creating a pipeline using the <code class=\"ef abr abs abt abu b\">tf.data.Dataset<\/code> API. With the pipeline, we can:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Shuffle the dataset.<\/li>\n\n\n\n<li>Tokenize all captions for each image through the vectorization layer.<\/li>\n\n\n\n<li>Map the images to their respective captions.<\/li>\n<\/ul>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"9555\">In addition, we will create a function that will load each image and resize it to a fixed size for the model. That ensures that the same number of pixels represents all the images.<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"e348\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\">IMAGE_SIZE = (<span class=\"hljs-number\">299<\/span>, <span class=\"hljs-number\">299<\/span>)\nBATCH_SIZE = <span class=\"hljs-number\">64<\/span>\nEPOCHS = <span class=\"hljs-number\">30<\/span>\nAUTOTUNE = tf.data.AUTOTUNE\n\n<span class=\"hljs-comment\"># load and resize each image to IMAGE_SIZE<\/span>\n<span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">decode_and_resize<\/span>(<span class=\"hljs-params\">image_path<\/span>):\n\n  image = tf.io.read_file(image_path)\n  image = tf.image.decode_jpeg(image, channels=<span class=\"hljs-number\">3<\/span>)\n  image = tf.image.resize(image, IMAGE_SIZE)\n  image = tf.image.convert_image_dtype(image, tf.float32)\n  <span class=\"hljs-keyword\">return<\/span> image\n\n<span class=\"hljs-comment\"># map each resized image to respective vectorized captions<\/span>\n<span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">process_input<\/span>(<span class=\"hljs-params\">img_path, captions<\/span>):\n  <span class=\"hljs-keyword\">return<\/span> decode_and_resize(img_path), vectorization(captions)\n\n<span class=\"hljs-comment\"># Function defining the transformation pipeline<\/span>\n<span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">make_dataset<\/span>(<span class=\"hljs-params\">images, captions<\/span>):\n  dataset = tf.data.Dataset.from_tensor_slices((images, captions))\n  dataset = dataset.shuffle(BATCH_SIZE * <span class=\"hljs-number\">8<\/span>)\n  dataset = dataset.<span class=\"hljs-built_in\">map<\/span>(process_input, num_parallel_calls=AUTOTUNE)\n  dataset = dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)\n  <span class=\"hljs-keyword\">return<\/span> dataset<\/span><\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"fb62\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-comment\"># create transformed training and validation data<\/span>\ntraining_data = make_dataset(<span class=\"hljs-built_in\">list<\/span>(training_set.keys()), <span class=\"hljs-built_in\">list<\/span>(training_set.values()))\n\nvalidation_data = make_dataset(<span class=\"hljs-built_in\">list<\/span>(validation_set.keys()), <span class=\"hljs-built_in\">list<\/span>(validation_set.values()))<\/span><\/pre>\n\n\n\n<h2 class=\"wp-block-heading abv abw ug be abx aby abz uw mf aca acb uz mk acc acd ace acf acg ach aci acj ack acl acm acn aco bj\" id=\"1a3e\">Building the Model<\/h2>\n\n\n\n<p 
class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"c94f\">The model will consist of three parts:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>An image feature extractor.<\/li>\n\n\n\n<li>The Transformer-based Encoder.<\/li>\n\n\n\n<li>The Transformer-based Decoder.<\/li>\n<\/ul>\n\n\n\n<h2 class=\"wp-block-heading adn abw ug be abx mb ado mc mf mg adp mh mk ml adq mm mp mq adr mr mu mv ads mw mz adt bj\" id=\"6d3e\">Image Feature Extractor<\/h2>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"8949\">We will use an image model to extract features from each image. The model is pre-trained on ImageNet as an image classification model. However, in this case, we don&#8217;t need the classification layer but the last layer with feature maps. We will use the Keras <a class=\"af gy\" href=\"https:\/\/www.tensorflow.org\/api_docs\/python\/tf\/keras\/applications\/efficientnet\/EfficientNetB0\" target=\"_blank\" rel=\"noopener ugc nofollow\">EfficientNetB0<\/a> model.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"b9d0\">Let&#8217;s take a look at the model results:<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"e0c2\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\">img_path = <span class=\"hljs-built_in\">list<\/span>(training_set.keys())[<span class=\"hljs-number\">1<\/span>]\n\nmodel = efficientnet.EfficientNetB0(\n      input_shape=(*IMAGE_SIZE, <span class=\"hljs-number\">3<\/span>),\n      include_top=<span class=\"hljs-literal\">False<\/span>, weights = <span class=\"hljs-string\">'imagenet'<\/span>,\n  )\n\ntest_img_batch = decode_and_resize(img_path)[tf.newaxis, :]\n<span class=\"hljs-built_in\">print<\/span>(test_img_batch.shape)\n<span class=\"hljs-built_in\">print<\/span>(model(test_img_batch).shape)<\/span><\/pre>\n\n\n\n<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:270\/1*jC7AtjKfwxd5wqlSJRc33g.png\" alt=\"\"\/><figcaption class=\"wp-element-caption\">Feature map from <a class=\"af gy\" href=\"https:\/\/www.tensorflow.org\/api_docs\/python\/tf\/keras\/applications\/efficientnet\/EfficientNetB0\" target=\"_blank\" rel=\"noopener ugc nofollow\">EfficientNetB0<\/a> model<\/figcaption><\/figure>\n\n\n\n<p><\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"9028\">The feature extractor returns a feature map for each model.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"b886\">Based on this model, we will create a new Convolutional Neural Network (CNN) Keras model for feature extraction. 
The CNN model will take as input the input tensor of feature maps from the EfficientNetB0 model.<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"77cf\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">get_cnn_model<\/span>():\n\n  <span class=\"hljs-comment\"># include_top = False: return model without the <\/span>\n  <span class=\"hljs-comment\"># classification layer<\/span>\n  model = efficientnet.EfficientNetB0(\n      input_shape=(*IMAGE_SIZE, <span class=\"hljs-number\">3<\/span>), include_top=<span class=\"hljs-literal\">False<\/span>, weights = <span class=\"hljs-string\">'imagenet'<\/span>,\n\n  )\n\n  model.trainable = <span class=\"hljs-literal\">False<\/span>\n  model_out = model.output\n  model_out = layers.Reshape((-<span class=\"hljs-number\">1<\/span>, model_out.shape[-<span class=\"hljs-number\">1<\/span>]))(model_out)\n  cnn_model = keras.models.Model(model.<span class=\"hljs-built_in\">input<\/span>, model_out)\n  <span class=\"hljs-keyword\">return<\/span> cnn_model<\/span><\/pre>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"7751\">Next, we build a Transformer-based Encoder and Decoder.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"b439\">Earlier sequence-to-sequence models implemented Recurrent Neural Networks (RNNs) like LSTM and GRU. The input sequence fed into those models was encoded into a fixed-length representation with information about the input sequence for output sequence generation. However, the fixed-length representations often posed limitations where the input sequence was too long and contained crucial information at different positions.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"a25a\">To fix that problem, an attention mechanism was added to enable the RNN models to focus on more relevant parts of the input sequence during the decoding process. So, instead of relying solely on the fixed-length representations, the attention mechanism calculates attention weights for each input position and computes a weighted sum of the input sequence&#8217;s encoder outputs. This weighted sum, often called the &#8220;attention context,&#8221; is an additional input to the decoder at each decoding step. However, the RNNs suffered from parallelism since they decoded one token at a time, making the model train slower, especially on long input sequences.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"64d8\">In this article, we implement the Transformer architecture for encoder and decoder. It is similar to the RNN model with attention, but the main difference is that Transformers entirely replace RNNs with an attention mechanism. That makes them parallelizable, and computations can happen simultaneously. 
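<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph\">The core operation inside these attention layers is scaled dot-product attention. Below is a minimal sketch of the idea (illustrative only; in the model itself, Keras&#8217;s MultiHeadAttention layer handles this for us):<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\">import tensorflow as tf\n\ndef scaled_dot_product_attention(query, key, value):\n    # similarity scores between every query and key position\n    scores = tf.matmul(query, key, transpose_b=True)\n    scores = scores \/ tf.math.sqrt(tf.cast(tf.shape(key)[-1], tf.float32))\n    weights = tf.nn.softmax(scores, axis=-1)     # attention weights per position\n    return tf.matmul(weights, value)             # weighted sum of the values\n\n# Example: one sequence of 4 tokens with 8-dimensional embeddings\nx = tf.random.normal((1, 4, 8))\nprint(scaled_dot_product_attention(x, x, x).shape)   # (1, 4, 8)<\/pre>\n\n\n\n<p class=\"pw-post-body-paragraph\">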
Layer outputs can thus be computed in parallel instead of one at a time, as in RNNs.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"1c7a\">To learn more about how Transformers work, you can read:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>&#8220;<a class=\"af gy\" href=\"https:\/\/arxiv.org\/abs\/1706.03762\" target=\"_blank\" rel=\"noopener ugc nofollow\"><em class=\"abi\">Attention is all you need<\/em><\/a>&#8221; paper.<\/li>\n\n\n\n<li><a class=\"af gy\" href=\"https:\/\/jalammar.github.io\/illustrated-transformer\/\" target=\"_blank\" rel=\"noopener ugc nofollow\">Illustrated Transformer<\/a> by Jay Alammar.<\/li>\n<\/ul>\n\n\n\n
<h2 class=\"wp-block-heading adn abw ug be abx mb ado mc mf mg adp mh mk ml adq mm mp mq adr mr mu mv ads mw mz adt bj\" id=\"f975\">The Transformer-Based Encoder<\/h2>\n\n\n\n
<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"a0bb\">We will pass the image features we have extracted as inputs to an encoder to generate new representations. The inputs first go through a self-attention layer. The layer creates three vectors (query, key, and value vectors) for each input position by multiplying the embedding by weight matrices learned during training. We use MultiHeadAttention so the model can focus on different positions simultaneously.<\/p>\n\n\n\n
<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"2a47\">The self-attention outputs are then added back to the original inputs (a residual connection) and passed through layer normalization, which keeps the outputs on a scale compatible with the inputs and helps preserve important information and gradients.<\/p>\n\n\n\n
<pre class=\"wp-block-preformatted\"><span id=\"49b4\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-keyword\">class<\/span> <span class=\"hljs-title.class\">Encoder<\/span>(keras.layers.Layer):\n  <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">__init__<\/span>(<span class=\"hljs-params\">self, embedding_dim, dense_dim, num_heads<\/span>):\n    <span class=\"hljs-built_in\">super<\/span>().__init__()\n    self.embedding_dim = embedding_dim\n    self.dense_dim = dense_dim\n    self.num_heads = num_heads\n\n    <span class=\"hljs-comment\"># Create the attention layer<\/span>\n    self.attention = keras.layers.MultiHeadAttention(\n        num_heads = num_heads, key_dim=embedding_dim, dropout=<span class=\"hljs-number\">0.0<\/span>\n    )\n\n    <span class=\"hljs-comment\"># Layer normalization<\/span>\n    self.layernorm1 = layers.LayerNormalization()\n    self.layernorm2 = layers.LayerNormalization()\n\n    self.dense = layers.Dense(embedding_dim, activation=<span class=\"hljs-string\">'relu'<\/span>)\n\n  <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">call<\/span>(<span class=\"hljs-params\">self, inputs, training, mask=<span class=\"hljs-literal\">None<\/span><\/span>):\n    inputs = self.layernorm1(inputs)\n    inputs = self.dense(inputs)\n\n    attention_output = self.attention(\n        query = inputs,\n        value = inputs,\n        key = inputs,\n        attention_mask = <span class=\"hljs-literal\">None<\/span>,\n        training = training\n    )\n\n    <span class=\"hljs-comment\"># residual connection<\/span>\n    <span class=\"hljs-comment\"># add actual inputs and self attention outputs<\/span>\n    <span class=\"hljs-comment\"># normalize them<\/span>\n    out = self.layernorm2(inputs + attention_output)\n    <span class=\"hljs-keyword\">return<\/span> out<\/span><\/pre>\n\n\n\n
<h2 class=\"wp-block-heading adn abw ug be abx mb ado mc mf mg adp mh mk ml adq mm mp mq adr mr mu mv ads mw mz adt bj\" id=\"cb11\">Positional Embedding Layer<\/h2>\n\n\n\n
<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"c261\">Transformers do not have an inherent knowledge of order or position like RNNs. Without positional information, they would treat the input sequence as an unordered bag of tokens, making different orderings indistinguishable. So before passing the input sequences to the model, we need to convert them into token embeddings and add positional information to each token. By doing so, the model can effectively encode both the content and the position of tokens in the input sequence, enabling it to capture positional relationships and dependencies in the data.<\/p>\n\n\n\n
<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"f9e5\">Below, we create two embedding layers: one for token embedding and one for positional embedding. The token embedding layer maps the tokens to dense vectors, while the positional embedding layer maps positions within the sequence to dense vectors.<\/p>\n\n\n\n
<pre class=\"wp-block-preformatted\"><span id=\"0d43\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-keyword\">class<\/span> <span class=\"hljs-title.class\">PositionalEmbedding<\/span>(keras.layers.Layer):\n  <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">__init__<\/span>(<span class=\"hljs-params\">self, seq_length, vocab_size, embedding_dim<\/span>):\n    <span class=\"hljs-built_in\">super<\/span>().__init__()\n    self.token_embeddings = layers.Embedding(\n        input_dim=vocab_size, output_dim=embedding_dim\n    )\n    self.position_embeddings = layers.Embedding(\n        input_dim=seq_length, output_dim=embedding_dim\n    )\n    self.seq_length = seq_length\n    self.vocab_size = vocab_size\n    self.embedding_dim = embedding_dim\n    self.embed_scale = tf.math.sqrt(tf.cast(embedding_dim, tf.float32))\n\n  <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">call<\/span>(<span class=\"hljs-params\">self, inputs<\/span>):\n    length = tf.shape(inputs)[-<span class=\"hljs-number\">1<\/span>]\n    positions = tf.<span class=\"hljs-built_in\">range<\/span>(start=<span class=\"hljs-number\">0<\/span>, limit=length, delta=<span class=\"hljs-number\">1<\/span>)\n    embedded_tokens = self.token_embeddings(inputs)\n    embedded_tokens = embedded_tokens * self.embed_scale\n    embedded_positions = self.position_embeddings(positions)\n    <span class=\"hljs-keyword\">return<\/span> embedded_tokens + embedded_positions\n\n  <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">compute_mask<\/span>(<span class=\"hljs-params\">self, inputs, mask=<span class=\"hljs-literal\">None<\/span><\/span>):\n    <span class=\"hljs-keyword\">return<\/span> tf.math.not_equal(inputs, <span class=\"hljs-number\">0<\/span>)<\/span><\/pre>\n\n\n\n<h2 class=\"wp-block-heading adn abw ug be abx mb ado mc mf mg adp mh mk ml adq mm mp mq adr mr mu mv ads mw mz adt bj\"
id=\"2eda\">The Transformer-Based Decoder<\/h2>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"1a29\">The decoder is more complex to implement. It generates the output one by one while consulting the representation generated by the encoder. Like in an encoder, the decoder has a positional embedding layer and stack of layers.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"ace8\">The output of the top encoder is transformed into a set of attention vectors used in the &#8220;encoder-decoder attention&#8221; layer, enabling the decoder to focus on appropriate places in the input sequence. The decoder&#8217;s self-attention layer can only attend to earlier positions in the output sequence. That is done by masking future positions before the softmax step in the self-attention calculation.<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"e984\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-keyword\">class<\/span> <span class=\"hljs-title.class\">Decoder<\/span>(keras.layers.Layer):\n<span class=\"hljs-meta\">    @classmethod<\/span>\n      <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">add_method<\/span>(<span class=\"hljs-params\">cls, func<\/span>):\n        <span class=\"hljs-built_in\">setattr<\/span>(cls, func.__name__, func)\n        <span class=\"hljs-keyword\">return<\/span> func\n\n    <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">__init__<\/span>(<span class=\"hljs-params\">self, embedding_dim, ff_dim, num_heads<\/span>):\n        <span class=\"hljs-built_in\">super<\/span>().__init__()\n        self.embedding_dim = embedding_dim\n        self.ff_dim = ff_dim\n        self.num_heads = num_heads\n        self.attention1 = layers.MultiHeadAttention(\n            num_heads=num_heads, key_dim=embedding_dim, dropout=<span class=\"hljs-number\">0.1<\/span>\n        )\n        self.attention2 = layers.MultiHeadAttention(\n            num_heads=num_heads, key_dim=embedding_dim, dropout=<span class=\"hljs-number\">0.1<\/span>\n        )\n        self.ffn_layer1 = layers.Dense(ff_dim, activation=<span class=\"hljs-string\">\"relu\"<\/span>)\n        self.ffn_layer2 = layers.Dense(embedding_dim)\n\n        self.layernorm1 = layers.LayerNormalization()\n        self.layernorm2 = layers.LayerNormalization()\n        self.layernorm3 = layers.LayerNormalization()\n\n        self.embedding = PositionalEmbedding(\n              embedding_dim=<span class=\"hljs-number\">512<\/span>, seq_length=sequenceLength, vocab_size=VOCAB_SIZE\n          )\n        self.out = layers.Dense(VOCAB_SIZE, activation=<span class=\"hljs-string\">\"softmax\"<\/span>)\n\n        self.dropout1 = layers.Dropout(<span class=\"hljs-number\">0.3<\/span>)\n        self.dropout2 = layers.Dropout(<span class=\"hljs-number\">0.5<\/span>)\n        self.supports_masking = <span class=\"hljs-literal\">True<\/span>\n\n      <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">call<\/span>(<span class=\"hljs-params\">self, inputs, encoder_outputs, training, mask=<span class=\"hljs-literal\">None<\/span><\/span>):\n        inputs = self.embedding(inputs)\n        causal_mask = self.get_causal_attention_mask(inputs)\n\n        <span class=\"hljs-keyword\">if<\/span> mask <span 
class=\"hljs-keyword\">is<\/span> <span class=\"hljs-keyword\">not<\/span> <span class=\"hljs-literal\">None<\/span>:\n            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)\n            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)\n            combined_mask = tf.minimum(combined_mask, causal_mask)\n\n        attention_output1 = self.attention1(\n            query=inputs,\n            value=inputs,\n            key=inputs,\n            attention_mask=combined_mask,\n            training=training,\n        )\n        out1 = self.layernorm1(inputs + attention_output1)\n\n        attention_output2 = self.attention2(\n            query=out1,\n            value=encoder_outputs,\n            key=encoder_outputs,\n            attention_mask=padding_mask,\n            training=training,\n        )\n        out2 = self.layernorm2(out1 + attention_output2)\n\n        ffn_out = self.ffn_layer1(out2)\n        ffn_out = self.dropout1(ffn_out, training=training)\n        ffn_out = self.ffn_layer2(ffn_out)\n\n        ffn_out = self.layernorm3(ffn_out + out2, training=training)\n        ffn_out = self.dropout2(ffn_out, training=training)\n        preds = self.out(ffn_out)\n        <span class=\"hljs-keyword\">return<\/span> preds<\/span><\/pre>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"6b98\">Below, we write a method to generate a causal attention mask for the self-attention mechanism in a decoder layer. The causal attention mask ensures that each token can only attend to its previous positions and itself during self-attention, preventing information flow from future positions to past positions.<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"c3aa\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-meta\">@Decoder.add_method<\/span>\n<span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">get_causal_attention_mask<\/span>(<span class=\"hljs-params\">self, inputs<\/span>):\n        input_shape = tf.shape(inputs)\n        batch_size, sequence_length = input_shape[<span class=\"hljs-number\">0<\/span>], input_shape[<span class=\"hljs-number\">1<\/span>]\n        i = tf.<span class=\"hljs-built_in\">range<\/span>(sequence_length)[:, tf.newaxis] <span class=\"hljs-comment\">#(sequence_length, 1)<\/span>\n        j = tf.<span class=\"hljs-built_in\">range<\/span>(sequence_length) <span class=\"hljs-comment\">#(sequence_length,)<\/span>\n\n        <span class=\"hljs-comment\">#create the causal attention mask<\/span>\n        mask = tf.cast(i &gt;= j, dtype=<span class=\"hljs-string\">\"int32\"<\/span>)\n        mask = tf.reshape(mask, (<span class=\"hljs-number\">1<\/span>, input_shape[<span class=\"hljs-number\">1<\/span>], input_shape[<span class=\"hljs-number\">1<\/span>]))\n        mult = tf.concat(\n            [tf.expand_dims(batch_size, -<span class=\"hljs-number\">1<\/span>), tf.constant([<span class=\"hljs-number\">1<\/span>, <span class=\"hljs-number\">1<\/span>], dtype=tf.int32)],\n            axis=<span class=\"hljs-number\">0<\/span>,\n        )\n        <span class=\"hljs-keyword\">return<\/span> tf.tile(mask, mult)<\/span><\/pre>\n\n\n\n<h2 class=\"wp-block-heading abv abw ug be abx aby abz uw mf aca acb uz mk acc acd ace acf acg ach aci acj ack acl acm acn aco bj\" id=\"bbbb\">The Model<\/h2>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz 
\n\n\n\n<h2 class=\"wp-block-heading abv abw ug be abx aby abz uw mf aca acb uz mk acc acd ace acf acg ach aci acj ack acl acm acn aco bj\" id=\"bbbb\">The Model<\/h2>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"2bb5\">In this section, we build the captioning model. The model combines the CNN feature extractor (the <code class=\"ef abr abs abt abu b\">cnn_model<\/code> method), the encoder, and the decoder to generate captions for images. When we call the model for training, it receives <code class=\"ef abr abs abt abu b\">(image, caption)<\/code> pairs.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"68e2\">The model also calculates the loss and the average accuracy by comparing the true tokens with the predicted tokens, ignoring padding positions.<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"e523\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-keyword\">class<\/span> <span class=\"hljs-title.class\">ImageCaptioningModel<\/span>(keras.Model):\n  <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">__init__<\/span>(<span class=\"hljs-params\">\n      self, cnn_model,\n      encoder, decoder,\n      num_captions_per_image=<span class=\"hljs-number\">5<\/span>,\n      image_aug=<span class=\"hljs-literal\">None<\/span>\n  <\/span>):\n\n      <span class=\"hljs-built_in\">super<\/span>().__init__()\n      self.cnn_model = cnn_model\n      self.encoder = encoder\n      self.decoder = decoder\n      self.loss_tracker = keras.metrics.Mean(name=<span class=\"hljs-string\">\"loss\"<\/span>)\n      self.acc_tracker = keras.metrics.Mean(name=<span class=\"hljs-string\">\"accuracy\"<\/span>)\n      self.num_captions_per_image = num_captions_per_image\n      self.image_aug = image_aug\n\n  <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">calculate_loss<\/span>(<span class=\"hljs-params\">self, y_true, y_pred, mask<\/span>):\n    loss = self.loss(y_true, y_pred)\n    mask = tf.cast(mask, dtype=loss.dtype)\n    loss *= mask\n    <span class=\"hljs-keyword\">return<\/span> tf.reduce_sum(loss) \/ tf.reduce_sum(mask)\n\n  <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">calculate_accuracy<\/span>(<span class=\"hljs-params\">self, y_true, y_pred, mask<\/span>):\n    accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=<span class=\"hljs-number\">2<\/span>))\n    accuracy = tf.math.logical_and(mask, accuracy)\n    accuracy = tf.cast(accuracy, dtype=tf.float32)\n    mask = tf.cast(mask, dtype=tf.float32)\n    <span class=\"hljs-keyword\">return<\/span> tf.reduce_sum(accuracy) \/ tf.reduce_sum(mask)\n\n  <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">_compute_caption_loss_and_acc<\/span>(<span class=\"hljs-params\">self, img_embed, batch_seq, training=<span class=\"hljs-literal\">True<\/span><\/span>):\n    encoder_out = self.encoder(img_embed, training=training)\n    batch_seq_inp = batch_seq[:, :-<span class=\"hljs-number\">1<\/span>]\n    batch_seq_true = batch_seq[:, <span class=\"hljs-number\">1<\/span>:]\n    mask = tf.math.not_equal(batch_seq_true, <span class=\"hljs-number\">0<\/span>)\n    batch_seq_pred = self.decoder(\n        batch_seq_inp, encoder_out, training=training, mask=mask\n    )\n    loss = self.calculate_loss(batch_seq_true, batch_seq_pred, mask)\n    acc = self.calculate_accuracy(batch_seq_true, batch_seq_pred, mask)\n    <span class=\"hljs-keyword\">return<\/span> loss, acc\n\n  <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">train_step<\/span>(<span class=\"hljs-params\">self, batch_data<\/span>):\n    batch_img, batch_seq = batch_data\n    batch_loss = 
<span class=\"hljs-number\">0<\/span>\n    batch_acc = <span class=\"hljs-number\">0<\/span>\n\n    <span class=\"hljs-keyword\">if<\/span> self.image_aug:\n        batch_img = self.image_aug(batch_img)\n\n    <span class=\"hljs-comment\"># 1. Get image embeddings<\/span>\n    img_embed = self.cnn_model(batch_img)\n\n    <span class=\"hljs-comment\"># 2. Pass each of the five captions one by one to the decoder<\/span>\n    <span class=\"hljs-comment\"># along with the encoder outputs and compute the loss as well as accuracy<\/span>\n    <span class=\"hljs-comment\"># for each caption.<\/span>\n    <span class=\"hljs-keyword\">for<\/span> i <span class=\"hljs-keyword\">in<\/span> <span class=\"hljs-built_in\">range<\/span>(self.num_captions_per_image):\n        <span class=\"hljs-keyword\">with<\/span> tf.GradientTape() <span class=\"hljs-keyword\">as<\/span> tape:\n            loss, acc = self._compute_caption_loss_and_acc(\n                img_embed, batch_seq[:, i, :], training=<span class=\"hljs-literal\">True<\/span>\n            )\n\n            <span class=\"hljs-comment\"># 3. Update loss and accuracy<\/span>\n            batch_loss += loss\n            batch_acc += acc\n\n        <span class=\"hljs-comment\"># 4. Get the list of all the trainable weights<\/span>\n        train_vars = (\n            self.encoder.trainable_variables + self.decoder.trainable_variables\n        )\n\n        <span class=\"hljs-comment\"># 5. Get the gradients<\/span>\n        grads = tape.gradient(loss, train_vars)\n\n        <span class=\"hljs-comment\"># 6. Update the trainable weights<\/span>\n        self.optimizer.apply_gradients(<span class=\"hljs-built_in\">zip<\/span>(grads, train_vars))\n\n    <span class=\"hljs-comment\"># 7. Update the trackers<\/span>\n    batch_acc \/= <span class=\"hljs-built_in\">float<\/span>(self.num_captions_per_image)\n    self.loss_tracker.update_state(batch_loss)\n    self.acc_tracker.update_state(batch_acc)\n\n    <span class=\"hljs-comment\"># 8. Return the loss and accuracy values<\/span>\n    <span class=\"hljs-keyword\">return<\/span> {<span class=\"hljs-string\">\"loss\"<\/span>: self.loss_tracker.result(), <span class=\"hljs-string\">\"acc\"<\/span>: self.acc_tracker.result()}\n\n  <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">test_step<\/span>(<span class=\"hljs-params\">self, batch_data<\/span>):\n    batch_img, batch_seq = batch_data\n    batch_loss = <span class=\"hljs-number\">0<\/span>\n    batch_acc = <span class=\"hljs-number\">0<\/span>\n\n    <span class=\"hljs-comment\"># 1. Get image embeddings<\/span>\n    img_embed = self.cnn_model(batch_img)\n\n    <span class=\"hljs-comment\"># 2. Pass each of the five captions one by one to the decoder<\/span>\n    <span class=\"hljs-comment\"># along with the encoder outputs and compute the loss as well as accuracy<\/span>\n    <span class=\"hljs-comment\"># for each caption.<\/span>\n    <span class=\"hljs-keyword\">for<\/span> i <span class=\"hljs-keyword\">in<\/span> <span class=\"hljs-built_in\">range<\/span>(self.num_captions_per_image):\n        loss, acc = self._compute_caption_loss_and_acc(\n            img_embed, batch_seq[:, i, :], training=<span class=\"hljs-literal\">False<\/span>\n        )\n\n        <span class=\"hljs-comment\"># 3. Update batch loss and batch accuracy<\/span>\n        batch_loss += loss\n        batch_acc += acc\n\n    batch_acc \/= <span class=\"hljs-built_in\">float<\/span>(self.num_captions_per_image)\n\n    <span class=\"hljs-comment\"># 4. 
Update the trackers<\/span>\n    self.loss_tracker.update_state(batch_loss)\n    self.acc_tracker.update_state(batch_acc)\n\n    <span class=\"hljs-comment\"># 5. Return the loss and accuracy values<\/span>\n    <span class=\"hljs-keyword\">return<\/span> {<span class=\"hljs-string\">\"loss\"<\/span>: self.loss_tracker.result(), <span class=\"hljs-string\">\"acc\"<\/span>: self.acc_tracker.result()}\n\n<span class=\"hljs-meta\">  @property<\/span>\n  <span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">metrics<\/span>(<span class=\"hljs-params\">self<\/span>):\n    <span class=\"hljs-comment\"># We need to list our metrics here so the `reset_states()` can be<\/span>\n    <span class=\"hljs-comment\"># called automatically.<\/span>\n    <span class=\"hljs-keyword\">return<\/span> [self.loss_tracker, self.acc_tracker]<\/span><\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"c77d\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\">cnn_model = get_cnn_model()\nencoder = Encoder(embedding_dim=<span class=\"hljs-number\">512<\/span>, dense_dim=<span class=\"hljs-number\">512<\/span>, num_heads=<span class=\"hljs-number\">1<\/span>)\ndecoder = Decoder(embedding_dim=<span class=\"hljs-number\">512<\/span>, ff_dim=<span class=\"hljs-number\">512<\/span>, num_heads=<span class=\"hljs-number\">2<\/span>)\ncaption_model = ImageCaptioningModel(\n    cnn_model=cnn_model,\n    encoder=encoder,\n    decoder=decoder\n)<\/span><\/pre>\n\n\n\n<h2 class=\"wp-block-heading adn abw ug be abx mb ado mc mf mg adp mh mk ml adq mm mp mq adr mr mu mv ads mw mz adt bj\" id=\"98ae\">Train the Model<\/h2>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"4100\">Since we have successfully implemented the model architecture, it is time to train it on the training data. We will monitor the model&#8217;s validation loss to gauge its performance. 
We do this by defining an <code class=\"ef abr abs abt abu b\">EarlyStopping<\/code> callback, which stops training if the validation loss does not improve for three consecutive epochs (a sign that the model has begun to overfit) and restores the best weights seen so far.<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"6c07\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-comment\"># Define the loss function<\/span>\ncross_entropy = keras.losses.SparseCategoricalCrossentropy(\n    from_logits=<span class=\"hljs-literal\">False<\/span>, reduction=<span class=\"hljs-string\">\"none\"<\/span>\n)\n\n<span class=\"hljs-comment\"># EarlyStopping criteria<\/span>\nearly_stopping = keras.callbacks.EarlyStopping(\n    patience=<span class=\"hljs-number\">3<\/span>,\n    restore_best_weights=<span class=\"hljs-literal\">True<\/span>\n    )\n\n<span class=\"hljs-comment\"># Compile the model<\/span>\ncaption_model.<span class=\"hljs-built_in\">compile<\/span>(\n    optimizer=keras.optimizers.Adam(learning_rate=<span class=\"hljs-number\">1e-4<\/span>),\n    loss=cross_entropy)\n\n<span class=\"hljs-comment\"># Fit the model<\/span>\ncaption_model.fit(\n    training_data,\n    epochs=EPOCHS,\n    validation_data=validation_data,\n    callbacks=[early_stopping],\n)<\/span><\/pre>
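\n\n\n\n<p class=\"pw-post-body-paragraph\">The loss is built with <code class=\"ef abr abs abt abu b\">reduction=\"none\"<\/code> on purpose: <code class=\"ef abr abs abt abu b\">calculate_loss<\/code> receives one loss value per token and applies the padding mask itself, so padding positions do not dilute the average. The sketch below is illustrative only, with made-up numbers, and shows the idea in isolation:<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"># Illustrative only: masked averaging of per-token cross-entropy.\ny_true = tf.constant([[3, 7, 0, 0]])    # token ids; 0 marks padding\ny_pred = tf.random.uniform((1, 4, 10))  # fake scores over a 10-word vocabulary\ny_pred = y_pred \/ tf.reduce_sum(y_pred, axis=-1, keepdims=True)\n\nper_token_loss = cross_entropy(y_true, y_pred)  # shape (1, 4): one value per token\nmask = tf.cast(tf.math.not_equal(y_true, 0), per_token_loss.dtype)\nmasked_loss = tf.reduce_sum(per_token_loss * mask) \/ tf.reduce_sum(mask)\nprint(masked_loss)  # averaged over the two real tokens only<\/pre>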
\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"25b7\">The accuracies and losses at each training epoch are shown below.<\/p>\n\n\n\n<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*7yN9bNbekWFnoi8jmupEJQ.png\" alt=\"Model training\"\/><figcaption class=\"wp-element-caption\">Model training<\/figcaption><\/figure>\n\n\n\n<h2 class=\"wp-block-heading abv abw ug be abx aby abz uw mf aca acb uz mk acc acd ace acf acg ach aci acj ack acl acm acn aco bj\" id=\"a67a\">Generating Captions<\/h2>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"6149\">Finally, it&#8217;s time to predict captions for images using the trained image captioning model. To caption an image with this model:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>Retrieve the vocabulary from the vectorization layer and map each token index back to its corresponding word.<\/li>\n\n\n\n<li>Select a random image from the validation set and extract its image features with the CNN model.<\/li>\n\n\n\n<li>Pass the image features to the encoder for encoding.<\/li>\n\n\n\n<li>Create a caption generation loop that generates tokens from the <code class=\"ef abr abs abt abu b\">&lt;start><\/code> token until the maximum decoded sentence length is reached or the end token <code class=\"ef abr abs abt abu b\">&lt;end><\/code> is generated.<br><br>&#8211; Tokenize the partial caption with the vectorization layer.<br><br>&#8211; Use the decoder to predict the next token in the sequence based on the encoded image features.<\/li>\n<\/ul>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"34b9\">We use the Kangas <code class=\"ef abr abs abt abu b\">Image()<\/code> class to display each randomly selected image.<\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"1a0c\">So, let&#8217;s add a simple method to do just that:<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"11b6\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\">vocab = vectorization.get_vocabulary()\nindex_lookup = <span class=\"hljs-built_in\">dict<\/span>(<span class=\"hljs-built_in\">zip<\/span>(<span class=\"hljs-built_in\">range<\/span>(<span class=\"hljs-built_in\">len<\/span>(vocab)), vocab))\nmax_decoded_sentence_length = sequenceLength - <span class=\"hljs-number\">1<\/span>\nvalid_images = <span class=\"hljs-built_in\">list<\/span>(validation_set.keys())\n\n\n<span class=\"hljs-keyword\">def<\/span> <span class=\"hljs-title.function\">generate_caption<\/span>():\n    <span class=\"hljs-comment\"># Select a random image from the validation dataset<\/span>\n    sample_img = np.random.choice(valid_images)\n\n    <span class=\"hljs-comment\"># Read the image from the disk<\/span>\n    sample_img = decode_and_resize(sample_img)\n    img = sample_img.numpy().clip(<span class=\"hljs-number\">0<\/span>, <span class=\"hljs-number\">255<\/span>).astype(np.uint8)\n    kg.Image(img).show()\n\n    <span class=\"hljs-comment\"># Pass the image to the CNN<\/span>\n    img = tf.expand_dims(sample_img, <span class=\"hljs-number\">0<\/span>)\n    img = caption_model.cnn_model(img)\n\n    <span class=\"hljs-comment\"># Pass the image features to the Transformer encoder<\/span>\n    encoded_img = caption_model.encoder(img, training=<span class=\"hljs-literal\">False<\/span>)\n\n    <span class=\"hljs-comment\"># Generate the caption using the Transformer decoder<\/span>\n    decoded_caption = <span class=\"hljs-string\">\"&lt;start&gt; \"<\/span>\n    <span class=\"hljs-keyword\">for<\/span> i <span class=\"hljs-keyword\">in<\/span> <span class=\"hljs-built_in\">range<\/span>(max_decoded_sentence_length):\n        tokenized_caption = vectorization([decoded_caption])[:, :-<span class=\"hljs-number\">1<\/span>]\n        mask = tf.math.not_equal(tokenized_caption, <span class=\"hljs-number\">0<\/span>)\n        predictions = caption_model.decoder(\n            tokenized_caption, encoded_img, training=<span class=\"hljs-literal\">False<\/span>, mask=mask\n        )\n        sampled_token_index = 
np.argmax(predictions[<span class=\"hljs-number\">0<\/span>, i, :])\n        sampled_token = index_lookup[sampled_token_index]\n        <span class=\"hljs-keyword\">if<\/span> sampled_token == <span class=\"hljs-string\">\"&lt;end&gt;\"<\/span>:\n            <span class=\"hljs-keyword\">break<\/span>\n        decoded_caption += <span class=\"hljs-string\">\" \"<\/span> + sampled_token\n\n    decoded_caption = decoded_caption.replace(<span class=\"hljs-string\">\"&lt;start&gt; \"<\/span>, <span class=\"hljs-string\">\"\"<\/span>)\n    decoded_caption = decoded_caption.replace(<span class=\"hljs-string\">\" &lt;end&gt;\"<\/span>, <span class=\"hljs-string\">\"\"<\/span>).strip().capitalize()\n    <span class=\"hljs-built_in\">print<\/span>(<span class=\"hljs-string\">\"PREDICTED CAPTION: \"<\/span>, decoded_caption)\n\n\n<span class=\"hljs-comment\"># Check predictions for a few samples<\/span>\ngenerate_caption()\ngenerate_caption()\ngenerate_caption()<\/span><\/pre>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"695d\">Predicted captions for each randomly selected image. We have displayed each image with Kangas.<\/p>\n\n\n\n<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:638\/1*YfZFgqxuwR_dF2x71cj99A.png\" alt=\"Predicted captions for random images: Image Captioning mode.\"\/><figcaption class=\"wp-element-caption\">Predicted captions for random images: Image Captioning mode.<\/figcaption><\/figure>\n\n\n\n<p><\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"d263\">Perfect!<\/p>\n\n\n\n<h2 class=\"wp-block-heading adn abw ug be abx mb ado mc mf mg adp mh mk ml adq mm mp mq adr mr mu mv ads mw mz adt bj\" id=\"8280\">Visualize the Loss and Accuracy<\/h2>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"0827\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\">plt.plot(caption_model.history.history[<span class=\"hljs-string\">'loss'<\/span>], label=<span class=\"hljs-string\">'loss'<\/span>)\nplt.plot(caption_model.history.history[<span class=\"hljs-string\">'val_loss'<\/span>], label=<span class=\"hljs-string\">'val_loss'<\/span>)\nplt.ylim([<span class=\"hljs-number\">0<\/span>, <span class=\"hljs-built_in\">max<\/span>(plt.ylim())])\nplt.xlabel(<span class=\"hljs-string\">'Epochs'<\/span>)\nplt.ylabel(<span class=\"hljs-string\">'CE\/token'<\/span>)\nplt.legend()<\/span><\/pre>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"4a84\">Loss:<\/p>\n\n\n\n<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:588\/1*PttjYUgfn6wNZOGSlZqEiQ.png\" alt=\"model loss line graph\"\/><figcaption class=\"wp-element-caption\">Model loss<\/figcaption><\/figure>\n\n\n\n<p><\/p>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu zs zt zu ux zv zw zx ml zy zz aba mq abb abc abd mv abe abf abg abh er bj\" id=\"ee1d\">Accuracy:<\/p>\n\n\n\n<pre class=\"wp-block-preformatted\"><span id=\"589b\" class=\"acy abw ug abu b bf acz ada l adb adc\" data-selectable-paragraph=\"\"><span class=\"hljs-selector-tag\">plt<\/span><span class=\"hljs-selector-class\">.plot<\/span>(caption_model.history.history[<span 
class=\"hljs-string\">'val_acc'<\/span>], label=<span class=\"hljs-string\">'val_accuracy'<\/span>)\n<span class=\"hljs-selector-tag\">plt<\/span><span class=\"hljs-selector-class\">.plot<\/span>(caption_model.history.history[<span class=\"hljs-string\">'acc'<\/span>], label=<span class=\"hljs-string\">'accuracy'<\/span>)\n<span class=\"hljs-selector-tag\">plt<\/span><span class=\"hljs-selector-class\">.ylim<\/span>([<span class=\"hljs-number\">0<\/span>, <span class=\"hljs-built_in\">max<\/span>(plt.<span class=\"hljs-built_in\">ylim<\/span>())])\n<span class=\"hljs-selector-tag\">plt<\/span><span class=\"hljs-selector-class\">.xlabel<\/span>(<span class=\"hljs-string\">'Epochs'<\/span>)\n<span class=\"hljs-selector-tag\">plt<\/span><span class=\"hljs-selector-class\">.ylabel<\/span>(<span class=\"hljs-string\">'CE\/token'<\/span>)\n<span class=\"hljs-selector-tag\">plt<\/span><span class=\"hljs-selector-class\">.legend<\/span>()<\/span><\/pre>\n\n\n\n<figure class=\"wp-block-image zb zc zd ze zf zg lk ll paragraph-image\"><img decoding=\"async\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:581\/1*WAFVSpcZvbebOhAWyr6oPg.png\" alt=\"model accuracy line graph\"\/><figcaption class=\"wp-element-caption\">Model accuracy<\/figcaption><\/figure>\n\n\n\n<h2 class=\"wp-block-heading abv abw ug be abx aby abz uw mf aca acb uz mk acc acd ace acf acg ach aci acj ack acl acm acn aco bj\" id=\"86a1\">Final Thoughts<\/h2>\n\n\n\n<p class=\"pw-post-body-paragraph zq zr ug be b uu acp zt zu ux acq zw zx ml acr zz aba mq acs abc abd mv act abf abg abh er bj\" id=\"0cbd\">In this piece, we have learned to generate image captions with TensorFlow and Transformer based encoder and decoder. We have learned:<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>How to visualize image data with Kangas and using the Kangas UI.<\/li>\n\n\n\n<li>How to preprocess image and caption data for proper model compatibility.<\/li>\n\n\n\n<li>How to create an image captioning model.<\/li>\n<\/ul>\n","protected":false},"excerpt":{"rendered":"<p>Image captioning is a compelling field that connects computer vision and natural language processing, enabling machines to generate textual descriptions of visual content. In an era dominated by visual content, the ability of machines to understand and describe images is a powerful stride towards human-like intelligence. This article will explore image captioning using TensorFlow. 
We [&hellip;]<\/p>\n","protected":false},"author":108,"featured_media":0,"comment_status":"closed","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"customer_name":"","customer_description":"","customer_industry":"","customer_technologies":"","customer_logo":"","footnotes":""},"categories":[7],"tags":[],"coauthors":[206],"class_list":["post-8416","post","type-post","status-publish","format-standard","hentry","category-tutorials"],"yoast_head":"<!-- This site is optimized with the Yoast SEO Premium plugin v25.9 (Yoast SEO v25.9) - https:\/\/yoast.com\/wordpress\/plugins\/seo\/ -->\n<title>Image Captioning Model for Image Visualization<\/title>\n<meta name=\"description\" content=\"Explore an image captioning model using TensorFlow, highlighting the critical steps involved in this blog article.\" \/>\n<meta name=\"robots\" content=\"index, follow, max-snippet:-1, max-image-preview:large, max-video-preview:-1\" \/>\n<link rel=\"canonical\" href=\"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization\" \/>\n<meta property=\"og:locale\" content=\"en_US\" \/>\n<meta property=\"og:type\" content=\"article\" \/>\n<meta property=\"og:title\" content=\"Image Captioning Model with TensorFlow, Transformers, and Kangas for Image Visualization\" \/>\n<meta property=\"og:description\" content=\"Explore an image captioning model using TensorFlow, highlighting the critical steps involved in this blog article.\" \/>\n<meta property=\"og:url\" content=\"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization\" \/>\n<meta property=\"og:site_name\" content=\"Comet\" \/>\n<meta property=\"article:publisher\" content=\"https:\/\/www.facebook.com\/cometdotml\" \/>\n<meta property=\"article:published_time\" content=\"2023-12-11T20:21:08+00:00\" \/>\n<meta property=\"article:modified_time\" content=\"2025-04-24T17:03:55+00:00\" \/>\n<meta property=\"og:image\" content=\"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*7tQcn_GPZzHb7fj22o7xDg.jpeg\" \/>\n<meta name=\"author\" content=\"Brian Mutea\" \/>\n<meta name=\"twitter:card\" content=\"summary_large_image\" \/>\n<meta name=\"twitter:creator\" content=\"@Cometml\" \/>\n<meta name=\"twitter:site\" content=\"@Cometml\" \/>\n<meta name=\"twitter:label1\" content=\"Written by\" \/>\n\t<meta name=\"twitter:data1\" content=\"Brian Mutea\" \/>\n\t<meta name=\"twitter:label2\" content=\"Est. reading time\" \/>\n\t<meta name=\"twitter:data2\" content=\"22 minutes\" \/>\n<!-- \/ Yoast SEO Premium plugin. 
-->","yoast_head_json":{"title":"Image Captioning Model for Image Visualization","description":"Explore an image captioning model using TensorFlow, highlighting the critical steps involved in this blog article.","robots":{"index":"index","follow":"follow","max-snippet":"max-snippet:-1","max-image-preview":"max-image-preview:large","max-video-preview":"max-video-preview:-1"},"canonical":"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization","og_locale":"en_US","og_type":"article","og_title":"Image Captioning Model with TensorFlow, Transformers, and Kangas for Image Visualization","og_description":"Explore an image captioning model using TensorFlow, highlighting the critical steps involved in this blog article.","og_url":"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization","og_site_name":"Comet","article_publisher":"https:\/\/www.facebook.com\/cometdotml","article_published_time":"2023-12-11T20:21:08+00:00","article_modified_time":"2025-04-24T17:03:55+00:00","og_image":[{"url":"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*7tQcn_GPZzHb7fj22o7xDg.jpeg","type":"","width":"","height":""}],"author":"Brian Mutea","twitter_card":"summary_large_image","twitter_creator":"@Cometml","twitter_site":"@Cometml","twitter_misc":{"Written by":"Brian Mutea","Est. reading time":"22 minutes"},"schema":{"@context":"https:\/\/schema.org","@graph":[{"@type":"Article","@id":"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization#article","isPartOf":{"@id":"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization\/"},"author":{"name":"Brian Mutea","@id":"https:\/\/www.comet.com\/site\/#\/schema\/person\/45acdda6535e03a9542e665f23953c3b"},"headline":"Image Captioning Model with TensorFlow, Transformers, and Kangas for Image Visualization","datePublished":"2023-12-11T20:21:08+00:00","dateModified":"2025-04-24T17:03:55+00:00","mainEntityOfPage":{"@id":"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization\/"},"wordCount":2455,"publisher":{"@id":"https:\/\/www.comet.com\/site\/#organization"},"image":{"@id":"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization#primaryimage"},"thumbnailUrl":"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*7tQcn_GPZzHb7fj22o7xDg.jpeg","articleSection":["Tutorials"],"inLanguage":"en-US"},{"@type":"WebPage","@id":"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization\/","url":"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization","name":"Image Captioning Model for Image 
Visualization","isPartOf":{"@id":"https:\/\/www.comet.com\/site\/#website"},"primaryImageOfPage":{"@id":"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization#primaryimage"},"image":{"@id":"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization#primaryimage"},"thumbnailUrl":"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*7tQcn_GPZzHb7fj22o7xDg.jpeg","datePublished":"2023-12-11T20:21:08+00:00","dateModified":"2025-04-24T17:03:55+00:00","description":"Explore an image captioning model using TensorFlow, highlighting the critical steps involved in this blog article.","breadcrumb":{"@id":"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization#breadcrumb"},"inLanguage":"en-US","potentialAction":[{"@type":"ReadAction","target":["https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization"]}]},{"@type":"ImageObject","inLanguage":"en-US","@id":"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization#primaryimage","url":"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*7tQcn_GPZzHb7fj22o7xDg.jpeg","contentUrl":"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*7tQcn_GPZzHb7fj22o7xDg.jpeg"},{"@type":"BreadcrumbList","@id":"https:\/\/www.comet.com\/site\/blog\/image-captioning-model-with-tensorflow-transformers-and-kangas-for-image-visualization#breadcrumb","itemListElement":[{"@type":"ListItem","position":1,"name":"Home","item":"https:\/\/www.comet.com\/site\/"},{"@type":"ListItem","position":2,"name":"Image Captioning Model with TensorFlow, Transformers, and Kangas for Image Visualization"}]},{"@type":"WebSite","@id":"https:\/\/www.comet.com\/site\/#website","url":"https:\/\/www.comet.com\/site\/","name":"Comet","description":"Build Better Models Faster","publisher":{"@id":"https:\/\/www.comet.com\/site\/#organization"},"potentialAction":[{"@type":"SearchAction","target":{"@type":"EntryPoint","urlTemplate":"https:\/\/www.comet.com\/site\/?s={search_term_string}"},"query-input":{"@type":"PropertyValueSpecification","valueRequired":true,"valueName":"search_term_string"}}],"inLanguage":"en-US"},{"@type":"Organization","@id":"https:\/\/www.comet.com\/site\/#organization","name":"Comet ML, Inc.","alternateName":"Comet","url":"https:\/\/www.comet.com\/site\/","logo":{"@type":"ImageObject","inLanguage":"en-US","@id":"https:\/\/www.comet.com\/site\/#\/schema\/logo\/image\/","url":"https:\/\/www.comet.com\/site\/wp-content\/uploads\/2025\/01\/logo_comet_square.png","contentUrl":"https:\/\/www.comet.com\/site\/wp-content\/uploads\/2025\/01\/logo_comet_square.png","width":310,"height":310,"caption":"Comet ML, Inc."},"image":{"@id":"https:\/\/www.comet.com\/site\/#\/schema\/logo\/image\/"},"sameAs":["https:\/\/www.facebook.com\/cometdotml","https:\/\/x.com\/Cometml","https:\/\/www.youtube.com\/channel\/UCmN63HKvfXSCS-UwVwmK8Hw"]},{"@type":"Person","@id":"https:\/\/www.comet.com\/site\/#\/schema\/person\/45acdda6535e03a9542e665f23953c3b","name":"Brian 
Mutea","image":{"@type":"ImageObject","inLanguage":"en-US","@id":"https:\/\/www.comet.com\/site\/#\/schema\/person\/image\/0008644e1041f4f2e48e3566c59bc055","url":"https:\/\/www.comet.com\/site\/wp-content\/uploads\/2023\/11\/1652705747012-96x96.jpg","contentUrl":"https:\/\/www.comet.com\/site\/wp-content\/uploads\/2023\/11\/1652705747012-96x96.jpg","caption":"Brian Mutea"},"url":"https:\/\/www.comet.com\/site\/blog\/author\/brianmuteakgmail-com\/"}]}},"_links":{"self":[{"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/posts\/8416","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/users\/108"}],"replies":[{"embeddable":true,"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/comments?post=8416"}],"version-history":[{"count":1,"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/posts\/8416\/revisions"}],"predecessor-version":[{"id":15423,"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/posts\/8416\/revisions\/15423"}],"wp:attachment":[{"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/media?parent=8416"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/categories?post=8416"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/tags?post=8416"},{"taxonomy":"author","embeddable":true,"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/coauthors?post=8416"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}