Menu

Benchmark multi-model/multi-view models.

Source code for mmbench.workflow.cvae

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2023
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Train a contrastive Variational Auto-Encoder (cVAE).
"""

# Imports
import os
import sys
import time
import copy
import torch
from torch.utils.data import TensorDataset
from mmbench.dataset import EUAIMSDataset
from mmbench.color_utils import print_title


def train_cvae(dataset, datasetdir, outdir, general_lat_dims=15,
               specific_lat_dims=5, beta=4, lambda1=1, lambda2=2,
               adam_lr=1e-4, n_epochs=1000, host="http://localhost",
               port=8085):
    """ Train a contrastive Variational Auto-Encoder (cVAE).

    Parameters
    ----------
    dataset: str
        the dataset name: euaims. It is only used to name the visdom
        environment; the EU-AIMS dataset is always the one loaded.
    datasetdir: str
        the path to the dataset associated data.
    outdir: str
        the destination folder. Checkpoints are written to
        ``<outdir>/checkpoints``, which is created (with any missing parents)
        if needed.
    general_lat_dims: int, default 15
        the number of latent dimensions in the general part of the latent
        space.
    specific_lat_dims: int, default 5
        the number of latent dimensions in the specific part of the latent
        space.
    beta: float, default 4
        weight of the KL divergence.
    lambda1: float, default 1
        weight for the salient disentanglement loss.
    lambda2: float, default 2
        weight for the background disentanglement loss.
    adam_lr: float, default 1e-4
        the initial learning rate in the ADAM optimizer.
    n_epochs: int, default 1000
        the number of training epochs.
    host: str, default 'http://localhost'
        the host on which visdom is launched.
    port: int, default 8085
        the port on which the visdom server is launched.
    """
    # Optional heavy dependencies are imported lazily so that importing this
    # module does not require them.
    from cvae.model import mmcVAE
    from cvae.loss import mmcVAELoss
    from cvae.utils import train_model
    from brainboard import Board

    print_title("Load dataset...")
    train_dataset = EUAIMSDataset(datasetdir, train=True)
    # Reuse the training split's scaler for the validation split so both
    # splits are normalized the same way (presumably fitted on train data).
    scaler = train_dataset.scaler
    test_dataset = EUAIMSDataset(datasetdir, train=False, scaler=scaler)
    datasets = {
        "train": train_dataset,
        "val": test_dataset}

    print_title("Create data loaders...")
    # Full-batch loading: each split is served as a single batch.
    dataloaders = {
        split: torch.utils.data.DataLoader(
            datasets[split], batch_size=len(datasets[split]),
            shuffle=(split == "train"), num_workers=1)
        for split in ["train", "val"]}
    # Peek at one training batch to infer the model input dimensions. Use
    # the builtin next(): the DataLoader iterator's `.next()` alias was
    # removed in recent PyTorch releases.
    data, _, _ = next(iter(dataloaders["train"]))

    print_title("Create model...")
    model_name = "cvae"
    checkpointdir = os.path.join(outdir, "checkpoints")
    # makedirs also creates a missing `outdir` and avoids the check-then-create
    # race of isdir() + mkdir().
    os.makedirs(checkpointdir, exist_ok=True)
    model = mmcVAE(
        input_channels=1, input_dim=data.shape[-2:], conv_flts=None,
        dense_hidden_dims=[128],
        latent_dims=[general_lat_dims, specific_lat_dims])
    print(f" model: {model_name}")
    print(model)
    board = Board(host=host, env=f"{dataset}_{model_name}", port=port)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=adam_lr)
    criterion = mmcVAELoss(beta=beta, lambda1=lambda1, lambda2=lambda2,
                           use_mse=True)

    print_title("Train model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): `n_epochs + 1` is kept from the original; presumably
    # train_model iterates over range(1, n_epochs) -- confirm against
    # cvae.utils.train_model.
    train_model(dataloaders, model, device, criterion, optimizer,
                n_epochs=(n_epochs + 1), board=board,
                checkpointdir=checkpointdir, save_after_epochs=100)

Follow us

© 2023, mmbench developers