Menu

Benchmark multi-model/multi-view models.

Source code for mmbench.baseline.neuroclav

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2022 - 2023
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Define the NeuroClav weakly supervised model.
"""
# Imports
import os
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from sklearn.preprocessing import StandardScaler
from mmbench.dataset import get_train_data, get_test_data
from mmbench.color_utils import print_title, print_subtitle


def train_neuroclav(dataset, datasetdir, outdir, fit_lat_dims=20, n_iter=100):
    """ Train the NeuroClav model.

    The function loads the 'rois' and 'clinical' modalities of the train and
    test splits, wraps them in a torch dataset, fits a NeuroCLAV model on the
    train split using a random ROI cutout augmentation, and finally saves
    the fitted model in the destination folder.

    Parameters
    ----------
    dataset: str
        the dataset name: euaims or hbn.
    datasetdir: str
        the path to the dataset associated data. It must contain a
        'clinical_names.npy' file listing the phenotype names.
    outdir: str
        the destination folder, created if it does not exist.
    fit_lat_dims: int, default 20
        the number of latent dimensions.
    n_iter: int, default 100
        the number of train iterations.
    """
    # Project-specific dependencies are imported lazily, so that importing
    # this module does not require them.
    from manifold.neuroclav import NeuroCLAV
    from data.collate import TwoViewsCollateFunction
    from data.augmentations.base import Compose, Transform
    from data.augmentations.intensity import ZNormalization

    class NeuroClavDataset(Dataset):
        """ Define a dataset that can be used with NeuroClav.
        """
        def __init__(self, X_train, y_train, X_test, y_test, phenotypes,
                     split="train", znorm=True):
            """ Init class.

            Parameters
            -----------
            X_train: ndarray (N, M)
                the input train data, where N is the number of subjects and
                M the number of ROIs.
            y_train: ndarray (N, K)
                the train phenotypes, where N is the number of subjects and
                K the number of auxiliary variables.
            X_test: ndarray (L, M)
                the input test data, where L is the number of subjects and
                M the number of ROIs.
            y_test: ndarray (L, K)
                the test phenotypes, where L is the number of subjects and
                K the number of auxiliary variables.
            phenotypes: list of str
                list of phenotype names.
            split: str, default 'train'
                split to instantiate in ('train', 'test').
            znorm: boolean, default True
                whether the phenotypes will be z-normalized across subjects
                or not.
            """
            assert split in ["train", "test"], f"Unknown split: {split}"
            for X, y in ((X_train, y_train), (X_test, y_test)):
                assert len(X) == len(y), (
                    "Different number of subjects in X and y.")
            if not isinstance(phenotypes, (list, tuple)):
                phenotypes = [phenotypes]
            for y in (y_train, y_test):
                assert y.shape[1] == len(phenotypes), (
                    "Different number of phenotypes in y and phenotypes.")
            self.X_train, self.X_test = X_train, X_test
            self.y_train, self.y_test = y_train, y_test
            self.phenotypes = phenotypes
            self.split = split
            if znorm:
                # The scaler is fitted on the train phenotypes only and then
                # applied to both splits.
                scaler = StandardScaler().fit(self.y_train)
                self.y_train = scaler.transform(self.y_train)
                self.y_test = scaler.transform(self.y_test)
            if split == "train":
                self.X, self.y = self.X_train, self.y_train
            else:
                self.X, self.y = self.X_test, self.y_test
            self.n_subjects = len(self.X)

        def __getitem__(self, idx):
            """ Return the (data, phenotypes) pair of subject 'idx'.
            """
            return self.X[idx], self.y[idx]

        def __len__(self):
            """ Return the number of subjects in the selected split.
            """
            return self.n_subjects

    class RandomCutoutCollateFunction(TwoViewsCollateFunction):
        """ Implements the transformations for NeuroClav with random cutout.
        """
        def __init__(self, n_iter=10, value=0.0, znorm=True, p=0.5):
            """ Init class.

            Parameters
            ----------
            n_iter: int, default 10
                number of erased areas.
            value: float, default 0.0
                the replacement value.
            znorm: bool, default True
                whether a ZNormalization step is appended after the cutout
                or not.
            p: float, default 0.5
                probability to apply the transformation.
            """
            # Keyword arguments are mandatory here: RandomCutout takes
            # (n_iter, value, inplace), so a leading positional patch size
            # would shift every argument (n_iter <- patch size,
            # value <- n_iter, inplace <- value).
            transform = [RandomCutout(n_iter=n_iter, value=value, p=p)]
            if znorm:
                transform.append(ZNormalization())
            transform = Compose(transform)
            super().__init__(transform)

    class RandomCutout(Transform):
        """ Define a random cutout transformation for ROI-like data.
        """
        def __init__(self, n_iter=1, value=0.0, inplace=False, **kwargs):
            """ Init class.

            Parameters
            ----------
            n_iter: int, default 1
                number of erased entries.
            value: float, default 0.0
                the replacement value.
            inplace: bool, default False
                whether the input array is modified in place or copied.
            kwargs: dict
                extra parameters forwarded to the parent Transform class.
            """
            super().__init__(**kwargs)
            self.n_iter = n_iter
            self.value = value
            self.inplace = inplace

        def parse_data(self, data):
            """ Convert the input data to a numpy array.

            Raises
            ------
            ValueError
                if the data is neither a torch tensor nor a numpy array.
            """
            if isinstance(data, torch.Tensor):
                return data.detach().cpu().numpy()
            elif isinstance(data, np.ndarray):
                return data
            else:
                raise ValueError("Unexpected type: %s" % type(data))

        def _erase_area(self, arr, idx):
            """ Replace the entry 'idx' of the array by the erase value.
            """
            arr[idx] = self.value
            return arr

        def apply_transform(self, arr):
            """ Erase 'n_iter' distinct randomly selected entries along the
            first axis of the array.
            """
            if not self.inplace:
                arr = np.copy(arr)
            # Sampling without replacement: n_iter must not exceed len(arr).
            indices = np.random.choice(len(arr), self.n_iter, replace=False)
            for idx in indices:
                arr = self._erase_area(arr, idx)
            return arr

    print_title("Training NeuroClav...")
    print_subtitle("Generating data loader...")
    modalities = ["rois", "clinical"]
    # Also creates missing parent folders and tolerates an existing folder.
    os.makedirs(outdir, exist_ok=True)
    X_train, _ = get_train_data(dataset, datasetdir, modalities)
    del X_train["index"]
    X_test, _ = get_test_data(dataset, datasetdir, modalities)
    del X_test["index"]
    # The 'rois' modality is the model input, the 'clinical' one provides
    # the auxiliary variables.
    X_train, y_train = [X_train[mod].to(torch.float32) for mod in modalities]
    X_test, y_test = [X_test[mod].to(torch.float32) for mod in modalities]
    clinical_names = np.load(
        os.path.join(datasetdir, "clinical_names.npy"), allow_pickle=True)
    clinical_names = clinical_names.tolist()
    print("- train:", X_train.shape, y_train.shape)
    print("- test:", X_test.shape, y_test.shape)
    print("- phenotypes:", clinical_names, len(clinical_names))
    datasets = {
        split: NeuroClavDataset(X_train, y_train, X_test, y_test,
                                clinical_names, split=split, znorm=True)
        for split in ("train", "test")}

    print_subtitle("Generating model...")
    model = NeuroCLAV(encoder=f"mlp{X_train.shape[1]}x256",
                      kernel="gaussian", n_components=fit_lat_dims,
                      batch_size=128, max_iteration=n_iter,
                      learning_rate_init=1e-4, dir=outdir,
                      valid_frequency=1, save_frequency=10,
                      validation_fraction=None, num_workers=1)
    print(model)
    print(model._build_model(model.encoder))

    print_subtitle("Fitting model...")
    # Erase 10% of the ROIs in each augmented view.
    n_erase = int(X_train.shape[1] * 0.1)
    print("- n erase:", n_erase)
    collate_fn = RandomCutoutCollateFunction(n_iter=n_erase, p=0.8,
                                             znorm=True)
    model.fit(datasets["train"], collate_fn=collate_fn,
              aux_names=clinical_names)
    model.save()
    # Additionally export the raw weights of the fitted network.
    path = os.path.join(outdir,
                        f"{model.model_.__class__.__name__}_final.pth")
    torch.save(model.model_.state_dict(), path)

Follow us

© 2023, mmbench developers