class Compose(Transform):
    """Compose several transforms together.

    Args:
        transforms: Sequence of instances of
            :class:`~torchio.transforms.Transform`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(self, transforms: Sequence[Transform], **kwargs):
        super().__init__(parse_input=False, **kwargs)
        # Materialize the input before validating it. If ``transforms`` is a
        # one-shot iterable (e.g. a generator), the validation loop would
        # otherwise exhaust it and ``self.transforms`` would end up empty.
        transforms_list = list(transforms)
        for transform in transforms_list:
            if not callable(transform):
                message = (
                    'One or more of the objects passed to the Compose'
                    f' transform are not callable: "{transform}"'
                )
                raise TypeError(message)
        self.transforms = transforms_list

    def __len__(self):
        """Return the number of composed transforms."""
        return len(self.transforms)

    def __getitem__(self, index) -> Transform:
        """Return the transform(s) at ``index`` in the internal list."""
        return self.transforms[index]

    def __repr__(self) -> str:
        return f'{self.name}({self.transforms})'

    def apply_transform(self, subject: Subject) -> Subject:
        """Apply each transform in order, feeding each output to the next."""
        for transform in self.transforms:
            subject = transform(subject)  # type: ignore[assignment]
        return subject

    def is_invertible(self) -> bool:
        """Return ``True`` only if every composed transform is invertible."""
        return all(t.is_invertible() for t in self.transforms)

    def inverse(self, warn: bool = True) -> Compose:
        """Return a composed transform with inverted order and transforms.

        Transforms that are not invertible are skipped. The returned
        :class:`Compose` is built without forwarding this instance's
        keyword arguments.

        Args:
            warn: Issue a warning if some transforms are not invertible.

        Returns:
            A new :class:`Compose` holding the inverses of the invertible
            transforms, in reverse order.
        """
        transforms = []
        for transform in self.transforms:
            if transform.is_invertible():
                transforms.append(transform.inverse())
            elif warn:
                message = f'Skipping {transform.name} as it is not invertible'
                warnings.warn(message, RuntimeWarning, stacklevel=2)
        # Undo the transforms in the opposite order they were applied
        transforms.reverse()
        result = Compose(transforms)
        if not transforms and warn:
            warnings.warn(
                'No invertible transforms found',
                RuntimeWarning,
                stacklevel=2,
            )
        return result
class OneOf(RandomTransform):
    """Apply only one of the given transforms.

    Args:
        transforms: Dictionary with instances of
            :class:`~torchio.transforms.Transform` as keys and
            probabilities as values. Probabilities are normalized so they sum
            to one. If a sequence is given, the same probability will be
            assigned to each transform.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torchio as tio
        >>> colin = tio.datasets.Colin27()
        >>> transforms_dict = {
        ...     tio.RandomAffine(): 0.75,
        ...     tio.RandomElasticDeformation(): 0.25,
        ... }  # Using 3 and 1 as probabilities would have the same effect
        >>> transform = tio.OneOf(transforms_dict)
        >>> transformed = transform(colin)
    """

    def __init__(self, transforms: TypeTransformsDict, **kwargs):
        super().__init__(parse_input=False, **kwargs)
        self.transforms_dict = self._get_transforms_dict(transforms)

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample one transform using the stored weights and apply it."""
        weights = torch.Tensor(list(self.transforms_dict.values()))
        # multinomial returns a one-element tensor; use a plain int as index
        index = int(torch.multinomial(weights, 1).item())
        transforms = list(self.transforms_dict.keys())
        transform = transforms[index]
        transformed = transform(subject)
        return transformed  # type: ignore[return-value]

    def _get_transforms_dict(
        self,
        transforms: TypeTransformsDict,
    ) -> Dict[Transform, float]:
        """Build and validate the transform -> probability mapping.

        Raises:
            ValueError: If ``transforms`` is neither a dict nor a sized
                sequence, if it is empty, if probabilities are invalid, or if
                any key is not a :class:`~torchio.transforms.Transform`.
        """
        if isinstance(transforms, dict):
            transforms_dict = dict(transforms)  # copy; normalized in place
            self._normalize_probabilities(transforms_dict)
        else:
            try:
                num_transforms = len(transforms)
            except TypeError as e:
                message = (
                    'Transforms argument must be a dictionary or a sequence,'
                    f' not {type(transforms)}'
                )
                raise ValueError(message) from e
            if num_transforms == 0:
                # Avoid an uncaught ZeroDivisionError for an empty sequence
                message = 'At least one transform must be given'
                raise ValueError(message)
            p = 1 / num_transforms
            transforms_dict = {transform: p for transform in transforms}
        for transform in transforms_dict:
            if not isinstance(transform, Transform):
                message = (
                    'All keys in transform_dict must be instances of'
                    f' torchio.Transform, not "{type(transform)}"'
                )
                raise ValueError(message)
        return transforms_dict

    @staticmethod
    def _normalize_probabilities(
        transforms_dict: Dict[Transform, float],
    ) -> None:
        """Scale the dict values in place so they sum to one.

        Raises:
            ValueError: If any probability is negative or all are zero.
        """
        probabilities = np.array(list(transforms_dict.values()), dtype=float)
        if np.any(probabilities < 0):
            message = (
                'Probabilities must be greater or equal to zero,'
                f' not "{probabilities}"'
            )
            raise ValueError(message)
        if np.all(probabilities == 0):
            message = (
                'At least one probability must be greater than zero,'
                f' but they are "{probabilities}"'
            )
            raise ValueError(message)
        total = probabilities.sum()  # computed once rather than per item
        for transform, probability in transforms_dict.items():
            # Reassigning existing keys does not change the dict size, so
            # mutating during iteration is safe here
            transforms_dict[transform] = probability / total