-
Notifications
You must be signed in to change notification settings - Fork 2.7k
[Feature] nnUNet-style Gaussian Noise and Blur #2373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ec1737f
ae02038
f963009
f878cf0
8452713
61ccea3
ecad01c
976c460
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
from mmcv.transforms.utils import cache_randomness | ||
from mmengine.utils import is_tuple_of | ||
from numpy import random | ||
from scipy.ndimage import gaussian_filter | ||
|
||
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset | ||
from mmseg.registry import TRANSFORMS | ||
|
@@ -1507,3 +1508,181 @@ def transform(self, results: dict) -> dict: | |
|
||
def __repr__(self): | ||
return self.__class__.__name__ + f'(crop_shape={self.crop_shape})' | ||
|
||
|
||
@TRANSFORMS.register_module() | ||
class BioMedicalGaussianNoise(BaseTransform): | ||
"""Add random Gaussian noise to image. | ||
|
||
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L53 # noqa:E501 | ||
|
||
Copyright (c) German Cancer Research Center (DKFZ) | ||
Licensed under the Apache License, Version 2.0 | ||
|
||
Required Keys: | ||
|
||
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X), | ||
N is the number of modalities, and data type is float32. | ||
|
||
Modified Keys: | ||
|
||
- img | ||
|
||
Args: | ||
prob (float): Probability to add Gaussian noise for | ||
each sample. Default to 0.1. | ||
mean (float): Mean or “centre” of the distribution. Default to 0.0. | ||
std (float): Standard deviation of distribution. Default to 0.1. | ||
""" | ||
|
||
def __init__(self, | ||
prob: float = 0.1, | ||
mean: float = 0.0, | ||
std: float = 0.1) -> None: | ||
super().__init__() | ||
assert 0.0 <= prob <= 1.0 and std >= 0.0 | ||
self.prob = prob | ||
self.mean = mean | ||
self.std = std | ||
|
||
def transform(self, results: Dict) -> Dict: | ||
"""Call function to add random Gaussian noise to image. | ||
|
||
Args: | ||
results (dict): Result dict. | ||
|
||
Returns: | ||
dict: Result dict with random Gaussian noise. | ||
""" | ||
if np.random.rand() < self.prob: | ||
rand_std = np.random.uniform(0, self.std) | ||
noise = np.random.normal( | ||
self.mean, rand_std, size=results['img'].shape) | ||
# noise is float64 array, convert to the results['img'].dtype | ||
noise = noise.astype(results['img'].dtype) | ||
results['img'] = results['img'] + noise | ||
return results | ||
|
||
def __repr__(self): | ||
repr_str = self.__class__.__name__ | ||
repr_str += f'(prob={self.prob}, ' | ||
repr_str += f'mean={self.mean}, ' | ||
repr_str += f'std={self.std})' | ||
return repr_str | ||
|
||
|
||
@TRANSFORMS.register_module() | ||
class BioMedicalGaussianBlur(BaseTransform): | ||
"""Add Gaussian blur with random sigma to image. | ||
|
||
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L81 # noqa:E501 | ||
|
||
Copyright (c) German Cancer Research Center (DKFZ) | ||
Licensed under the Apache License, Version 2.0 | ||
|
||
Required Keys: | ||
|
||
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X), | ||
N is the number of modalities, and data type is float32. | ||
|
||
Modified Keys: | ||
|
||
- img | ||
|
||
Args: | ||
sigma_range (Tuple[float, float]|float): range to randomly | ||
select sigma value. Default to (0.5, 1.0). | ||
prob (float): Probability to apply Gaussian blur | ||
for each sample. Default to 0.2. | ||
prob_per_channel (float): Probability to apply Gaussian blur | ||
for each channel (axis N of the image). Default to 0.5. | ||
different_sigma_per_channel (bool): whether to use different | ||
sigma for each channel (axis N of the image). Default to True. | ||
different_sigma_per_axis (bool): whether to use different | ||
sigma for axis Z, X and Y of the image. Default to True. | ||
""" | ||
|
||
def __init__(self, | ||
sigma_range: Tuple[float, float] = (0.5, 1.0), | ||
prob: float = 0.2, | ||
prob_per_channel: float = 0.5, | ||
different_sigma_per_channel: bool = True, | ||
different_sigma_per_axis: bool = True) -> None: | ||
super().__init__() | ||
assert 0.0 <= prob <= 1.0 | ||
assert 0.0 <= prob_per_channel <= 1.0 | ||
assert isinstance(sigma_range, Sequence) and len(sigma_range) == 2 | ||
self.sigma_range = sigma_range | ||
self.prob = prob | ||
self.prob_per_channel = prob_per_channel | ||
self.different_sigma_per_channel = different_sigma_per_channel | ||
self.different_sigma_per_axis = different_sigma_per_axis | ||
|
||
def _get_valid_sigma(self, value_range) -> Tuple[float, ...]: | ||
"""Ensure the `value_range` to be either a single value or a sequence | ||
of two values. If the `value_range` is a sequence, generate a random | ||
value with `[value_range[0], value_range[1]]` based on uniform | ||
sampling. | ||
|
||
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/augmentations/utils.py#L625 # noqa:E501 | ||
|
||
Args: | ||
value_range (tuple|list|float|int): the input value range | ||
""" | ||
if (isinstance(value_range, (list, tuple))): | ||
if (value_range[0] == value_range[1]): | ||
value = value_range[0] | ||
else: | ||
orig_type = type(value_range[0]) | ||
value = np.random.uniform(value_range[0], value_range[1]) | ||
value = orig_type(value) | ||
return value | ||
|
||
def _gaussian_blur(self, data_sample: np.ndarray) -> np.ndarray: | ||
"""Random generate sigma and apply Gaussian Blur to the data | ||
Args: | ||
data_sample (np.ndarray): data sample with multiple modalities, | ||
the data shape is (N, Z, Y, X) | ||
""" | ||
sigma = None | ||
for c in range(data_sample.shape[0]): | ||
if np.random.rand() < self.prob_per_channel: | ||
# if no `sigma` is generated, generate one | ||
# if `self.different_sigma_per_channel` is True, | ||
# re-generate random sigma for each channel | ||
if (sigma is None or self.different_sigma_per_channel): | ||
if (not self.different_sigma_per_axis): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From this implementation, only when However,
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here are some mistakes in this docstring. Actually, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a little question at L1657-1659, If There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry I got it, it generates sigma for |
||
sigma = self._get_valid_sigma(self.sigma_range) | ||
else: | ||
sigma = [ | ||
self._get_valid_sigma(self.sigma_range) | ||
for _ in data_sample.shape[1:] | ||
] | ||
# apply gaussian filter with `sigma` | ||
data_sample[c] = gaussian_filter( | ||
data_sample[c], sigma, order=0) | ||
return data_sample | ||
|
||
def transform(self, results: Dict) -> Dict: | ||
"""Call function to add random Gaussian blur to image. | ||
|
||
Args: | ||
results (dict): Result dict. | ||
|
||
Returns: | ||
dict: Result dict with random Gaussian noise. | ||
""" | ||
if np.random.rand() < self.prob: | ||
results['img'] = self._gaussian_blur(results['img']) | ||
return results | ||
|
||
def __repr__(self): | ||
repr_str = self.__class__.__name__ | ||
repr_str += f'(prob={self.prob}, ' | ||
repr_str += f'prob_per_channel={self.prob_per_channel}, ' | ||
repr_str += f'sigma_range={self.sigma_range}, ' | ||
repr_str += 'different_sigma_per_channel='\ | ||
f'{self.different_sigma_per_channel}, ' | ||
repr_str += 'different_sigma_per_axis='\ | ||
f'{self.different_sigma_per_axis})' | ||
return repr_str |
Uh oh!
There was an error while loading. Please reload this page.