import os
import pickle
from os.path import join
import ants
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.io import loadmat
from deepmreye.util.data_io import download_mask
# --------------------------------------------------------------------------------
# --------------------------ANTS TRANSFORMS---------------------------------------
# --------------------------------------------------------------------------------
[docs]
def register_to_eye_masks(dme_template, func, masks, verbose=1, transforms=None, metric="GC"):
    """Register functional to DeepMReye template (dme_template) using different sized masks.

    Registration is run once per entry in ``masks``, in order. Each pass registers
    the time-averaged output of the previous pass, so later masks refine earlier ones.

    Parameters
    ----------
    dme_template : ants Image
        Ants image with template file
    func : ants Image
        Functional image to register to dme_template
    masks : list
        List of Ants image objects (or None for no mask) containing variable sized masks
    verbose : int, optional
        Verbosity level of function, by default 1. Only controls printing;
        transformation statistics are collected regardless of verbosity.
    transforms : str or list of str, optional
        ANTs transform type(s). A list gives one transform per mask; a single string
        is applied to every mask. By default None, i.e. "Similarity" for every mask.
    metric : str, optional
        Which metric to quantify fit, by default 'GC'

    Returns
    -------
    func : ants Image
        Functional image registered to dme_template
    transformation_stats : array
        One value per mask: the mean of the "AffineTransform_float_3_3" entry of the
        forward affine transform file. Used for the dataset report.
    """
    # A bare string would otherwise be indexed character by character below.
    if isinstance(transforms, str):
        transforms = [transforms] * len(masks)
    transformation_stats = []
    for idx, mask in enumerate(masks):
        type_of_transform = "Similarity" if transforms is None else transforms[idx]
        register_to_nau = ants.registration(
            fixed=dme_template,
            moving=func.get_average_of_timeseries(),
            aff_random_sampling_rate=1,
            type_of_transform=type_of_transform,
            mask=mask,
            aff_metric=metric,
            aff_sampling=512,
            aff_iterations=(200, 200, 200, 10),
            aff_smoothing_sigmas=(0, 0, 0, 0),
        )
        # For SyN-type transforms the affine .mat file is the second forward transform
        # (the first is presumably the warp field); otherwise it is the first.
        fwd_transforms = register_to_nau["fwdtransforms"]
        affine_file = fwd_transforms[1] if "SyN" in type_of_transform else fwd_transforms[0]
        registered_fwd = loadmat(affine_file)["AffineTransform_float_3_3"]
        # Collected independently of `verbose` so the report statistics are never missing.
        transformation_stats.append(np.mean(registered_fwd))
        if verbose > 0:
            print(
                f"Mask {idx}/{len(masks) - 1}, "
                f"Sum: {np.sum(registered_fwd):.3f}, "
                f"Mean {np.mean(registered_fwd):.3f}, "
                f"Std {np.std(registered_fwd):.3f}, "
                f"Median {np.median(registered_fwd):.3f}"
            )
        # Apply the transform to every volume; the result feeds the next (finer) mask.
        func = ants.apply_transforms(
            fixed=dme_template, moving=func, transformlist=register_to_nau["fwdtransforms"], imagetype=3
        )
    return func, np.array(transformation_stats)
