Menu

Benchmark multi-model/multi-view models.

Source code for mmbench.plotting

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2022
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Plotting utility functions.
"""

# Imports
import os
from itertools import combinations
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import cm
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from scipy.stats import ttest_1samp
from scipy.stats import ttest_ind as ttest
from mmbench.color_utils import print_subtitle, print_result


def plot_barrier_clusters(data, labels, scores, task_name, metric_name):
    """ Display the barrier clustering results.

    The left panel shows, for each cluster, the mean time course with a
    +/- one standard deviation band. The right panel shows the clustering
    scores as a function of the number of clusters, with a dashed vertical
    line at the number of clusters found in `labels`.

    Parameters
    ----------
    data: array (N, t)
        the time courses obtained when interpolating the weights of any two
        pairs of instances.
    labels: array (N, )
        the time courses clusters.
    scores: array (K, )
        the clustering metrics used to determine the best number of
        clusters.
    task_name: str
        the task name used in the barrier experiment.
    metric_name: str
        metric name used to select the best number of clusters.

    Returns
    -------
    fig: matplotlib.figure.Figure
        the generated figure.

    Notes
    -----
    The global matplotlib ``rcParams`` font settings are updated as a side
    effect of calling this function.
    """
    fontparams = {"font.size": 11, "font.weight": "bold",
                  "font.family": "serif", "font.style": "italic"}
    plt.rcParams.update(fontparams)
    labelparams = {"size": 16, "weight": "semibold", "family": "serif"}

    unique_labels = np.unique(labels)
    n_cluster = len(unique_labels)
    max_clusters = len(scores)
    alpha = range(data.shape[-1])
    xk = range(1, max_clusters + 1)
    # ``matplotlib.cm.get_cmap`` has been removed from recent Matplotlib
    # releases; the pyplot variant is still available and accepts a LUT
    # size.
    cmap = plt.get_cmap("hsv", max_clusters)

    fig = plt.figure()

    # Left panel: mean +/- std time course of each cluster.
    ax = fig.add_subplot(1, 2, 1)
    for label in unique_labels:
        ts = data[labels == label]
        mean_ts = np.mean(ts, axis=0)
        std_ts = np.std(ts, axis=0)
        ax.plot(alpha, mean_ts, label=f"basin {label + 1}", c=cmap(label))
        ax.fill_between(alpha, mean_ts - std_ts, mean_ts + std_ts,
                        alpha=0.3, facecolor=cmap(label))
    ax.spines[["right", "top"]].set_visible(False)
    ax.set_xlabel(r"$\alpha$", labelparams)
    ax.set_ylabel(task_name, labelparams)
    # Dedicated names so the `labels` argument is not shadowed.
    handles, legend_labels = ax.get_legend_handles_labels()
    kw = dict(ncol=len(handles), loc="lower center", frameon=False)
    leg = ax.legend(handles, legend_labels, bbox_to_anchor=[0.5, 1.04],
                    **kw)
    ax.add_artist(leg)
    # Leave room above the axes for the legend.
    fig.subplots_adjust(top=0.9)

    # Right panel: clustering score for each candidate number of clusters.
    ax = fig.add_subplot(1, 2, 2)
    ax.plot(xk, scores)
    ymin, ymax = ax.get_ylim()
    ax.vlines(n_cluster, ymin, ymax, linestyles="dashed")
    # Limits are read again as drawing the line may rescale the view.
    ymin, ymax = ax.get_ylim()
    ax.text(n_cluster, (ymin + ymax) / 2, f"k={n_cluster}", ha="center",
            va="center", rotation="vertical", backgroundcolor="white")
    ax.spines[["right", "top"]].set_visible(False)
    ax.set_xlabel("k", labelparams)
    ax.set_ylabel(metric_name, labelparams)

    return fig
def plot_mat(key, mat, ax=None, figsize=(5, 2), dpi=300, fontsize=16,
             fontweight="bold", title=None, vmin=None, vmax=None):
    """ Display a mat array.

    Parameters
    ----------
    key: str
        the mat array identifier, used as title when `title` is None.
    mat: array (n, n)
        the mat array to display.
    ax: matplotlib.axes.Axes, default None
        the axes used to display the plot. A new figure is created when
        not specified.
    figsize: (float, float), default (5, 2)
        width, height in inches (only used when `ax` is None).
    dpi: float, default 300
        the resolution of the figure in dots-per-inch (only used when `ax`
        is None).
    fontsize: int or str, default 16
        size in points or relative size, e.g., 'smaller', 'x-large'. A
        numeric size is scaled by 1.5 for the title; a relative size is
        used as is.
    fontweight: str, default 'bold'
        the font weight, e.g. 'normal', 'bold', 'heavy', 'light',
        'ultrabold' or 'ultralight'.
    title: str, default None
        the title displayed on the figure.
    vmin: float, default None
        lower bound of the color scale.
    vmax: float, default None
        upper bound of the color scale.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
    ax.imshow(mat, aspect="auto", cmap="Reds", vmin=vmin, vmax=vmax)
    # Relative sizes ('x-large', ...) cannot be multiplied.
    title_size = fontsize if isinstance(fontsize, str) else fontsize * 1.5
    # Set the title on the requested axes, not on pyplot's current axes.
    ax.set_title(key if title is None else title, fontsize=title_size,
                 pad=2, fontweight=fontweight)
