
Integrate with PyTorch Lightning

PyTorch Lightning helps organize PyTorch code and decouple the science code from the engineering code. It’s more of a style-guide than a framework. By organizing PyTorch code under a LightningModule, Lightning makes things like TPU, multi-GPU and 16-bit precision training (40+ other features) trivial.


import comet_ml
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger

# Arguments passed to CometLogger are forwarded to the comet_ml.Experiment class
comet_logger = CometLogger()

# Your training code

trainer = Trainer(
    max_epochs=3,
    logger=comet_logger
)
trainer.fit(model, train_loader, eval_loader)

Configure Comet for PyTorch Lightning

PyTorch Lightning ships with a dedicated CometLogger. Find more information about configuring the logger in the PyTorch Lightning documentation.
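
As a minimal sketch of what such configuration might look like, Comet-specific settings can be passed to the logger as keyword arguments; the workspace, project, and experiment names below are placeholders, and the API key is optional if Comet is configured elsewhere (see the note at the end of this page):

# A minimal sketch of a configured CometLogger; the workspace,
# project, and experiment names are placeholders
comet_logger = CometLogger(
    api_key="YOUR_API_KEY",            # placeholder; optional if set via environment
    workspace="my-workspace",          # placeholder
    project_name="my-project",         # placeholder
    experiment_name="lightning-mnist", # placeholder
)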

End-to-end example

import comet_ml
import os

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64

class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Single linear layer mapping flattened 28x28 images to 10 classes
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # Metrics logged here are forwarded to Comet by the CometLogger
        self.logger.log_metrics({"train_loss": loss}, step=batch_idx)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.logger.log_metrics({"val_loss": loss}, step=batch_idx)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

# Init our model
model = Model()

# Arguments passed to CometLogger are forwarded to the comet_ml.Experiment class
comet_logger = CometLogger()

# Log hyperparameters
comet_logger.log_hyperparams({"batch_size": BATCH_SIZE})

# Init DataLoader from MNIST Dataset
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

eval_ds = MNIST(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor())
eval_loader = DataLoader(eval_ds, batch_size=BATCH_SIZE)

# Initialize a trainer
trainer = Trainer(
    gpus=AVAIL_GPUS,
    max_epochs=3,
    logger=comet_logger
)

# Train the model ⚡
trainer.fit(model, train_loader, eval_loader)
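
Beyond metrics and hyperparameters, the logger's experiment property exposes the underlying comet_ml.Experiment object, so anything not covered by the Lightning API can be logged directly. A minimal sketch; the tags and asset path below are placeholders:

# Access the wrapped comet_ml.Experiment for additional logging;
# the tags and file path are placeholders
comet_logger.experiment.add_tags(["mnist", "baseline"])
comet_logger.experiment.log_asset("path/to/checkpoint.pt")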

Note

There are alternatives to setting the API key programmatically; see the Comet documentation on configuration for details.
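
For example, a minimal sketch of one such alternative, assuming an interactive session: comet_ml.init() prompts for the key if it is not already available (for instance via the COMET_API_KEY environment variable or a local .comet.config file) and persists it, so CometLogger() can then be created without arguments.

import comet_ml

# Prompts for the API key if none is configured (e.g. via the
# COMET_API_KEY environment variable or a .comet.config file)
comet_ml.init()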

Try it out!

Here's an example of using Comet with PyTorch Lightning.

Open In Colab
