Skip to content

[Feature] nnUNet-style Gaussian Noise and Blur #2373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mmseg/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
GenerateEdge, LoadAnnotations,
LoadBiomedicalAnnotation, LoadBiomedicalData,
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
Expand All @@ -42,5 +43,6 @@
'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge'
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
]
5 changes: 4 additions & 1 deletion mmseg/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
LoadBiomedicalData, LoadBiomedicalImageFromFile,
LoadImageFromNDArray)
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
GenerateEdge, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)

# yapf: enable
__all__ = [
'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange',
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'ResizeShortestEdge'
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
]
179 changes: 179 additions & 0 deletions mmseg/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from mmcv.transforms.utils import cache_randomness
from mmengine.utils import is_tuple_of
from numpy import random
from scipy.ndimage import gaussian_filter

from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
from mmseg.registry import TRANSFORMS
Expand Down Expand Up @@ -1507,3 +1508,181 @@ def transform(self, results: dict) -> dict:

def __repr__(self):
return self.__class__.__name__ + f'(crop_shape={self.crop_shape})'


@TRANSFORMS.register_module()
class BioMedicalGaussianNoise(BaseTransform):
"""Add random Gaussian noise to image.

Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L53 # noqa:E501

Copyright (c) German Cancer Research Center (DKFZ)
Licensed under the Apache License, Version 2.0

Required Keys:

- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
N is the number of modalities, and data type is float32.

Modified Keys:

- img

Args:
prob (float): Probability to add Gaussian noise for
each sample. Default to 0.1.
mean (float): Mean or “centre” of the distribution. Default to 0.0.
std (float): Standard deviation of distribution. Default to 0.1.
"""

def __init__(self,
prob: float = 0.1,
mean: float = 0.0,
std: float = 0.1) -> None:
super().__init__()
assert 0.0 <= prob <= 1.0 and std >= 0.0
self.prob = prob
self.mean = mean
self.std = std

def transform(self, results: Dict) -> Dict:
"""Call function to add random Gaussian noise to image.

Args:
results (dict): Result dict.

Returns:
dict: Result dict with random Gaussian noise.
"""
if np.random.rand() < self.prob:
rand_std = np.random.uniform(0, self.std)
noise = np.random.normal(
self.mean, rand_std, size=results['img'].shape)
# noise is float64 array, convert to the results['img'].dtype
noise = noise.astype(results['img'].dtype)
results['img'] = results['img'] + noise
return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'mean={self.mean}, '
repr_str += f'std={self.std})'
return repr_str


@TRANSFORMS.register_module()
class BioMedicalGaussianBlur(BaseTransform):
"""Add Gaussian blur with random sigma to image.

Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L81 # noqa:E501

Copyright (c) German Cancer Research Center (DKFZ)
Licensed under the Apache License, Version 2.0

Required Keys:

- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
N is the number of modalities, and data type is float32.

Modified Keys:

- img

Args:
sigma_range (Tuple[float, float]|float): range to randomly
select sigma value. Default to (0.5, 1.0).
prob (float): Probability to apply Gaussian blur
for each sample. Default to 0.2.
prob_per_channel (float): Probability to apply Gaussian blur
for each channel (axis N of the image). Default to 0.5.
different_sigma_per_channel (bool): whether to use different
sigma for each channel (axis N of the image). Default to True.
different_sigma_per_axis (bool): whether to use different
sigma for axis Z, X and Y of the image. Default to True.
"""

def __init__(self,
sigma_range: Tuple[float, float] = (0.5, 1.0),
prob: float = 0.2,
prob_per_channel: float = 0.5,
different_sigma_per_channel: bool = True,
different_sigma_per_axis: bool = True) -> None:
super().__init__()
assert 0.0 <= prob <= 1.0
assert 0.0 <= prob_per_channel <= 1.0
assert isinstance(sigma_range, Sequence) and len(sigma_range) == 2
self.sigma_range = sigma_range
self.prob = prob
self.prob_per_channel = prob_per_channel
self.different_sigma_per_channel = different_sigma_per_channel
self.different_sigma_per_axis = different_sigma_per_axis

def _get_valid_sigma(self, value_range) -> Tuple[float, ...]:
"""Ensure the `value_range` to be either a single value or a sequence
of two values. If the `value_range` is a sequence, generate a random
value with `[value_range[0], value_range[1]]` based on uniform
sampling.

Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/augmentations/utils.py#L625 # noqa:E501

Args:
value_range (tuple|list|float|int): the input value range
"""
if (isinstance(value_range, (list, tuple))):
if (value_range[0] == value_range[1]):
value = value_range[0]
else:
orig_type = type(value_range[0])
value = np.random.uniform(value_range[0], value_range[1])
value = orig_type(value)
return value

def _gaussian_blur(self, data_sample: np.ndarray) -> np.ndarray:
"""Random generate sigma and apply Gaussian Blur to the data
Args:
data_sample (np.ndarray): data sample with multiple modalities,
the data shape is (N, Z, Y, X)
"""
sigma = None
for c in range(data_sample.shape[0]):
if np.random.rand() < self.prob_per_channel:
# if no `sigma` is generated, generate one
# if `self.different_sigma_per_channel` is True,
# re-generate random sigma for each channel
if (sigma is None or self.different_sigma_per_channel):
if (not self.different_sigma_per_axis):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From this implementation, only when self.different_sigma_per_axis is True, it will generate different sigma for each channels along Z axis?