[docs]
def run_participant(
    fp_func, dme_template, eyemask_big, eyemask_small, x_edges, y_edges, z_edges, replace_with=0, transforms=None
):
    """Run preprocessing for one participant with templates and masks preloaded to avoid computational overhead.

    The functional is registered to the template in three passes (no mask, big eye
    mask, small eye mask), then the eye region is cut out. Through ``cut_mask`` (called
    with ``save_overview=True``) this writes ``report_<name>.html`` and
    ``mask_<name>.p`` next to ``fp_func``.

    Parameters
    ----------
    fp_func : string
        Filepath to participant functional
    dme_template : ants Image
        Preloaded Image to dme_template
    eyemask_big : ants Image
        Big eyemask as ants Image
    eyemask_small : ants Image
        Small eyemask as ants Image
    x_edges : list
        Edges of mask in x-dimension
    y_edges : list
        Edges of mask in y-dimension
    z_edges : list
        Edges of mask in z-dimension
    replace_with : int, optional
        Values outside of mask are set to this, by default 0
    transforms : list of str, optional
        One ANTs transform type per registration pass (no mask, big mask, small mask),
        forwarded to ``register_to_eye_masks``; by default None (that function's default).

    Returns
    -------
    masked_eye : numpy array
        Eye-region voxels of the registered functional, as returned by ``cut_mask``
    transformation_statistics : numpy array
        Per-pass registration statistics from ``register_to_eye_masks``
    """
    # Load subject specific run. File should be Nifti and 4D
    # but should also work with other formats which can be read with AntsPy
    func = ants.image_read(fp_func)
    # Register to deepmreye template (dme_template).
    # If registration fails quality check, try below line
    # with additional parameter "transforms=['Affine', 'Affine', 'SyNAggro']"
    transform_to_dme, transformation_statistics = register_to_eye_masks(
        dme_template, func, masks=[None, eyemask_big, eyemask_small], transforms=transforms
    )
    # Cut mask and save to subject folder with subject report / quality control plots.
    # Only the masked eye is needed here; the other return values are discarded.
    _, masked_eye, _ = cut_mask(
        transform_to_dme,
        eyemask_small.numpy(),
        x_edges,
        y_edges,
        z_edges,
        replace_with=replace_with,
        save_overview=True,
        fp_func=fp_func,
    )
    return (masked_eye, transformation_statistics)
# --------------------------------------------------------------------------------
# --------------------------MASKING-----------------------------------------------
# --------------------------------------------------------------------------------
[docs]
def get_masks(data_path=""):
    """Load the DeepMReye template and eye masks, downloading any file that is missing.

    Parameters
    ----------
    data_path : str, optional
        Directory containing the .nii mask files. The empty-string default resolves
        to the ``masks`` directory located next to this module.

    Returns
    -------
    eyemask_small : ants Image
        Eyemask containing voxels within the eye
    eyemask_big : ants Image
        Square eye mask centered on both eyes
    dme_template : ants Image
        Template brain using centered gaze positions
    mask : numpy array
        eyemask_small as a numpy array (from ``get_mask_edges``), used to cut the model input
    x_edges : tuple
        Edges of mask in x-dimension
    y_edges : tuple
        Edges of mask in y-dimension
    z_edges : tuple
        Edges of mask in z-dimension
    """
    if data_path == "":
        data_path = os.path.abspath(join(__file__, "..", "masks"))

    def fetch(filename):
        # Download on demand; either way the file is read from disk afterwards.
        if not os.path.exists(filename):
            print(f"Downloading mask: {filename}")
            download_mask(filename)
        return ants.image_read(filename)

    eyemask_small, eyemask_big, dme_template = (
        fetch(join(data_path, name)) for name in ("eyemask_small.nii", "eyemask_big.nii", "dme_template.nii")
    )
    mask, x_edges, y_edges, z_edges = get_mask_edges(mask=eyemask_small)
    return eyemask_small, eyemask_big, dme_template, mask, x_edges, y_edges, z_edges
[docs]
def get_mask_edges(mask, split=True):
    """Get the bounding-box edges of a mask, optionally split into left and right eye.

    Parameters
    ----------
    mask : ants Image
        Mask image; voxels equal to 1 are inside the mask. Any object with a
        ``.numpy()`` method returning a 3D array works.
    split : bool, optional
        Split the mask along the first axis at the widest gap between occupied
        slices (i.e. between the two eyes), by default True

    Returns
    -------
    mask : numpy array
        The mask as a numpy array
    x_edges : tuple
        split=True: (left max, left min, right max, right min) along the first axis.
        split=False: (middle, min, max, middle).
    y_edges, z_edges : tuple
        (max, min) along the second / third axis over the whole mask

    Raises
    ------
    ValueError
        If the mask has no voxels equal to 1, or if ``split`` is True and all mask
        voxels lie in a single slice along the first axis.
    """
    mask_array = mask.numpy()
    edge_indices = np.where(mask_array == 1)
    x_idx = edge_indices[0]
    if x_idx.size == 0:
        raise ValueError("Mask contains no voxels equal to 1.")
    if split:
        # np.where returns indices in row-major order, so x_idx is non-decreasing and the
        # largest jump separates the two eyes. For collin27 the cut is at 45.
        if x_idx[0] == x_idx[-1]:
            raise ValueError("Mask occupies a single slice along the first axis; cannot split into two eyes.")
        gap = np.argmax(np.diff(x_idx))
        middle_cut = (x_idx[gap] + x_idx[gap + 1]) // 2
        # middle_cut is floored, so it may equal the last left-eye index but never the first
        # right-eye index; '<=' assigns every voxel to exactly one side.
        left_indices = x_idx[x_idx <= middle_cut]
        right_indices = x_idx[x_idx > middle_cut]
        x_edges = (np.max(left_indices), np.min(left_indices), np.max(right_indices), np.min(right_indices))
    else:
        x_min, x_max = np.min(x_idx), np.max(x_idx)
        middle = x_min + (x_max - x_min) // 2
        x_edges = (middle, x_min, x_max, middle)
    y_edges = (np.max(edge_indices[1]), np.min(edge_indices[1]))
    z_edges = (np.max(edge_indices[2]), np.min(edge_indices[2]))
    return (mask_array, x_edges, y_edges, z_edges)
