Benchmark multi-model/multi-view models.
Source code for mmbench.workflow.barrier
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2022 - 2023
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Define barrier experiments.
"""
# Imports
import os
import copy
from pprint import pprint
import numpy as np
import torch
from mmbench.config import ConfigParser
from mmbench.color_utils import (
print_title, print_subtitle, print_text, print_result)
from mmbench.dataset import (
get_test_data, get_train_data, get_test_full_data, get_train_full_data)
from mmbench.workflow.predict import get_predictor
from mmbench.model import get_models
from brainboard.metric import eval_interpolation
from mmbench.plotting import mat_display, barrier_display
def _downstream_score(model, loaders, y_train, y_test, eval_fn=None,
                      eval_kwargs=None):
    """ Score a model on the downstream task from its features.

    Used as the evaluation callback of `eval_interpolation`; the keyword
    names ('eval_fn', 'eval_kwargs', 'y_train', 'y_test') must match the
    keys of the kwargs dictionary built in `benchmark_barrier_exp`.

    Parameters
    ----------
    model: torch.nn.Module
        the (possibly interpolated) model to evaluate.
    loaders: list
        exactly two inputs: the train data then the test data.
    y_train, y_test: array-like
        the downstream targets of the train/test data.
    eval_fn: callable, default None
        optional function called as ``eval_fn(model, data, **eval_kwargs)``
        whose returned mapping values are concatenated along axis 1 to build
        the features; when None the raw model output is used.
    eval_kwargs: dict, default None
        extra keyword arguments forwarded to `eval_fn`.

    Returns
    -------
    score: object
        the value returned by the scorer from `get_predictor` for a predictor
        fitted on the train features and evaluated on the test features.
    """
    eval_kwargs = eval_kwargs or {}
    model.eval()
    with torch.no_grad():
        features = []
        for data in loaders:
            if eval_fn is not None:
                latents = eval_fn(model, data, **eval_kwargs).values()
                # Stack the returned blocks side by side as one feature matrix.
                z = np.concatenate(list(latents), axis=1)
            else:
                z = model(data).cpu().detach().numpy()
            features.append(z)
        X_train, X_test = features
        clf, scorer, _ = get_predictor(y_train)
        clf.fit(X_train, y_train)
        return scorer(clf, X_test, y_test)


def benchmark_barrier_exp(dataset, datasetdir, configfile, outdir,
                          downstream_name, dtype="full", n_coeffs=10):
    """ Compare the performance barrier interpolating the weights of any two
    pairs of instances of the same network and monitoring a common downstream
    task.

    Only models whose 'checkpointfile' setting is a list/tuple are used. For
    each such model, every pair of checkpoints (diagonal included) is passed
    to `eval_interpolation` with `n_coeffs` interpolation points, and the
    area between the resulting score curve and the straight line joining its
    endpoints is stored in a symmetric matrix.

    Parameters
    ----------
    dataset: str
        the dataset name: euaims or hbn.
    datasetdir: str
        the path to the dataset associated data.
    configfile: str
        the path to the config file describing the different models to
        compare. This configuration file is a Python (\\*.py) file with a
        dictionary named '_models' containing the different model settings.
        Keys of this dictionary are the model names, each being described
        with a model getter function 'get' and associated kwargs
        'get_kwargs', as well as an evaluation function 'eval' and associated
        kwargs 'eval_kwargs'. The getter and evaluation functions are defined
        in the 'mmbench.model' module.
    outdir: str
        the destination folder (created, with its parents, if missing).
    downstream_name: str
        the name of the column that contains the downstream classification
        task.
    dtype: str, default 'full'
        the data type: 'complete' or 'full'.
    n_coeffs: int, default 10
        number of interpolation points.

    Notes
    -----
    Results are written in `outdir` as
    'barrier_interp_<dataset>_<downstream_name>.npz' (one area matrix per
    model) and 'barrier_curves_<dataset>_<downstream_name>.npz' (one
    (n_models, n_models, n_coeffs) curve array per model).
    """
    print_title(f"COMPARE MODEL WEIGHTS: {dataset}")
    assert dtype in ("complete", "full")
    benchdir = outdir
    # makedirs also creates missing parents and tolerates an existing folder.
    os.makedirs(benchdir, exist_ok=True)
    print_text(f"Benchmark directory: {benchdir}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print_subtitle("Loading data...")
    modalities = ["clinical", "rois"]
    print_text(f"modalities: {modalities}")
    if dtype == "full":
        train_loader, test_loader = (get_train_full_data, get_test_full_data)
    else:
        train_loader, test_loader = (get_train_data, get_test_data)
    data_train, meta_train_df = train_loader(dataset, datasetdir, modalities)
    assert downstream_name in meta_train_df.columns, (
        f"Specify a downstream task from: {meta_train_df.columns}")
    data_test, meta_test_df = test_loader(dataset, datasetdir, modalities)
    y_train = meta_train_df[downstream_name]
    y_test = meta_test_df[downstream_name]
    for mod in modalities:
        data_train[mod] = data_train[mod].to(device).float()
        data_test[mod] = data_test[mod].to(device).float()
    print_text([(key, arr.shape) for key, arr in data_train.items()])
    print_text(meta_train_df)
    print_text([(key, arr.shape) for key, arr in data_test.items()])
    print_text(meta_test_df)

    print_subtitle("Parsing config...")
    parser = ConfigParser("latent-config", configfile)
    pprint(parser.config.models)

    print_subtitle("Loading models...")
    models = {}
    default_params = {
        "n_channels": len(modalities),
        "n_feats": [data_test[mod].shape[1] for mod in modalities],
        "modalities": modalities}
    for name, params in parser.config.models.items():
        checkpoints = params["get_kwargs"]["checkpointfile"]
        # Models whose checkpoints are not given as a list/tuple are skipped.
        if not isinstance(checkpoints, (list, tuple)):
            continue
        _models = get_models(
            params["get"],
            **parser.set_auto_params(params["get_kwargs"], default_params))
        eval_kwargs = parser.set_auto_params(
            params["eval_kwargs"], default_params)
        eval_kwargs["n_samples"] = 1
        eval_kwargs["verbose"] = 0
        if name == "sMCVAE":
            eval_kwargs["threshold"] = None
        models[name] = (_models, params["eval"], eval_kwargs)
    for name, (_models, _, _) in models.items():
        print_text(f"model: {name}")
        print(_models[0])

    print_subtitle("Evaluate models...")
    results_test, results_curve = {}, {}
    # Running bounds of all area matrices, shared by the final display.
    global_min, global_max = np.inf, -np.inf
    _, _, sname = get_predictor(y_test)
    for name, (_models, eval_fct, eval_kwargs) in models.items():
        if not isinstance(_models[0], torch.nn.Module):
            continue
        print_text(f"model: {name}")
        kwargs = {"eval_fn": eval_fct, "eval_kwargs": eval_kwargs,
                  "y_train": y_train, "y_test": y_test}
        n_models = len(_models)
        # Upper-triangular pairs (diagonal included); results are mirrored.
        iu = np.array(np.triu_indices(n_models, k=0)).T
        mat = np.zeros((n_models, n_models))
        points_curve = np.zeros((n_models, n_models, n_coeffs))
        for i1, i2 in iu:
            model1 = _models[i1].to(device).eval()
            model2 = _models[i2].to(device).eval()
            state1 = model1.state_dict()
            state2 = model2.state_dict()
            coeffs, metrics = eval_interpolation(
                copy.deepcopy(model1), state1, state2, [data_train, data_test],
                _downstream_score, n_coeffs=n_coeffs, eval_kwargs=kwargs)
            points_curve[i1, i2] = metrics
            # Assumes interpolating model2 -> model1 gives the same curve
            # reversed, so it is not recomputed.
            points_curve[i2, i1] = metrics[::-1]
            mat[i1, i2] = area(metrics, coeffs)
            mat[i2, i1] = mat[i1, i2]
        vmax = np.max(points_curve)
        vmin = np.min(points_curve)
        span = vmax - vmin
        # 5% padding; flat curves (e.g. a single checkpoint) would otherwise
        # produce identical lower and upper axis limits.
        margin = 0.05 * span if span > 0 else 0.05
        y_axes = [vmin - margin, vmax + margin]
        barrier_display(coeffs, points_curve, name, downstream_name,
                        dataset, benchdir, y_axes, sname)
        global_min = min(global_min, np.min(mat))
        global_max = max(global_max, np.max(mat))
        print(mat)
        results_test[name] = mat
        results_curve[name] = points_curve
    # Keep the historical default range when no model could be evaluated.
    scale = [global_min, global_max] if results_test else [100, -100]
    mat_display(results_test, dataset, outdir, downstream_name, scale)
    barrier_file = os.path.join(
        benchdir, f"barrier_interp_{dataset}_{downstream_name}.npz")
    np.savez_compressed(barrier_file, **results_test)
    curve_file = os.path.join(
        benchdir, f"barrier_curves_{dataset}_{downstream_name}.npz")
    np.savez_compressed(curve_file, **results_curve)
    print_result(f"barrier interpolation: {barrier_file}, {curve_file}")
def area(y, x):
    """ Calculation of the area between a curve y and the line (ax + b)
    joining its two endpoints (x[0], y[0]) and (x[-1], y[-1]).

    The area is the trapezoidal integral of ``|y - (ax + b)|`` over `x`, so
    parts of the curve above and below the line both count positively.

    Parameters
    ----------
    y: array-like
        the points of the curve.
    x: array-like
        x-axis, same length as `y`.

    Returns
    -------
    area: float
        area of the curve y relative to its base line; 0 when fewer than two
        points are given.
    """
    # Accept plain lists as documented: `slope * x` needs an array.
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)
    if len(x) < 2:
        return 0.0
    slope = (y[-1] - y[0]) / (x[-1] - x[0])
    intercept = y[0] - slope * x[0]
    baseline = slope * x + intercept
    # max(y, baseline) - min(y, baseline) == |y - baseline| pointwise, and the
    # trapezoid rule is linear, so this equals trapz(up) - trapz(down).
    # `np.trapz` is deprecated in NumPy 2.x in favour of `np.trapezoid`.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return float(trapezoid(np.abs(y - baseline), x))
Follow us