However, different_sigma_per_axis is not working for the function described in docstring, it works as different_sigma_per_channel

different_sigma_per_axis (bool): whether to use different
            sigma for axis X and Y of the image. Default to True.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here are some mistakes in this docstring. Actually, different_sigma_per_axis controls axis X, Y, Z, and different_sigma_per_channel controls axis N. The docstring is corrected now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a little question at L1657-1659, If different_sigma_per_axis is True, it will generate different sigma along Z? not X Y Z?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry I got it, it generates sigma for data_sample.shape[1:], I just ignore :

sigma = self._get_valid_sigma(self.sigma_range)
else:
sigma = [
self._get_valid_sigma(self.sigma_range)
for _ in data_sample.shape[1:]
]
# apply gaussian filter with `sigma`
data_sample[c] = gaussian_filter(
data_sample[c], sigma, order=0)
return data_sample

def transform(self, results: Dict) -> Dict:
"""Call function to add random Gaussian blur to image.

Args:
results (dict): Result dict.

Returns:
dict: Result dict with random Gaussian noise.
"""
if np.random.rand() < self.prob:
results['img'] = self._gaussian_blur(results['img'])
return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'prob_per_channel={self.prob_per_channel}, '
repr_str += f'sigma_range={self.sigma_range}, '
repr_str += 'different_sigma_per_channel='\
f'{self.different_sigma_per_channel}, '
repr_str += 'different_sigma_per_axis='\
f'{self.different_sigma_per_axis})'
return repr_str
108 changes: 108 additions & 0 deletions tests/test_datasets/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,111 @@ def test_biomedical3d_random_crop():
assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20)
assert crop_results['img_shape'] == (d - 20, h - 20, w - 20)
assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20)


def test_biomedical_gaussian_noise():
# test assertion for invalid prob
with pytest.raises(AssertionError):
transform = dict(type='BioMedicalGaussianNoise', prob=1.5)
TRANSFORMS.build(transform)

# test assertion for invalid std
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianNoise', prob=0.2, mean=0.5, std=-0.5)
TRANSFORMS.build(transform)

transform = dict(type='BioMedicalGaussianNoise', prob=1.0)
noise_module = TRANSFORMS.build(transform)
assert str(noise_module) == 'BioMedicalGaussianNoise'\
'(prob=1.0, ' \
'mean=0.0, ' \
'std=0.1)'

transform = dict(type='BioMedicalGaussianNoise', prob=1.0)
noise_module = TRANSFORMS.build(transform)
results = dict(
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
transform = LoadBiomedicalImageFromFile()
results = transform(copy.deepcopy(results))
original_img = copy.deepcopy(results['img'])
results = noise_module(results)
assert original_img.shape == results['img'].shape


def test_biomedical_gaussian_blur():
# test assertion for invalid prob
with pytest.raises(AssertionError):
transform = dict(type='BioMedicalGaussianBlur', prob=-1.5)
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=0.6)
smooth_module = TRANSFORMS.build(transform)

with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(0.6))
smooth_module = TRANSFORMS.build(transform)

with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(15, 8, 9))
TRANSFORMS.build(transform)

with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range='0.16')
TRANSFORMS.build(transform)

transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(0.7, 0.8))
smooth_module = TRANSFORMS.build(transform)
assert str(
smooth_module
) == 'BioMedicalGaussianBlur(prob=1.0, ' \
'prob_per_channel=0.5, '\
'sigma_range=(0.7, 0.8), ' \
'different_sigma_per_channel=True, '\
'different_sigma_per_axis=True)'

transform = dict(type='BioMedicalGaussianBlur', prob=1.0)
smooth_module = TRANSFORMS.build(transform)
assert str(
smooth_module
) == 'BioMedicalGaussianBlur(prob=1.0, ' \
'prob_per_channel=0.5, '\
'sigma_range=(0.5, 1.0), ' \
'different_sigma_per_channel=True, '\
'different_sigma_per_axis=True)'

results = dict(
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
transform = LoadBiomedicalImageFromFile()
results = transform(copy.deepcopy(results))
original_img = copy.deepcopy(results['img'])
results = smooth_module(results)
assert original_img.shape == results['img'].shape
# the max value in the smoothed image should be less than the original one
assert original_img.max() >= results['img'].max()
assert original_img.min() <= results['img'].min()

transform = dict(
type='BioMedicalGaussianBlur',
prob=1.0,
different_sigma_per_axis=False)
smooth_module = TRANSFORMS.build(transform)

results = dict(
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
transform = LoadBiomedicalImageFromFile()
results = transform(copy.deepcopy(results))
original_img = copy.deepcopy(results['img'])
results = smooth_module(results)
assert original_img.shape == results['img'].shape
# the max value in the smoothed image should be less than the original one
assert original_img.max() >= results['img'].max()
assert original_img.min() <= results['img'].min()
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy