
Integrate with PyTorch

PyTorch is a popular open source machine learning framework based on the Torch library, used for applications such as computer vision and natural language processing.

PyTorch enables fast, flexible experimentation and efficient production through a user-friendly front-end, distributed training, and an ecosystem of tools and libraries.

Instrument PyTorch with Comet to start managing experiments, create dataset versions and track hyperparameters for faster and easier reproducibility and collaboration.


This integration also supports PyTorch Distributed Data Parallel (DDP). See PyTorch Distributed Data Parallel below.

Start logging

Connect Comet to your existing code by creating a Comet Experiment.

Add the following lines of code to your script or notebook:

import comet_ml  # import comet_ml before torch so automatic logging can hook in
import torch
import torchvision

experiment = comet_ml.Experiment(
    api_key="<Your API Key>",
    project_name="<Your Project Name>",
)

# Your code here


There are other ways to configure Comet, such as a .comet.config file or environment variables. See Configure Comet for details.
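For example, to keep the API key out of your source code, you can export it as an environment variable. This sketch assumes the standard COMET_API_KEY and COMET_PROJECT_NAME variables. comet_ml can read these on its own, so comet_ml.Experiment() with no arguments also works; the hypothetical experiment_kwargs helper only makes a missing key fail early with a clear message.

```python
import os


def experiment_kwargs(project_name=None):
    """Hypothetical helper: build keyword arguments for comet_ml.Experiment.

    Reads COMET_API_KEY (required) and, unless a project name is passed in,
    COMET_PROJECT_NAME (optional) from the environment.
    """
    api_key = os.environ.get("COMET_API_KEY")
    if not api_key:
        raise RuntimeError("COMET_API_KEY is not set; export it before starting")
    kwargs = {"api_key": api_key}
    project = project_name or os.environ.get("COMET_PROJECT_NAME")
    if project:
        kwargs["project_name"] = project
    return kwargs


# Usage (requires comet_ml and a valid key):
# import comet_ml
# experiment = comet_ml.Experiment(**experiment_kwargs())
```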

Log automatically

After an Experiment has been created, Comet automatically logs the following PyTorch items, by default, with no additional configuration:

  • Model and graph description
  • Training loss

You can easily turn the automatic logging on and off for any or all items. See Configure Comet for PyTorch for more details.
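If you turn off automatic metric logging, you can still log the training loss yourself with experiment.log_metric. The sketch below is one possible approach and is not part of Comet's API: the hypothetical LossLogger class averages the loss over a window of steps and logs that average every log_every steps, which keeps the chart readable on long runs.

```python
class LossLogger:
    """Hypothetical helper: log the mean loss over every `log_every` steps."""

    def __init__(self, experiment, log_every=50, name="loss"):
        self.experiment = experiment
        self.log_every = log_every
        self.name = name
        self._window = []

    def update(self, loss_value, step):
        """Record one batch loss (a float) and log the window mean when it fills."""
        self._window.append(float(loss_value))
        if len(self._window) == self.log_every:
            mean_loss = sum(self._window) / len(self._window)
            # Manual logging call: experiment.log_metric(name, value, step=...)
            self.experiment.log_metric(self.name, mean_loss, step=step)
            self._window.clear()
            return mean_loss
        return None
```

In a training loop, call logger.update(loss.item(), step) after loss.backward().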


Don't see what you need to log here? We have your back. You can manually log any kind of data to Comet using the Experiment object. For example, use experiment.log_image to log images, or experiment.log_audio to log audio.
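As a sketch of manual logging, the hypothetical helper below logs the first few images of a batch and names each one after its label. It assumes experiment.log_image accepts the image data directly (NumPy arrays and PyTorch tensors are among the documented input types) together with the name and step keyword arguments.

```python
from itertools import islice


def log_sample_images(experiment, images, labels, step=None, max_images=4):
    """Hypothetical helper: log up to `max_images` images, named by label.

    `images` and `labels` are parallel sequences; each image is passed
    unchanged to experiment.log_image.
    """
    count = 0
    for idx, (image, label) in enumerate(islice(zip(images, labels), max_images)):
        experiment.log_image(image, name=f"sample_{idx}_label_{int(label)}", step=step)
        count += 1
    return count
```

For MNIST batches shaped (N, 1, 28, 28), one option is log_sample_images(experiment, images.cpu().squeeze(1), labels, step=batch_idx), which drops the single channel dimension before logging.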

End-to-end example

The following is a basic example of using Comet with PyTorch.

If you can't wait, check out the results of this example PyTorch experiment for a preview of what's to come.

Install dependencies

pip install comet_ml torch torchvision tqdm

Run the example

import comet_ml
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm

experiment = comet_ml.Experiment()

hyper_params = {"batch_size": 100, "num_epochs": 2, "learning_rate": 0.01}
experiment.log_parameters(hyper_params)  # record the hyperparameters with the run

# MNIST Dataset
dataset = datasets.MNIST(
    root="./data/", train=True, transform=transforms.ToTensor(), download=True
)

# Data Loader (Input Pipeline)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset, batch_size=hyper_params["batch_size"], shuffle=True
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

def train(model, optimizer, criterion, dataloader, epoch):
    model.train()
    total_loss = 0
    correct = 0
    for batch_idx, (images, labels) in enumerate(tqdm(dataloader)):
        # Move the batch to the same device as the model
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass, backward pass, and parameter update
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        pred = outputs.argmax(
            dim=1, keepdim=True
        )  # get the index of the max log-probability

        # Compute train accuracy
        batch_correct = pred.eq(labels.view_as(pred)).sum().item()
        batch_total = labels.size(0)

        total_loss += loss.item()
        correct += batch_correct

        # Log batch_accuracy to Comet; step is each batch
        experiment.log_metric("batch_accuracy", batch_correct / batch_total)

    # loss.item() is a per-batch mean, so average over the number of batches
    total_loss /= len(dataloader)
    correct /= len(dataloader.dataset)

    experiment.log_metrics({"accuracy": correct, "loss": total_loss}, epoch=epoch)

model = Net().to(device)

# Loss and Optimizer
# The model outputs log-probabilities (log_softmax), so NLLLoss is the matching loss
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=hyper_params["learning_rate"])

# Train the Model
with experiment.train():
    print("Running Model Training")
    for epoch in range(hyper_params["num_epochs"]):
        train(model, optimizer, criterion, dataloader, epoch)

# Finish the experiment so all data is uploaded (needed explicitly in notebooks)
experiment.end()
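A note on epoch-level averages: nn.CrossEntropyLoss and nn.NLLLoss return the mean loss over a batch by default, so an epoch loss should average the batch means rather than divide their sum by the number of samples. When the last batch is smaller than the others, weighting each batch mean by its size gives the exact per-sample mean. The hypothetical epoch_metrics function below shows that arithmetic on plain Python numbers.

```python
def epoch_metrics(batch_losses, batch_correct, batch_sizes):
    """Hypothetical helper: combine per-batch results into epoch metrics.

    batch_losses: mean loss of each batch (for example, loss.item())
    batch_correct: number of correct predictions in each batch
    batch_sizes: number of samples in each batch
    """
    total_samples = sum(batch_sizes)
    # Weight each batch mean by its size to recover the per-sample mean loss
    loss = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / total_samples
    accuracy = sum(batch_correct) / total_samples
    return {"loss": loss, "accuracy": accuracy}
```

For batches of 100, 100 and 50 samples with mean losses 0.5, 0.3 and 0.2, the weighted mean is (50 + 30 + 10) / 250 = 0.36, whereas dividing the summed means by 250 would give 0.004. The returned dict can be passed to experiment.log_metrics(metrics, epoch=epoch).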

PyTorch Distributed Data Parallel

Are you running distributed training with PyTorch? See the example for logging PyTorch DDP with Comet in the comet-examples repository.
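A common pattern in distributed training, sketched here independently of the comet-examples code, is to report metrics from a single process so each value is logged once. The sketch assumes the job is launched with torchrun, which sets the RANK environment variable for every process; is_main_process, NoOpExperiment and make_experiment are hypothetical names.

```python
import contextlib
import os


def is_main_process():
    """True on global rank 0, or when RANK is unset (single-process run)."""
    return int(os.environ.get("RANK", "0")) == 0


class NoOpExperiment:
    """Stand-in for non-zero ranks: accepts any logging call and does nothing."""

    def __getattr__(self, name):
        # Any method lookup (log_metric, train, ...) returns a no-op function
        def _noop(*args, **kwargs):
            # A null context lets `with experiment.train():` work as well
            return contextlib.nullcontext()

        return _noop


def make_experiment(factory):
    """Call factory() (for example comet_ml.Experiment) only on the main process."""
    return factory() if is_main_process() else NoOpExperiment()
```

Usage: experiment = make_experiment(comet_ml.Experiment). Only values computed on rank 0 are logged; to log global values, reduce them across ranks first, for example with torch.distributed.all_reduce.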

Try it out!

Don't just take our word for it, try it out for yourself.

Configure Comet for PyTorch

You can control which PyTorch items are logged automatically. Use any of the following methods.

Pass arguments when you create the Experiment:

experiment = comet_ml.Experiment(
    log_graph=True,  # Can be True or False
    auto_metric_logging=True,  # Can be True or False
)

Add or remove these fields from your .comet.config file under the [comet_auto_log] section to enable or disable logging (each value can be true or false):

[comet_auto_log]
graph=true
metrics=true

Set the following environment variables (each value can be true or false):

export COMET_AUTO_LOG_GRAPH=true
export COMET_AUTO_LOG_METRICS=true
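If you manage .comet.config files from scripts, the [comet_auto_log] section can be written with Python's standard configparser module. This is a sketch rather than a Comet tool: set_auto_log_options is a hypothetical helper, it keeps other sections of an existing file, and configparser drops any comments when it rewrites the file.

```python
import configparser
from pathlib import Path


def set_auto_log_options(config_path, graph=True, metrics=True):
    """Hypothetical helper: create or update [comet_auto_log] in a config file."""
    path = Path(config_path)
    config = configparser.ConfigParser()
    config.read(path)  # a missing file is silently skipped
    if not config.has_section("comet_auto_log"):
        config.add_section("comet_auto_log")
    config.set("comet_auto_log", "graph", "true" if graph else "false")
    config.set("comet_auto_log", "metrics", "true" if metrics else "false")
    with path.open("w") as f:
        # Write key=value without spaces, matching the format shown above
        config.write(f, space_around_delimiters=False)
    return path
```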

For more information about configuring Comet, see Configure Comet.

Aug. 3, 2022