[docs]
def cut_mask(to_mask, mask, x_edges, y_edges, z_edges, replace_with=0, save_overview=True, fp_func=None, verbose=0):
    """Cut the two eye regions out of a 4D image, blanking voxels outside the mask.

    Parameters
    ----------
    to_mask : ants Image
        4D image to mask (time last). It is not modified.
    mask : numpy array
        3D mask; voxels with value < 1 are set to ``replace_with``
    x_edges : list
        (left max, left min, right max, right min) along the first axis
    y_edges : list
        (max, min) along the second axis
    z_edges : list
        (max, min) along the third axis
    replace_with : int, optional
        Values outside of mask are set to this, by default 0
    save_overview : bool, optional
        Saves report / quality control figure when set to True, by default True
    fp_func : str, optional
        Filepath to the functional. Outputs (``mask_<name>.p`` and, if
        ``save_overview``, ``report_<name>.html``) are written to its directory.
        When None (default) nothing is written.
    verbose : int, optional
        Verbosity level of this function, by default 0

    Returns
    -------
    original_input : ants Image
        The unmodified input image (``to_mask`` itself)
    masked_eye : numpy array
        Right-eye block followed by left-eye block, concatenated along the first axis
    mask : numpy array
        The mask that was passed in
    """
    # Mask a copy so the caller's image is left untouched.
    original_input = to_mask
    masked = to_mask.copy()
    masked[mask < 1, ...] = replace_with
    # NOTE(review): the max edges are exclusive slice bounds, so the slice at each max
    # edge is not included; changing this alters the model input shape.
    y_slice = slice(y_edges[1], y_edges[0])
    z_slice = slice(z_edges[1], z_edges[0])
    masked_eye_left = masked[x_edges[1] : x_edges[0], y_slice, z_slice, ...]
    masked_eye_right = masked[x_edges[3] : x_edges[2], y_slice, z_slice, ...]
    masked_eye = np.concatenate((masked_eye_right, masked_eye_left))
    if verbose > 0:
        print(f"Voxels > 0 / Mean of voxels: {np.sum(np.mean(masked_eye, axis=3) > 0)} / {np.mean(masked_eye)}")
    if fp_func is None:
        # No functional path to name or place the outputs after; skip writing files.
        return (original_input, masked_eye, mask)
    # Save report figure and masked eye (.p) next to the functional
    participant_basename = os.path.basename(fp_func).split(".")[0]
    out_dir = os.path.dirname(fp_func)
    if save_overview:
        fn_full_mask = join(out_dir, f"report_{participant_basename}")
        plot_subject_report(fn_full_mask, original_input, masked_eye, mask)
    fn_masked_eye = join(out_dir, f"mask_{participant_basename}.p")
    with open(fn_masked_eye, "wb") as fh:
        pickle.dump(masked_eye, fh)
    return (original_input, masked_eye, mask)
