Modules
- class rna2seg.dataset_zarr.RNA2segDataset(*args: Any, **kwargs: Any)
Bases: Dataset
- __getitem__(idx)
Retrieves a patch of spatial transcriptomics data based on the given index.
- Parameters:
idx (int) – The index of the patch to retrieve.
- Returns:
A dictionary containing the following elements:
“img_cellbound”: Tensor representing the cell boundary image.
“dapi”: Tensor representing the DAPI-stained image for nuclear staining.
“rna_img”: Tensor representing the spatial RNA expression image.
“mask_flow”: Tensor representing the cellular flow field mask for segmentation.
“mask_gradient”: Tensor representing the gradient mask for segmentation refinement.
“background” (optional): Tensor representing the background image if test_return_background is enabled.
“idx”: Integer index of the patch.
“patch_index”: Integer representing the patch identifier.
“bounds”: List defining the spatial boundaries of the patch.
“segmentation_nuclei” (optional): Tensor representing the nuclear segmentation mask if available.
“list_gene” (optional): Tensor containing the list of detected genes if return_df is enabled.
“array_coord” (optional): Tensor containing spatial coordinates of detected transcripts if return_df is enabled.
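Because several keys are optional, code that consumes a patch may want to know which keys to expect for a given configuration. The sketch below uses hypothetical helpers (`expected_keys`, `check_patch`; not part of rna2seg) that encode the key list above. It assumes “segmentation_nuclei” is present exactly when a nuclei segmentation exists, which the description only states as “if available”.

```python
# Hypothetical helpers (not part of rna2seg) encoding the key list above.
REQUIRED_KEYS = {
    "img_cellbound", "dapi", "rna_img", "mask_flow", "mask_gradient",
    "idx", "patch_index", "bounds",
}


def expected_keys(test_return_background=False, has_nuclei_seg=False, return_df=False):
    """Return the keys a patch dictionary should contain for the given flags."""
    keys = set(REQUIRED_KEYS)
    if test_return_background:
        keys.add("background")
    if has_nuclei_seg:  # assumption: present whenever a nuclei segmentation exists
        keys.add("segmentation_nuclei")
    if return_df:
        keys |= {"list_gene", "array_coord"}
    return keys


def check_patch(sample, **flags):
    """Raise KeyError on missing keys; return any keys not expected for these flags."""
    expected = expected_keys(**flags)
    missing = expected - set(sample)
    if missing:
        raise KeyError(f"missing keys: {sorted(missing)}")
    return set(sample) - expected
```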
- __init__(sdata: spatialdata.SpatialData, channels_dapi: list[str], channels_cellbound: list[str] | None = None, key_cell_consistent: str | None = None, key_nucleus_consistent: str | None = None, key_nuclei_segmentation: str | None = None, dict_gene_value: dict | None = None, training_mode: bool = False, evaluation_mode: bool = False, patch_width: int | None = None, patch_overlap: int | None = None, list_patch_index: list[int] | None = None, list_annotation_patches: list[int] | None = None, gene_column='gene', density_threshold: float | None = None, kernel_size_background_density: None | float = 5, kernel_size_rna2img: float = 0.5, max_filter_size_rna2img: float = 2, transform_resize: albumentations.Compose | None = None, transform_dapi: albumentations.Compose | None = None, augment_rna_density: bool = False, min_nb_cell_per_patch=1, remove_cell_in_background_threshold=0.05, remove_nucleus_seg_from_bg=True, addition_mode=True, min_transcripts: int = 1, return_df=False, gene2index=None, augmentation_img=False, recompute_flow=True, test_return_background=False, patch_dir: Path | str | None = None, experiment_name='input_target_rna2seg', use_cache=False, shape_patch_key=None)
Initializes the dataset with the provided parameters.
- Parameters:
sdata – SpatialData object containing spatial transcriptomics data.
channels_dapi – List of DAPI channels for nuclear staining.
channels_cellbound – List of cell boundary channels, or None if not provided.
key_cell_consistent – Key for consistent cell segmentation, or None.
key_nucleus_consistent – Key for consistent nucleus segmentation, or None.
key_nuclei_segmentation – Key for nuclei segmentation, or None.
dict_gene_value – Dictionary containing gene encodings, or None.
training_mode – Boolean flag to enable training mode.
evaluation_mode – Boolean flag to enable evaluation mode.
patch_width – Width of the patch for segmentation, or None.
patch_overlap – Overlap of the patch for segmentation, or None.
list_patch_index – List of patch indices, or None.
list_annotation_patches – List of annotation patches to exclude, or None.
gene_column – Column name for genes, default is “gene”.
density_threshold – Threshold for density calculation, or None.
kernel_size_background_density – Kernel size for background density calculation, default is 5.
kernel_size_rna2img – Gaussian kernel size for RNA image transformation, default is 0.5.
max_filter_size_rna2img – Max filter size for RNA image transformation, default is 2.
transform_resize – Resize transformation function for images, or None.
transform_dapi – Transformation function for DAPI images, or None.
augment_rna_density – Boolean flag to enable RNA density augmentation.
min_nb_cell_per_patch – Minimum number of cells per patch for inclusion.
remove_cell_in_background_threshold – Threshold for removing cells in the background.
remove_nucleus_seg_from_bg – Boolean flag to remove nucleus segmentation from the background.
addition_mode – Boolean flag to enable RNA spot value addition.
return_df – Boolean flag to return DataFrame, default is False.
gene2index – Dictionary mapping genes to indices, or None if return_df is False.
augmentation_img – Boolean flag to enable image augmentation.
recompute_flow – Boolean flag to recompute the flow field.
test_return_background – Boolean flag to return background image for testing.
patch_dir – Directory for patch storage, or None.
experiment_name – Name of the experiment, default is ‘input_target_rna2seg’.
use_cache – Boolean flag to enable cache usage.
shape_patch_key – Key for patch shape, or None.
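gene2index maps gene names to integer indices and is used when return_df is enabled. A minimal way to build one from transcript gene names is sketched below. The helper names are hypothetical, and starting indices at 1 (leaving 0 free, for example for padding) is an assumption, since the expected convention is not documented here. RNA2seg also accepts a gene2index argument, so it seems reasonable to pass the same mapping to both.

```python
def build_gene2index(genes, start=1):
    """Map each distinct gene name to an integer, in alphabetical order.

    start=1 keeps index 0 free (e.g. for padding); this convention is an assumption.
    """
    return {gene: i for i, gene in enumerate(sorted(set(genes)), start=start)}


def encode_genes(genes, gene2index, unknown=0):
    """Translate gene names to indices; genes missing from the mapping get `unknown`."""
    return [gene2index.get(g, unknown) for g in genes]


gene2index = build_gene2index(["Cd3e", "Actb", "Cd3e", "Vim"])
print(gene2index)                                          # {'Actb': 1, 'Cd3e': 2, 'Vim': 3}
print(encode_genes(["Vim", "Xist", "Actb"], gene2index))   # [3, 0, 1]
```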
- get_rna_img(bounds, key_transcripts='transcripts', dict_gene_value=None)
Generates an image of RNA from spatial transcriptomics data within the specified bounds.
- Parameters:
bounds (tuple[int, int, int, int]) – The bounding box coordinates for the extracted region (xmin, ymin, xmax, ymax).
key_transcripts (str) – The key to access transcriptomic data in the dataset. Defaults to “transcripts”.
dict_gene_value (dict[str, float] | None) – Dictionary mapping gene names to encodings/colors. If None, all genes have a value of 1.
- Returns:
An image representation of RNA expression in the selected region.
- Return type:
np.ndarray
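The core of an RNA image is a rasterization of transcript coordinates into pixels within the bounds. The sketch below shows this with NumPy, assuming transcript coordinates are already in the pixel space of the bounds and that dict_gene_value maps each gene to a single scalar (the description above also allows colors). Summing coincident transcripts resembles what addition_mode appears to enable, but that correspondence is an assumption, and `rasterize_transcripts` is a hypothetical helper, not the library's implementation.

```python
import numpy as np


def rasterize_transcripts(x, y, genes, bounds, dict_gene_value=None):
    """Accumulate transcript values into a 2D image covering bounds=(xmin, ymin, xmax, ymax).

    Simplified sketch: the Gaussian kernel (kernel_size_rna2img) and max filter
    (max_filter_size_rna2img) steps of the real pipeline are omitted.
    """
    xmin, ymin, xmax, ymax = bounds
    img = np.zeros((int(ymax - ymin), int(xmax - xmin)), dtype=np.float32)
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Keep transcripts inside the half-open box [xmin, xmax) x [ymin, ymax).
    inside = (x >= xmin) & (x < xmax) & (y >= ymin) & (y < ymax)
    cols = (x[inside] - xmin).astype(int)
    rows = (y[inside] - ymin).astype(int)
    if dict_gene_value is None:
        values = np.ones(cols.size, dtype=np.float32)  # every gene counts as 1
    else:
        kept_genes = np.asarray(genes, dtype=object)[inside]
        # Assumption: one scalar per gene; genes absent from the dict contribute 0.
        values = np.array([dict_gene_value.get(g, 0.0) for g in kept_genes], dtype=np.float32)
    # np.add.at sums transcripts that fall on the same pixel instead of overwriting.
    np.add.at(img, (rows, cols), values)
    return img
```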
- get_segmentation_img(bounds, key_cell, image=None, color='red', size_line=5)
Generates an image of cell segmentation within the specified bounds.
- Parameters:
bounds (tuple[int, int, int, int]) – The bounding box coordinates for the extracted region (xmin, ymin, xmax, ymax).
key_cell (str) – The key to access cell segmentation data.
image (np.ndarray | None) – The base image on which segmentation will be overlaid. If None, uses DAPI staining.
color (str) – The color used to outline cell contours. Defaults to “red”.
size_line (int) – The thickness of the segmentation contour lines. Defaults to 5.
- Returns:
An image with cell segmentation overlaid.
- Return type:
np.ndarray
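One way to draw cell outlines works on a raster label mask (0 = background, one integer per cell) and marks the pixels that touch a different label. This is a swapped-in technique for illustration only: get_segmentation_img takes segmentation data via key_cell and a line thickness size_line, whereas this sketch draws 1-pixel outlines from a label array and uses an RGB tuple instead of a color name. Both helper names are hypothetical.

```python
import numpy as np


def label_boundaries(labels):
    """Return a boolean mask of cell pixels that touch a different label (4-neighbourhood)."""
    labels = np.asarray(labels)
    edge = np.zeros(labels.shape, dtype=bool)
    # Compare each pixel with its right and lower neighbours; mark both sides of a change.
    diff_x = labels[:, 1:] != labels[:, :-1]
    diff_y = labels[1:, :] != labels[:-1, :]
    edge[:, 1:] |= diff_x
    edge[:, :-1] |= diff_x
    edge[1:, :] |= diff_y
    edge[:-1, :] |= diff_y
    # Background (label 0) pixels are never drawn.
    return edge & (labels > 0)


def overlay_contours(image, labels, color=(255, 0, 0)):
    """Paint cell contours onto a copy of an (H, W, 3) uint8 image."""
    out = np.array(image, dtype=np.uint8, copy=True)
    out[label_boundaries(labels)] = color
    return out
```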
- get_staining_img(bounds)
Generates an image of the DAPI and cell boundary stainings within the specified bounds.
- Parameters:
bounds (tuple[int, int, int, int]) – The bounding box coordinates for the extracted region (xmin, ymin, xmax, ymax).
- Returns:
An image of the stainings (channels_dapi and channels_cellbound) in the selected region.
- Return type:
np.ndarray
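Conceptually, the staining image is each configured channel cropped to the bounds and stacked along a channel axis. The sketch below assumes the channels are available as 2D NumPy arrays keyed by name, with bounds in pixel coordinates. The channel names and the DAPI-first ordering are illustrative assumptions, `crop_and_stack` is a hypothetical helper, and any normalization the real method applies is not reproduced.

```python
import numpy as np


def crop_and_stack(channels, bounds, names):
    """Crop each named (H, W) channel to bounds=(xmin, ymin, xmax, ymax) and stack as (C, h, w)."""
    xmin, ymin, xmax, ymax = bounds
    # Rows index y and columns index x, so the crop is [ymin:ymax, xmin:xmax].
    return np.stack([np.asarray(channels[name])[ymin:ymax, xmin:xmax] for name in names])


# Hypothetical channel names: DAPI first, then the cell boundary stainings.
channels = {"DAPI": np.arange(16).reshape(4, 4), "Cellbound1": np.ones((4, 4))}
stack = crop_and_stack(channels, (1, 0, 3, 2), ["DAPI", "Cellbound1"])
print(stack.shape)  # (2, 2, 2)
```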
- class rna2seg.models.RNA2seg(*args: Any, **kwargs: Any)
Bases: Module
RNA2seg: a deep learning-based method for cell segmentation using spatial transcriptomics data together with membrane and nuclear stainings.
- __init__(device, pretrained_model: Path | str | None = 'default_pretrained', net: str = 'unet', n_inv_chan: int = 3, nb_rna_channel: int = 1, nout: int = 3, nbase=[32, 64, 128, 256], sz: int = 3, flow_threshold: float = 0.9, min_cell_size: float = 200, cellbound_flow_threshold: float = 0.4, gene2index=None)
Initialize RNA2seg.
- Parameters:
device (str) – The computing device (e.g., “cpu” or “cuda”).
pretrained_model (Path | str | None) – Path to a pretrained model. If “default_pretrained”, a trained RNA2seg model is downloaded from Hugging Face. If None, weights are randomly initialized. Defaults to “default_pretrained”.
net (str) – Backbone network architecture. Can be “unet” or “vmunet”. Defaults to “unet”.
n_inv_chan (int) – Number of channels in the staining input image. The stainings are combined and encoded into an image of n_inv_chan channels using a Channel-Net. Defaults to 3.
nb_rna_channel (int) – Number of RNA channels used as input. Defaults to 1.
nout (int) – Number of output channels. Following the Cellpose method, the network outputs cell probabilities on one channel and the 2-channel flow. Defaults to 3.
nbase (list[int]) – List defining the number of channels at each layer of the network. Defaults to [32, 64, 128, 256].
sz (int) – Kernel size for convolutions. Defaults to 3.
flow_threshold (float) – Threshold for flow consistency during segmentation. Defaults to 0.9.
min_cell_size (float) – Minimum cell size (in pixels) to retain. Defaults to 200.
cellbound_flow_threshold (float) – Threshold for cell boundary detection. Defaults to 0.4.
gene2index (dict or None) – Mapping from gene names to indices, if applicable. Defaults to None.
- Raises:
ValueError – If an invalid model or configuration is provided.
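With nout=3, the description above says the network emits one cell-probability channel and a two-channel flow. The sketch below shows how such an output array could be split. It assumes the Cellpose channel order (flows in channels 0 and 1, cell-probability logit in channel 2) and a Cellpose-style cell-probability threshold of 0.0; neither the order nor the threshold is confirmed for RNA2seg, and `split_network_output` is a hypothetical helper.

```python
import numpy as np


def split_network_output(y, cellprob_threshold=0.0):
    """Split a (B, 3, H, W) output into flows (B, 2, H, W), cell probability (B, H, W), foreground.

    Assumes the Cellpose channel order: channels 0-1 are the flow field,
    channel 2 is the cell-probability logit.
    """
    y = np.asarray(y)
    if y.ndim != 4 or y.shape[1] != 3:
        raise ValueError(f"expected shape (B, 3, H, W), got {y.shape}")
    flows = y[:, :2]
    cellprob = y[:, 2]
    # Pixels whose logit exceeds the threshold are candidate cell pixels.
    foreground = cellprob > cellprob_threshold
    return flows, cellprob, foreground
```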
- encode(imgs=None, dapi=None, img_cellbound=None)
Encodes the input images (DAPI and cell boundary) for the model.
- The method supports two modes of input:
You can pass a tensor imgs containing the full image data.
Alternatively, you can pass DAPI and cell boundary images separately (dapi and img_cellbound).
These two images are concatenated along the channel dimension and then passed through the network.
- Parameters:
imgs (torch.Tensor | None) – Tensor containing the full set of image channels, including RNA channels. Kept for compatibility with the old version. Defaults to None.
dapi (torch.Tensor | None) – Tensor representing the DAPI staining image, used for encoding. Must be provided if img_cellbound is provided.
img_cellbound (torch.Tensor | None) – Tensor representing the cell boundary image, used for encoding. Must be provided if dapi is provided.
- Returns:
Tensor representing the encoded image data, after processing through the network’s AdaptorNet.
- Return type:
torch.Tensor
- Raises:
AssertionError – If imgs is provided while dapi and img_cellbound are also provided, or if neither is provided when one of them is required.
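The two input modes of encode amount to an input-validation step before the staining tensor reaches the network. The sketch below (hypothetical helper `build_staining_input`, NumPy arrays standing in for torch tensors) mirrors the documented behavior: `imgs` alone is returned unchanged, while `dapi` and `img_cellbound` are concatenated along the channel axis, assumed here to be axis 1 of a (B, C, H, W) batch with DAPI first. How the legacy `imgs` tensor separates RNA channels from stainings is not described above and is not modeled.

```python
import numpy as np


def build_staining_input(imgs=None, dapi=None, img_cellbound=None):
    """Return the (B, C, H, W) staining array from one of the two documented input modes."""
    separate = dapi is not None or img_cellbound is not None
    # Mode 1 (imgs) and mode 2 (dapi + img_cellbound) are mutually exclusive.
    assert not (imgs is not None and separate), "pass imgs or dapi/img_cellbound, not both"
    if imgs is not None:
        return np.asarray(imgs)
    # Mode 2 requires both images; assumed channel order is DAPI first.
    assert dapi is not None and img_cellbound is not None, "dapi and img_cellbound go together"
    return np.concatenate([np.asarray(dapi), np.asarray(img_cellbound)], axis=1)
```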
- forward(input_dict=None, list_gene=None, array_coord=None, dapi=None, img_cellbound=None, rna_img=None)
Forward pass for the RNA2seg model.
- The forward method supports two modes of input:
You can pass a dictionary (input_dict) containing all the relevant inputs (‘list_gene’, ‘array_coord’, ‘dapi’, ‘img_cellbound’, and ‘rna_img’). If provided, the values in the dictionary override the individual function arguments.
Alternatively, you can pass each argument independently.
The DAPI and cell boundary images are used for encoding, and the RNA image (either encoded or pre-encoded) is combined with the other inputs. The concatenated data is then passed through the model to generate the output.
- Parameters:
input_dict (dict | None) – Optional dictionary containing the following keys: ‘list_gene’, ‘array_coord’, ‘dapi’, ‘img_cellbound’, and ‘rna_img’. If provided, the values override the function arguments.
list_gene (torch.Tensor | None) – Tensor representing the list of genes for RNA encoding. Cannot be provided simultaneously with rna_img. Defaults to None.
array_coord (torch.Tensor | None) – Tensor containing the coordinates for RNA encoding. Required if list_gene is provided.
dapi (torch.Tensor) – Tensor representing the DAPI staining image, used as input for encoding.
img_cellbound (torch.Tensor) – Tensor representing the cell boundary image, used as input for encoding.
rna_img (torch.Tensor | None) – Tensor representing the pre-encoded RNA image. Cannot be provided simultaneously with list_gene. Defaults to None.
- Returns:
Tensor representing the model’s output after processing the input data.
- Return type:
torch.Tensor
- Raises:
AssertionError – If neither list_gene nor rna_img is provided, or if both are provided simultaneously.
AssertionError – If array_coord is not provided when list_gene is used.
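The argument rules of forward can be expressed as a small resolution step before the model runs. In this sketch, `resolve_forward_inputs` is hypothetical; it assumes that only the keys actually present in `input_dict` override the keyword arguments, and that unrelated keys (such as the masks returned by the dataset) are ignored.

```python
FORWARD_KEYS = ("list_gene", "array_coord", "dapi", "img_cellbound", "rna_img")


def resolve_forward_inputs(input_dict=None, **kwargs):
    """Merge forward() arguments; keys present in input_dict override keyword arguments."""
    args = {k: kwargs.get(k) for k in FORWARD_KEYS}
    if input_dict is not None:
        # Assumption: only known keys are taken from input_dict; extra keys are ignored.
        args.update({k: v for k, v in input_dict.items() if k in FORWARD_KEYS})
    has_genes = args["list_gene"] is not None
    has_rna_img = args["rna_img"] is not None
    # Exactly one RNA input: raw transcripts (list_gene) or a pre-encoded image (rna_img).
    assert has_genes != has_rna_img, "provide exactly one of list_gene or rna_img"
    if has_genes:
        assert args["array_coord"] is not None, "array_coord is required with list_gene"
    return args
```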
- load_model(filename, device=None)
Load the model from a file.
- Parameters:
filename (str) – The path to the file where the model is saved.
device (torch.device | None) – The device to load the model on. If None, the model is loaded on the CPU. Defaults to None.
- run(path_temp_save, input_dict=None, list_gene=None, array_coord=None, dapi=None, img_cellbound=None, rna_img=None, bounds=None, min_area=0)
Evaluates the model on a batch of images or a single image, and optionally on staining images.
- Parameters:
path_temp_save (str | Path) – The directory where the results will be saved.
input_dict (dict | None) – A dictionary containing the inputs for the model. It can include ‘list_gene’, ‘array_coord’, ‘dapi’, ‘img_cellbound’, ‘rna_img’, and ‘bounds’.
list_gene (torch.Tensor | None) – List of gene expressions to use for encoding RNA. Defaults to None.
array_coord (torch.Tensor | None) – Coordinates array for the genes, required if list_gene is provided.
dapi (torch.Tensor) – DAPI stained image used for encoding.
img_cellbound (torch.Tensor) – Image of cell boundaries used for encoding.
rna_img (torch.Tensor | None) – RNA image, either encoded or pre-encoded.
bounds (list | None) – Bounds for the image, used for transformations. Defaults to None.
min_area (int) – The minimum area to consider for detected cells. Defaults to 0 (no filtering).
- Returns:
A tuple containing the flow, cell probability, predicted masks, and cells (as a GeoDataFrame).
- Return type:
tuple[torch.Tensor, torch.Tensor, np.ndarray, GeoDataFrame]
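The min_area argument drops small detections. A minimal NumPy sketch of such filtering on a label mask is shown below. It assumes area is measured in pixels of the predicted mask and that surviving cells are relabeled consecutively, neither of which is stated above; `filter_small_cells` is a hypothetical helper, not the library's implementation.

```python
import numpy as np


def filter_small_cells(masks, min_area=0):
    """Zero out labels whose pixel area is below min_area, then relabel 1..N consecutively.

    masks: non-negative integer label array, 0 = background.
    """
    masks = np.asarray(masks)
    if masks.size == 0 or masks.max() == 0:
        return masks.copy()
    areas = np.bincount(masks.ravel())  # areas[label] = number of pixels with that label
    keep = areas >= min_area
    keep[0] = False  # label 0 is background and is never kept as a cell
    # new_ids[old_label] gives the new consecutive label, or 0 if the cell is dropped.
    new_ids = np.zeros(areas.size, dtype=masks.dtype)
    new_ids[keep] = np.arange(1, keep.sum() + 1, dtype=masks.dtype)
    return new_ids[masks]
```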
- save_model(filename)
Save the model to a file.
- Parameters:
filename (str) – The path to the file where the model will be saved.
- rna2seg.dataset_zarr.patches.create_patch_rna2seg(sdata: spatialdata.SpatialData, image_key: str, points_key: str, patch_width: int, patch_overlap: int, min_points_per_patch: int, folder_patch_rna2seg: Path | str | None = None, overwrite: bool = False, gene_column_name: str = 'gene')
Creates patches from the spatial data so that it can be processed in manageable sizes. Saves the patch shapes into the zarr store and precomputes a transcript.csv file for each patch.
- Parameters:
sdata – SpatialData object containing the spatial dataset.
image_key – Key identifying the image in sdata.
points_key – Key identifying the points in sdata.
patch_width – Width of each patch.
patch_overlap – Overlap between adjacent patches.
min_points_per_patch – Minimum number of transcripts required for a patch.
folder_patch_rna2seg – Directory where patches will be saved. If None, defaults to sdata.path/.rna2seg.
overwrite – Whether to overwrite the existing patch shapes if they already exist in the zarr store.
gene_column_name – Column name for genes in the points data, default is “gene”.
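The patching step can be pictured as tiling the image extent with square windows whose start positions are spaced patch_width - patch_overlap apart, then discarding windows that contain fewer than min_points_per_patch transcripts. The sketch below uses hypothetical helpers and plain NumPy; it ignores coordinate transformations between image and transcripts, and windows on the last row or column may extend past the extent, which the actual function may handle differently.

```python
import numpy as np


def patch_bounds(xmin, ymin, xmax, ymax, patch_width, patch_overlap):
    """Tile a bounding box with square patches of side patch_width overlapping by patch_overlap."""
    step = patch_width - patch_overlap
    if step <= 0:
        raise ValueError("patch_overlap must be smaller than patch_width")
    bounds = []
    for y0 in range(int(ymin), int(ymax), step):
        for x0 in range(int(xmin), int(xmax), step):
            bounds.append((x0, y0, x0 + patch_width, y0 + patch_width))
    return bounds


def keep_dense_patches(bounds, x, y, min_points_per_patch):
    """Return the patches containing at least min_points_per_patch transcripts."""
    x, y = np.asarray(x), np.asarray(y)
    kept = []
    for x0, y0, x1, y1 in bounds:
        n = np.count_nonzero((x >= x0) & (x < x1) & (y >= y0) & (y < y1))
        if n >= min_points_per_patch:
            kept.append((x0, y0, x1, y1))
    return kept
```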