{"id":7611,"date":"2023-09-22T11:45:28","date_gmt":"2023-09-22T19:45:28","guid":{"rendered":"https:\/\/live-cometml.pantheonsite.io\/?p=7611"},"modified":"2025-04-24T17:13:56","modified_gmt":"2025-04-24T17:13:56","slug":"tracking-jax-and-flax-models-with-comet","status":"publish","type":"post","link":"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/","title":{"rendered":"Tracking JAX and Flax models with Comet"},"content":{"rendered":"\n<link rel=\"canonical\" href=\"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\">\n\n\n\n<div class=\"fh fi fj fk fl\">\n<div class=\"ab ca\">\n<div class=\"ch bg et eu ev ew\">\n<figure class=\"lw lx ly lz ma mb lt lu paragraph-image\">\n<div class=\"mc md eb me bg mf\" tabindex=\"0\" role=\"button\">\n<figure><img loading=\"lazy\" decoding=\"async\" class=\"bg mg mh c\" role=\"presentation\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*WpEtFvcW7jxALJ5dXz44BQ.png\" alt=\"\" width=\"700\" height=\"394\"><\/figure><div class=\"lt lu lv\"><picture><\/picture><\/div>\n<\/div>\n<\/figure>\n<p id=\"7b08\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\"><a class=\"af nf\" href=\"https:\/\/jax.readthedocs.io\/en\/latest\/\" target=\"_blank\" rel=\"noopener ugc nofollow\">JAX<\/a> is a <a class=\"af nf\" href=\"https:\/\/www.machinelearningnuggets.com\/python-for-data-science\/\" target=\"_blank\" rel=\"noopener ugc nofollow\">Python<\/a> library offering high performance in machine learning with <a class=\"af nf\" href=\"https:\/\/www.tensorflow.org\/xla\" target=\"_blank\" rel=\"noopener ugc nofollow\">XLA<\/a> and Just In Time (JIT) compilation. Its API is similar to <a class=\"af nf\" href=\"https:\/\/www.machinelearningnuggets.com\/numpy-tutorial\/\" target=\"_blank\" rel=\"noopener ugc nofollow\">NumPy\u2019s<\/a> with a few differences. 
JAX ships with functionalities that aim to improve and increase speed in machine learning research. These functionalities include:<\/p>\n<ul class=\"\">\n<li id=\"6cf8\" class=\"mi mj fo be b mk ml mm mn mo mp mq mr ms ng mu mv mw nh my mz na ni nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Automatic differentiation<\/li>\n<li id=\"053e\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Vectorization<\/li>\n<li id=\"a09e\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">JIT compilation<\/li>\n<\/ul>\n<p id=\"7869\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">Flax is a neural network library for JAX. This article will cover how to track JAX and Flax models with Comet.<\/p>\n<p id=\"8f70\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">Let\u2019s get started.<\/p>\n<\/div>\n<\/div>\n<\/div>\n\n\n\n<div class=\"fh fi fj fk fl\">\n<div class=\"ab ca\">\n<div class=\"ch bg et eu ev ew\">\n<h1 id=\"79b6\" class=\"nz oa fo be ob oc od oe of og oh oi oj ok ol om on oo op oq or os ot ou ov ow bj\" data-selectable-paragraph=\"\">Getting started<\/h1>\n<p id=\"fc91\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">Start by installing <a class=\"af nf\" href=\"https:\/\/www.comet.com\/\" target=\"_blank\" rel=\"noopener ugc nofollow\">Comet<\/a>.<\/p>\n<pre class=\"pc pd pe pf pg ph pi pj pk ax pl bj\"><span id=\"71a6\" class=\"pm oa fo pi b ho pn po l ie pp\" data-selectable-paragraph=\"\">pip install comet_ml<\/span><\/pre>\n<p id=\"194e\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms 
mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">Next, install JAX and Flax.<\/p>\n<pre class=\"pc pd pe pf pg ph pi pj pk ax pl bj\"><span id=\"1d2a\" class=\"pm oa fo pi b ho pn po l ie pp\" data-selectable-paragraph=\"\">pip install -q jax jaxlib flax<\/span><\/pre>\n<p id=\"757f\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">Import the libraries you\u2019ll use in this project.<\/p>\n<pre>import comet_ml\nimport jax\nimport jax.numpy as jnp                # JAX NumPy\n\nfrom flax import linen as nn           # The Linen API\nfrom flax.training import train_state  # Useful dataclass to keep train state\n\nimport numpy as np                     # Ordinary NumPy\nimport optax                           # Optimizers\nimport tensorflow_datasets as tfds     # TFDS for MNIST<\/pre>\n<h1 id=\"b2c7\" class=\"nz oa fo be ob oc pt oe of og pu oi oj ok pv om on oo pw oq or os px ou ov ow bj\" data-selectable-paragraph=\"\">Log parameters<\/h1>\n<p id=\"01b9\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">In this project, you\u2019ll build a simple <a class=\"af nf\" href=\"https:\/\/www.machinelearningnuggets.com\/cnn-tensorflow\/\" target=\"_blank\" rel=\"noopener ugc nofollow\">Convolutional Neural Network<\/a> using the MNIST dataset. 
Define the network parameters and log them to Comet.<\/p>\n<p id=\"4528\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">To achieve that, create a Comet experiment.<\/p>\n<pre>experiment = comet_ml.Experiment(\n    api_key=\"YOUR_API_KEY\",\n    project_name=\"JAX_Flax_CNN\", log_code=True)<\/pre>\n<p id=\"7541\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">Next, use this experiment to log the network parameters.<\/p>\n<pre># these will all get logged\nparams = {\n    \"features_1\": 32,\n    \"kernel_size\": 3,\n    \"window_shape\": 2,\n    \"categories\": 10,\n    \"features_2\": 64,\n    \"features_3\": 256,\n    \"strides\": 2,\n    \"cross_entropy_loss\": \"softmax_cross_entropy\",\n    \"dataset\": \"MNIST\"\n}\n\nexperiment.log_parameters(params)<\/pre>\n<h1 id=\"7b46\" class=\"nz oa fo be ob oc pt oe of og pu oi oj ok pv om on oo pw oq or os px ou ov ow bj\" data-selectable-paragraph=\"\">Flax network definition<\/h1>\n<p id=\"995f\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">In Flax, networks are defined using the <a class=\"af nf\" href=\"https:\/\/flax.readthedocs.io\/en\/latest\/api_reference\/flax.linen.html?highlight=linen\" target=\"_blank\" rel=\"noopener ugc nofollow\">Linen package<\/a>. 
Define a simple <a class=\"af nf\" href=\"https:\/\/www.machinelearningnuggets.com\/cnn-tensorflow\/\" target=\"_blank\" rel=\"noopener ugc nofollow\">CNN network<\/a> using the parameters defined above.<\/p>\n<p id=\"76c5\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">Flax networks are defined explicitly using <code class=\"cw py pz qa pi b\"><a class=\"af nf\" href=\"https:\/\/flax.readthedocs.io\/en\/latest\/notebooks\/optax_update_guide.html?highlight=setup\" target=\"_blank\" rel=\"noopener ugc nofollow\">setup<\/a><\/code> or <code class=\"cw py pz qa pi b\">inline<\/code> using <code class=\"cw py pz qa pi b\"><a class=\"af nf\" href=\"https:\/\/flax.readthedocs.io\/en\/latest\/api_reference\/flax.linen.html?highlight=compact#flax.linen.compact\" target=\"_blank\" rel=\"noopener ugc nofollow\">nn.compact<\/a><\/code>.<\/p>\n<p data-selectable-paragraph=\"\"><strong><a href=\"https:\/\/flax.readthedocs.io\/en\/latest\/guides\/setup_or_nncompact.html?source=post_page-----9adad64608da--------------------------------\">setup vs compact<\/a><\/strong><\/p>\n<div class=\"qb qc qd qe qf qg\">\n<div class=\"qh ab hy\">\n<div class=\"qi ab cn ca qj qk\">\n<div class=\"qo l\">\n<pre>class CNN(nn.Module):\n  \"\"\"A simple CNN model.\"\"\"\n\n  @nn.compact\n  def __call__(self, x):\n    x = nn.Conv(features=params['features_1'], kernel_size=(params['kernel_size'], params['kernel_size']))(x)\n    x = nn.relu(x)\n    x = nn.avg_pool(x, window_shape=(params['window_shape'], params['window_shape']), strides=(params['strides'], params['strides']))\n    x = nn.Conv(features=params['features_2'], kernel_size=(params['kernel_size'], params['kernel_size']))(x)\n    x = nn.relu(x)\n    x = nn.avg_pool(x, window_shape=(params['window_shape'], params['window_shape']), strides=(params['strides'], params['strides']))\n    x = x.reshape((x.shape[0], -1))  # flatten\n    x = 
nn.Dense(features=params['features_3'])(x)\n    x = nn.relu(x)\n    x = nn.Dense(features=params['categories'])(x)\n    return x<\/pre>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<h1 id=\"a0fe\" class=\"nz oa fo be ob oc pt oe of og pu oi oj ok pv om on oo pw oq or os px ou ov ow bj\" data-selectable-paragraph=\"\">Compute metrics<\/h1>\n<p id=\"e0a2\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">Next, define the metrics used to compute loss and accuracy during training. In JAX, we compute the loss using the <a class=\"af nf\" href=\"https:\/\/optax.readthedocs.io\/en\/latest\/\" target=\"_blank\" rel=\"noopener ugc nofollow\">Optax library<\/a>.<\/p>\n<pre>def cross_entropy_loss(*, logits, labels):\n  labels_onehot = jax.nn.one_hot(labels, num_classes=params['categories'])\n  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()<\/pre>\n<p id=\"b94d\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">The <code class=\"cw py pz qa pi b\">compute_metrics<\/code> function will calculate and return the loss and accuracy.<\/p>\n<pre>def compute_metrics(*, logits, labels):\n  loss = cross_entropy_loss(logits=logits, labels=labels)\n  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)\n  metrics = {\n      'loss': loss,\n      'accuracy': accuracy,\n  }\n  return metrics<\/pre>\n<h1 id=\"8950\" class=\"nz oa fo be ob oc pt oe of og pu oi oj ok pv om on oo pw oq or os px ou ov ow bj\" data-selectable-paragraph=\"\">Loading data in JAX<\/h1>\n<p id=\"d0b1\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">JAX and Flax don\u2019t ship with data loaders. 
Instead, you can use data loaders from other libraries, such as <a class=\"af nf\" href=\"https:\/\/www.machinelearningnuggets.com\/tag\/tensorflow\/\" target=\"_blank\" rel=\"noopener ugc nofollow\">TensorFlow<\/a> or <a class=\"af nf\" href=\"https:\/\/pytorch.org\/tutorials\/beginner\/basics\/data_tutorial.html\" target=\"_blank\" rel=\"noopener ugc nofollow\">PyTorch<\/a>.<\/p>\n<p id=\"3d35\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">In this case, we load the dataset using TensorFlow Datasets.<\/p>\n<pre>def get_datasets():\n  \"\"\"Load MNIST train and test datasets into memory.\"\"\"\n  ds_builder = tfds.builder('mnist')\n  ds_builder.download_and_prepare()\n  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))\n  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))\n  train_ds['image'] = jnp.float32(train_ds['image']) \/ 255.\n  test_ds['image'] = jnp.float32(test_ds['image']) \/ 255.\n  return train_ds, test_ds<\/pre>\n<h1 id=\"c2bb\" class=\"nz oa fo be ob oc pt oe of og pu oi oj ok pv om on oo pw oq or os px ou ov ow bj\" data-selectable-paragraph=\"\">Create Flax training state<\/h1>\n<p id=\"a153\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">In Flax, we create a training state to store all the training information, such as parameters and the optimizer state. This is achieved using <code class=\"cw py pz qa pi b\"><a class=\"af nf\" href=\"https:\/\/flax.readthedocs.io\/en\/latest\/api_reference\/flax.training.html#train-state\" target=\"_blank\" rel=\"noopener ugc nofollow\">train_state<\/a><\/code> from Flax. 
In the training state function:<\/p>\n<ul class=\"\">\n<li id=\"aeaa\" class=\"mi mj fo be b mk ml mm mn mo mp mq mr ms ng mu mv mw nh my mz na ni nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Create an instance of the network.<\/li>\n<li id=\"c5d3\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Call the <code class=\"cw py pz qa pi b\">init<\/code> method with a PRNG key and a sample input to obtain the network parameters.<\/li>\n<li id=\"ba60\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Create and return the training state, passing the model\u2019s <a class=\"af nf\" href=\"https:\/\/flax.readthedocs.io\/en\/latest\/api_reference\/flax.linen.html?highlight=apply#flax.linen.apply\" target=\"_blank\" rel=\"noopener ugc nofollow\">apply<\/a> function, the parameters, and the optimizer.<\/li>\n<\/ul>\n<pre>def create_train_state(rng, learning_rate, momentum):\n  \"\"\"Creates initial `TrainState`.\"\"\"\n  cnn = CNN()\n  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']\n  tx = optax.sgd(learning_rate, momentum)\n  return train_state.TrainState.create(\n      apply_fn=cnn.apply, params=params, tx=tx)<\/pre>\n<figure class=\"pc pd pe pf pg mb\"><\/figure>\n<h1 id=\"3167\" class=\"nz oa fo be ob oc pt oe of og pu oi oj ok pv om on oo pw oq or os px ou ov ow bj\" data-selectable-paragraph=\"\">Network training<\/h1>\n<p id=\"7102\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">The next step is to define the model training functions. 
Let\u2019s start by defining a function that will train the network for one step.<\/p>\n<p id=\"3617\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">The function:<\/p>\n<ul class=\"\">\n<li id=\"51c4\" class=\"mi mj fo be b mk ml mm mn mo mp mq mr ms ng mu mv mw nh my mz na ni nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Applies the network to a batch of data and computes the loss and logits.<\/li>\n<li id=\"2c69\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Computes gradients with respect to the loss.<\/li>\n<li id=\"8627\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Applies the gradients to obtain a new state.<\/li>\n<li id=\"1daf\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Computes and returns the model metrics.<\/li>\n<\/ul>\n<p id=\"ef3e\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">Applying <code class=\"cw py pz qa pi b\"><a class=\"af nf\" href=\"https:\/\/jax.readthedocs.io\/en\/latest\/notebooks\/thinking_in_jax.html?highlight=jit#to-jit-or-not-to-jit\" target=\"_blank\" rel=\"noopener ugc nofollow\">jax.jit<\/a><\/code> makes the function run faster.<\/p>\n<pre>@jax.jit\ndef train_step(state, batch):\n  \"\"\"Train for a single step.\"\"\"\n  def loss_fn(params):\n    logits = CNN().apply({'params': params}, batch['image'])\n    loss = cross_entropy_loss(logits=logits, labels=batch['label'])\n    return loss, logits\n  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n  (_, logits), grads = grad_fn(state.params)\n  state = state.apply_gradients(grads=grads)\n  metrics = 
compute_metrics(logits=logits, labels=batch['label'])\n  return state, metrics<\/pre>\n<p id=\"3df2\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">Next, define a training function that applies the above training step. The function:<\/p>\n<ul class=\"\">\n<li id=\"dd7f\" class=\"mi mj fo be b mk ml mm mn mo mp mq mr ms ng mu mv mw nh my mz na ni nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Shuffles the training data.<\/li>\n<li id=\"5785\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Runs the training step for each batch.<\/li>\n<li id=\"21b2\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Obtains the training metrics from the device using <code class=\"cw py pz qa pi b\">jax.device_get<\/code><\/li>\n<li id=\"e812\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Computes the mean of the metrics from each batch.<\/li>\n<li id=\"4883\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Returns the new state together with the model metrics.<\/li>\n<\/ul>\n<pre>def train_epoch(state, train_ds, batch_size, epoch, rng):\n  \"\"\"Train for a single epoch.\"\"\"\n  train_ds_size = len(train_ds['image'])\n  steps_per_epoch = train_ds_size \/\/ batch_size\n\n  perms = jax.random.permutation(rng, train_ds_size)\n  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch\n  perms = perms.reshape((steps_per_epoch, batch_size))\n  batch_metrics = []\n  for perm in perms:\n    batch = {k: v[perm, ...] 
for k, v in train_ds.items()}\n    state, metrics = train_step(state, batch)\n    batch_metrics.append(metrics)\n\n  # compute mean of metrics across each batch in epoch.\n  batch_metrics_np = jax.device_get(batch_metrics)\n  epoch_metrics_np = {\n      k: np.mean([metrics[k] for metrics in batch_metrics_np])\n      for k in batch_metrics_np[0]}\n\n  # accuracy is returned as a percentage\n  return state, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100<\/pre>\n<figure class=\"pc pd pe pf pg mb\"><\/figure>\n<h1 id=\"ce61\" class=\"nz oa fo be ob oc pt oe of og pu oi oj ok pv om on oo pw oq or os px ou ov ow bj\" data-selectable-paragraph=\"\">Network evaluation<\/h1>\n<p id=\"ca45\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">The evaluation step checks the performance of the network on the testing data.<\/p>\n<pre>@jax.jit\ndef eval_step(params, batch):\n  logits = CNN().apply({'params': params}, batch['image'])\n  return compute_metrics(logits=logits, labels=batch['label'])<\/pre>\n<p id=\"f3aa\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">Next, define a function that applies the evaluation step to the testing data. The function also obtains the evaluation metrics from the device.<\/p>\n<pre>def eval_model(params, test_ds):\n  metrics = eval_step(params, test_ds)\n  metrics = jax.device_get(metrics)\n  summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)\n  # accuracy is returned as a percentage, matching train_epoch\n  return summary['loss'], summary['accuracy'] * 100<\/pre>\n<\/div>\n<\/div>\n<\/div>\n\n\n\n<div class=\"fh fi fj fk fl\">\n<div class=\"ab ca\">\n<div class=\"ch bg et eu ev ew\">\n<blockquote class=\"qv\"><p id=\"ce2e\" class=\"qw qx fo be qy qz ra rb rc rd re ne dv\" data-selectable-paragraph=\"\">Want to try Comet for yourself? 
<a class=\"af nf\" href=\"\/signup?utm_source=heartbeat&amp;utm_medium=referral&amp;utm_campaign=AMS_US_EN_SNUP_heartbeat_CTA\" target=\"_blank\" rel=\"noopener ugc nofollow\">Sign up for a free account today!<\/a><\/p><\/blockquote>\n<\/div>\n<\/div>\n<\/div>\n\n\n\n<div class=\"fh fi fj fk fl\">\n<div class=\"ab ca\">\n<div class=\"ch bg et eu ev ew\">\n<h1 id=\"a8e2\" class=\"nz oa fo be ob oc od oe of og oh oi oj ok ol om on oo op oq or os ot ou ov ow bj\" data-selectable-paragraph=\"\">Train the Flax network<\/h1>\n<p id=\"cd81\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">You now have all the required building blocks for training the Flax CNN.<\/p>\n<h2 id=\"76e0\" class=\"pm oa fo be ob rf rg rh of ri rj rk oj ms rl rm rn mw ro rp rq na rr rs rt ru bj\" data-selectable-paragraph=\"\">Download data<\/h2>\n<p id=\"da20\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">Start by downloading the data.<\/p>\n<pre>train_ds, test_ds = get_datasets()<\/pre>\n<h2 id=\"ee16\" class=\"pm oa fo be ob rf rg rh of ri rj rk oj ms rl rm rn mw ro rp rq na rr rs rt ru bj\" data-selectable-paragraph=\"\">Set random seed<\/h2>\n<p id=\"a7d4\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">JAX requires pure functions. This means that they should not have any side effects. Therefore, even the random seed generation should be reproducible. 
Random numbers are derived from an explicit key, which you create with <code class=\"cw py pz qa pi b\"><a class=\"af nf\" href=\"https:\/\/jax.readthedocs.io\/en\/latest\/_autosummary\/jax.random.PRNGKey.html#jax.random.PRNGKey\" target=\"_blank\" rel=\"noopener ugc nofollow\">jax.random.PRNGKey<\/a><\/code> and split whenever you need fresh randomness.<\/p>\n<pre>rng = jax.random.PRNGKey(0)\nrng, init_rng = jax.random.split(rng)<\/pre>\n<div class=\"qb qc qd qe qf qg\">\n<div class=\"qh ab hy\">\n<div class=\"qi ab cn ca qj qk\">\n<div class=\"qo l\">\n<p class=\"be b dw z ig ql ii ij qm il in dv\"><strong><a href=\"https:\/\/jax.readthedocs.io\/en\/latest\/jep\/263-prng.html?source=post_page-----9adad64608da--------------------------------\">JAX PRNG Design<\/a><\/strong><\/p>\n<\/div>\n<\/div>\n<\/div>\n<\/div>\n<h2 id=\"a2d4\" class=\"pm oa fo be ob rf rg rh of ri rj rk oj ms rl rm rn mw ro rp rq na rr rs rt ru bj\" data-selectable-paragraph=\"\">Initialize train state<\/h2>\n<p id=\"4bcf\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">Create a training state using the function defined earlier.<\/p>\n<pre>learning_rate = 0.1\nmomentum = 0.9\n\nstate = create_train_state(init_rng, learning_rate, momentum)\ndel init_rng  # Must not be used anymore.<\/pre>\n<h2 id=\"2057\" class=\"pm oa fo be ob rf rg rh of ri rj rk oj ms rl rm rn mw ro rp rq na rr rs rt ru bj\" data-selectable-paragraph=\"\">Log model metrics<\/h2>\n<p id=\"1a5c\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">The next step is to apply the training function for the desired number of epochs. 
At each epoch, we:<\/p>\n<ul class=\"\">\n<li id=\"057f\" class=\"mi mj fo be b mk ml mm mn mo mp mq mr ms ng mu mv mw nh my mz na ni nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Obtain the training and test metrics.<\/li>\n<li id=\"55fe\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Log the metrics to Comet.<\/li>\n<li id=\"91e8\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Print the metrics to the console.<\/li>\n<\/ul>\n<pre>num_epochs = 10\nbatch_size = 32\n\ntraining_loss = []\ntraining_accuracy = []\ntesting_loss = []\ntesting_accuracy = []\n\n\nfor epoch in range(1, num_epochs + 1):\n  # Use a separate PRNG key to permute image data during shuffling\n  rng, input_rng = jax.random.split(rng)\n  # Train for one epoch over all training batches\n  state, train_loss, train_accuracy = train_epoch(state, train_ds, batch_size, epoch, input_rng)\n  training_loss.append(train_loss)\n  training_accuracy.append(train_accuracy)\n  # Evaluate on the test set after each training epoch\n  test_loss, test_accuracy = eval_model(state.params, test_ds)\n  testing_loss.append(test_loss)\n  testing_accuracy.append(test_accuracy)\n\n  experiment.log_metric(\"train_loss\", train_loss, step=None, epoch=epoch, include_context=True)\n  experiment.log_metric(\"train_accuracy\", train_accuracy, step=None, epoch=epoch, include_context=True)\n  experiment.log_metric(\"test_loss\", test_loss, step=None, epoch=epoch, include_context=True)\n  experiment.log_metric(\"test_accuracy\", test_accuracy, step=None, epoch=epoch, include_context=True)\n  print(f'Epoch {epoch} train loss {train_loss} train accuracy {train_accuracy}. 
Test loss {test_loss} test accuracy {test_accuracy}')<\/pre>\n<figure class=\"pc pd pe pf pg mb lt lu paragraph-image\">\n<div class=\"mc md eb me bg mf\" tabindex=\"0\" role=\"button\">\n<figure><img loading=\"lazy\" decoding=\"async\" class=\"bg mg mh c\" role=\"presentation\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*F4atM7B3prIstZgCCZ5umA.png\" alt=\"\" width=\"700\" height=\"129\"><\/figure><div class=\"lt lu rw\"><picture><\/picture><\/div>\n<\/div>\n<\/figure>\n<h2 id=\"abbb\" class=\"pm oa fo be ob rf rg rh of ri rj rk oj ms rl rm rn mw ro rp rq na rr rs rt ru bj\" data-selectable-paragraph=\"\">Log model charts<\/h2>\n<p id=\"d5f1\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">Since the metrics are saved in lists, you can plot the data and log the chart to Comet.<\/p>\n<pre>import matplotlib.pyplot as plt\n\nfig = plt.figure(figsize=(8, 6))\n\nplt.plot(training_loss, label=\"Training\")\nplt.plot(testing_loss, label=\"Test\")\nplt.xlabel(\"Epoch\")\nplt.ylabel(\"Loss\")\nplt.legend()\n\n# log the figure before showing it, since some backends release the figure on show()\nexperiment.log_figure(figure_name=\"Loss visualization\", figure=fig)\nplt.show()<\/pre>\n<p id=\"f14d\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">Don\u2019t forget to end the experiment once you are done.<\/p>\n<pre>experiment.end()<\/pre>\n<figure class=\"pc pd pe pf pg mb lt lu paragraph-image\">\n<div class=\"mc md eb me bg mf\" tabindex=\"0\" role=\"button\">\n<figure><img loading=\"lazy\" decoding=\"async\" class=\"bg mg mh c\" role=\"presentation\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*NoAUQgwRbDFkCkSHMnexVQ.png\" alt=\"\" width=\"700\" height=\"428\"><\/figure><div class=\"lt lu rx\"><picture><source srcset=\"https:\/\/miro.medium.com\/v2\/resize:fit:640\/format:webp\/1*NoAUQgwRbDFkCkSHMnexVQ.png 640w, 
https:\/\/miro.medium.com\/v2\/resize:fit:720\/format:webp\/1*NoAUQgwRbDFkCkSHMnexVQ.png 720w, https:\/\/miro.medium.com\/v2\/resize:fit:750\/format:webp\/1*NoAUQgwRbDFkCkSHMnexVQ.png 750w, https:\/\/miro.medium.com\/v2\/resize:fit:786\/format:webp\/1*NoAUQgwRbDFkCkSHMnexVQ.png 786w, https:\/\/miro.medium.com\/v2\/resize:fit:828\/format:webp\/1*NoAUQgwRbDFkCkSHMnexVQ.png 828w, https:\/\/miro.medium.com\/v2\/resize:fit:1100\/format:webp\/1*NoAUQgwRbDFkCkSHMnexVQ.png 1100w, https:\/\/miro.medium.com\/v2\/resize:fit:1400\/format:webp\/1*NoAUQgwRbDFkCkSHMnexVQ.png 1400w\" type=\"image\/webp\" sizes=\"(min-resolution: 4dppx) and (max-width: 700px) 50vw, (-webkit-min-device-pixel-ratio: 4) and (max-width: 700px) 50vw, (min-resolution: 3dppx) and (max-width: 700px) 67vw, (-webkit-min-device-pixel-ratio: 3) and (max-width: 700px) 65vw, (min-resolution: 2.5dppx) and (max-width: 700px) 80vw, (-webkit-min-device-pixel-ratio: 2.5) and (max-width: 700px) 80vw, (min-resolution: 2dppx) and (max-width: 700px) 100vw, (-webkit-min-device-pixel-ratio: 2) and (max-width: 700px) 100vw, 700px\"><source srcset=\"https:\/\/miro.medium.com\/v2\/resize:fit:640\/1*NoAUQgwRbDFkCkSHMnexVQ.png 640w, https:\/\/miro.medium.com\/v2\/resize:fit:720\/1*NoAUQgwRbDFkCkSHMnexVQ.png 720w, https:\/\/miro.medium.com\/v2\/resize:fit:750\/1*NoAUQgwRbDFkCkSHMnexVQ.png 750w, https:\/\/miro.medium.com\/v2\/resize:fit:786\/1*NoAUQgwRbDFkCkSHMnexVQ.png 786w, https:\/\/miro.medium.com\/v2\/resize:fit:828\/1*NoAUQgwRbDFkCkSHMnexVQ.png 828w, https:\/\/miro.medium.com\/v2\/resize:fit:1100\/1*NoAUQgwRbDFkCkSHMnexVQ.png 1100w, https:\/\/miro.medium.com\/v2\/resize:fit:1400\/1*NoAUQgwRbDFkCkSHMnexVQ.png 1400w\" sizes=\"(min-resolution: 4dppx) and (max-width: 700px) 50vw, (-webkit-min-device-pixel-ratio: 4) and (max-width: 700px) 50vw, (min-resolution: 3dppx) and (max-width: 700px) 67vw, (-webkit-min-device-pixel-ratio: 3) and (max-width: 700px) 65vw, (min-resolution: 2.5dppx) and (max-width: 700px) 80vw, 
(-webkit-min-device-pixel-ratio: 2.5) and (max-width: 700px) 80vw, (min-resolution: 2dppx) and (max-width: 700px) 100vw, (-webkit-min-device-pixel-ratio: 2) and (max-width: 700px) 100vw, 700px\" data-testid=\"og\"><\/picture><\/div>\n<\/div>\n<\/figure>\n<h2 id=\"b608\" class=\"pm oa fo be ob rf rg rh of ri rj rk oj ms rl rm rn mw ro rp rq na rr rs rt ru bj\" data-selectable-paragraph=\"\">View the experiment on Comet<\/h2>\n<p id=\"0d84\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd ne fh bj\" data-selectable-paragraph=\"\">Click the link generated when you end the experiment to view the experiment on Comet\u2019s UI.<\/p>\n<p id=\"949c\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">The Charts dashboard shows plots of the metrics you logged.<\/p>\n<figure class=\"pc pd pe pf pg mb lt lu paragraph-image\">\n<div class=\"mc md eb me bg mf\" tabindex=\"0\" role=\"button\">\n<figure><img loading=\"lazy\" decoding=\"async\" class=\"bg mg mh c\" role=\"presentation\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*jvfw4y3uvBQzHcRwepQ68A.png\" alt=\"\" width=\"700\" height=\"291\"><\/figure><div class=\"lt lu ry\"><picture><source srcset=\"https:\/\/miro.medium.com\/v2\/resize:fit:640\/format:webp\/1*jvfw4y3uvBQzHcRwepQ68A.png 640w, https:\/\/miro.medium.com\/v2\/resize:fit:720\/format:webp\/1*jvfw4y3uvBQzHcRwepQ68A.png 720w, https:\/\/miro.medium.com\/v2\/resize:fit:750\/format:webp\/1*jvfw4y3uvBQzHcRwepQ68A.png 750w, https:\/\/miro.medium.com\/v2\/resize:fit:786\/format:webp\/1*jvfw4y3uvBQzHcRwepQ68A.png 786w, https:\/\/miro.medium.com\/v2\/resize:fit:828\/format:webp\/1*jvfw4y3uvBQzHcRwepQ68A.png 828w, https:\/\/miro.medium.com\/v2\/resize:fit:1100\/format:webp\/1*jvfw4y3uvBQzHcRwepQ68A.png 1100w, https:\/\/miro.medium.com\/v2\/resize:fit:1400\/format:webp\/1*jvfw4y3uvBQzHcRwepQ68A.png 1400w\" 
type=\"image\/webp\" sizes=\"(min-resolution: 4dppx) and (max-width: 700px) 50vw, (-webkit-min-device-pixel-ratio: 4) and (max-width: 700px) 50vw, (min-resolution: 3dppx) and (max-width: 700px) 67vw, (-webkit-min-device-pixel-ratio: 3) and (max-width: 700px) 65vw, (min-resolution: 2.5dppx) and (max-width: 700px) 80vw, (-webkit-min-device-pixel-ratio: 2.5) and (max-width: 700px) 80vw, (min-resolution: 2dppx) and (max-width: 700px) 100vw, (-webkit-min-device-pixel-ratio: 2) and (max-width: 700px) 100vw, 700px\"><source srcset=\"https:\/\/miro.medium.com\/v2\/resize:fit:640\/1*jvfw4y3uvBQzHcRwepQ68A.png 640w, https:\/\/miro.medium.com\/v2\/resize:fit:720\/1*jvfw4y3uvBQzHcRwepQ68A.png 720w, https:\/\/miro.medium.com\/v2\/resize:fit:750\/1*jvfw4y3uvBQzHcRwepQ68A.png 750w, https:\/\/miro.medium.com\/v2\/resize:fit:786\/1*jvfw4y3uvBQzHcRwepQ68A.png 786w, https:\/\/miro.medium.com\/v2\/resize:fit:828\/1*jvfw4y3uvBQzHcRwepQ68A.png 828w, https:\/\/miro.medium.com\/v2\/resize:fit:1100\/1*jvfw4y3uvBQzHcRwepQ68A.png 1100w, https:\/\/miro.medium.com\/v2\/resize:fit:1400\/1*jvfw4y3uvBQzHcRwepQ68A.png 1400w\" sizes=\"(min-resolution: 4dppx) and (max-width: 700px) 50vw, (-webkit-min-device-pixel-ratio: 4) and (max-width: 700px) 50vw, (min-resolution: 3dppx) and (max-width: 700px) 67vw, (-webkit-min-device-pixel-ratio: 3) and (max-width: 700px) 65vw, (min-resolution: 2.5dppx) and (max-width: 700px) 80vw, (-webkit-min-device-pixel-ratio: 2.5) and (max-width: 700px) 80vw, (min-resolution: 2dppx) and (max-width: 700px) 100vw, (-webkit-min-device-pixel-ratio: 2) and (max-width: 700px) 100vw, 700px\" data-testid=\"og\"><\/picture><\/div>\n<\/div>\n<\/figure>\n<p id=\"9d10\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">The hyperparameters dashboard shows the logged parameters.<\/p>\n<figure class=\"pc pd pe pf pg mb lt lu paragraph-image\">\n<div class=\"mc md eb me bg mf\" tabindex=\"0\" 
role=\"button\">\n<figure><img loading=\"lazy\" decoding=\"async\" class=\"bg mg mh c\" role=\"presentation\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*oGZGOjmTcivDeI0dQHxcwQ.png\" alt=\"\" width=\"700\" height=\"291\"><\/figure><div class=\"lt lu ry\"><picture><source srcset=\"https:\/\/miro.medium.com\/v2\/resize:fit:640\/format:webp\/1*oGZGOjmTcivDeI0dQHxcwQ.png 640w, https:\/\/miro.medium.com\/v2\/resize:fit:720\/format:webp\/1*oGZGOjmTcivDeI0dQHxcwQ.png 720w, https:\/\/miro.medium.com\/v2\/resize:fit:750\/format:webp\/1*oGZGOjmTcivDeI0dQHxcwQ.png 750w, https:\/\/miro.medium.com\/v2\/resize:fit:786\/format:webp\/1*oGZGOjmTcivDeI0dQHxcwQ.png 786w, https:\/\/miro.medium.com\/v2\/resize:fit:828\/format:webp\/1*oGZGOjmTcivDeI0dQHxcwQ.png 828w, https:\/\/miro.medium.com\/v2\/resize:fit:1100\/format:webp\/1*oGZGOjmTcivDeI0dQHxcwQ.png 1100w, https:\/\/miro.medium.com\/v2\/resize:fit:1400\/format:webp\/1*oGZGOjmTcivDeI0dQHxcwQ.png 1400w\" type=\"image\/webp\" sizes=\"(min-resolution: 4dppx) and (max-width: 700px) 50vw, (-webkit-min-device-pixel-ratio: 4) and (max-width: 700px) 50vw, (min-resolution: 3dppx) and (max-width: 700px) 67vw, (-webkit-min-device-pixel-ratio: 3) and (max-width: 700px) 65vw, (min-resolution: 2.5dppx) and (max-width: 700px) 80vw, (-webkit-min-device-pixel-ratio: 2.5) and (max-width: 700px) 80vw, (min-resolution: 2dppx) and (max-width: 700px) 100vw, (-webkit-min-device-pixel-ratio: 2) and (max-width: 700px) 100vw, 700px\"><source srcset=\"https:\/\/miro.medium.com\/v2\/resize:fit:640\/1*oGZGOjmTcivDeI0dQHxcwQ.png 640w, https:\/\/miro.medium.com\/v2\/resize:fit:720\/1*oGZGOjmTcivDeI0dQHxcwQ.png 720w, https:\/\/miro.medium.com\/v2\/resize:fit:750\/1*oGZGOjmTcivDeI0dQHxcwQ.png 750w, https:\/\/miro.medium.com\/v2\/resize:fit:786\/1*oGZGOjmTcivDeI0dQHxcwQ.png 786w, https:\/\/miro.medium.com\/v2\/resize:fit:828\/1*oGZGOjmTcivDeI0dQHxcwQ.png 828w, https:\/\/miro.medium.com\/v2\/resize:fit:1100\/1*oGZGOjmTcivDeI0dQHxcwQ.png 1100w, 
https:\/\/miro.medium.com\/v2\/resize:fit:1400\/1*oGZGOjmTcivDeI0dQHxcwQ.png 1400w\" sizes=\"(min-resolution: 4dppx) and (max-width: 700px) 50vw, (-webkit-min-device-pixel-ratio: 4) and (max-width: 700px) 50vw, (min-resolution: 3dppx) and (max-width: 700px) 67vw, (-webkit-min-device-pixel-ratio: 3) and (max-width: 700px) 65vw, (min-resolution: 2.5dppx) and (max-width: 700px) 80vw, (-webkit-min-device-pixel-ratio: 2.5) and (max-width: 700px) 80vw, (min-resolution: 2dppx) and (max-width: 700px) 100vw, (-webkit-min-device-pixel-ratio: 2) and (max-width: 700px) 100vw, 700px\" data-testid=\"og\"><\/picture><\/div>\n<\/div>\n<\/figure>\n<p id=\"8664\" class=\"pw-post-body-paragraph mi mj fo be b mk ml mm mn mo mp mq mr ms mt mu mv mw mx my mz na nb nc nd ne fh bj\" data-selectable-paragraph=\"\">The Graphics dashboard shows all the logged charts.<\/p>\n<figure class=\"pc pd pe pf pg mb lt lu paragraph-image\">\n<div class=\"mc md eb me bg mf\" tabindex=\"0\" role=\"button\">\n<figure><img loading=\"lazy\" decoding=\"async\" class=\"bg mg mh c\" role=\"presentation\" src=\"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*WUkRiaT1EdKCeGoT4KnCHA.png\" alt=\"\" width=\"700\" height=\"289\"><\/figure><div class=\"lt lu rz\"><picture><source srcset=\"https:\/\/miro.medium.com\/v2\/resize:fit:640\/format:webp\/1*WUkRiaT1EdKCeGoT4KnCHA.png 640w, https:\/\/miro.medium.com\/v2\/resize:fit:720\/format:webp\/1*WUkRiaT1EdKCeGoT4KnCHA.png 720w, https:\/\/miro.medium.com\/v2\/resize:fit:750\/format:webp\/1*WUkRiaT1EdKCeGoT4KnCHA.png 750w, https:\/\/miro.medium.com\/v2\/resize:fit:786\/format:webp\/1*WUkRiaT1EdKCeGoT4KnCHA.png 786w, https:\/\/miro.medium.com\/v2\/resize:fit:828\/format:webp\/1*WUkRiaT1EdKCeGoT4KnCHA.png 828w, https:\/\/miro.medium.com\/v2\/resize:fit:1100\/format:webp\/1*WUkRiaT1EdKCeGoT4KnCHA.png 1100w, https:\/\/miro.medium.com\/v2\/resize:fit:1400\/format:webp\/1*WUkRiaT1EdKCeGoT4KnCHA.png 1400w\" type=\"image\/webp\" sizes=\"(min-resolution: 4dppx) and (max-width: 
700px) 50vw, (-webkit-min-device-pixel-ratio: 4) and (max-width: 700px) 50vw, (min-resolution: 3dppx) and (max-width: 700px) 67vw, (-webkit-min-device-pixel-ratio: 3) and (max-width: 700px) 65vw, (min-resolution: 2.5dppx) and (max-width: 700px) 80vw, (-webkit-min-device-pixel-ratio: 2.5) and (max-width: 700px) 80vw, (min-resolution: 2dppx) and (max-width: 700px) 100vw, (-webkit-min-device-pixel-ratio: 2) and (max-width: 700px) 100vw, 700px\"><source srcset=\"https:\/\/miro.medium.com\/v2\/resize:fit:640\/1*WUkRiaT1EdKCeGoT4KnCHA.png 640w, https:\/\/miro.medium.com\/v2\/resize:fit:720\/1*WUkRiaT1EdKCeGoT4KnCHA.png 720w, https:\/\/miro.medium.com\/v2\/resize:fit:750\/1*WUkRiaT1EdKCeGoT4KnCHA.png 750w, https:\/\/miro.medium.com\/v2\/resize:fit:786\/1*WUkRiaT1EdKCeGoT4KnCHA.png 786w, https:\/\/miro.medium.com\/v2\/resize:fit:828\/1*WUkRiaT1EdKCeGoT4KnCHA.png 828w, https:\/\/miro.medium.com\/v2\/resize:fit:1100\/1*WUkRiaT1EdKCeGoT4KnCHA.png 1100w, https:\/\/miro.medium.com\/v2\/resize:fit:1400\/1*WUkRiaT1EdKCeGoT4KnCHA.png 1400w\" sizes=\"(min-resolution: 4dppx) and (max-width: 700px) 50vw, (-webkit-min-device-pixel-ratio: 4) and (max-width: 700px) 50vw, (min-resolution: 3dppx) and (max-width: 700px) 67vw, (-webkit-min-device-pixel-ratio: 3) and (max-width: 700px) 65vw, (min-resolution: 2.5dppx) and (max-width: 700px) 80vw, (-webkit-min-device-pixel-ratio: 2.5) and (max-width: 700px) 80vw, (min-resolution: 2dppx) and (max-width: 700px) 100vw, (-webkit-min-device-pixel-ratio: 2) and (max-width: 700px) 100vw, 700px\" data-testid=\"og\"><\/picture><\/div>\n<\/div>\n<\/figure>\n<\/div>\n<\/div>\n<\/div>\n\n\n\n<div class=\"fh fi fj fk fl\">\n<div class=\"ab ca\">\n<div class=\"ch bg et eu ev ew\">\n<h1 id=\"58d8\" class=\"nz oa fo be ob oc od oe of og oh oi oj ok ol om on oo op oq or os ot ou ov ow bj\" data-selectable-paragraph=\"\">Final thoughts<\/h1>\n<p id=\"a80f\" class=\"pw-post-body-paragraph mi mj fo be b mk ox mm mn mo oy mq mr ms oz mu mv mw pa my mz na pb nc nd 
ne fh bj\" data-selectable-paragraph=\"\">In this article, you have seen how to track Flax experiments with Comet. Along the way, you have also seen how to:<\/p>\n<ul class=\"\">\n<li id=\"6699\" class=\"mi mj fo be b mk ml mm mn mo mp mq mr ms ng mu mv mw nh my mz na ni nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Load datasets in JAX.<\/li>\n<li id=\"4391\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Define networks with Flax.<\/li>\n<li id=\"2295\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Create a training state in Flax.<\/li>\n<li id=\"de60\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\">Train a CNN in Flax.<\/li>\n<\/ul>\n<h2 id=\"b857\" class=\"pm oa fo be ob rf rg rh of ri rj rk oj ms rl rm rn mw ro rp rq na rr rs rt ru bj\" data-selectable-paragraph=\"\">Resources<\/h2>\n<ul class=\"\">\n<li id=\"5860\" class=\"mi mj fo be b mk ox mm mn mo oy mq mr ms sb mu mv mw sc my mz na sd nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\"><a class=\"af nf\" href=\"https:\/\/colab.research.google.com\/drive\/1OUh0-dWTdbIqjdL4G8zE9MEYzJarIspA?usp=sharing\" target=\"_blank\" rel=\"noopener ugc nofollow\">Notebook<\/a><\/li>\n<li id=\"cc15\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\"><a class=\"af nf\" href=\"https:\/\/jax.readthedocs.io\/en\/latest\/\" target=\"_blank\" rel=\"noopener ugc nofollow\">JAX<\/a><\/li>\n<li id=\"5383\" class=\"mi mj fo be b mk nm mm mn mo nn mq mr ms no mu mv mw np my mz na nq nc nd ne nj nk nl bj\" data-selectable-paragraph=\"\"><a class=\"af nf\" href=\"https:\/\/flax.readthedocs.io\/en\/latest\/\" target=\"_blank\" rel=\"noopener ugc 
nofollow\">Flax<\/a><\/li>\n<\/ul>\n<\/div>\n<\/div>\n<\/div>\n\n\n\n<div class=\"ab ca nr ns nt nu\" role=\"separator\"><\/div>\n","protected":false},"excerpt":{"rendered":"<p>JAX is a Python library offering high performance in machine learning with XLA and Just In Time (JIT) compilation. Its API is similar to NumPy\u2019s with a few differences. JAX ships with functionalities that aim to improve and increase speed in machine learning research. These functionalities include: Automatic differentiation Vectorization JIT compilation Flax is a [&hellip;]<\/p>\n","protected":false},"author":63,"featured_media":0,"comment_status":"closed","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"customer_name":"","customer_description":"","customer_industry":"","customer_technologies":"","customer_logo":"","_jetpack_memberships_contains_paid_content":false,"footnotes":""},"categories":[9,7],"tags":[],"coauthors":[163],"class_list":["post-7611","post","type-post","status-publish","format-standard","hentry","category-product","category-tutorials"],"yoast_head":"<!-- This site is optimized with the Yoast SEO Premium plugin v25.9 (Yoast SEO v25.9) - https:\/\/yoast.com\/wordpress\/plugins\/seo\/ -->\n<title>Tracking JAX and Flax models with Comet - Comet<\/title>\n<meta name=\"robots\" content=\"index, follow, max-snippet:-1, max-image-preview:large, max-video-preview:-1\" \/>\n<link rel=\"canonical\" href=\"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/\" \/>\n<meta property=\"og:locale\" content=\"en_US\" \/>\n<meta property=\"og:type\" content=\"article\" \/>\n<meta property=\"og:title\" content=\"Tracking JAX and Flax models with Comet\" \/>\n<meta property=\"og:description\" content=\"JAX is a Python library offering high performance in machine learning with XLA and Just In Time (JIT) compilation. Its API is similar to NumPy\u2019s with a few differences. 
JAX ships with functionalities that aim to improve and increase speed in machine learning research. These functionalities include: Automatic differentiation Vectorization JIT compilation Flax is a [&hellip;]\" \/>\n<meta property=\"og:url\" content=\"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/\" \/>\n<meta property=\"og:site_name\" content=\"Comet\" \/>\n<meta property=\"article:publisher\" content=\"https:\/\/www.facebook.com\/cometdotml\" \/>\n<meta property=\"article:published_time\" content=\"2023-09-22T19:45:28+00:00\" \/>\n<meta property=\"article:modified_time\" content=\"2025-04-24T17:13:56+00:00\" \/>\n<meta property=\"og:image\" content=\"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*WpEtFvcW7jxALJ5dXz44BQ.png\" \/>\n<meta name=\"author\" content=\"Derrick Mwiti\" \/>\n<meta name=\"twitter:card\" content=\"summary_large_image\" \/>\n<meta name=\"twitter:creator\" content=\"@Cometml\" \/>\n<meta name=\"twitter:site\" content=\"@Cometml\" \/>\n<meta name=\"twitter:label1\" content=\"Written by\" \/>\n\t<meta name=\"twitter:data1\" content=\"Derrick Mwiti\" \/>\n\t<meta name=\"twitter:label2\" content=\"Est. reading time\" \/>\n\t<meta name=\"twitter:data2\" content=\"7 minutes\" \/>\n<!-- \/ Yoast SEO Premium plugin. -->","yoast_head_json":{"title":"Tracking JAX and Flax models with Comet - Comet","robots":{"index":"index","follow":"follow","max-snippet":"max-snippet:-1","max-image-preview":"max-image-preview:large","max-video-preview":"max-video-preview:-1"},"canonical":"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/","og_locale":"en_US","og_type":"article","og_title":"Tracking JAX and Flax models with Comet","og_description":"JAX is a Python library offering high performance in machine learning with XLA and Just In Time (JIT) compilation. Its API is similar to NumPy\u2019s with a few differences. 
JAX ships with functionalities that aim to improve and increase speed in machine learning research. These functionalities include: Automatic differentiation Vectorization JIT compilation Flax is a [&hellip;]","og_url":"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/","og_site_name":"Comet","article_publisher":"https:\/\/www.facebook.com\/cometdotml","article_published_time":"2023-09-22T19:45:28+00:00","article_modified_time":"2025-04-24T17:13:56+00:00","og_image":[{"url":"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*WpEtFvcW7jxALJ5dXz44BQ.png","type":"","width":"","height":""}],"author":"Derrick Mwiti","twitter_card":"summary_large_image","twitter_creator":"@Cometml","twitter_site":"@Cometml","twitter_misc":{"Written by":"Derrick Mwiti","Est. reading time":"7 minutes"},"schema":{"@context":"https:\/\/schema.org","@graph":[{"@type":"Article","@id":"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/#article","isPartOf":{"@id":"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/"},"author":{"name":"Derrick Mwiti","@id":"https:\/\/www.comet.com\/site\/#\/schema\/person\/9808205cca68ec95b6fbd918d195cea6"},"headline":"Tracking JAX and Flax models with 
Comet","datePublished":"2023-09-22T19:45:28+00:00","dateModified":"2025-04-24T17:13:56+00:00","mainEntityOfPage":{"@id":"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/"},"wordCount":722,"publisher":{"@id":"https:\/\/www.comet.com\/site\/#organization"},"image":{"@id":"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/#primaryimage"},"thumbnailUrl":"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*WpEtFvcW7jxALJ5dXz44BQ.png","articleSection":["Product","Tutorials"],"inLanguage":"en-US"},{"@type":"WebPage","@id":"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/","url":"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/","name":"Tracking JAX and Flax models with Comet - Comet","isPartOf":{"@id":"https:\/\/www.comet.com\/site\/#website"},"primaryImageOfPage":{"@id":"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/#primaryimage"},"image":{"@id":"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/#primaryimage"},"thumbnailUrl":"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*WpEtFvcW7jxALJ5dXz44BQ.png","datePublished":"2023-09-22T19:45:28+00:00","dateModified":"2025-04-24T17:13:56+00:00","breadcrumb":{"@id":"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/#breadcrumb"},"inLanguage":"en-US","potentialAction":[{"@type":"ReadAction","target":["https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/"]}]},{"@type":"ImageObject","inLanguage":"en-US","@id":"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/#primaryimage","url":"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*WpEtFvcW7jxALJ5dXz44BQ.png","contentUrl":"https:\/\/miro.medium.com\/v2\/resize:fit:700\/1*WpEtFvcW7jxALJ5dXz44BQ.png"},{"@type":"BreadcrumbList","@id":"https:\/\/www.comet.com\/site\/blog\/tracking-jax-and-flax-models-with-comet\/#breadcrumb","itemListEl
ement":[{"@type":"ListItem","position":1,"name":"Home","item":"https:\/\/www.comet.com\/site\/"},{"@type":"ListItem","position":2,"name":"Tracking JAX and Flax models with Comet"}]},{"@type":"WebSite","@id":"https:\/\/www.comet.com\/site\/#website","url":"https:\/\/www.comet.com\/site\/","name":"Comet","description":"Build Better Models Faster","publisher":{"@id":"https:\/\/www.comet.com\/site\/#organization"},"potentialAction":[{"@type":"SearchAction","target":{"@type":"EntryPoint","urlTemplate":"https:\/\/www.comet.com\/site\/?s={search_term_string}"},"query-input":{"@type":"PropertyValueSpecification","valueRequired":true,"valueName":"search_term_string"}}],"inLanguage":"en-US"},{"@type":"Organization","@id":"https:\/\/www.comet.com\/site\/#organization","name":"Comet ML, Inc.","alternateName":"Comet","url":"https:\/\/www.comet.com\/site\/","logo":{"@type":"ImageObject","inLanguage":"en-US","@id":"https:\/\/www.comet.com\/site\/#\/schema\/logo\/image\/","url":"https:\/\/www.comet.com\/site\/wp-content\/uploads\/2025\/01\/logo_comet_square.png","contentUrl":"https:\/\/www.comet.com\/site\/wp-content\/uploads\/2025\/01\/logo_comet_square.png","width":310,"height":310,"caption":"Comet ML, Inc."},"image":{"@id":"https:\/\/www.comet.com\/site\/#\/schema\/logo\/image\/"},"sameAs":["https:\/\/www.facebook.com\/cometdotml","https:\/\/x.com\/Cometml","https:\/\/www.youtube.com\/channel\/UCmN63HKvfXSCS-UwVwmK8Hw"]},{"@type":"Person","@id":"https:\/\/www.comet.com\/site\/#\/schema\/person\/9808205cca68ec95b6fbd918d195cea6","name":"Derrick Mwiti","image":{"@type":"ImageObject","inLanguage":"en-US","@id":"https:\/\/www.comet.com\/site\/#\/schema\/person\/image\/b7db96aa11f77239bbde5eb79ede1493","url":"https:\/\/secure.gravatar.com\/avatar\/d52d009e8d0a72c0dcd785caadeefbb3fb7aa64567e9f5a1e65f5faad18f2426?s=96&d=mm&r=g","contentUrl":"https:\/\/secure.gravatar.com\/avatar\/d52d009e8d0a72c0dcd785caadeefbb3fb7aa64567e9f5a1e65f5faad18f2426?s=96&d=mm&r=g","caption":"Derrick 
Mwiti"},"url":"https:\/\/www.comet.com\/site\/blog\/author\/mwitiderrickgmail-com\/"}]}},"jetpack_featured_media_url":"","jetpack_sharing_enabled":true,"_links":{"self":[{"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/posts\/7611","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/users\/63"}],"replies":[{"embeddable":true,"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/comments?post=7611"}],"version-history":[{"count":1,"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/posts\/7611\/revisions"}],"predecessor-version":[{"id":15533,"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/posts\/7611\/revisions\/15533"}],"wp:attachment":[{"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/media?parent=7611"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/categories?post=7611"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/tags?post=7611"},{"taxonomy":"author","embeddable":true,"href":"https:\/\/www.comet.com\/site\/wp-json\/wp\/v2\/coauthors?post=7611"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}