# --------------------------------------------------------------------------------
# --------------------------VISUALIZATIONS----------------------------------------
# --------------------------------------------------------------------------------
[docs]
def plot_subject_report(
    fn_subject, original_input, masked_eye, mask, color="rgb(0, 150, 175)", bg_color="rgb(14, 17, 23, 0)"
):
    """Plot quality check figure for given subject and write it to ``fn_subject + ".html"``.

    Top row: three slices of the time-averaged image with the mask overlaid in red,
    and a histogram of the time-averaged eye-mask voxels (values <= 0 excluded).
    Bottom row: time-averaged eye-mask voxels averaged along each spatial axis, and the
    min-max normalized mean timecourses of the whole image (white) and the eye mask.

    Parameters
    ----------
    fn_subject : string
        Output filepath without extension (".html" is appended)
    original_input : ants Image
        4D functional image of the subject (time last)
    masked_eye : array
        4D numpy array of masked eye voxels, time on the last axis
    mask : array
        3D mask array, overlaid on original_input (presumably on the same grid — confirm)
    color : str, optional
        Color of histogram and eye-mask timecourse, by default "rgb(0, 150, 175)"
    bg_color : str, optional
        Plot and paper background color, by default "rgb(14, 17, 23, 0)"
    """

    def _minmax(values):
        # Scale to [0, 1]; a constant signal maps to zeros instead of NaN (0 / 0).
        span = np.max(values) - np.min(values)
        if span == 0:
            return np.zeros_like(values, dtype=float)
        return (values - np.min(values)) / span

    def _label(x, y, text):
        # Title/annotation positioned in paper coordinates.
        return dict(x=x, y=y, xref="paper", yref="paper", text=text, font=(dict(size=20)), showarrow=False)

    # Prepare data
    whole_brain_mask = original_input.get_average_of_timeseries()
    eye_mask = np.mean(masked_eye, axis=3)
    eye_mask_flat = eye_mask.flatten()
    # Also remove zero (out-of-mask with the default replace_with) for histogram
    eye_mask_flat = eye_mask_flat[eye_mask_flat > 0]
    whole_brain_timecourse = _minmax(np.mean(original_input.numpy(), axis=(0, 1, 2)))
    eye_mask_timecourse = _minmax(np.mean(masked_eye, axis=(0, 1, 2)))

    # Plot
    fig = make_subplots(
        rows=2, cols=4, column_widths=[0.2, 0.2, 0.2, 0.4], row_heights=[0.6, 0.4], vertical_spacing=0.13
    )
    # Top row, columns 1-3: grey anatomical slice, then the mask as translucent red on top.
    overlay = [[0, "rgba(0, 0, 0, 0)"], [1.0, "rgba(255, 0, 0, 0.25)"]]
    slice_views = [
        (whole_brain_mask[25, :, :].transpose(), mask[25, :, :].transpose()),
        (whole_brain_mask[:, 90, :].transpose(), mask[:, 90, :].transpose()),
        (whole_brain_mask[:, :, 15], mask[:, :, 15]),
    ]
    for col, (brain_slice, mask_slice) in enumerate(slice_views, start=1):
        fig.add_trace(go.Heatmap(z=brain_slice, showscale=False, colorscale="Greys_r"), row=1, col=col)
        fig.add_trace(go.Heatmap(z=mask_slice, showscale=False, colorscale=overlay), row=1, col=col)
    fig.add_trace(
        go.Histogram(
            x=eye_mask_flat,
            nbinsx=75,
            marker={"line": {"width": 0.75, "color": "rgb(255, 255, 255)"}},
            marker_color=color,
        ),
        row=1,
        col=4,
    )
    # Bottom row, columns 1-3: eye-mask voxels averaged along each spatial axis.
    fig.add_trace(go.Heatmap(z=np.mean(eye_mask, axis=0).transpose(), showscale=False, colorscale="Hot"), row=2, col=1)
    fig.add_trace(go.Heatmap(z=np.mean(eye_mask, axis=1).transpose(), showscale=False, colorscale="Hot"), row=2, col=2)
    fig.add_trace(go.Heatmap(z=np.mean(eye_mask, axis=2), showscale=False, colorscale="Hot"), row=2, col=3)
    # Bottom row, column 4: normalized timecourses.
    fig.add_trace(
        go.Scatter(x=np.arange(0, len(eye_mask_timecourse)), y=eye_mask_timecourse, marker_color=color), row=2, col=4
    )
    fig.add_trace(
        go.Scatter(
            x=np.arange(0, len(whole_brain_timecourse)), y=whole_brain_timecourse, marker_color="rgb(255, 255, 255)"
        ),
        row=2,
        col=4,
    )
    annotations = [
        _label(0.07, 1.03, "x=-20"),
        _label(0.29, 1.03, "y=36"),
        _label(0.54, 1.03, "z=-30"),
        _label(0.17, 1.1, "<b>Transformed MNI space with eye mask (r)</b>"),
        _label(0.93, 1.1, "<b>Histogram of eye mask voxels</b>"),
        _label(0.24, 0.42, "<b>Eye mask voxels</b>"),
        _label(0.99, 0.41, "<b>Average across whole brain (w) & eye mask (b)</b>"),
    ]
    fig.update_layout(
        autosize=False,
        showlegend=False,
        width=1400,
        height=600,
        margin=dict(t=70, l=30, b=50, r=30),
        plot_bgcolor=bg_color,
        paper_bgcolor=bg_color,
        font={"color": "#FFFFFF", "size": 13},
        annotations=annotations,
    )
    # Hide grids (tick labels stay visible) on every subplot; the histogram also gets a y-axis line.
    for col in (1, 2, 3, 4):
        fig.update_xaxes(showgrid=False, showticklabels=True, col=col)
        fig.update_yaxes(showgrid=False, showticklabels=True, col=col)
    fig.update_yaxes(showgrid=False, showticklabels=True, showline=True, col=4, row=1)
    # Add mean and median to hist
    fig.add_vline(
        x=np.mean(eye_mask_flat),
        annotation=dict(text="Mean", y=0.9),
        line=dict(color="rgb(255, 255, 255)"),
        row=1,
        col=4,
    )
    fig.add_vline(
        x=np.median(eye_mask_flat), annotation=dict(text="Median"), line=dict(color="rgb(255, 255, 255)"), row=1, col=4
    )
    fig.write_html(fn_subject + ".html")
