Integrate with PyTorch Lightning

PyTorch Lightning is the deep learning framework for professional AI researchers and machine learning engineers who need maximal flexibility without sacrificing performance at scale. Lightning evolves with you as your projects go from idea to paper or production.

Instrument PyTorch Lightning with Comet to start managing experiments, creating dataset versions, and tracking hyperparameters for faster and easier reproducibility and collaboration.


Start logging

Connect Comet to your existing Lightning code by adding the CometLogger to your script or notebook.

import comet_ml
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger

# Arguments passed to CometLogger are forwarded to the comet_ml.Experiment class
comet_logger = CometLogger()

# Your model and data loaders go here (see the end-to-end example below)

trainer = Trainer(
    max_epochs=3,
    logger=comet_logger,
)
trainer.fit(model, train_loader, eval_loader)
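
Note that Comet needs an API key before the logger can report anything. Below is a minimal sketch, assuming you expose the key through the COMET_API_KEY environment variable (one of several configuration options Comet supports); the placeholder value is yours to fill in.

import os

# Assumption: an API key created in your Comet account settings.
# Exporting COMET_API_KEY in your shell works just as well as setting it here.
os.environ["COMET_API_KEY"] = "<your-api-key>"

# Alternatively, pass the key directly; CometLogger accepts api_key as a parameter:
# comet_logger = CometLogger(api_key="<your-api-key>")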

End-to-end example

Here is a basic example of how you can use Comet with PyTorch Lightning. If you can't wait and would like a preview of what's to come, check out a completed experiment here.

Install dependencies

pip install comet_ml pytorch-lightning torch torchvision

Run the example

import os

import comet_ml
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64


class Model(pl.LightningModule):
    def __init__(self, layer_size=784):
        super().__init__()
        # save_hyperparameters() records init arguments (layer_size) with the logger
        self.save_hyperparameters()
        self.l1 = torch.nn.Linear(layer_size, 10)

    def forward(self, x):
        # Flatten 28x28 MNIST images into 784-dim vectors before the linear layer
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # Metrics logged through self.logger are sent to Comet
        self.logger.log_metrics({"train_loss": loss}, step=batch_nb)
        return loss

    def validation_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        self.logger.log_metrics({"val_loss": loss}, step=batch_nb)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


# Init our model
model = Model()

# Arguments made to CometLogger are passed on to the comet_ml.Experiment class
comet_logger = CometLogger(project_name="comet-examples-lightning")

# Log parameters
comet_logger.log_hyperparams({"batch_size": BATCH_SIZE})

# Init DataLoader from MNIST Dataset
train_ds = MNIST(
    PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor()
)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

eval_ds = MNIST(
    PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor()
)
eval_loader = DataLoader(eval_ds, batch_size=BATCH_SIZE)

# Initialize a trainer (accelerator/devices replaces the removed gpus argument)
trainer = Trainer(
    accelerator="gpu" if AVAIL_GPUS else "cpu",
    devices=AVAIL_GPUS or 1,
    max_epochs=3,
    logger=comet_logger,
)

# Train the model ⚡
trainer.fit(model, train_loader, eval_loader)
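
When run as a script, the experiment is finalized automatically at process exit. In a notebook you may want to end it explicitly so any remaining data is uploaded; a minimal sketch, reusing the comet_logger defined above:

# comet_logger.experiment is the underlying comet_ml.Experiment object;
# end() flushes and uploads any pending data (mainly useful in notebooks).
comet_logger.experiment.end()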

Try it out!

Don't just take our word for it, try it out yourself.

Open In Colab

Configure Comet for PyTorch Lightning

PyTorch Lightning now ships with a dedicated CometLogger. Find more information about configuring the logger here.

Note

You can pass all the initialization parameters of the Experiment object as keyword arguments to the Lightning CometLogger. There are other ways to configure Comet as well. See more here.
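
For example, here is a short sketch passing extra options through the logger; the experiment name is hypothetical, and auto_histogram_weight_logging is a comet_ml.Experiment keyword argument that the logger forwards on.

comet_logger = CometLogger(
    project_name="comet-examples-lightning",
    experiment_name="mnist-baseline",     # hypothetical name for this run
    auto_histogram_weight_logging=True,   # forwarded to comet_ml.Experiment
)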


Apr. 25, 2024