Benchmark multi-model/multi-view models.
Source code for mmbench.dataset
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2022
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Define the different datasets.
"""
# Imports
import os
import numpy as np
import pandas as pd
from types import SimpleNamespace
from sklearn.preprocessing import StandardScaler
from torchvision import transforms
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# Optional dependency: when the `cvae` package cannot be imported, fall back
# to `object` so that classes deriving from ContrastiveDataset can still be
# defined (they will not be usable without `cvae`).
try:
    from cvae.datasets import ContrastiveDataset
except Exception:
    # Deliberately best-effort (any import-time failure triggers the
    # fallback), but unlike a bare `except:` this no longer swallows
    # KeyboardInterrupt / SystemExit.
    ContrastiveDataset = object
from mopoe.multimodal_cohort.dataset import MultimodalDataset, DataManager
from mopoe.multimodal_cohort.dataset import MissingModalitySampler
from mmbench.color_utils import print_text
from mmbench.residualize import residualize as residualizer
# Global parameters
# IQ threshold forwarded to `iq_threshold` for each dataset name; None means
# that no IQ filtering is applied for that dataset.
IQ_MAP = dict(euaims=75., hbn=None)
def get_train_data(dataset, datasetdir, modalities, residualize=False):
    """ Load the train split and apply the dataset specific IQ filtering.

    The 'hbn' dataset goes through `get_data_legacy` (dtype 'train'); any
    other dataset goes through `get_data` with the 'complete' data type.
    See `get_data` and `iq_threshold` for documentation.

    Returns
    -------
    data: dict
        the train data for each modality.
    meta_df: DataFrame
        the associated meta information.
    """
    if dataset == "hbn":
        data, meta_df = get_data_legacy(
            dataset, datasetdir, modalities, dtype="train")
    else:
        _, all_meta_df, splits, train_indices, _ = get_data(
            dataset, datasetdir, modalities, dtype="complete",
            residualize=residualize)
        # Keep only the meta information rows of the train subjects.
        meta_df = pd.DataFrame(
            data=all_meta_df.values[train_indices],
            columns=all_meta_df.columns,
            index=all_meta_df.index[train_indices])
        data = {name: split.X_train for name, split in splits.items()}
    # The threshold is None for datasets without IQ filtering.
    return iq_threshold(
        dataset, data, meta_df, threshold=IQ_MAP.get(dataset))
def get_test_data(dataset, datasetdir, modalities, residualize=False):
    """ Load the test split and apply the dataset specific IQ filtering.

    The 'hbn' dataset goes through `get_data_legacy` (dtype 'test'); any
    other dataset goes through `get_data` with the 'complete' data type.
    See `get_data` and `iq_threshold` for documentation.

    Returns
    -------
    data: dict
        the test data for each modality.
    meta_df: DataFrame
        the associated meta information.
    """
    if dataset == "hbn":
        data, meta_df = get_data_legacy(
            dataset, datasetdir, modalities, dtype="test")
    else:
        _, all_meta_df, splits, _, test_indices = get_data(
            dataset, datasetdir, modalities, dtype="complete",
            residualize=residualize)
        # Keep only the meta information rows of the test subjects.
        meta_df = pd.DataFrame(
            data=all_meta_df.values[test_indices],
            columns=all_meta_df.columns,
            index=all_meta_df.index[test_indices])
        data = {name: split.X_test for name, split in splits.items()}
    # The threshold is None for datasets without IQ filtering.
    return iq_threshold(
        dataset, data, meta_df, threshold=IQ_MAP.get(dataset))
def get_train_full_data(dataset, datasetdir, modalities, residualize=False):
    """ Load the 'full' train split and apply the dataset specific IQ
    filtering.

    The 'hbn' dataset goes through `get_data_legacy` (dtype 'full_train');
    any other dataset goes through `get_data` with the 'full' data type.
    See `get_data` and `iq_threshold` for documentation.

    Returns
    -------
    data: dict
        the train data for each modality.
    meta_df: DataFrame
        the associated meta information.
    """
    if dataset == "hbn":
        data, meta_df = get_data_legacy(
            dataset, datasetdir, modalities, dtype="full_train")
    else:
        _, all_meta_df, splits, train_indices, _ = get_data(
            dataset, datasetdir, modalities, dtype="full",
            residualize=residualize)
        # Keep only the meta information rows of the train subjects.
        meta_df = pd.DataFrame(
            data=all_meta_df.values[train_indices],
            columns=all_meta_df.columns,
            index=all_meta_df.index[train_indices])
        data = {name: split.X_train for name, split in splits.items()}
    # The threshold is None for datasets without IQ filtering.
    return iq_threshold(
        dataset, data, meta_df, threshold=IQ_MAP.get(dataset))
def get_test_full_data(dataset, datasetdir, modalities, residualize=False):
    """ Load the 'full' test split and apply the dataset specific IQ
    filtering.

    The 'hbn' dataset goes through `get_data_legacy` (dtype 'full_test');
    any other dataset goes through `get_data` with the 'full' data type.
    See `get_data` and `iq_threshold` for documentation.

    Returns
    -------
    data: dict
        the test data for each modality.
    meta_df: DataFrame
        the associated meta information.
    """
    if dataset == "hbn":
        data, meta_df = get_data_legacy(
            dataset, datasetdir, modalities, dtype="full_test")
    else:
        _, all_meta_df, splits, _, test_indices = get_data(
            dataset, datasetdir, modalities, dtype="full",
            residualize=residualize)
        # Keep only the meta information rows of the test subjects.
        meta_df = pd.DataFrame(
            data=all_meta_df.values[test_indices],
            columns=all_meta_df.columns,
            index=all_meta_df.index[test_indices])
        data = {name: split.X_test for name, split in splits.items()}
    # The threshold is None for datasets without IQ filtering.
    return iq_threshold(
        dataset, data, meta_df, threshold=IQ_MAP.get(dataset))
def iq_threshold(dataset, data, meta_df, threshold=80, col_name="fsiq"):
    """ Remove subjects whose IQ is not above a user-defined threshold.

    Parameters
    ----------
    dataset: str
        the dataset name. Not used by the filtering itself, kept in the
        signature for the callers.
    data: dict
        the loaded data for each modality: tensors whose first dimension
        indexes the subjects, in the same order as the rows of `meta_df`.
        The input dictionary is not modified.
    meta_df: DataFrame
        the associated meta information.
    threshold: int, default 80
        only the subjects with an IQ strictly greater than this value are
        kept (a missing/NaN IQ compares as False and is thus removed). If
        None no thresholding is applied and the inputs are returned as is.
    col_name: str, default 'fsiq'
        the name of the column containing the IQ information.

    Returns
    -------
    data: dict
        the loaded data thresholded for each modality.
    meta_df: DataFrame
        the associated thresholded meta information.
    """
    if threshold is None:
        return data, meta_df
    assert col_name in meta_df.columns, "Can't find the given IQ column name."
    # Force a boolean mask: the column may be stored with an object dtype,
    # and `torch.from_numpy` does not accept object arrays.
    keep = np.asarray(meta_df[col_name].values > threshold, dtype=bool)
    print_text(f"Filtering data: {np.sum(keep)}/{len(meta_df)}")
    meta_df = meta_df.loc[keep]
    indices = torch.argwhere(torch.from_numpy(keep)).flatten()
    # Build a new dictionary rather than overwriting the caller's entries.
    filtered = {key: torch.index_select(tensor, 0, indices)
                for key, tensor in data.items()}
    return filtered, meta_df
def get_data_legacy(dataset, datasetdir, modalities, dtype):
    """ Load the train/test data using the legacy dataset loaders.

    Parameters
    ----------
    dataset: str
        the dataset name: euaims or hbn.
    datasetdir: str
        the path to the dataset associated data.
    modalities: list of str
        the modalities to load.
    dtype: str
        the data type: 'train', 'test', 'full_test', 'full_train' or 'full'.

    Returns
    -------
    data: dict
        the loaded data for each modality.
    meta_df: DataFrame
        the associated meta information, completed with one column per
        clinical score.

    Raises
    ------
    ValueError
        if the data type is not one of the supported values.
    """
    trainset, testset = get_dataset(dataset, datasetdir, modalities)
    selection = {
        "train": [trainset],
        "test": [testset],
        "full": [trainset, testset],
        "full_test": [testset],
        "full_train": [trainset],
    }
    if dtype not in selection:
        raise ValueError("Unexpected data type.")

    if dtype.startswith("full"):
        # Gather every batch that contains the ROI block, keeping track of
        # the batches for which the clinical block is missing.
        rois_blocks, clinical_blocks = [], []
        all_meta = None
        for subset in selection[dtype]:
            sampler = MissingModalitySampler(subset, batch_size=len(subset))
            loader = DataLoader(subset, batch_sampler=sampler, num_workers=0)
            for batch, _, batch_meta in loader:
                if "rois" not in batch:
                    continue
                rois_blocks.append(batch["rois"])
                clinical_blocks.append(
                    batch["clinical"] if "clinical" in batch else None)
                if all_meta is None:
                    all_meta = {key: [val] for key, val in batch_meta.items()}
                    continue
                for key, val in batch_meta.items():
                    all_meta[key].append(val)
        # Number of clinical features: 0 stands for a missing block and is
        # discarded as soon as a real block size is available.
        clinical_size = {0 if item is None else item.size(1)
                         for item in clinical_blocks}
        if len(clinical_size) > 1:
            clinical_size.remove(0)
        assert len(clinical_size) == 1, "All blocks must have the same size."
        clinical_size = next(iter(clinical_size))
        # Replace the missing clinical blocks by NaN-filled blocks having
        # as many rows as the matching ROI block.
        clinical_blocks = [
            (torch.full((roi_items.size(0), clinical_size), float("nan"))
             if clin_items is None else clin_items)
            for roi_items, clin_items in zip(rois_blocks, clinical_blocks)]
        data = {"rois": torch.cat(rois_blocks, dim=0)}
        data["clinical"] = torch.cat(clinical_blocks, dim=0)
        print(data["rois"].shape, data["clinical"].shape)
        for key in all_meta:
            all_meta[key] = np.concatenate(all_meta[key], axis=0)
        meta = all_meta
    else:
        subset = selection[dtype][0]
        sampler = MissingModalitySampler(subset, batch_size=len(subset))
        loader = DataLoader(subset, batch_sampler=sampler, num_workers=0)
        # Draw the first batch of a fresh iterator until it contains all the
        # requested modalities.
        # NOTE(review): presumably the sampler yields a different modality
        # pattern on each new iterator, otherwise this loop cannot end when
        # the first draw is incomplete - confirm.
        while True:
            data, _, meta = next(iter(loader))
            if all(mod in data.keys() for mod in modalities):
                break

    # One row of scores per clinical variable.
    scores = data["clinical"].T
    clinical_names = np.load(
        os.path.join(datasetdir, "clinical_names.npy"), allow_pickle=True)
    clinical_names = [name.replace("t1_", "") for name in clinical_names]
    meta = {key: (val.numpy() if isinstance(val, torch.Tensor) else val)
            for key, val in meta.items()}
    del meta["participant_id"]
    meta.update(zip(clinical_names, scores))
    meta_df = pd.DataFrame.from_dict(meta)
    return data, meta_df
def get_data(dataset, datasetdir, modalities, dtype="complete",
             test_size=0.2, residualize=False, random_state=42):
    """ Load the train/test data.

    Parameters
    ----------
    dataset: str
        the dataset name: euaims or hbn. Not used by the loading itself,
        kept in the signature for the callers.
    datasetdir: str
        the path to the dataset associated data.
    modalities: list of str
        the modalities to load. 'rois' is always required, and 'clinical'
        is required by the 'complete' data type.
    dtype: str, default 'complete'
        the data type: 'complete' keeps the subjects without missing value
        in both the 'rois' and 'clinical' modalities, 'full' keeps the
        subjects without missing value in the 'rois' modality.
    test_size: float, default=0.2
        should be between 0.0 and 1.0 and represent the proportion of the
        dataset to include in the test split.
    residualize: bool, default False
        optionally residualize the image ('rois') data.
    random_state: int, default 42
        controls the shuffling applied to the data before applying the split.

    Returns
    -------
    data: dict of DataFrame
        the loaded data for each modality, with rows in the same order as
        `meta_df`.
    meta_df: DataFrame
        the associated meta information.
    tensors: dict of SimpleNamespace
        the split and standardized input data for each modality, available
        as tensors in the `X_train` and `X_test` attributes.
    train_indices: array of int
        the train indices (row positions in `meta_df`).
    test_indices: array of int
        the test indices (row positions in `meta_df`).

    Raises
    ------
    ValueError
        if the data type is not 'complete' or 'full'.
    """
    from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

    data, meta_df = load_data(datasetdir, modalities)
    data["rois"].dropna(inplace=True)
    if dtype == "full":
        subjects = data["rois"].index
    elif dtype == "complete":
        data["clinical"].dropna(inplace=True)
        subjects = set(data["rois"].index).intersection(
            set(data["clinical"].index))
    else:
        raise ValueError("Unexpected data type.")
    meta_df = meta_df[meta_df.index.isin(subjects)]
    # The split below returns row positions that are applied to the meta
    # information and to every modality: all tables must therefore share the
    # same row order. Each modality is stored in its own subject order, so
    # align them explicitly on the meta information (this is a no-op when
    # the files are already consistently ordered).
    data = {key: df.loc[meta_df.index] for key, df in data.items()}

    msss = MultilabelStratifiedShuffleSplit(
        n_splits=1, test_size=test_size, random_state=random_state)
    train_indices, test_indices = next(
        msss.split(list(subjects), meta_df.values))
    meta_train_df = pd.DataFrame(data=meta_df.values[train_indices],
                                 columns=meta_df.columns,
                                 index=meta_df.index[train_indices])
    meta_test_df = pd.DataFrame(data=meta_df.values[test_indices],
                                columns=meta_df.columns,
                                index=meta_df.index[test_indices])

    tensors = {}
    for key, df in data.items():
        X_train = df.values[train_indices]
        X_test = df.values[test_indices]
        if residualize and key == "rois":
            X_train, X_test = residualizer(
                meta_train_df, X_train, meta_test_df, X_test,
                formula_res="age + sex",
                formula_full="age + sex + asd", site_name="site",
                discrete_vars=["sex"], continuous_vars=["age"])
        # The scaler is fitted on the train split only.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)
        tensors[key] = SimpleNamespace(X_train=torch.from_numpy(X_train),
                                       X_test=torch.from_numpy(X_test))
    return data, meta_df, tensors, train_indices, test_indices
def load_data(datasetdir, modalities):
    """ Load the data.

    For each modality `<mod>`, the files `<mod>_data.npy`,
    `<mod>_subjects.npy` and `<mod>_names.npy` are read from the dataset
    directory. Each modality table is then padded with NaN-filled rows for
    the subjects that are only available in the other modalities.

    Parameters
    ----------
    datasetdir: str
        the path to the dataset associated data.
    modalities: list of str
        the modalities to load.

    Returns
    -------
    data: dict of DataFrame
        the loaded data, indexed by subject. The padding rows come after
        the modality own subjects, in a deterministic (first-seen) order.
    meta_df: DataFrame
        the associated meta information read from `metadata.tsv` and indexed
        by 'participant_id'.
    """
    meta_df = pd.read_csv(os.path.join(datasetdir, "metadata.tsv"), sep="\t")
    meta_df.set_index("participant_id", inplace=True)
    data = {}
    all_subjects = []
    for mod in modalities:
        values = np.load(os.path.join(datasetdir, f"{mod}_data.npy"))
        subjects = np.load(os.path.join(datasetdir, f"{mod}_subjects.npy"))
        # NOTE: `allow_pickle=True` can execute arbitrary code, only load
        # trusted files.
        names = np.load(os.path.join(datasetdir, f"{mod}_names.npy"),
                        allow_pickle=True)
        all_subjects.extend(subjects.tolist())
        data[mod] = pd.DataFrame(data=values, columns=names, index=subjects)
    # Unique subjects in first-seen order: unlike a set, the iteration order
    # does not depend on hash randomization, so the padded tables have the
    # same row order on every run.
    all_subjects = list(dict.fromkeys(all_subjects))
    for key, df in data.items():
        available = set(df.index)
        missing_subjects = [subject for subject in all_subjects
                            if subject not in available]
        if not missing_subjects:
            continue
        padding = pd.DataFrame(np.nan, index=missing_subjects,
                               columns=df.columns)
        data[key] = pd.concat([df, padding])
    return data, meta_df
def get_dataset(dataset, datasetdir, modalities):
    """ Load the train/test datasets.

    A standard scaler is fitted for each modality on the train dataset and
    applied on the fly to both the train and test samples.

    Parameters
    ----------
    dataset: str
        the dataset name: euaims or hbn.
    datasetdir: str
        the path to the dataset associated data.
    modalities: list of str
        the modalities to load.

    Returns
    -------
    trainset, testset: MultimodalDataset
        the loaded train/test datasets.
    """
    manager = DataManager(
        dataset, datasetdir, modalities, overwrite=False,
        allow_missing_blocks=False)
    # One transformation pipeline per modality: add a leading axis, apply
    # the fitted scaler, convert to a tensor and squeeze the result.
    transform = {}
    for mod, scaler in set_scalers(manager.train_dataset, modalities).items():
        steps = [unsqueeze_0, scaler.transform, transforms.ToTensor(),
                 torch.squeeze]
        transform[mod] = transforms.Compose(steps)
    fetcher = manager.fetcher
    trainset = MultimodalDataset(
        fetcher.train_input_path, fetcher.train_metadata_path,
        on_the_fly_transform=transform)
    testset = MultimodalDataset(
        fetcher.test_input_path, fetcher.test_metadata_path,
        on_the_fly_transform=transform)
    return trainset, testset
def set_scalers(dataset, modalities):
    """ Fit a standard scaler modality by modality.

    Parameters
    ----------
    dataset: MultimodalDataset
        a multi modal dataset, iterated as (data, label, meta) items.
    modalities: list of str
        the modalities to load.

    Returns
    -------
    scalers: dict
        a fitted standard scaler for each modality.
    """
    # Single pass over the dataset: collect the available samples of each
    # requested modality.
    samples = {}
    for item, _, _ in dataset:
        for mod in modalities:
            if mod not in item:
                continue
            if mod not in samples:
                samples[mod] = []
            samples[mod].append(item[mod])
    scalers = {}
    for mod in modalities:
        scalers[mod] = StandardScaler()
        scalers[mod].fit(samples[mod])
    return scalers
def unsqueeze_0(x):
    """ Insert a new leading dimension of size one by calling
    `x.unsqueeze` on dimension 0.
    """
    expanded = x.unsqueeze(0)
    return expanded
class EUAIMSDataset(ContrastiveDataset):
    """ From the EUAIMS cohort the target and background datasets are composed
    of T1w MRI FreeSurfer ROI features of ASD patients and TD controls,
    respectively.
    """

    def __init__(self, root, train=True, transform=None, flatten=False,
                 seed=42, scaler=None):
        """ Init class.

        Parameters
        ----------
        root: str
            root directory of dataset where data will be saved.
        train: bool, default True
            specifies training or test dataset.
        transform: callable, default None
            optional transform to be applied on a sample.
        flatten: bool, default False
            optionally select all data.
        seed: int, default 42
            for reproducibility fix a seed.
        scaler: sklearn-like scaler, default None
            optionally set a fitted scaler. When training without a scaler,
            a new StandardScaler is created.
        """
        if train and scaler is None:
            scaler = StandardScaler()
        super().__init__(root, train, transform, flatten, seed, scaler)

    def get_data(self):
        r""" Get the background and target data.

        The `metadata_<split>.tsv`, `rois_data.npy` and `rois_subjects.npy`
        files are read from the root folder, where split is 'train' or
        'test'. Subjects with an 'asd' value of 1 form the background
        (controls) and subjects with an 'asd' value of 2 form the target
        (patients).

        Returns
        -------
        background: array (N, n_channels, \*)
            the background data.
        background_labels: array (N, )
            the background labels.
        target: array (M, n_channels, \*)
            the target data.
        target_labels: array (M, )
            the target labels.

        Raises
        ------
        ValueError
            if one of the required files is missing.
        """
        split = "train" if self.train else "test"
        meta_split_file = os.path.join(self.root, f"metadata_{split}.tsv")
        roi_file = os.path.join(self.root, "rois_data.npy")
        subject_file = os.path.join(self.root, "rois_subjects.npy")
        # Check every required resource (the ROI file was previously never
        # checked, the metadata file being checked twice instead).
        for path in (meta_split_file, roi_file, subject_file):
            self.is_file(path)

        df = pd.read_csv(meta_split_file, sep="\t")
        subjects = df["participant_id"].values
        data = np.load(roi_file)
        all_subjects = np.load(subject_file)
        # Keep the data rows of the subjects listed in the split metadata.
        indices = np.nonzero(np.isin(all_subjects, subjects))[0]
        print(data.shape, len(subjects), len(indices))
        data = data[indices]
        subjects = all_subjects[indices]
        df = df[df["participant_id"].isin(subjects)]
        data = np.expand_dims(data, axis=1)
        print(f"EUAIMS data: {data.shape}")
        print(f"EUAIMS data dynamic: {data.min()} - {data.max()}")
        print(f"EUAIMS subjects: {subjects.shape}")
        print(f"EUAIMS metadata: {df.shape}")
        print(df)

        # The data rows follow the subjects file order while the metadata
        # rows follow the TSV order: fetch the diagnosis in the data order
        # so that the masks below select the right rows.
        diagnosis = df.set_index("participant_id")["asd"].loc[subjects].values
        controls_indices = (diagnosis == 1)
        background = data[controls_indices]
        background_labels = np.array(["td"] * len(background))
        patients_indices = (diagnosis == 2)
        target = data[patients_indices]
        target_labels = np.array(["asd"] * len(target))
        print(f"EUAIMS controls: {background.shape}")
        print(f"EUAIMS patients: {target.shape}")
        return background, background_labels, target, target_labels

    def is_file(self, path):
        """ Check whether a EUAIMS data resource file is here.

        Parameters
        ----------
        path: str
            the path to the resource file.

        Raises
        ------
        ValueError
            if the path is not an existing file.
        """
        if not os.path.isfile(path):
            raise ValueError("The root folder must contains the EUAIMS data.")
Follow us