def _significance_index(pvalue, thresholds):
    """ Return the index of the last threshold that `pvalue` is below.

    Index 0 is returned when no threshold is passed (e.g. NaN p-value or
    p-value not strictly below the first threshold).
    """
    passed = np.nonzero(pvalue < thresholds)[0]
    return int(passed.max()) if passed.size > 0 else 0


def plot_bar(key, rsa, ax=None, figsize=(5, 2), dpi=300, fontsize=16,
             fontsize_star=25, fontweight="bold", line_width=2.5,
             marker_size=.1, title=None, palette="Spectral", report_t=False,
             do_pairwise_stars=False, do_one_sample_stars=True,
             yname="model fit (r)"):
    """ Display results with bar plots.

    Parameters
    ----------
    key: str
        the analysis identifier.
    rsa: dict of dict
        the sampling data with the analysis identifiers as first key and
        experimental conditions as second key.
    ax: matplotlib.axes.Axes, default None
        the axes used to display the plot. A new figure is created when
        not specified.
    figsize: (float, float), default (5, 2)
        width, height in inches (only used when `ax` is None).
    dpi: float, default 300
        the resolution of the figure in dots-per-inch (only used when `ax`
        is None).
    fontsize: int or str, default 16
        size in points or relative size, e.g., 'smaller', 'x-large'. A
        numeric size is scaled by 1.5 for the title; a relative size is
        used as is.
    fontsize_star: int or str, default 25
        size in points or relative size, e.g., 'smaller', 'x-large'
        intended for the pairwise statistics. Currently unused, kept for
        backward compatibility.
    fontweight: str, default 'bold'
        the font weight, e.g. 'normal', 'bold', 'heavy', 'light',
        'ultrabold' or 'ultralight'.
    line_width: float, default 2.5
        the bar plot line width.
    marker_size: float, default .1
        the sampling scatter plot marker size.
    title: str, default None
        the title displayed on the figure, `key` if not specified.
    palette: palette name, list, or dict
        colors to use for the different levels of the hue variable. Should
        be something that can be interpreted by color_palette(), or a
        dictionary mapping hue levels to matplotlib colors.
    report_t: bool, default False
        optionally, generate a report with pairwise statistics.
    do_pairwise_stars: bool, default False
        optionally, display pairwise statistics (requires the
        `statannotations` package).
    do_one_sample_stars: bool, default True
        optionally, display sampling statistics.
    yname: str, default 'model fit (r)'
        optionally, name of the metric on y-axis.

    Returns
    -------
    pairwise_stat_df: pandas.DataFrame or None
        the generated pairwise statistics ('pair', 'tval' and 'pval'
        columns), None if neither `report_t` nor `do_pairwise_stars` is
        set.
    """
    from matplotlib.container import BarContainer

    # Internal column name of the sampled values in the long-format table.
    value_col = "model fit (r)"
    thresholds = np.array((1, .05, .001, .0001))
    stars = np.array(("n.s.", "*", "**", "***"))

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    # Long-format table: one row per sample with its condition.
    data = rsa[key]
    records = {value_col: [], "condition": []}
    for cond, samples in data.items():
        records[value_col].extend(samples)
        records["condition"].extend([cond] * len(samples))
    data_df = pd.DataFrame.from_dict(records)

    # Draw explicitly on `ax` rather than on pyplot's current axes.
    sns.stripplot(data=data_df, x="condition", y=value_col, jitter=0.15,
                  alpha=1.0, color="k", size=marker_size, ax=ax)
    sns.barplot(data=data_df, x="condition", y=value_col, errcolor="r",
                alpha=0.3, linewidth=line_width, errwidth=line_width,
                palette=palette, ax=ax)
    # Make the bar edges opaque. All bar containers are visited since,
    # depending on the seaborn version, bars may be spread over several
    # containers.
    for container in ax.containers:
        if not isinstance(container, BarContainer):
            continue
        for patch in container:
            patch.set_edgecolor(mcolors.to_rgba(patch.get_edgecolor(), 1.))

    # Six evenly spaced y ticks spanning the automatic tick range.
    locs = ax.get_yticks()
    new_y = np.linspace(locs[0], locs[-1], 6)
    ax.set_yticks(new_y)
    ax.set_yticklabels([f"{yy:.2f}" for yy in new_y], fontsize=fontsize,
                       fontweight=fontweight)
    ax.set_ylabel(yname, fontsize=fontsize, fontweight=fontweight)

    for side in ("top", "right", "bottom"):
        ax.spines[side].set_visible(False)
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_linewidth(line_width)

    # Tick labels: drop the last '_'-separated token and stack the others.
    # Labels without underscore are kept as is instead of becoming empty.
    xlabels = [item.get_text() for item in ax.get_xticklabels()]
    _xlabels = []
    for item in xlabels:
        parts = item.split("_")
        _xlabels.append("\n".join(parts[:-1]) if len(parts) > 1 else item)
    ax.set_xticklabels(_xlabels, fontsize=fontsize, fontweight=fontweight)
    ax.xaxis.get_label().set_visible(False)

    # Extra head-room above the bars.
    ax.set_ylim(np.array(ax.get_ylim()) * (1, 1.1))

    # Relative sizes ('x-large', ...) cannot be multiplied.
    title_size = fontsize if isinstance(fontsize, str) else fontsize * 1.5
    ax.set_title(key if title is None else title, fontsize=title_size,
                 pad=2, fontweight=fontweight)

    if do_one_sample_stars:
        # One-sample t-test of each condition against zero.
        for idx, name in enumerate(xlabels):
            one_sample = ttest_1samp(data[name], 0)
            these_stars = stars[
                _significance_index(one_sample.pvalue, thresholds)]
            _xlabels[idx] = f"{_xlabels[idx]}\n({these_stars})"
        ax.set_xticklabels(_xlabels, fontsize=fontsize,
                           fontweight=fontweight)

    if report_t or do_pairwise_stars:
        size = len(xlabels)
        pairwise_p = np.zeros((size, size))
        stats = {"pair": [], "tval": [], "pval": []}
        for idx1, name1 in enumerate(xlabels):
            for idx2, name2 in enumerate(xlabels):
                tval, pval = ttest(data[name1], data[name2])
                # Degrees of freedom of the independent two-sample t-test.
                dof = len(data[name1]) + len(data[name2]) - 2
                if pval <= .001:
                    print(f"{key} {name1} > {name2} | "
                          f"t({dof}) = {tval:.2f} p < .001")
                else:
                    # Also reached for NaN p-values, printed as 'nan'.
                    print(f"{key} {name1} > {name2} | "
                          f"t({dof}) = {tval:.2f} p = {pval:.2f}")
                pairwise_p[idx1, idx2] = pval
                stats["pair"].append(
                    f"qname-{key}_src-{name1.replace('_', '-')}_"
                    f"dest-{name2.replace('_', '-')}")
                stats["tval"].append(tval)
                stats["pval"].append(pval)
        pairwise_stat_df = pd.DataFrame.from_dict(stats)
    else:
        pairwise_stat_df = None

    if do_pairwise_stars:
        from statannotations.Annotator import Annotator
        # Only the significant pairs are annotated.
        pairs, annotations = [], []
        for idx1, idx2 in combinations(range(len(xlabels)), 2):
            sig_idx = _significance_index(pairwise_p[idx1, idx2],
                                          thresholds)
            if sig_idx != 0:
                pairs.append([xlabels[idx1], xlabels[idx2]])
                annotations.append(stars[sig_idx])
        if len(pairs) > 0:
            annotator = Annotator(
                ax, pairs, data=data_df, x="condition", y=value_col,
                order=xlabels)
            annotator.set_custom_annotations(annotations)
            annotator.annotate()

    return pairwise_stat_df
