Reconstruction of visual stimuli from Miyawaki et al. 2008

This example reproduces the experiment presented in Miyawaki et al.[1].

It reconstructs 10x10 binary images from functional MRI data. Random images are used as the training set, and structured images are used for reconstruction.

The code is somewhat elaborate because the example, like the original article, uses a multiscale prediction on the images seen by the subject.

See also

For an encoding approach for the same dataset, see Encoding models for visual stimuli from Miyawaki et al. 2008

Note

If you are using Nilearn with a version older than 0.9.0, then you should either upgrade your version or import maskers from the input_data module instead of the maskers module.

That is, you should manually replace in the following example all occurrences of:

from nilearn.maskers import NiftiMasker

with:

from nilearn.input_data import NiftiMasker
import sys
import time

First we load the Miyawaki dataset

from nilearn import datasets
from nilearn.plotting import show

# Fetch the dataset and time how long it takes.
sys.stderr.write("Fetching dataset...")
t0 = time.time()

miyawaki_dataset = datasets.fetch_miyawaki2008()

# Report where the first 4D functional image is stored.
print(
    "First functional nifti image (4D) is located "
    f"at: {miyawaki_dataset.func[0]}"
)

# The first 12 entries are the "figure" runs (used below as the test set);
# every remaining entry is a "random" run (used below as the training set).
X_figure_filenames, X_random_filenames = (
    miyawaki_dataset.func[:12],
    miyawaki_dataset.func[12:],
)
y_figure_filenames, y_random_filenames = (
    miyawaki_dataset.label[:12],
    miyawaki_dataset.label[12:],
)

# Shape of one stimulus image.
y_shape = (10, 10)

sys.stderr.write(f" Done ({time.time() - t0:.2f}s).\n")

Then we prepare and mask the data

import numpy as np

from nilearn.maskers import MultiNiftiMasker

sys.stderr.write("Preprocessing data...")
t0 = time.time()

# Mask the fMRI runs with the mask shipped with the dataset. Each transform
# call receives a list of files and the result is iterated run by run below.
masker = MultiNiftiMasker(
    mask_img=miyawaki_dataset.mask,
    detrend=True,
    standardize=False,
    n_jobs=2,
)
masker.fit()
X_train = masker.transform(X_random_filenames)
X_test = masker.transform(X_figure_filenames)

# Read the comma-separated label files. Each file is reshaped into a stack of
# images of shape ``y_shape``, filled in column-major (Fortran) order.
y_train = [
    np.loadtxt(label_file, dtype=int, delimiter=",").reshape(
        (-1, *y_shape), order="F"
    )
    for label_file in y_random_filenames
]
y_test = [
    np.loadtxt(label_file, dtype=int, delimiter=",").reshape(
        (-1, *y_shape), order="F"
    )
    for label_file in y_figure_filenames
]

# Within every run, pair scan ``t + 2`` with label ``t``: drop the first two
# scans and the last two labels, then stack all runs together.
# NOTE(review): presumably this offset accounts for the hemodynamic delay.
X_train = np.vstack([run_scans[2:] for run_scans in X_train])
y_train = np.vstack([run_labels[:-2] for run_labels in y_train]).astype(float)
X_test = np.vstack([run_scans[2:] for run_scans in X_test])
y_test = np.vstack([run_labels[:-2] for run_labels in y_test]).astype(float)

# Number of columns of the masked data matrix.
n_features = X_train.shape[1]

def flatten(list_of_2d_array):
    """Return a list with the 1D (ravelled) version of every input array.

    Parameters
    ----------
    list_of_2d_array : iterable
        Objects exposing a ``ravel()`` method (e.g. numpy arrays).

    Returns
    -------
    list
        One ravelled array per input element, in the same order.
    """
    raveled = []
    for image in list_of_2d_array:
        raveled.append(image.ravel())
    return raveled


# Build the operators for the multiscale computation.
# The image is square, so y_rows == y_cols.
y_cols = y_shape[1]

# Original data: identity operator with one row per image pixel
# (derived from y_shape rather than hard-coded).
design_matrix = np.eye(y_shape[0] * y_shape[1])

# Example of matrix used for multiscale (average of vertically adjacent
# pixels)
#
# 0.5 *
#
# 1 1 0 0 0 0 0 0 0 0
# 0 1 1 0 0 0 0 0 0 0
# 0 0 1 1 0 0 0 0 0 0
# 0 0 0 1 1 0 0 0 0 0
# 0 0 0 0 1 1 0 0 0 0
# 0 0 0 0 0 1 1 0 0 0
# 0 0 0 0 0 0 1 1 0 0
# 0 0 0 0 0 0 0 1 1 0
# 0 0 0 0 0 0 0 0 1 1

# Left-multiplying by height_tf averages consecutive rows;
# right-multiplying by width_tf averages consecutive columns.
height_tf = (np.eye(y_cols) + np.eye(y_cols, k=1))[: y_cols - 1] * 0.5
width_tf = height_tf.T