# --------------------------------------------------------------------------------
# -----------------------IMG MANIPULATIONS----------------------------------------
# --------------------------------------------------------------------------------
[docs]
def normalize_img(img_in, mad_time=False, standardize_tr=True, std_cut_after=5):
    """Normalize the 4D input across different dimensions, ignoring zero voxels.

    Steps:
    (1) Standardize each voxel over time (mean/std, or median/MAD when ``mad_time``).
    (2) Optionally standardize each time point over space.
    (3) Optionally clip at ``std_cut_after`` times the overall std.
    (4) Set non-finite results and originally-zero voxels to 0.

    Parameters
    ----------
    img_in : array-like
        4D numpy array (or array-like) with time on the last axis. It is not modified.
    mad_time : bool, optional
        Determines if median absolute deviation should be used across time dimension, by default False
    standardize_tr : bool, optional
        Determines if each image should be normalized across spatial dimensions, by default True
    std_cut_after : int, optional
        Gets rid of outliers after normalization, by default 5; None disables clipping

    Returns
    -------
    img_in : numpy array
        Normalized array with the input's shape (floating dtype; integer input becomes float64)
    """
    # Copy (np.transpose alone is a view) so the NaN assignment below cannot write into
    # the caller's array; NaN also requires a floating dtype.
    img = np.array(img_in, copy=True)
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)
    # Transpose so time comes first
    img = np.transpose(img, axes=(3, 0, 1, 2))
    zero_indices = img == 0
    # Treat zeros as missing so they do not bias the statistics below.
    img[zero_indices] = np.nan
    if mad_time:
        # Median absolute deviation (MAD)
        center = np.nanmedian(img, axis=0)
        scale = np.nanmedian(np.abs(img - center), axis=0)
    else:
        center = np.nanmean(img, axis=0)
        scale = np.nanstd(img, axis=0)
    img = (img - center) / scale
    # Normalize each functional (time point) on its own
    if standardize_tr:
        img = np.array([(volume - np.nanmean(volume)) / np.nanstd(volume) for volume in img])
    if std_cut_after is not None:
        limit = std_cut_after * np.nanstd(img)
        img[img > limit] = limit
        img[img < -limit] = -limit
    # Division by zero (constant voxels / volumes) yields inf or NaN: replace with 0
    img[~np.isfinite(img)] = 0
    img[zero_indices] = 0
    # Transpose back to original axis order (time last)
    return np.transpose(img, axes=(1, 2, 3, 0))
