What is Skull Stripping?
Skull stripping, also known as brain extraction, is a preprocessing step that removes the skull and other non-brain tissues from MRI scans. It is a crucial step in many neuroimaging analyses, such as tissue classification and image registration.
Skull stripping is a key area of study in brain image processing. It is a preliminary step in many medical applications because removing irrelevant tissue can improve the speed and accuracy of the downstream analysis and diagnosis.
Skull stripping involves segmenting brain tissue (cortex and cerebellum) from the surrounding region (skull and non-brain area). It is also a preprocessing step for cortical surface reconstruction. One classical approach takes an intensity-normalized image and deforms a tessellated ellipsoidal template into the shape of the inner surface of the skull.
What exactly are we going to do with the 3D MRI brain images?
I was researching the early detection of Alzheimer's disease, so I was working with 3D MRI images of the human brain. Let's say I have a raw 3D MRI brain image like the one below.
The output would look like the images below.
I want to get the brain mask that has been applied in the skull stripping process on my raw 3D image.
I want the preprocessed skull-stripped image that only has the brain portion excluding any non-brain tissues.
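To make these two outputs concrete, here is a minimal NumPy sketch on a tiny synthetic volume. The array values and the function name `apply_brain_mask` are made up for illustration. In the rest of this article the real mask comes from a deep-learning model, not from hand-written values.

```python
import numpy as np


def apply_brain_mask(volume: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Zero out every voxel that the binary mask marks as non-brain."""
    if volume.shape != mask.shape:
        raise ValueError("volume and mask must have the same shape")
    # Voxels where the mask is 0 become 0; voxels where it is 1 keep their intensity.
    return volume * (mask > 0)


# Synthetic 3x3x3 "scan" (hypothetical values): every voxel has intensity 100.
volume = np.full((3, 3, 3), 100.0)

# Synthetic mask: only the centre voxel counts as "brain".
mask = np.zeros((3, 3, 3), dtype=np.uint8)
mask[1, 1, 1] = 1

stripped = apply_brain_mask(volume, mask)
print(stripped[1, 1, 1])  # the centre voxel keeps its intensity
print(stripped.sum())     # every other voxel contributes zero
```

The mask and the stripped volume have the same shape as the input, which is why both can be saved as separate NIfTI files next to the raw scan.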
Why is Skull Stripping necessary in human brain-related research?
Skull stripping is a necessary step in human brain-related research for several reasons:
Isolation of Brain Tissue: The skull and other surrounding tissues are not relevant to many types of brain analysis, such as MRI or fMRI studies. Removing them allows researchers to focus specifically on the brain tissue itself, which is the primary object of study in neuroimaging research.
Improved Accuracy: Skull stripping improves the accuracy of brain imaging analysis by removing non-brain tissues that can interfere with the interpretation of results. This interference can occur due to artifacts or inaccuracies introduced by including non-brain tissues in the analysis.
Reduced Computational Load: Including the skull in the analysis increases the computational load required for processing brain images. By removing the skull, researchers can streamline the analysis process, making it more efficient and reducing computational resources.
Enhanced Visualization: Skull stripping improves the visualization of brain structures in neuroimaging data. It allows for clearer and more precise rendering of brain images, which is essential for accurately identifying and studying various brain regions and structures.
Standardization: Skull stripping helps standardize neuroimaging data processing pipelines across different studies and research groups. By employing consistent skull stripping techniques, researchers can ensure comparability and reproducibility of results, facilitating collaboration and meta-analyses in the field.
Overall, skull stripping is a critical preprocessing step in human brain-related research, enabling more accurate analysis, improved visualization, and enhanced comparability of neuroimaging data.
How can Skull Stripping be performed using Python?
There are many ways to apply skull stripping to 3D MRI images, and many Python modules can help us do that.
In this article, however, I am going to use a simple piece of code that applies skull stripping at a basic level using ANTsPy.
For this, we are going to use a helper file that keeps the notebook code short and reusable.
Let's name it helpers.py.
Helpers
import matplotlib.pyplot as plt
from ipywidgets import interact
import numpy as np
import SimpleITK as sitk
import cv2


def explore_3D_array(arr: np.ndarray, cmap: str = 'gray'):
    """
    Given a 3D array with shape (Z,X,Y), this function creates an interactive
    widget to check out all the 2D arrays with shape (X,Y) inside the 3D array.
    The purpose of this function is to visually inspect the 2D arrays in the image.

    Args:
      arr : 3D array with shape (Z,X,Y) that represents the volume of an MRI image
      cmap : Which color map to use to plot the slices in matplotlib.pyplot
    """

    def fn(SLICE):
        plt.figure(figsize=(7,7))
        plt.imshow(arr[SLICE, :, :], cmap=cmap)

    interact(fn, SLICE=(0, arr.shape[0]-1))


def explore_3D_array_comparison(arr_before: np.ndarray, arr_after: np.ndarray, cmap: str = 'gray'):
    """
    Given two 3D arrays with shape (Z,X,Y), this function creates an interactive
    widget to check out all the 2D arrays with shape (X,Y) inside the 3D arrays.
    The purpose of this function is to visually compare the 2D arrays after some transformation.

    Args:
      arr_before : 3D array with shape (Z,X,Y) that represents the volume of an MRI image, before any transform
      arr_after : 3D array with shape (Z,X,Y) that represents the volume of an MRI image, after some transform
      cmap : Which color map to use to plot the slices in matplotlib.pyplot
    """

    assert arr_after.shape == arr_before.shape

    def fn(SLICE):
        fig, (ax1, ax2) = plt.subplots(1, 2, sharex='col', sharey='row', figsize=(10,10))

        ax1.set_title('Before', fontsize=15)
        ax1.imshow(arr_before[SLICE, :, :], cmap=cmap)

        ax2.set_title('After', fontsize=15)
        ax2.imshow(arr_after[SLICE, :, :], cmap=cmap)

        plt.tight_layout()

    interact(fn, SLICE=(0, arr_before.shape[0]-1))


def show_sitk_img_info(img: sitk.Image):
    """
    Given a sitk.Image instance, prints the information about the MRI image it contains.

    Args:
      img : instance of the sitk.Image to check out
    """
    pixel_type = img.GetPixelIDTypeAsString()
    origin = img.GetOrigin()
    dimensions = img.GetSize()
    spacing = img.GetSpacing()
    direction = img.GetDirection()

    info = {'Pixel Type' : pixel_type, 'Dimensions': dimensions, 'Spacing': spacing, 'Origin': origin, 'Direction' : direction}
    for k,v in info.items():
        print(f' {k} : {v}')


def add_suffix_to_filename(filename: str, suffix:str) -> str:
    """
    Takes a NIfTI filename and appends a suffix.

    Args:
      filename : NIfTI filename
      suffix : suffix to append

    Returns:
      str : filename after appending the suffix
    """
    if filename.endswith('.nii'):
        result = filename.replace('.nii', f'_{suffix}.nii')
        return result
    elif filename.endswith('.nii.gz'):
        result = filename.replace('.nii.gz', f'_{suffix}.nii.gz')
        return result
    else:
        raise RuntimeError('filename with unknown extension')


def rescale_linear(array: np.ndarray, new_min: int, new_max: int):
    """Rescale an array linearly."""
    minimum, maximum = np.min(array), np.max(array)
    m = (new_max - new_min) / (maximum - minimum)
    b = new_min - m * minimum
    return m * array + b


def explore_3D_array_with_mask_contour(arr: np.ndarray, mask: np.ndarray, thickness: int = 1):
    """
    Given a 3D array with shape (Z,X,Y), this function creates an interactive
    widget to check out all the 2D arrays with shape (X,Y) inside the 3D array. The binary
    mask provided will be used to overlay contours of the region of interest over the
    array. The purpose of this function is to visually inspect the region delimited by the mask.

    Args:
      arr : 3D array with shape (Z,X,Y) that represents the volume of an MRI image
      mask : binary mask to obtain the region of interest
      thickness : thickness of the drawn contour lines, in pixels
    """
    assert arr.shape == mask.shape

    # Rescale both arrays to [0, 1] so that the green contour color (0,1,0) is valid
    _arr = rescale_linear(arr,0,1)
    _mask = rescale_linear(mask,0,1)
    _mask = _mask.astype(np.uint8)

    def fn(SLICE):
        arr_rgb = cv2.cvtColor(_arr[SLICE, :, :], cv2.COLOR_GRAY2RGB)
        contours, _ = cv2.findContours(_mask[SLICE, :, :], cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

        arr_with_contours = cv2.drawContours(arr_rgb, contours, -1, (0,1,0), thickness)

        plt.figure(figsize=(7,7))
        plt.imshow(arr_with_contours)

    interact(fn, SLICE=(0, arr.shape[0]-1))
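The two non-interactive helpers are easy to check on their own. The snippet below is a small sketch that re-declares them so that it runs without the plotting dependencies. The input values and file names are hypothetical, and the filename helper is written with slicing instead of `str.replace`, which should give the same result for ordinary NIfTI names.

```python
import numpy as np


def rescale_linear(array: np.ndarray, new_min: float, new_max: float) -> np.ndarray:
    """Map the range [array.min(), array.max()] linearly onto [new_min, new_max]."""
    minimum, maximum = np.min(array), np.max(array)
    m = (new_max - new_min) / (maximum - minimum)  # slope
    b = new_min - m * minimum                      # intercept
    return m * array + b


def add_suffix_to_filename(filename: str, suffix: str) -> str:
    """Insert '_<suffix>' in front of a .nii or .nii.gz extension."""
    for ext in ('.nii.gz', '.nii'):
        if filename.endswith(ext):
            return f"{filename[:-len(ext)]}_{suffix}{ext}"
    raise RuntimeError('filename with unknown extension')


# Hypothetical intensities: the smallest maps to 0, the largest to 1.
intensities = np.array([10.0, 20.0, 30.0])
scaled = rescale_linear(intensities, 0, 1)
print(scaled)

# Hypothetical file names.
print(add_suffix_to_filename('scan.nii.gz', 'brainMask'))
print(add_suffix_to_filename('scan.nii', 'brainMask'))
```

Note that `rescale_linear` divides by `maximum - minimum`, so a constant array (for example an all-zero mask) would lead to a division by zero.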
Additional required modules and libraries
We also need some Python modules and libraries to run the pipeline.
I can simply create a text (.txt) file named requirements.txt and list each module/library name on its own line.
ipykernel
ants
ipywidgets
matplotlib
opencv-python
jupyter
notebook
antspyx
SimpleITK
antspynet
Notebook
We can use VSCode or any other Python IDEs for writing our code. But I like to use the Jupyter Notebook as we can run different segments of our code separately. Also, all of the outputs stay in the same notebook.
You can also use the Google Colab notebook for free! Kaggle Notebook can also be used for free.
If you use Google Colab, you need to mount your Google Drive in the notebook so that you can directly access all the files stored there.
# Connect Google Drive with Colab
from google.colab import drive
drive.mount('/content/drive')
Now we can simply install all the required modules/libraries that are listed in our requirements.txt file.
# Install necessary components
! pip install -r requirements.txt
This will install all the packages/modules that are listed on that specific text file. If you need more modules to install, you can simply add the name in that text file. If you need a specific version of a library, then you can also do that in the text file.
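If you want to confirm what actually got installed, a small sketch like the one below can help. It only uses the standard library. The function name `report_installed` is my own, and the package list is just a subset of the names from requirements.txt.

```python
from importlib import metadata


def report_installed(packages):
    """Return {package name: installed version, or None if it is not installed}."""
    report = {}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


# Distribution names as they appear in requirements.txt.
wanted = ['antspyx', 'antspynet', 'SimpleITK', 'opencv-python']
for name, version in report_installed(wanted).items():
    print(f'{name}: {version or "NOT INSTALLED"}')
```

To pin a version in requirements.txt, write the name followed by `==` and the version on the same line, for example `SimpleITK==2.3.1`.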
Output
Now we can import the modules in our code. Let me check whether ANTsPy and SimpleITK were installed successfully, since those two libraries are essential for our task.
%matplotlib inline
# matplotlib will be displayed inline in the notebook
import os
import sys
sys.path.append('/content/drive/MyDrive/MRI_Image_Processing/notebooks')
# appending the path to include our custom helper
import helpers
from helpers import *
import ants
import SimpleITK as sitk
print(f'AntsPy Version = {ants.__version__}')
print(f'SimpleITK version = {sitk.__version__}')
Output:
AntsPy Version = 0.4.2
SimpleITK version = 2.3.1
The rest of the code, with its output and the necessary comments, is given below:
# Define the base directory by getting the directory name of the specified path
BASE_DIR = os.path.dirname("/content/drive/MyDrive/MRI_Image_Processing/")
# Print the path to the project folder
print(f'project folder = {BASE_DIR}')
Output
project folder = /content/drive/MyDrive/MRI_Image_Processing
import os
# Define the directory path
directory_path = '/content/drive/MyDrive/MRI_Image_Processing/assets/raw_examples/'
# Initialize an empty list to store filenames
raw_examples = []
# Iterate through files in the directory and add filenames with extensions to raw_examples list
for filename in os.listdir(directory_path):
# Check if the path refers to a file (not a subdirectory)
if os.path.isfile(os.path.join(directory_path, filename)):
raw_examples.append(filename)
# Display the updated raw_examples list
print(raw_examples)
Output:
['ADNI_136_S_0184_MR_MPR____N3__Scaled_2_Br_20081008132905229_S18601_I119714.nii.gz', 'ADNI_136_S_0195_MR_MPR____N3__Scaled_Br_20090708095227200_S65770_I148270.nii.gz', 'ADNI_136_S_0184_MR_MPR____N3__Scaled_Br_20080123103107781_S18601_I88159.nii.gz', 'ADNI_136_S_0086_MR_MPR____N3__Scaled_Br_20070815111150885_S31407_I67781.nii.gz', 'ADNI_136_S_0195_MR_MPR____N3__Scaled_2_Br_20081008133516751_S12748_I119721.nii.gz', 'ADNI_136_S_0184_MR_MPR____N3__Scaled_2_Br_20081008132712063_S12474_I119712.nii.gz', 'ADNI_136_S_0195_MR_MPR____N3__Scaled_Br_20080123103529723_S19574_I88164.nii.gz', 'ADNI_136_S_0184_MR_MPR-R____N3__Scaled_Br_20080415154909319_S47135_I102840.nii.gz', 'ADNI_136_S_0196_MR_MPR____N3__Scaled_Br_20070215192140032_S13831_I40269.nii.gz', 'ADNI_136_S_0195_MR_MPR____N3__Scaled_Br_20071113184014832_S28912_I81981.nii.gz', 'ADNI_136_S_0196_MR_MPR____N3__Scaled_Br_20070809224441151_S31829_I66740.nii.gz', 'ADNI_136_S_0086_MR_MPR____N3__Scaled_Br_20081013161936541_S49691_I120416.nii.gz', 'ADNI_136_S_0195_MR_MPR____N3__Scaled_2_Br_20081008133704276_S19574_I119723.nii.gz', 'ADNI_136_S_0086_MR_MPR____N3__Scaled_Br_20070215172221943_S14069_I40172.nii.gz', 'ADNI_136_S_0184_MR_MPR____N3__Scaled_Br_20070215174801158_S12474_I40191.nii.gz', 'ADNI_136_S_0195_MR_MPR-R____N3__Scaled_Br_20071110123442398_S39882_I81460.nii.gz', 'ADNI_136_S_0184_MR_MPR____N3__Scaled_Br_20090708094745554_S64785_I148265.nii.gz', 'ADNI_136_S_0195_MR_MPR____N3__Scaled_Br_20070215185520914_S12748_I40254.nii.gz', 'ADNI_136_S_0195_MR_MPR-R____N3__Scaled_Br_20081013162450631_S47389_I120423.nii.gz', 'ADNI_136_S_0184_MR_MPR-R____N3__Scaled_Br_20070819190556867_S28430_I69136.nii.gz']
Above, we can see all the necessary input data that we want to process.
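The loop above collects every file in the folder, in whatever order `os.listdir` returns them. If the folder might contain other files, a slightly stricter variant can keep only NIfTI files and sort them. This is a sketch with a hypothetical function name (`list_nifti_files`), demonstrated on a temporary folder with made-up file names so that it runs anywhere.

```python
import tempfile
from pathlib import Path


def list_nifti_files(directory):
    """Return the sorted names of the .nii / .nii.gz files directly inside `directory`."""
    return sorted(
        p.name
        for p in Path(directory).iterdir()
        if p.is_file() and p.name.endswith(('.nii', '.nii.gz'))
    )


# Demo on a throwaway folder (hypothetical file names).
with tempfile.TemporaryDirectory() as tmp:
    for name in ('b_scan.nii.gz', 'a_scan.nii', 'notes.txt'):
        (Path(tmp) / name).touch()
    (Path(tmp) / 'subfolder').mkdir()
    found = list_nifti_files(tmp)

print(found)  # ['a_scan.nii', 'b_scan.nii.gz']
```

With the real data you would pass `directory_path` instead of the temporary folder.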
from antspynet.utilities import brain_extraction

# Initialize an empty list to store the generated probabilistic brain masks
prob_brain_masks = []

# Loop through all items in the raw_examples list
for raw_example in raw_examples:
    # Create the full file path by joining the base directory with the path to the specific image file
    raw_img_path = os.path.join(BASE_DIR, 'assets', 'raw_examples', raw_example)

    # Read the image using AntsPy's image_read function, specifying reorientation to 'IAL'
    raw_img_ants = ants.image_read(raw_img_path, reorient='IAL')

    # Print the shape of the image as a numpy array in the format (Z, X, Y)
    print(f'shape = {raw_img_ants.numpy().shape} -> (Z, X, Y)')

    # Display the 3D array using a custom function 'explore_3D_array' with the image array and a specified colormap
    explore_3D_array(arr=raw_img_ants.numpy(), cmap='nipy_spectral')

    # Generate a probabilistic brain mask using the 'brain_extraction' function
    # The 'modality' parameter specifies the imaging modality as 'bold'
    # NOTE: the ADNI files listed above appear to be MPRAGE (T1-weighted) scans, for which
    # antspynet also offers modality='t1'; 'bold' is kept here to match the published outputs
    # The 'verbose' parameter is set to 'True' to display detailed progress information
    prob_brain_mask = brain_extraction(raw_img_ants, modality='bold', verbose=True)

    # Append the generated probabilistic brain mask to the list
    prob_brain_masks.append(prob_brain_mask)

    # Print the filename or any relevant information about the current image (optional)
    print(f"Probabilistic Brain Mask for: {raw_example}")

    # Print the probabilistic brain mask
    print(prob_brain_mask)

    # Visualize the probabilistic brain mask using the 'explore_3D_array' function
    # This function displays the 3D array representation of the brain mask
    explore_3D_array(prob_brain_mask.numpy())

    # Generate a binary brain mask from the probabilistic brain mask using a threshold
    brain_mask = ants.get_mask(prob_brain_mask, low_thresh=0.5)

    # Visualize the original image overlaid with the brain mask contour
    explore_3D_array_with_mask_contour(raw_img_ants.numpy(), brain_mask.numpy())

    # Define the output folder path by joining the base directory with the 'assets' and 'preprocessed' directories
    out_folder = os.path.join(BASE_DIR, 'assets', 'preprocessed')

    # Create a subfolder within the 'preprocessed' directory named after the raw file (without extension)
    out_folder = os.path.join(out_folder, raw_example.split('.')[0])

    # Create the output folder if it doesn't exist already
    os.makedirs(out_folder, exist_ok=True)

    # Generate a filename by adding the suffix 'brainMaskByDL' to the original raw file name
    out_filename = add_suffix_to_filename(raw_example, suffix='brainMaskByDL')

    # Create the full output file path by joining the output folder with the generated filename
    out_path = os.path.join(out_folder, out_filename)

    # Print the relative path of the input raw image file (excluding the base directory)
    print(raw_img_path[len(BASE_DIR):])

    # Print the relative path of the output file (excluding the base directory)
    print(out_path[len(BASE_DIR):])

    # Save the brain mask to a file
    brain_mask.to_file(out_path)

    # Create a masked image by applying the binary brain mask ('brain_mask') to the original image ('raw_img_ants')
    masked = ants.mask_image(raw_img_ants, brain_mask)

    # Visualize the 3D array representation of the masked image
    explore_3D_array(masked.numpy())

    # Generate a filename by adding the suffix 'brainMaskedByDL' to the original raw file name
    out_filename = add_suffix_to_filename(raw_example, suffix='brainMaskedByDL')

    # Create the full output file path by joining the output folder with the generated filename
    out_path = os.path.join(out_folder, out_filename)

    # Print the relative path of the input raw image file (excluding the base directory)
    print(raw_img_path[len(BASE_DIR):])

    # Print the relative path of the output file (excluding the base directory)
    print(out_path[len(BASE_DIR):])

    # Save the masked image ('masked') to a file specified by 'out_path'
    masked.to_file(out_path)
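The step that turns the probabilistic mask into a binary one is easy to picture with plain NumPy. The snippet below is only a conceptual sketch with made-up probabilities and a hypothetical function name (`binarize`). `ants.get_mask` is the call actually used above, and it may additionally clean up the thresholded mask, which this sketch does not attempt.

```python
import numpy as np


def binarize(prob_mask: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Turn per-voxel brain probabilities into a 0/1 mask."""
    return (prob_mask >= threshold).astype(np.uint8)


# Hypothetical 1x2x3 probability map, standing in for the model output.
prob = np.array([[[0.05, 0.49, 0.50],
                  [0.75, 0.99, 0.10]]])

mask = binarize(prob)
print(mask)

# Number of voxels kept as "brain" at this threshold.
print(int(mask.sum()))
```

Raising the threshold keeps fewer voxels and gives a tighter mask, while lowering it keeps more of the uncertain border region.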
Conclusion
The output is very large, so it is not reproduced here in full.
But you can check the entire project here: FahimFBA/skull-stripping-3D-brain-mri/
Also, you will find a warning in the output like below:
WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_predict_function.<locals>.predict_function at 0x792af6defe20> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
I can solve it for you. But I am leaving this task for you to resolve! 😉
You can find my raw data and the preprocessed data in the assets directory.
The complete notebook with all the outputs is here!
Some sample output images are given below.
Make sure to ⭐ the repository if you like this!
Cheers! 😊