Benchmark multi-model/multi-view models.
Source code for mmbench.baseline.pls
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2022 - 2023
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Define the PLS model.
"""
# Imports
import os
import torch
from mmbench.dataset import get_train_data
from sklearn.cross_decomposition import PLSRegression
from sklearn.model_selection import train_test_split
from mmbench.color_utils import print_title, print_subtitle
from joblib import dump
def train_pls(dataset, datasetdir, outdir, fit_lat_dims=3, n_iter=10,
              random_state=None):
    """ Train the PLS models.

    Parameters
    ----------
    dataset: str
        the dataset name: euaims or hbn.
    datasetdir: str
        the path to the dataset associated data.
    outdir: str
        the destination folder (created, with its parents, if missing).
    fit_lat_dims: int, default 3
        the number of latent dimensions.
    n_iter: int, default 10
        the number of trained models.
    random_state: list of int, default None
        controls the shuffling applied to the data before applying the split.
        Pass a list of at least n_iter int (one seed per trained model) for
        reproducible output across multiple function calls.

    Raises
    ------
    ValueError
        if 'random_state' contains fewer than 'n_iter' elements.

    Note
    ----
    One model is generated per iteration and stored in 'outdir' in a file
    named 'pls_model_<idx>.joblib', where <idx> goes from 0 to n_iter - 1.
    'outdir' must correspond to the path given in the configuration file for
    the PLS checkpointfile.
    """
    # Fail fast, before loading anything, rather than with an IndexError in
    # the middle of the training loop once some models are already saved.
    if random_state is None:
        random_state = [None] * n_iter
    elif len(random_state) < n_iter:
        raise ValueError(
            f"'random_state' must contain at least n_iter={n_iter} "
            f"elements, got {len(random_state)}.")

    print_title("PLS ")
    print_subtitle("Loading data...")
    modalities = ["clinical", "rois"]
    X_train, _ = get_train_data(dataset, datasetdir, modalities)
    # Drop the 'index' entry so that only the modalities are reported below.
    del X_train["index"]
    print("train:", [(key, arr.shape) for key, arr in X_train.items()])
    # Following the order of 'modalities': the clinical data is the target
    # block (Y) and the rois data is the predictor block (X).
    Y_train, X_train = [X_train[mod].to(torch.float32) for mod in modalities]

    # Also creates missing parent folders and tolerates an existing folder.
    os.makedirs(outdir, exist_ok=True)

    print_subtitle("Create models...")
    for idx in range(n_iter):
        # Each model is fitted on a different 80% subset of the training
        # data; the held-out 20% is not used here.
        Xi_train, _, Yi_train, _ = train_test_split(
            X_train, Y_train, test_size=0.2, random_state=random_state[idx])
        pls = PLSRegression(n_components=fit_lat_dims)
        pls.fit(Xi_train, Yi_train)
        model_file = os.path.join(outdir, f"pls_model_{idx}.joblib")
        dump(pls, model_file)
Follow us