
Benchmark multi-model/multi-view models.

Note

This page is reference documentation. It only explains the function signature, not how to use it. Please refer to the gallery for the big picture.

mmbench.workflow.smcvae.train_model(dataloaders, model, device, criterion, optimizer, scheduler=None, n_epochs=100, checkpointdir=None, save_after_epochs=1, board=None, board_updates=None, load_best=False)

General function to train a model and display training metrics.

Parameters

dataloaders : dict of torch.utils.data.DataLoader

the train & validation data loaders.

model : nn.Module

the model to be trained.

device : torch.device

the device to work on.

criterion : torch.nn._Loss

the criterion to be optimized.

optimizer : torch.optim.Optimizer

the optimizer.

scheduler : torch.optim.lr_scheduler, default None

the scheduler.

n_epochs : int, default 100

the number of epochs.

checkpointdir : str, default None

a destination folder where intermediate models/histories will be saved.

save_after_epochs : int, default 1

the number of epochs to wait before each save of the model.

board : brainboard.Board, default None

a board to display live results.

board_updates : list of callable, default None

callables that update the items displayed on the board.

load_best : bool, default False

optionally load the best model with respect to the loss.
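As a rough illustration of how the arguments listed above might be assembled, here is a minimal sketch using plain PyTorch objects. Several things in it are assumptions and not taken from this reference: the "train" and "val" dictionary keys, the toy model, the MSE loss, and the step scheduler are placeholders chosen only to match the documented types. The smcvae workflow may expect a specific model and loss interface that this page does not describe. The call to train_model is left commented out so the block runs without mmbench installed.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy data: 64 samples with 10 features each (purely illustrative).
torch.manual_seed(0)
dataset = TensorDataset(torch.randn(64, 10))

# ASSUMPTION: the dict is keyed by "train" and "val". The reference only
# says it holds the train and validation data loaders.
dataloaders = {
    "train": DataLoader(dataset, batch_size=16, shuffle=True),
    "val": DataLoader(dataset, batch_size=16, shuffle=False),
}

# Hypothetical stand-in model and loss, chosen only to match the types.
model = nn.Sequential(nn.Linear(10, 4), nn.ReLU(), nn.Linear(4, 10))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Keyword names mirror the documented signature.
train_kwargs = dict(
    dataloaders=dataloaders,
    model=model,
    device=device,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    n_epochs=20,
    checkpointdir=None,    # set to a folder path to save checkpoints
    save_after_epochs=5,
    board=None,            # no live display in this sketch
    board_updates=None,
    load_best=False,
)

# With mmbench installed, the call would look like this:
# from mmbench.workflow.smcvae import train_model
# train_model(**train_kwargs)
```

Collecting the arguments in a dictionary keeps the keyword names next to the signature, which makes it easier to spot a mismatch if the function changes. The return value is not documented on this page, so the sketch does not assign it.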


© 2023, mmbench developers