yt_tall = [np.dot(height_tf, m) for m in y_train]
yt_large = [np.dot(m, width_tf) for m in y_train]
yt_big = [np.dot(height_tf, np.dot(m, width_tf)) for m in y_train]

# Add it to the training set: each target row is the flattened original image
# followed by the flattened tall, large and big versions, in that order.
y_train = [
    np.r_[image.ravel(), tall.ravel(), large.ravel(), big.ravel()]
    for image, tall, large, big in zip(y_train, yt_tall, yt_large, yt_big)
]

y_test = np.asarray(flatten(y_test))
y_train = np.asarray(y_train)

# Remove rest period: rows whose first pixel is -1 are discarded. Each mask is
# computed once, before y is filtered, and applied to both X and y.
train_is_stimulus = y_train[:, 0] != -1
X_train = X_train[train_is_stimulus]
y_train = y_train[train_is_stimulus]
test_is_stimulus = y_test[:, 0] != -1
X_test = X_test[test_is_stimulus]
y_test = y_test[test_is_stimulus]

sys.stderr.write(f" Done ({time.time() - t0:.2f}s).\n")

We define our prediction function

sys.stderr.write("Training classifiers... \r")
t0 = time.time()

from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import OrthogonalMatchingPursuit
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Fit one independent pipeline per column of y_train (one per target value to
# predict) and collect the fitted pipelines in column order.
n_clfs = y_train.shape[1]
clfs = []
for i in range(n_clfs):
    # Progress indicator, rewritten in place thanks to the carriage return.
    sys.stderr.write(
        f"Training classifiers {int(i + 1):03}/{int(n_clfs)}... \r"
    )

    # Select the 500 best-scoring features, standardize them, then fit an
    # orthogonal matching pursuit model limited to 10 non-zero coefficients.
    clf = Pipeline(
        steps=[
            ("selection", SelectKBest(f_classif, k=500)),
            ("scl", StandardScaler()),
            ("clf", OrthogonalMatchingPursuit(n_nonzero_coefs=10)),
        ]
    )
    clf.fit(X_train, y_train[:, i])
    clfs.append(clf)

sys.stderr.write(
    f"Training classifiers {n_clfs:03d}/{n_clfs:d}... "
    f"Done ({(time.time() - t0):.2f}s).\n"
)

Here we run the prediction: the decoding itself

sys.stderr.write("Calculating scores and outputs...")
t0 = time.time()

# Every fitted pipeline predicts one target column for all test samples.
# Stacking the predictions gives one row per pipeline, hence the transpose to
# get one row per sample.
y_pred = np.asarray([clf.predict(X_test) for clf in clfs]).T


# We need to undo the multiscale encoding to get back to the image shape
def split_multi_scale(y, y_shape):
    """Split multiscale data into four images of the original shape.

    Parameters
    ----------
    y : numpy.ndarray of shape (n_samples, n_values)
        Each row is expected to hold, concatenated in this order, the
        flattened original image (``yw * yh`` values), the tall image
        (``(yw - 1) * yh``), the large image (``yw * (yh - 1)``) and the big
        image (``(yw - 1) * (yh - 1)``).
    y_shape : tuple of two int
        Shape ``(yw, yh)`` of the original image. The operators are derived
        from this argument only, so the function does not rely on any
        module-level variable.

    Returns
    -------
    tuple of four numpy.ndarray
        ``(y_pred, y_pred_tall, y_pred_large, y_pred_big)``, each of shape
        ``(n_samples, yw * yh)``.
    """
    yw, yh = y_shape

    # Index of original image
    split_index = [yw * yh]
    # Index of tall image (one row fewer than the original)
    split_index.append(split_index[-1] + (yw - 1) * yh)
    # Index of large image (one column fewer than the original)
    split_index.append(split_index[-1] + yw * (yh - 1))
    # Index of big image (one row and one column fewer)
    split_index.append(split_index[-1] + (yw - 1) * (yh - 1))

    # We split according to computed indices. The last index is the end of the
    # expected data, so any trailing piece returned by np.split is ignored.
    y_preds = np.split(y, split_index, axis=1)

    # y_pred is the original image
    y_pred = y_preds[0]

    # y_pred_tall is the image averaged over vertically adjacent pixels.
    # Left-multiplying by a (yw, yw - 1) operator brings it back to yw rows:
    # inner rows average their two neighbors, while the first and last rows
    # copy their single neighbor (weight forced to 1).
    height_tf_i = (np.eye(yw) + np.eye(yw, k=-1))[:, : yw - 1] * 0.5
    height_tf_i.flat[0] = 1
    height_tf_i.flat[-1] = 1
    y_pred_tall = [
        np.dot(height_tf_i, np.reshape(m, (yw - 1, yh))).flatten()
        for m in y_preds[1]
    ]
    y_pred_tall = np.asarray(y_pred_tall)

    # y_pred_large is the image averaged over horizontally adjacent pixels.
    # Right-multiplying by a (yh - 1, yh) operator brings it back to yh
    # columns, with the same treatment of the two border columns.
    width_tf_i = (np.eye(yh) + np.eye(yh, k=1))[: yh - 1] * 0.5
    width_tf_i.flat[0] = 1
    width_tf_i.flat[-1] = 1
    y_pred_large = [
        np.dot(np.reshape(m, (yw, yh - 1)), width_tf_i).flatten()
        for m in y_preds[2]
    ]
    y_pred_large = np.asarray(y_pred_large)

    # y_pred_big is the image averaged over 2x2 patches. We use previous
    # matrices to get it back in original shape: columns first, then rows.
    y_pred_big = [
        np.dot(np.reshape(m, (yw - 1, yh - 1)), width_tf_i) for m in y_preds[3]
    ]
    y_pred_big = [
        np.dot(height_tf_i, np.reshape(m, (yw - 1, yh))).flatten()
        for m in y_pred_big
    ]
    y_pred_big = np.asarray(y_pred_big)

    return (y_pred, y_pred_tall, y_pred_large, y_pred_big)


# Bring each scale of the prediction back to the original image shape.
(
    y_pred,
    y_pred_tall,
    y_pred_large,
    y_pred_big,
) = split_multi_scale(y_pred, y_shape)

# Final reconstruction: equal-weight (0.25 each) combination of the four
# scales, accumulated one term at a time in the same left-to-right order.
y_pred = 0.25 * y_pred
y_pred = y_pred + 0.25 * y_pred_tall
y_pred = y_pred + 0.25 * y_pred_large
y_pred = y_pred + 0.25 * y_pred_big

sys.stderr.write(f" Done ({time.time() - t0:.2f}s).\n")

Let us quantify our prediction error

from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)

