Menu

Benchmark multi-model/multi-view models.

Source code for mmbench.workflow.smcvae

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2023
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Train the sparse Multi-Channels Variational Auto-Encoder (sMCVAE).
"""

# Imports
import os
import sys
import time
import copy
import progressbar
import numpy as np
import torch
from torch.utils.data import TensorDataset
from mmbench.dataset import get_train_data, get_test_data
from mmbench.color_utils import print_title


def train_smcvae(dataset, datasetdir, outdir, fit_lat_dims=10, beta=1,
                 adam_lr=2e-3, n_epochs=10000, host="http://localhost",
                 port=8085):
    """ Train the sparse Multi-Channels Variational Auto-Encoder (sMCVAE).

    The 'clinical' and 'rois' modalities of the requested dataset are
    loaded, wrapped in full-batch data loaders, and used to fit a sparse
    MCVAE. Checkpoints are written in '<outdir>/checkpoints'.

    Parameters
    ----------
    dataset: str
        the dataset name: euaims or hbn.
    datasetdir: str
        the path to the dataset associated data.
    outdir: str
        the destination folder.
    fit_lat_dims: int, default 10
        the number of latent dimensions.
    beta: float, default 1
        the loss beta-VAE weight (0.5 for HBN).
    adam_lr: float, default 2e-3
        the initial learning rate in the ADAM optimizer.
    n_epochs: int, default 10000
        the number of training epochs.
    host: str, default 'http://localhost'
        the host on which visdom is launched.
    port: int, default 8085
        the port on which the visdom server is launched.
    """
    # Heavy optional dependencies are imported lazily, only when training.
    from brainboard import Board
    from brainite.models import MCVAE
    from brainite.losses import MCVAELoss

    print_title("Load dataset...")
    modalities = ["clinical", "rois"]
    X_train, _ = get_train_data(dataset, datasetdir, modalities)
    # The 'index' entry is not a modality: drop it before building tensors.
    del X_train["index"]
    print("train:", [(key, arr.shape) for key, arr in X_train.items()])
    X_test, _ = get_test_data(dataset, datasetdir, modalities)
    del X_test["index"]
    print("test:", [(key, arr.shape) for key, arr in X_test.items()])

    print_title("Create data loaders...")
    X_train = [X_train[mod].to(torch.float32) for mod in modalities]
    X_test = [X_test[mod].to(torch.float32) for mod in modalities]
    print("train:", [arr.shape for arr in X_train])
    datasets = {
        "train": TensorDataset(*X_train),
        "val": TensorDataset(*X_test)}
    # One batch per split: the batch size is the full split size.
    dataloaders = {
        split: torch.utils.data.DataLoader(
            datasets[split], batch_size=len(datasets[split]),
            shuffle=(split == "train"), num_workers=1)
        for split in ["train", "val"]}

    print_title("Create model...")
    model_name = "smcvae"
    n_channels = len(X_train)
    n_feats = [X.shape[1] for X in X_train]
    checkpointdir = os.path.join(outdir, "checkpoints")
    # Also creates 'outdir' when missing and tolerates an existing folder.
    os.makedirs(checkpointdir, exist_ok=True)
    model = MCVAE(
        latent_dim=fit_lat_dims, n_channels=n_channels, n_feats=n_feats,
        vae_model="dense", vae_kwargs={}, sparse=True,
        noise_init_logvar=-3, noise_fixed=False)
    print(f" model: {model_name}")
    print(model)
    board = Board(host=host, env=f"{dataset}_{model_name}", port=port)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=adam_lr)
    criterion = MCVAELoss(n_channels, beta=beta, sparse=True)

    print_title("Train model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_model(dataloaders, model, device, criterion, optimizer,
                n_epochs=(n_epochs + 1), board=board,
                checkpointdir=checkpointdir,
                board_updates=update_dropout_rate, save_after_epochs=100)
def train_model(dataloaders, model, device, criterion, optimizer,
                scheduler=None, n_epochs=100, checkpointdir=None,
                save_after_epochs=1, board=None, board_updates=None,
                load_best=False):
    """ General function to train a model and display training metrics.

    Parameters
    ----------
    dataloaders: dict of torch.utils.data.DataLoader
        the train & validation data loaders, stored under the 'train' and
        'val' keys.
    model: nn.Module
        the model to be trained. Calling it must return a 2-uplet
        (outputs, layer_outputs).
    device: torch.device
        the device to work on.
    criterion: torch.nn._Loss
        the criterion to be optimized. It is called as criterion(outputs)
        and, if that call fails, as criterion(outputs, batch_data). It must
        return a 2-uplet (loss, extra_loss) where extra_loss is a dict of
        scalar tensors.
    optimizer: torch.optim.Optimizer
        the optimizer.
    scheduler: torch.optim.lr_scheduler, default None
        the scheduler, stepped once after each training phase.
    n_epochs: int, default 100
        the number of epochs.
    checkpointdir: str, default None
        a destination folder where intermediate models/histories will be
        saved.
    save_after_epochs: int, default 1
        determines when the model is saved and represents the number of
        epochs before saving.
    board: brainboard.Board, default None
        a board to display live results (refreshed every 25 epochs).
    board_updates: list of callable, default None
        update displayed item on the board. Each callable is called with
        (model, board, outputs, layer_outputs).
    load_best: bool, default False
        optionally load the best model regarding the validation loss.
    """
    since = time.time()
    if board_updates is not None:
        board_updates = listify(board_updates)
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = sys.float_info.max
    # NOTE(review): len() of a torch DataLoader is its number of batches,
    # not its number of samples - confirm this is the intended normalizer.
    dataset_sizes = {x: len(dataloaders[x]) for x in ["train", "val"]}
    model = model.to(device)
    with progressbar.ProgressBar(max_value=n_epochs) as bar:
        for epoch in range(n_epochs):
            for phase in ["train", "val"]:
                if phase == "train":
                    model.train()
                else:
                    model.eval()
                running_loss = 0.0
                running_extra_loss = {}
                for batch_data in dataloaders[phase]:
                    # NOTE(review): only the first element of a list batch
                    # is kept - confirm the loaders pack all the model
                    # inputs in that first element.
                    if isinstance(batch_data, list):
                        batch_data = batch_data[0]
                    batch_data = to_device(batch_data, device)
                    batch_size = batch_data[0].size(0)

                    # Zero the parameter gradients
                    optimizer.zero_grad()

                    # Forward: track history if only in train
                    with torch.set_grad_enabled(phase == "train"):
                        outputs, layer_outputs = model(batch_data)
                        criterion.layer_outputs = layer_outputs
                        try:
                            loss, extra_loss = criterion(outputs)
                        except Exception:
                            # Fallback for criteria that also need the
                            # input data. KeyboardInterrupt and SystemExit
                            # are deliberately not intercepted.
                            loss, extra_loss = criterion(
                                outputs, batch_data)

                        # Backward + optimize only if in training phase
                        if phase == "train":
                            loss.backward()
                            optimizer.step()

                    # Statistics: weight every batch contribution by the
                    # batch size at accumulation time, so that previously
                    # accumulated batches are never rescaled.
                    running_loss += loss.item() * batch_size
                    for key, val in extra_loss.items():
                        running_extra_loss[key] = (
                            running_extra_loss.get(key, 0.0) +
                            val.item() * batch_size)

                if scheduler is not None and phase == "train":
                    scheduler.step()
                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_extra_loss = {
                    key: val / dataset_sizes[phase]
                    for key, val in running_extra_loss.items()}

                if board is not None:
                    if epoch % 25 == 0:
                        board.update_plot(
                            "loss_{0}".format(phase), epoch, epoch_loss)
                        for name, val in epoch_extra_loss.items():
                            board.update_plot(
                                "{0}_{1}".format(name, phase), epoch, val)

                # Display validation results, using the outputs of the
                # last processed batch
                if (board is not None and board_updates is not None and
                        phase == "val"):
                    if epoch % 25 == 0:
                        for update in board_updates:
                            update(model, board, outputs, layer_outputs)

                # Deep copy the best model
                if phase == "val" and epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())

            # Save intermediate results: at this point 'epoch_loss' is the
            # loss of the last phase, i.e. the validation loss
            if checkpointdir is not None and epoch % save_after_epochs == 0:
                outfile = os.path.join(
                    checkpointdir, "model_{0}.pth".format(epoch))
                checkpoint(
                    model=model, outfile=outfile, optimizer=optimizer,
                    scheduler=scheduler, epoch=epoch, epoch_loss=epoch_loss)
            bar.update(epoch)
    time_elapsed = time.time() - since
    print("Training complete in {:.0f}m {:.0f}s".format(
        time_elapsed // 60, time_elapsed % 60))
    print("Best val loss: {:4f}".format(best_loss))

    # Load best model weights
    if load_best:
        model.load_state_dict(best_model_wts)
def listify(data):
    """ Ensure that the input is a list or tuple.

    Parameters
    ----------
    data: object
        the input data.

    Returns
    -------
    out: list or tuple
        the input itself when it is already a list or a tuple, otherwise
        a one-element list wrapping the input.
    """
    if isinstance(data, (list, tuple)):
        return data
    return [data]
def to_device(data, device):
    """ Transfer data to device.

    Parameters
    ----------
    data: tensor or list/tuple of tensor
        the data to be transfered.
    device: torch.device
        the device to work on.

    Returns
    -------
    out: tensor or list of tensor
        the transfered data: a list when the input is a list or a tuple,
        a single tensor otherwise.
    """
    # Tuples are accepted as well as lists: they have no 'to' method and
    # must be transfered element by element.
    if isinstance(data, (list, tuple)):
        return [tensor.to(device) for tensor in data]
    return data.to(device)
def checkpoint(model, outfile, optimizer=None, scheduler=None, **kwargs):
    """ Save the weights of a given model.

    The saved object is a dictionary holding the model state under the
    'model' key, the optimizer and scheduler states under the 'optimizer'
    and 'scheduler' keys when they are provided, and all the extra keyword
    arguments.

    Parameters
    ----------
    model: nn.Module
        the model to be saved.
    outfile: str
        the destination file name.
    optimizer: torch.optim.Optimizer, default None
        the optimizer.
    scheduler: torch.optim.lr_scheduler, default None
        the scheduler.
    kwargs: dict
        others parameters to be saved.
    """
    state = dict(kwargs)
    # The states are stored last: they take precedence over extra
    # parameters that would use the same names.
    state["model"] = model.state_dict()
    if optimizer is not None:
        state["optimizer"] = optimizer.state_dict()
    if scheduler is not None:
        state["scheduler"] = scheduler.state_dict()
    torch.save(state, outfile)
def update_dropout_rate(model, board, outputs, layer_outputs=None):
    """ Display the dropout rate.

    Nothing is displayed when the model 'log_alpha' attribute is None.

    Parameters
    ----------
    model: nn.Module
        a model exposing the 'log_alpha' and 'dropout' attributes.
    board: brainboard.Board
        the board where the 'dropout_probability' histogram is updated.
    outputs: object
        the model outputs (not used, part of the board update interface).
    layer_outputs: object, default None
        the model intermediate outputs (not used, part of the board update
        interface).
    """
    if model.log_alpha is not None:
        # The tensor is detached and copied to the host first: a direct
        # conversion to numpy fails for a tensor living on a GPU or
        # requiring gradients.
        rates = model.dropout.detach().cpu().numpy()
        board.update_hist("dropout_probability", np.sort(rates.reshape(-1)))

Follow us

© 2023, mmbench developers