integration.pytorch

load_model

comet_ml.integration.pytorch.load_model(MODEL_URI: str,
    map_location: Any = None, pickle_module: Optional[Module] = None,
    **torch_load_args) -> ModelStateDict

Loads a model's state_dict from an Experiment, the Model Registry, or disk, given a URI. The model is loaded with torch.load, and the function returns a PyTorch state_dict that you then need to load into your model.

Here is an example of loading a model from the Model Registry for inference:

import torch.nn as nn
from comet_ml.integration.pytorch import load_model

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        ...

    def forward(self, x):
        ...
        return x

# Initialize model
model = TheModelClass()

# Load the model state dict from Comet Registry
model.load_state_dict(load_model("registry://WORKSPACE/TheModel:1.2.4"))

model.eval()

prediction = model(...)

Here is an example of loading a checkpoint from an Experiment to resume training:

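import torch
from comet_ml.integration.pytorch import load_model

# Initialize the model and its optimizer (use the same optimizer type as
# during training; SGD with lr=0.01 is only an example)
model = TheModelClass()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)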

# Load the model state dict from a Comet Experiment
checkpoint = load_model("experiment://e1098c4e1e764ff89881b868e4c70f5/TheModel")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.train()
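load_model returns whatever dictionary was saved, so it matters which part of it you pass to load_state_dict. PyTorch's load_state_dict matches keys strictly by default and reports mismatches as missing and unexpected keys. The standard-library sketch below imitates only that key comparison; compare_state_dict_keys and the parameter names are hypothetical, and no PyTorch objects are involved. It illustrates why the resume-training example indexes checkpoint['model_state_dict'] instead of loading the whole checkpoint.

```python
from typing import Iterable, List, Mapping, Tuple


def compare_state_dict_keys(
    model_keys: Iterable[str], loaded: Mapping[str, object]
) -> Tuple[List[str], List[str]]:
    """Return (missing_keys, unexpected_keys), mimicking strict key matching.

    Hypothetical helper for illustration; it is not part of PyTorch or comet_ml.
    """
    expected = set(model_keys)
    provided = set(loaded)
    missing = sorted(expected - provided)      # in the model, absent from the dict
    unexpected = sorted(provided - expected)   # in the dict, unknown to the model
    return missing, unexpected


# Hypothetical parameter names of TheModelClass
model_keys = ["fc.weight", "fc.bias"]

# A resume-training checkpoint wraps the state dict with extra entries
checkpoint = {
    "epoch": 3,
    "loss": 0.5,
    "model_state_dict": {"fc.weight": [[0.1]], "fc.bias": [0.0]},
}

print(compare_state_dict_keys(model_keys, checkpoint))
# → (['fc.bias', 'fc.weight'], ['epoch', 'loss', 'model_state_dict'])
print(compare_state_dict_keys(model_keys, checkpoint["model_state_dict"]))
# → ([], [])
```

Passing the whole checkpoint would fail a strict load, while the nested "model_state_dict" entry matches the model's keys exactly.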

Args:

  • MODEL_URI: string (required), a URI defining the model's location. Possible options are:

    • file://data/my-model

    • file:///path/to/my-model

    • registry://workspace/registry_name (loads the latest version)

    • registry://workspace/registry_name:version

    • experiment://experiment_key/model_name

    • experiment://workspace/project_name/experiment_name/model_name

  • map_location: (optional) passed to torch.load (see torch.load)

  • pickle_module: (optional) passed to torch.load (see torch.load)

  • **torch_load_args: (optional) additional keyword arguments passed to torch.load (see torch.load)

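Returns: the model's state dict, or more generally the object that torch.load deserializes (for example, a checkpoint dictionary)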
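To make the URI forms listed above concrete, here is a minimal sketch that splits each form into its components. parse_model_uri and ModelLocation are hypothetical names, not part of comet_ml, and this is not how Comet resolves URIs internally. Reading file://data/my-model as the relative path data/my-model is an assumption based on its contrast with the file:/// form.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelLocation:
    """Hypothetical container for the parts of a model URI."""
    scheme: str                          # "file", "registry" or "experiment"
    path: Optional[str] = None           # file path, for file:// URIs
    workspace: Optional[str] = None
    registry_name: Optional[str] = None
    version: Optional[str] = None        # None is read as "latest version"
    experiment_key: Optional[str] = None
    project_name: Optional[str] = None
    experiment_name: Optional[str] = None
    model_name: Optional[str] = None


def parse_model_uri(uri: str) -> ModelLocation:
    """Split a model URI into its parts (illustrative only, not Comet's parser)."""
    scheme, sep, rest = uri.partition("://")
    if not sep:
        raise ValueError(f"missing scheme in {uri!r}")
    if scheme == "file":
        # Assumption: file://data/my-model -> "data/my-model" (relative),
        # file:///path/to/my-model -> "/path/to/my-model" (absolute)
        return ModelLocation(scheme, path=rest)
    parts = rest.split("/")
    if scheme == "registry" and len(parts) == 2:
        # registry://workspace/registry_name[:version]
        name, _, version = parts[1].partition(":")
        return ModelLocation(scheme, workspace=parts[0], registry_name=name,
                             version=version or None)
    if scheme == "experiment" and len(parts) == 2:
        # experiment://experiment_key/model_name
        return ModelLocation(scheme, experiment_key=parts[0], model_name=parts[1])
    if scheme == "experiment" and len(parts) == 4:
        # experiment://workspace/project_name/experiment_name/model_name
        return ModelLocation(scheme, workspace=parts[0], project_name=parts[1],
                             experiment_name=parts[2], model_name=parts[3])
    raise ValueError(f"unsupported model URI {uri!r}")


loc = parse_model_uri("registry://WORKSPACE/TheModel:1.2.4")
print(loc.workspace, loc.registry_name, loc.version)  # → WORKSPACE TheModel 1.2.4
```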

log_model

comet_ml.integration.pytorch.log_model(experiment, model, model_name,
    metadata=None, pickle_module=None, **torch_save_args)

Logs a PyTorch model to an Experiment. The model is serialized with torch.save and stored as an Experiment Model.

The model parameter can be either an instance of torch.nn.Module or any object supported by torch.save; see the PyTorch tutorial on saving and loading models for more details.

Here is an example of logging a model for inference:

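import comet_ml
import torch.nn as nn
from comet_ml.integration.pytorch import log_model

# Create the Experiment that the model will be logged to
experiment = comet_ml.Experiment()

class TheModelClass(nn.Module):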
    def __init__(self):
        super(TheModelClass, self).__init__()
        ...

    def forward(self, x):
        ...
        return x

# Initialize model
model = TheModelClass()

# Train the model (`train` stands for your own training loop)
train(model)

# Save the model for inference
log_model(experiment, model, model_name="TheModel")

Here is an example of logging a checkpoint for resuming training:

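# `epoch`, `model`, `optimizer` and `loss` come from your training loop
model_checkpoint = {
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
    # ... add any other state needed to resume training
}
log_model(experiment, model_checkpoint, model_name="TheModel")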
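The checkpoint is an ordinary dictionary, so the keys chosen when logging are exactly the keys the resume-training example reads back after load_model. The sketch below illustrates that round trip with the standard-library pickle module as a stand-in for torch.save and torch.load (torch.save uses Python's pickle by default through its pickle_module parameter). save_checkpoint and load_checkpoint are hypothetical helpers, and the tensors are replaced by plain lists so the example runs without PyTorch.

```python
import io
import pickle


def save_checkpoint(checkpoint: dict) -> bytes:
    """Stand-in for torch.save: serialize the checkpoint with pickle."""
    buffer = io.BytesIO()
    pickle.dump(checkpoint, buffer)
    return buffer.getvalue()


def load_checkpoint(data: bytes) -> dict:
    """Stand-in for torch.load: deserialize the checkpoint with pickle."""
    return pickle.load(io.BytesIO(data))


# Toy values standing in for model.state_dict(), optimizer.state_dict(), etc.
model_checkpoint = {
    "epoch": 5,
    "model_state_dict": {"linear.weight": [[0.1, -0.2]], "linear.bias": [0.0]},
    "optimizer_state_dict": {"state": {}, "param_groups": [{"lr": 0.01}]},
    "loss": 0.42,
}

restored = load_checkpoint(save_checkpoint(model_checkpoint))

# The keys written at save time are exactly the keys available at load time
print(sorted(restored))  # → ['epoch', 'loss', 'model_state_dict', 'optimizer_state_dict']
```

A typo in a key name on either side only surfaces as a KeyError when resuming, so it can help to define the key names once and reuse them in both places.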

Args:

  • experiment: Experiment (required), the Experiment instance to log the model to
  • model: torch.nn.Module, state dict, or any other object supported by torch.save (required), the model to log
  • model_name: string (required), the name of the model
  • metadata: dict (optional), additional data to attach to the model. Must be a JSON-encodable dict
  • pickle_module: (optional) passed to torch.save (see torch.save documentation)
  • **torch_save_args: (optional) additional keyword arguments passed to torch.save (see torch.save documentation)

Returns: None
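Because metadata must be a JSON-encodable dict, building it programmatically can let non-serializable values slip in. A quick pre-check with json.dumps surfaces such problems before log_model is called. check_metadata is a hypothetical helper, not part of comet_ml, and this documentation does not say whether log_model validates metadata in the same way.

```python
import json
from typing import Any


def check_metadata(metadata: Any) -> dict:
    """Raise if `metadata` is not a JSON-encodable dict (hypothetical helper)."""
    if not isinstance(metadata, dict):
        raise TypeError("metadata must be a dict")
    try:
        # json.dumps raises TypeError for unsupported types
        # and ValueError for circular references
        json.dumps(metadata)
    except (TypeError, ValueError) as exc:
        raise ValueError(f"metadata is not JSON-encodable: {exc}") from exc
    return metadata


# Strings, numbers, lists and nested dicts are fine
check_metadata({"framework": "pytorch", "epochs": 10, "tags": ["baseline"]})

# An arbitrary Python object is not
try:
    check_metadata({"created": object()})
except ValueError as err:
    print(err)
```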

watch

comet_ml.integration.pytorch.watch(model: torch.nn.Module,
    log_step_interval: int = 1000) -> None

Enables automatic logging of the parameters and gradients of each layer in the given PyTorch module; these are logged as histograms. An Experiment must be created before calling this function.

Args:

  • model: torch.nn.Module, an instance of torch.nn.Module.
  • log_step_interval: int (optional), determines how often layers are logged (default is every 1000 steps).
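The sketch below illustrates the two ideas behind watch: sampling on a step interval and summarizing a layer's values as a histogram. It is a standard-library approximation under stated assumptions, not Comet's implementation: should_log assumes logging happens at step 0 and at every multiple of log_step_interval, histogram uses simple equal-width bins (Comet's binning may differ), and the layer values are toy numbers rather than real parameters or gradients.

```python
from collections import Counter
from typing import Dict, List


def should_log(step: int, log_step_interval: int = 1000) -> bool:
    """Assumed policy: log at step 0 and every `log_step_interval` steps after."""
    return step % log_step_interval == 0


def histogram(values: List[float], num_bins: int = 4) -> Dict[int, int]:
    """Count values into equal-width bins between min and max (simplified)."""
    lo, hi = min(values), max(values)
    width = (hi - lo) / num_bins or 1.0  # avoid zero width when all values match
    # The maximum value falls into the last bin rather than a bin of its own
    counts = Counter(min(int((v - lo) / width), num_bins - 1) for v in values)
    return dict(sorted(counts.items()))


# Toy per-layer values standing in for a module's parameters
layer_values = {"fc1.weight": [-0.5, -0.1, 0.0, 0.2, 0.9], "fc2.bias": [0.1, 0.1]}

# Only steps 0, 1000, 2000 and 3000 pass the default interval
for step in range(0, 3001, 500):
    if should_log(step):
        for name, values in layer_values.items():
            print(step, name, histogram(values))
```

A larger log_step_interval logs fewer histograms, trading detail for less logging overhead.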
Feb. 24, 2024