# --------------------------------------------------------------------------------
# --------------------------LABEL I/O---------------------------------------------
# --------------------------------------------------------------------------------
[docs]
def load_label(label_path, label_type="calibration_run"):
    """Load label for experiment, which should return X,Y coordinates for each timepoint.

    This function can be exchanged for experiment specific loading of labels,
    or by using different label types.

    Parameters
    ----------
    label_path : str
        Directory containing the label file (``stim_vals.csv`` for 'calibration_run')
    label_type : str, optional
        Which type of labels are used in the experiment, by default 'calibration_run'
        (currently the only supported value)

    Returns
    -------
    this_label : numpy array
        Shape (5 * n_rows, 10, 2): X,Y gaze coordinates in visual angle. Each CSV data
        row is repeated 5 times along the first axis and 10 times along the second.

    Raises
    ------
    ValueError
        If ``label_type`` is not supported.
    """
    if label_type != "calibration_run":
        raise ValueError(f"Unsupported label_type {label_type!r}; expected 'calibration_run'.")
    # Load labels from file
    fn_labels = join(label_path, "stim_vals.csv")
    labels = np.genfromtxt(fn_labels, delimiter=",")
    # Drop the first row (presumably a header, which genfromtxt parses as NaN).
    labels = labels[1:]
    # NOTE(review): the 5x / 10x repeats presumably expand each stimulus to 5 volumes
    # with 10 samples each — confirm against the acquisition protocol.
    this_label = np.repeat(labels, 5, axis=0)[:, np.newaxis, :]
    this_label = np.repeat(this_label, 10, axis=1)
    # Normalize label: map [-0.95, 0.95] to [0, 1], then center to [-0.5, 0.5]
    this_label = (this_label - -0.95) / (0.95 - -0.95)
    this_label -= 0.5
    # Y-axis is flipped for this dataset
    this_label[..., 1] *= -1
    # Convert to visual angles (presumably screen width / height in degrees — confirm)
    this_label[..., 0] *= 19
    this_label[..., 1] *= 14.7
    return this_label
[docs]
def save_data(participant, participant_data, participant_labels, participant_ids, processed_data, center_labels=False):
    """Bundle one participant's runs into a single .npz file for fast (lazy) loading during training.

    Every time point i is stored as three entries: ``data_i``, ``label_i`` and ``identifier_i``.

    Parameters
    ----------
    participant : str
        Participant label, used as the file name (``np.savez`` appends ``.npz``)
    participant_data : list
        One 4D (X,Y,Z,t) array per run; concatenated along t
    participant_labels : list
        One array per run with t as first axis; concatenated along it and cast to float32
    participant_ids : list
        One 2D array per run with t on the second axis; concatenated along it
    processed_data : str
        Directory the .npz file is written to
    center_labels : bool, optional
        Subtract the NaN-ignoring median over the first two label axes, by default False
    """
    # Time becomes the leading axis so each time point can be stored on its own.
    volumes = np.moveaxis(np.concatenate(participant_data, axis=3), 3, 0)
    labels = np.concatenate(participant_labels, axis=0).astype("float32")
    identifiers = np.concatenate(participant_ids, axis=1).T
    if center_labels:
        labels = labels - np.nanmedian(labels, axis=(0, 1))

    data_dict = {}
    for idx, (volume, label, identifier) in enumerate(zip(volumes, labels, identifiers)):
        data_dict.update({f"data_{idx}": volume, f"label_{idx}": label, f"identifier_{idx}": identifier})

    subject_file_path = join(processed_data, participant)
    nan_count = np.isnan(labels).sum()
    print(
        f"Saving eye data {volumes.shape} "
        f"and targets {labels.shape} "
        f"(NaN {nan_count}) to file {subject_file_path}"
    )
    np.savez(subject_file_path, **data_dict)