Benchmark multi-model/multi-view models.
Source code for mmbench.workflow.rsa
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2022 - 2023
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Define the Representational Similarity Analysis (RSA) workflows.
"""
# Imports
import os
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from mmbench.stat_utils import data2mat, vec2mat, fit_rsa
from mmbench.color_utils import (
print_title, print_subtitle, print_text, print_result,
print_error)
import matplotlib.pyplot as plt
from mmbench.plotting import plot_mat, plot_bar
[docs]def benchmark_rsa_exp(dataset, datasetdir, outdir):
""" Compare the learned latent space of different models using
Representational Similarity Analysis (RSA).
Parameters
----------
dataset: str
the dataset name: euaims or hbn.
datasetdir: str
the path to the dataset associated data.
outdir: str
the destination folder.
Notes
-----
- The samples are generated with the 'bench-latent' sub-command and are
stored in the 'outdir' in a file named 'latent_vecs_<dataset>.npz'.
- The samples shape is (n_samples, n_subjects, latent_dim). All samples
must have the same number of samples and subjects, but possibly
different latent dimensions.
"""
print_title(f"COMPARE MODELS USING RSA ANALYSIS: {dataset}")
benchdir = outdir
print_text(f"Benchmark directory: {benchdir}")
print_subtitle("Loading data...")
latent_data = np.load(os.path.join(benchdir, f"latent_vecs_{dataset}.npz"))
smats, shape = {}, None
for key in latent_data.keys():
samples = latent_data[key]
assert samples.ndim == 3, (
"Expect samples with shape (n_samples, n_subjects, latent_dim).")
if shape is None:
shape = samples.shape
n_samples, n_subjects, _ = samples.shape
assert n_samples == shape[0], (
"All samples must have the same number of samples.")
assert n_subjects == shape[1], (
"All samples must have the same number of subjects.")
smats[key] = data2mat(samples)
n_subjects = smats[key].shape[1]
print_text(f"{key} similarities: {smats[key].shape}")
meta_df = pd.read_csv(
os.path.join(benchdir, f"latent_meta_{dataset}.tsv"), sep="\t")
clinical_scores = ["asd", "age", "sex", "site", "fsiq"]
scale_scores = ["ordinal", "ratio", "ordinal", "ratio", "ratio"]
scores = dict((qname, scale)
for qname, scale in zip(clinical_scores, scale_scores))
indices = range(n_subjects)
cmats = dict()
cidxs = dict()
le = LabelEncoder()
clinical_scores = meta_df.columns
for qname in clinical_scores:
if qname not in scores:
print_error(f"Unknown score {qname}, use default ratio scale.")
if qname in ("site", "sex"):
meta_df[qname] = le.fit_transform(meta_df[qname].values)
scale = scores.get(qname, "ratio")
vec = meta_df[qname].values[indices]
idx = ~np.isnan(vec)
vec = vec[idx]
cmat = vec2mat(vec, data_scale=scale)
cmats[qname] = cmat
cidxs[qname] = idx
print_text(f"{qname} number of outliers measures: {np.sum(~idx)}")
print_text(f"{qname} features similarities: {cmat.shape}")
print_subtitle("Compute RSA...")
data = dict((key, arr[:, indices][..., indices])
for key, arr in smats.items())
rsa_results, rsa_records = dict(), dict()
for qname in clinical_scores:
for key, smat in data.items():
res = fit_rsa(smat, cmats[qname], idxs=cidxs[qname])
n_samples = len(res)
rsa_records.setdefault(key, []).extend(res.tolist())
rsa_results.setdefault(qname, {})[key] = res
rsa_records.setdefault("score", []).extend([qname] * n_samples)
rsa_df = pd.DataFrame.from_dict(rsa_records)
print(rsa_df.groupby("score").describe().loc[
:, (slice(None), ["count", "mean", "std"])])
rsa_df.to_csv(os.path.join(benchdir, "rsa.tsv"), sep="\t", index=False)
print_subtitle("Display subject's (dis)similarity matrices...")
ncols = n_samples
nrows = len(data)
plt.figure(figsize=np.array((ncols, nrows)) * 4)
idx1 = 0
for name, sdata in data.items():
_name = " ".join(name.split("_")[:-1])
for idx2, smat in enumerate(sdata):
ax = plt.subplot(nrows, ncols, idx1 + 1)
plot_mat(f"{_name} ({idx2 + 1})", smat, ax=ax, figsize=None,
dpi=300, fontsize=12)
idx1 += 1
plt.subplots_adjust(
left=None, bottom=None, right=None, top=None, wspace=.5, hspace=.5)
plt.suptitle(f"{dataset.upper()} SUBJECTS (S) MAT", fontsize=20, y=.95)
filename = os.path.join(benchdir, f"sub_mat_{dataset}.png")
plt.savefig(filename)
print_result(f"subjects mat: {filename}")
print_subtitle("Display score's (dis)similarity matrices...")
ncols = 4
nrows = int(np.ceil(len(cmats) / ncols))
plt.figure(figsize=np.array((ncols, nrows)) * 4)
for idx, (name, cmat) in enumerate(cmats.items()):
_name = " ".join(name.split("_"))
ax = plt.subplot(nrows, ncols, idx + 1)
plot_mat(_name.upper(), cmat, ax=ax, figsize=None, dpi=300,
fontsize=12)
plt.subplots_adjust(
left=None, bottom=None, right=None, top=None, wspace=.5, hspace=.5)
plt.suptitle(f"{dataset.upper()} CLINICAL (C) MAT", fontsize=20, y=.95)
filename = os.path.join(benchdir, f"clinical_mat_{dataset}.png")
plt.savefig(filename)
print_result(f"clinical mat: {filename}")
print_subtitle("Display Kendall tau statistics...")
ncols = 3
nrows = int(np.ceil(len(clinical_scores) / ncols))
plt.figure(figsize=np.array((ncols, nrows)) * 4)
pairwise_stats = []
for idx, qname in enumerate(clinical_scores):
ax = plt.subplot(nrows, ncols, idx + 1)
pairwise_stat_df = plot_bar(
qname, rsa_results, ax=ax, figsize=None, dpi=300, fontsize=7,
fontsize_star=12, fontweight="bold", line_width=2.5,
marker_size=3, title=qname.upper(), report_t=True,
do_one_sample_stars=True, do_pairwise_stars=True, palette="Set2",
yname="correlation")
if pairwise_stat_df is not None:
pairwise_stats.append(pairwise_stat_df)
if len(pairwise_stats) > 0:
pairwise_stat_df = pd.concat(pairwise_stats)
pairwise_stat_df.to_csv(
os.path.join(benchdir, "rsa_pairwise_stats.tsv"), sep="\t",
index=False)
plt.subplots_adjust(
left=None, bottom=None, right=None, top=None, wspace=.5, hspace=.5)
plt.suptitle(f"{dataset.upper()} RSA RESULTS", fontsize=20, y=.95)
filename = os.path.join(benchdir, f"rsa_{dataset}.png")
plt.savefig(filename)
print_result(f"RSA: {filename}")
Follow us