Menu

Benchmark multi-model/multi-view models.

Source code for mmbench.baseline.vae

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2022 - 2023
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Define the VAE model.
"""

# Imports
import os
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from sklearn.preprocessing import StandardScaler
from mmbench.dataset import get_train_data, get_test_data
from mmbench.color_utils import print_title, print_subtitle, print_result
from mmbench.workflow.smcvae import train_model


[docs]def train_vae(dataset, datasetdir, outdir, fit_lat_dims=20, n_iter=100, host="http://localhost", port=8085): """ Train the VAE model. Parameters ---------- dataset: str the dataset name: euaims or hbn. datasetdir: str the path to the dataset associated data. outdir: str the destination folder. fit_lat_dims: int, default 20 the number of latent dimensions. n_iter: int, default 100 the number of train iterations. host: str, default 'http://localhost' the host on which visdom is launched. port: int, default 8085 the port on which the visdom server is launched. """ from brainboard import Board from brainite.models import VAE from brainite.losses import BetaHLoss print_title("Training VAE...") print_subtitle("Generating data loader...") modalities = ["rois", "clinical"] if not os.path.isdir(outdir): os.mkdir(outdir) X_train, _ = get_train_data(dataset, datasetdir, modalities) del X_train["index"] X_test, _ = get_test_data(dataset, datasetdir, modalities) del X_test["index"] X_train = znorm(X_train["rois"]).to(torch.float32) X_test = znorm(X_test["rois"]).to(torch.float32) print("- train:", X_train.shape) print("- test:", X_test.shape) datasets = { "train": torch.utils.data.TensorDataset(X_train), "val": torch.utils.data.TensorDataset(X_test)} dataloaders = { split: torch.utils.data.DataLoader( datasets[split], batch_size=100, shuffle=(True if split == "train" else False), num_workers=1) for split in ["train", "val"]} print_subtitle("Generating model...") model_name = "vae" model = VAE(input_channels=1, input_dim=X_train.shape[1], conv_flts=None, dense_hidden_dims=[256], latent_dim=fit_lat_dims, noise_fixed=True, dropout=0.1, sparse=False) print(model) print_subtitle("Fitting model...") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001, weight_decay=0.01) criterion = BetaHLoss(beta=10, steps_anneal=0, use_mse=True) scheduler = MyStepLR(optimizer=optimizer, init_lr=0.001, min_lr=0.0001, step_size=100, gamma=0.9) board = Board(host=host, env=f"{dataset}_{model_name}", port=port) train_model(dataloaders, model, device, criterion, optimizer, scheduler=scheduler, board=board, n_epochs=n_iter) path = os.path.join(outdir, f"{model.__class__.__name__}_final.pth") torch.save(model.state_dict(), path)
[docs]class MyStepLR(torch.optim.lr_scheduler.MultiStepLR): """ Custom step LR scheduler. """
[docs] def __init__(self, optimizer, init_lr, min_lr, step_size, gamma): n_steps = int(np.log(min_lr / init_lr) / np.log(gamma)) milestones = [step_size * idx for idx in range(1, n_steps)] super(MyStepLR, self).__init__( optimizer, milestones=milestones, gamma=gamma)
def step(self, metric=None, epoch=None): super(MyStepLR, self).step()
[docs]def znorm(arr): print(arr) std, mean, eps = (1., 0., 1e-8) return std * (arr - torch.mean(arr)) / (torch.std(arr) + eps) + mean

Follow us

© 2023, mmbench developers