def barrier_display(coeffs, l_metrics, model_name, downstream, dataset, outdir,
                    scale, sname):
    """ Save barrier curves for a model.

    One panel is drawn per starting run, each showing the curves towards
    all the other runs. The figure is saved in `outdir` as
    'barrier_<model_name>_<downstream>_<dataset>.png' and then closed.

    Parameters
    ----------
    coeffs : list
        the abscissa of the graph.
    l_metrics : array (n, n, n_coeffs)
        value matrix of the curve between two models.
    model_name : str
        name of the model.
    downstream : str
        name of the downstream task.
    dataset: str
        the dataset name: euaims or hbn.
    outdir : str
        the destination folder.
    scale : tuple (min, max)
        min and max values displayed on the y-axis of each panel.
    sname : str
        the name of the scorer, used as y-axis label.
    """
    print_subtitle(f"Display {model_name}_{downstream} figures...")
    ncols = 3
    # At least the historical 4 rows, more when there are over 12 runs.
    nrows = max(4, -(-len(l_metrics) // ncols))
    fig = plt.figure(figsize=np.array((ncols, nrows)) * 4)
    for idx, row in enumerate(l_metrics):
        ax = plt.subplot(nrows, ncols, idx + 1)
        plot_curve(
            coeffs, row, ax=ax, figsize=None, dpi=300, fontsize=7,
            fontweight="bold", title=f"from run {idx + 1}")
        ax.set_ylim(scale[0], scale[1])
        ax.set_ylabel(sname)
    fig.subplots_adjust(wspace=1, hspace=.5)
    fig.suptitle(f"{model_name} {downstream} BARRIER FIGURES", fontsize=20,
                 y=.95)
    filename = os.path.join(
        outdir, f"barrier_{model_name}_{downstream}_{dataset}.png")
    fig.savefig(filename)
    # Release the figure: repeated calls would otherwise pile up open
    # figures in memory.
    plt.close(fig)
    print_result(f"BARRIER: {filename}")
def mat_display(matrices, dataset, outdir, downstream_name, scale):
    """ Plot area matrices.

    One panel with its own colorbar is drawn per model. The figure is
    saved in `outdir` as 'barrier_area_<downstream_name>_<dataset>.png'
    and then closed.

    Parameters
    ----------
    matrices : dict
        area matrix dictionaries by models.
    dataset: str
        the dataset name: euaims or hbn.
    outdir : str
        the destination folder.
    downstream_name: str
        the name of the column that contains the downstream classification
        task.
    scale : tuple (min, max)
        min and max values of matrix in matrices, used as bounds of the
        color scale.
    """
    ncols = 2
    # At least the historical 3 rows, more when there are over 6 models.
    nrows = max(3, -(-len(matrices) // ncols))
    fig = plt.figure(figsize=np.array((ncols, nrows)) * 4)
    for idx, key in enumerate(matrices):
        ax = plt.subplot(nrows, ncols, idx + 1)
        plot_mat(
            key, matrices[key], ax=ax, figsize=None, dpi=300, fontsize=7,
            fontweight="bold", title=f"{key}", vmin=scale[0], vmax=scale[1])
        # Every other row/column is ticked, with 1-based labels; the tick
        # range follows the matrix size.
        n_rows, n_cols = np.shape(matrices[key])[:2]
        xticks = np.arange(0, n_cols, 2)
        yticks = np.arange(0, n_rows, 2)
        ax.set_xticks(xticks)
        ax.set_yticks(yticks)
        ax.set_xticklabels(xticks + 1)
        ax.set_yticklabels(yticks + 1)
        fig.colorbar(ax.images[0], ax=ax)
    fig.subplots_adjust(wspace=.5, hspace=.5)
    fig.suptitle(f"{dataset} BARRIER AREA", fontsize=20, y=.95)
    filename = os.path.join(
        outdir, f"barrier_area_{downstream_name}_{dataset}.png")
    fig.savefig(filename)
    # Release the figure: repeated calls would otherwise pile up open
    # figures in memory.
    plt.close(fig)
    print_result(f"AREA: {filename}")
def plot_curve(xticks, mat, ax=None, figsize=(5, 2), dpi=300, fontsize=16,
               fontweight="bold", title=None):
    """ Display a list of curves.

    Parameters
    ----------
    xticks: list
        the x coordinates shared by all the curves.
    mat: array (n_curve, n_points)
        the matrix containing the points of the curves.
    ax: matplotlib.axes.Axes, default None
        the axis used to display the plot. A new figure is created when
        not specified.
    figsize: (float, float), default (5, 2)
        width, height in inches (only used when `ax` is None).
    dpi: float, default 300
        the resolution of the figure in dots-per-inch (only used when `ax`
        is None).
    fontsize: int or str, default 16
        size in points or relative size, e.g., 'smaller', 'x-large'. A
        numeric size is scaled by 1.5 for the title; a relative size is
        used as is.
    fontweight: str, default 'bold'
        the font weight, e.g. 'normal', 'bold', 'heavy', 'light',
        'ultrabold' or 'ultralight'.
    title: str, default None
        the title displayed on the figure, no title if not specified.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
    for idx, elem in enumerate(mat):
        ax.plot(xticks, elem, label=f"to {idx+1}")
    # Shrink the axes to 80% of its width to fit the legend on the right.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    # No fallback title: there is no identifier to display here, and the
    # x coordinates are not a meaningful title.
    if title is not None:
        # Relative sizes ('x-large', ...) cannot be multiplied.
        title_size = (
            fontsize if isinstance(fontsize, str) else fontsize * 1.5)
        ax.set_title(title, fontsize=title_size, pad=2,
                     fontweight=fontweight)

Follow us

© 2023, mmbench developers