# Binarize the reconstruction once and reuse it for every metric, and take the
# number of pixels from the data instead of hard-coding it.
y_pred_binary = y_pred > 0.5
n_pixels = y_test.shape[1]

print("Scores")
print("------")
# Every metric is computed pixel by pixel (one column at a time) and then
# averaged over pixels.
# NOTE(review): accuracy_score returns a fraction in [0, 1] by default, not a
# percentage as the printed label suggests - confirm the intended label.
accuracy_to_print = np.mean(
    [
        accuracy_score(y_test[:, i], y_pred_binary[:, i])
        for i in range(n_pixels)
    ]
)
print(f"  - Accuracy (percent): {accuracy_to_print:f}")

# zero_division=0 is passed to precision, recall and F1 alike: an undefined
# score counts as 0 (the same value the default setting yields) without
# emitting a warning for every affected pixel.
precision_to_print = np.mean(
    [
        precision_score(y_test[:, i], y_pred_binary[:, i], zero_division=0)
        for i in range(n_pixels)
    ]
)
print(f"  - Precision: {precision_to_print:f}")

recall_to_print = np.mean(
    [
        recall_score(y_test[:, i], y_pred_binary[:, i], zero_division=0)
        for i in range(n_pixels)
    ]
)
print(f"  - Recall: {recall_to_print:f}")

f1_score_to_print = np.mean(
    [
        f1_score(y_test[:, i], y_pred_binary[:, i], zero_division=0)
        for i in range(n_pixels)
    ]
)
print(f"  - F1-score: {f1_score_to_print:f}")

And finally, we plot six reconstructed images, to compare with ground truth

from pathlib import Path

from matplotlib import pyplot as plt

# Figures are written below the current working directory.
output_dir = Path.cwd() / "results" / "plot_miyawaki_reconstruction"
output_dir.mkdir(exist_ok=True, parents=True)
print(f"Output will be saved to: {output_dir}")

for i in range(6):
    # Show every tenth test sample: 0, 10, ..., 50.
    j = 10 * i

    # One figure with three side-by-side panels, built through the figure and
    # axes objects rather than the implicit "current axes".
    fig = plt.figure()

    sp1 = fig.add_subplot(1, 3, 1)
    sp1.axis("off")
    sp1.set_title("Stimulus")

    sp2 = fig.add_subplot(1, 3, 2)
    sp2.axis("off")
    sp2.set_title("Reconstruction")

    sp3 = fig.add_subplot(1, 3, 3)
    sp3.axis("off")
    sp3.set_title("Binarized")

    # Ground truth, raw reconstruction, and reconstruction thresholded at 0.5.
    sp1.imshow(
        np.reshape(y_test[j], (10, 10)),
        cmap="gray",
        interpolation="nearest",
    )
    sp2.imshow(
        np.reshape(y_pred[j], (10, 10)),
        cmap="gray",
        interpolation="nearest",
    )
    sp3.imshow(
        np.reshape(y_pred[j] > 0.5, (10, 10)),
        cmap="gray",
        interpolation="nearest",
    )

    fig.savefig(output_dir / f"miyawaki2008_reconstruction_{int(i)}.png")

show()

References

Estimated memory usage: 0 MB

Gallery generated by Sphinx-Gallery