Benchmark multi-modal/multi-view models.
Note
This page is reference documentation. It only explains the function signature, not how to use it. Please refer to the gallery for the big picture.
mmbench.workflow.smcvae.train_model(dataloaders, model, device, criterion, optimizer, scheduler=None, n_epochs=100, checkpointdir=None, save_after_epochs=1, board=None, board_updates=None, load_best=False)
General function to train a model and display training metrics.
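To make this description concrete, here is a framework-agnostic sketch of the kind of epoch loop such a function typically runs. It is not the mmbench implementation: `fit_sketch`, `train_epoch`, `validate_epoch` and `scheduler_step` are hypothetical names, each pass over a data loader is replaced by a callable returning a loss, and stepping the scheduler once per epoch after validation is an assumption.

```python
from typing import Callable, Dict, List, Optional


def fit_sketch(
    train_epoch: Callable[[], float],
    validate_epoch: Callable[[], float],
    n_epochs: int = 100,
    scheduler_step: Optional[Callable[[], None]] = None,
) -> Dict[str, List[float]]:
    """Hypothetical epoch loop: train, validate, step the scheduler, record and display losses."""
    history: Dict[str, List[float]] = {"train": [], "validation": []}
    for epoch in range(1, n_epochs + 1):
        train_loss = train_epoch()      # stands in for one pass over the training loader
        val_loss = validate_epoch()     # stands in for one pass over the validation loader
        if scheduler_step is not None:  # optional, like scheduler=None (placement is assumed)
            scheduler_step()
        history["train"].append(train_loss)
        history["validation"].append(val_loss)
        # "Display training metrics": here simply printed to stdout.
        print(f"epoch {epoch}/{n_epochs} - train loss: {train_loss:.4f}"
              f" - validation loss: {val_loss:.4f}")
    return history


# Example run with fake, illustrative loss values.
train_losses = iter([1.0, 0.8, 0.6])
steps: List[int] = []
hist = fit_sketch(
    train_epoch=lambda: next(train_losses),
    validate_epoch=lambda: 0.5,
    n_epochs=3,
    scheduler_step=lambda: steps.append(1),
)
```

In the real function the model, device, criterion and optimizer would be used inside the training and validation passes; the sketch hides them behind the two callables to keep the loop structure visible.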
Parameters
dataloaders : dict of torch.utils.data.DataLoader
the train & validation data loaders.
model : nn.Module
the model to be trained.
device : torch.device
the device to work on.
criterion : torch.nn._Loss
the criterion to be optimized.
optimizer : torch.optim.Optimizer
the optimizer.
scheduler : torch.optim.lr_scheduler, default None
the learning rate scheduler.
n_epochs : int, default 100
the number of epochs.
checkpointdir : str, default None
a destination folder where intermediate models/histories will be saved.
save_after_epochs : int, default 1
the saving frequency, i.e. the number of epochs between two successive model saves.
board : brainboard.Board, default None
a board to display live results.
board_updates : list of callable, default None
callables used to update the items displayed on the board.
load_best : bool, default False
if True, load the best model with respect to the loss.
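The interplay between `n_epochs`, `save_after_epochs` and `load_best` can be sketched as follows, assuming that a checkpoint is written every `save_after_epochs` epochs (counting epochs from 1) and that the "best" model is the one with the lowest recorded loss. Both helpers, `checkpoint_epochs` and `best_epoch`, are hypothetical and only illustrate these assumptions; the actual function may, for instance, also save the last epoch or monitor a specific (e.g. validation) loss.

```python
from typing import List, Sequence


def checkpoint_epochs(n_epochs: int = 100, save_after_epochs: int = 1) -> List[int]:
    """Hypothetical: 1-based epochs at which a checkpoint would be written,
    assuming one save every `save_after_epochs` epochs."""
    if save_after_epochs < 1:
        raise ValueError("save_after_epochs must be >= 1")
    return [epoch for epoch in range(1, n_epochs + 1) if epoch % save_after_epochs == 0]


def best_epoch(losses: Sequence[float]) -> int:
    """Hypothetical: 1-based epoch with the lowest loss, i.e. the weights that
    load_best=True would restore under this assumption (first one on ties)."""
    if not losses:
        raise ValueError("losses must not be empty")
    return min(range(len(losses)), key=lambda i: losses[i]) + 1


print(checkpoint_epochs(n_epochs=10, save_after_epochs=3))  # [3, 6, 9]
print(best_epoch([0.9, 0.4, 0.5, 0.45]))                     # 2
```

With the default `save_after_epochs=1`, every epoch would produce a checkpoint in `checkpointdir` under this assumption, which is the safest but most disk-hungry setting.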