Modules

class rna2seg.dataset_zarr.RNA2segDataset(*args: Any, **kwargs: Any)

Bases: Dataset

__getitem__(idx)

Retrieves a patch of spatial transcriptomics data based on the given index.

Parameters:

idx (int) – The index of the patch to retrieve.

Returns:

A dictionary containing the following elements:

  • “img_cellbound”: Tensor representing the cell boundary image.

  • “dapi”: Tensor representing the DAPI-stained image for nuclear staining.

  • “rna_img”: Tensor representing the spatial RNA expression image.

  • “mask_flow”: Tensor representing the cellular flow field mask for segmentation.

  • “mask_gradient”: Tensor representing the gradient mask for segmentation refinement.

  • “background” (optional): Tensor representing the background image if test_return_background is enabled.

  • “idx”: Integer index of the patch.

  • “patch_index”: Integer representing the patch identifier.

  • “bounds”: List defining the spatial boundaries of the patch.

  • “segmentation_nuclei” (optional): Tensor representing the nuclear segmentation mask if available.

  • “list_gene” (optional): Tensor containing the list of detected genes if return_df is enabled.

  • “array_coord” (optional): Tensor containing spatial coordinates of detected transcripts if return_df is enabled.
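When consuming these dictionaries downstream, it can help to check which keys a sample actually carries. The following is a minimal sketch: the key names come from the list above, while the helper describe_sample is hypothetical and not part of rna2seg. It assumes every key not marked optional is always present.

```python
# Key names are taken from the list above; the helper itself is hypothetical.
REQUIRED_KEYS = {
    "img_cellbound", "dapi", "rna_img", "mask_flow", "mask_gradient",
    "idx", "patch_index", "bounds",
}
OPTIONAL_KEYS = {"background", "segmentation_nuclei", "list_gene", "array_coord"}


def describe_sample(sample: dict) -> dict:
    """Report missing required keys, present optional keys, and unknown keys."""
    keys = set(sample)
    return {
        "missing": sorted(REQUIRED_KEYS - keys),
        "optional_present": sorted(OPTIONAL_KEYS & keys),
        "unexpected": sorted(keys - REQUIRED_KEYS - OPTIONAL_KEYS),
    }


# Stand-in sample: real values would be tensors, placeholders are used here.
sample = {key: None for key in REQUIRED_KEYS}
sample["list_gene"] = None
sample["array_coord"] = None
report = describe_sample(sample)
```

With a real dataset, sample would be the result of indexing an RNA2segDataset instance, and the report would show whether options such as return_df took effect.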

__init__(sdata: spatialdata.SpatialData, channels_dapi: list[str], channels_cellbound: list[str] | None = None, key_cell_consistent: str | None = None, key_nucleus_consistent: str | None = None, key_nuclei_segmentation: str | None = None, dict_gene_value: dict | None = None, training_mode: bool = False, evaluation_mode: bool = False, patch_width: int | None = None, patch_overlap: int | None = None, list_patch_index: list[int] | None = None, list_annotation_patches: list[int] | None = None, gene_column='gene', density_threshold: float | None = None, kernel_size_background_density: None | float = 5, kernel_size_rna2img: float = 0.5, max_filter_size_rna2img: float = 2, transform_resize: albumentations.Compose | None = None, transform_dapi: albumentations.Compose | None = None, augment_rna_density: bool = False, min_nb_cell_per_patch=1, remove_cell_in_background_threshold=0.05, remove_nucleus_seg_from_bg=True, addition_mode=True, min_transcripts: int = 1, return_df=False, gene2index=None, augmentation_img=False, recompute_flow=True, test_return_background=False, patch_dir: Path | str | None = None, experiment_name='input_target_rna2seg', use_cache=False, shape_patch_key=None)

Initializes the dataset with the provided parameters.

Parameters:
  • sdata – SpatialData object containing spatial transcriptomics data.

  • channels_dapi – List of DAPI channels for nuclear staining.

  • channels_cellbound – List of cell boundary channels, or None if not provided.

  • key_cell_consistent – Key for consistent cell segmentation, or None.

  • key_nucleus_consistent – Key for consistent nucleus segmentation, or None.

  • key_nuclei_segmentation – Key for nuclei segmentation, or None.

  • dict_gene_value – Dictionary containing gene encodings, or None.

  • training_mode – Boolean flag to enable training mode.

  • evaluation_mode – Boolean flag to enable evaluation mode.

  • patch_width – Width of the patch for segmentation, or None.

  • patch_overlap – Overlap of the patch for segmentation, or None.

  • list_patch_index – List of patch indices, or None.

  • list_annotation_patches – List of annotation patches to exclude, or None.

  • gene_column – Column name for genes, default is “gene”.

  • density_threshold – Threshold for density calculation, or None.

  • kernel_size_background_density – Kernel size for background density calculation, default is 5.

  • kernel_size_rna2img – Gaussian kernel size for RNA image transformation, default is 0.5.

  • max_filter_size_rna2img – Max filter size for RNA image transformation, default is 2.

  • transform_resize – Resize transformation function for images, or None.

  • transform_dapi – Transformation function for DAPI images, or None.

  • augment_rna_density – Boolean flag to enable RNA density augmentation.

  • min_nb_cell_per_patch – Minimum number of cells per patch for inclusion.

  • remove_cell_in_background_threshold – Threshold for removing cells in the background.

  • remove_nucleus_seg_from_bg – Boolean flag to remove nucleus segmentation from the background.

  • addition_mode – Boolean flag to enable RNA spot value addition.

  • return_df – Boolean flag to return DataFrame, default is False.

  • gene2index – Dictionary mapping genes to indices, or None if return_df is False.

  • augmentation_img – Boolean flag to enable image augmentation.

  • recompute_flow – Boolean flag to recompute the flow field.

  • test_return_background – Boolean flag to return background image for testing.

  • patch_dir – Directory for patch storage, or None.

  • experiment_name – Name of the experiment, default is ‘input_target_rna2seg’.

  • use_cache – Boolean flag to enable cache usage.

  • shape_patch_key – Key for patch shape, or None.

get_rna_img(bounds, key_transcripts='transcripts', dict_gene_value=None)

Generates an image of RNA from spatial transcriptomics data within the specified bounds.

Parameters:
  • bounds (tuple[int, int, int, int]) – The bounding box coordinates for the extracted region (xmin, ymin, xmax, ymax).

  • key_transcripts (str) – The key to access transcriptomic data in the dataset. Defaults to “transcripts”.

  • dict_gene_value (dict[str, float] | None) – Dictionary mapping gene names to encodings/colors. If None, all genes have a value of 1.

Returns:

An image representation of RNA expression in the selected region.

Return type:

np.ndarray
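To illustrate how bounds and dict_gene_value interact, here is a simplified sketch that accumulates transcript values into a pixel grid. It is not the rna2seg implementation: the function rasterize_transcripts and its tuple input format are assumptions, values are simply summed per pixel, and any smoothing or filtering applied by the library is omitted.

```python
import numpy as np


def rasterize_transcripts(transcripts, bounds, dict_gene_value=None):
    """Accumulate per-gene values into a 2D image covering `bounds`.

    transcripts: iterable of (x, y, gene) tuples in the same pixel
    coordinate system as bounds = (xmin, ymin, xmax, ymax).
    """
    xmin, ymin, xmax, ymax = bounds
    img = np.zeros((ymax - ymin, xmax - xmin), dtype=np.float32)
    for x, y, gene in transcripts:
        col, row = int(x) - xmin, int(y) - ymin
        if 0 <= row < img.shape[0] and 0 <= col < img.shape[1]:
            # Every gene counts as 1 when no encoding is supplied.
            value = 1.0 if dict_gene_value is None else dict_gene_value[gene]
            img[row, col] += value
    return img


# Hypothetical transcripts; the third one lies outside the bounds.
spots = [(12.3, 20.8, "Actb"), (12.9, 20.1, "Gapdh"), (99.0, 99.0, "Actb")]
img = rasterize_transcripts(spots, bounds=(10, 20, 14, 24))
img_encoded = rasterize_transcripts(
    spots, bounds=(10, 20, 14, 24), dict_gene_value={"Actb": 0.5, "Gapdh": 0.25}
)
```

The first call treats all genes equally, matching the documented behaviour when dict_gene_value is None; the second weights each transcript by its gene encoding.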

get_segmentation_img(bounds, key_cell, image=None, color='red', size_line=5)

Generates an image of cell segmentation within the specified bounds.

Parameters:
  • bounds (tuple[int, int, int, int]) – The bounding box coordinates for the extracted region (xmin, ymin, xmax, ymax).

  • key_cell (str) – The key to access cell segmentation data.

  • image (np.ndarray | None) – The base image on which segmentation will be overlaid. If None, uses DAPI staining.

  • color (str) – The color used to outline cell contours. Defaults to “red”.

  • size_line (int) – The thickness of the segmentation contour lines. Defaults to 5.

Returns:

An image with cell segmentation overlaid.

Return type:

np.ndarray

get_staining_img(bounds)

Generates an image of the DAPI and cell boundary stainings within the specified bounds.

Parameters:

bounds (tuple[int, int, int, int]) – The bounding box coordinates for the extracted region (xmin, ymin, xmax, ymax).

Returns:

An image of the different stainings (channels_dapi and channels_cellbound) in the selected region.

Return type:

np.ndarray

class rna2seg.models.RNA2seg(*args: Any, **kwargs: Any)

Bases: Module

RNA2seg: A deep learning-based method for cell segmentation using spatial transcriptomics data and membrane and nuclei stainings.

__init__(device, pretrained_model: Path | str | None = 'default_pretrained', net: str = 'unet', n_inv_chan: int = 3, nb_rna_channel: int = 1, nout: int = 3, nbase=[32, 64, 128, 256], sz: int = 3, flow_threshold: float = 0.9, min_cell_size: float = 200, cellbound_flow_threshold: float = 0.4, gene2index=None)

Initializes RNA2seg.

Parameters:
  • device (str) – The computing device (e.g., “cpu” or “cuda”).

  • pretrained_model (Path | str | None) – Path to a pretrained model. If “default_pretrained”, a trained RNA2seg model is downloaded from Hugging Face. If None, weights are randomly initialized. Defaults to “default_pretrained”.

  • net (str) – Backbone network architecture. Can be “unet” or “vmunet”. Defaults to “unet”.

  • n_inv_chan (int) – Number of channels in the staining input image. The stainings are combined and encoded into an image of n_inv_chan channels using a Channel-Net. Defaults to 3.

  • nb_rna_channel (int) – Number of RNA channels used as input. Defaults to 1.

  • nout (int) – Number of output channels. Following the Cellpose method, the network outputs cell probabilities on one channel and the flow on two channels. Defaults to 3.

  • nbase (list[int]) – List defining the number of channels at each layer of the network. Defaults to [32, 64, 128, 256].

  • sz (int) – Kernel size for convolutions. Defaults to 3.

  • flow_threshold (float) – Threshold for flow consistency during segmentation. Defaults to 0.9.

  • min_cell_size (float) – Minimum cell size (in pixels) to retain. Defaults to 200.

  • cellbound_flow_threshold (float) – Threshold for cell boundary detection. Defaults to 0.4.

  • gene2index (dict or None) – Mapping from gene names to indices, if applicable. Defaults to None.

Raises:

ValueError – If an invalid model or configuration is provided.

encode(imgs=None, dapi=None, img_cellbound=None)

Encodes the input images (DAPI and cell boundary) for the model.

The method supports two modes of input:
  1. You can pass a tensor imgs containing the full image data.

  2. Alternatively, you can pass DAPI and cell boundary images separately (dapi and img_cellbound).

These two images are concatenated along the channel dimension and then passed through the network.

Parameters:
  • imgs (torch.Tensor | None) – Tensor containing the full set of image channels, including RNA channels. This is kept for compatibility with the old version. Defaults to None.

  • dapi – Tensor representing the DAPI staining image, used for encoding. Must be provided if img_cellbound is provided.

  • img_cellbound (torch.Tensor | None) – Tensor representing the cell boundary image, used for encoding. Must be provided if dapi is provided.

Returns:

Tensor representing the encoded image data, after processing through the network’s AdaptorNet.

Return type:

torch.Tensor

Raises:

AssertionError – If imgs is provided while dapi and img_cellbound are also provided, or if neither is provided when one of them is required.
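The two input modes described above are mutually exclusive. The sketch below mirrors that rule in plain Python; check_encode_inputs is a hypothetical helper, not a function of rna2seg, and it is slightly stricter than the documented behaviour because it rejects imgs combined with either of the two separate images.

```python
def check_encode_inputs(imgs=None, dapi=None, img_cellbound=None):
    """Return which input mode is in use, or raise AssertionError."""
    separate = dapi is not None or img_cellbound is not None
    if imgs is not None:
        assert not separate, "pass either imgs or dapi/img_cellbound, not both"
        return "imgs"
    assert dapi is not None and img_cellbound is not None, (
        "dapi and img_cellbound must be provided together"
    )
    return "separate"


# Strings stand in for tensors in this illustration.
mode_a = check_encode_inputs(imgs="full_tensor")
mode_b = check_encode_inputs(dapi="dapi_tensor", img_cellbound="cb_tensor")
try:
    check_encode_inputs(dapi="dapi_tensor")  # img_cellbound is missing
    rejected = False
except AssertionError:
    rejected = True
```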

forward(input_dict=None, list_gene=None, array_coord=None, dapi=None, img_cellbound=None, rna_img=None)

Forward pass for the RNA2seg model.

The forward method supports two modes of input:
  1. You can pass a dictionary (input_dict) containing all the relevant inputs (‘list_gene’, ‘array_coord’, ‘dapi’, ‘img_cellbound’, and ‘rna_img’). If provided, the values in the dictionary will override the individual function arguments.

  2. Alternatively, you can pass each argument independently.

The DAPI and cell boundary images are used for encoding, and the RNA image (either encoded or pre-encoded) is combined with the other inputs. The concatenated data is then passed through the model to generate the output.

Parameters:
  • input_dict (dict | None) – Optional dictionary containing the following keys: ‘list_gene’, ‘array_coord’, ‘dapi’, ‘img_cellbound’, and ‘rna_img’. If provided, the values will override the function arguments.

  • list_gene (torch.Tensor | None) – Tensor representing the list of genes for RNA encoding. Cannot be provided simultaneously with rna_img. Defaults to None.

  • array_coord (torch.Tensor | None) – Tensor containing the coordinates for RNA encoding. Required if list_gene is provided.

  • dapi (torch.Tensor) – Tensor representing the DAPI staining image, used as input for encoding.

  • img_cellbound (torch.Tensor) – Tensor representing the cell boundary image, used as input for encoding.

  • rna_img – Tensor representing the pre-encoded RNA image. Cannot be provided simultaneously with list_gene. Defaults to None.

Returns:

Tensor representing the model’s output after processing the input data.

Return type:

torch.Tensor

Raises:
  • AssertionError – If neither list_gene nor rna_img is provided, or if both are provided simultaneously.

  • AssertionError – If array_coord is not provided when list_gene is used.
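The precedence and validation rules for forward can be summarised in a small stand-alone sketch. The helper resolve_forward_inputs is hypothetical and only reproduces the documented rules: dictionary values win over keyword arguments, exactly one of list_gene and rna_img must be given, and array_coord must accompany list_gene. It assumes that a key present in input_dict overrides the argument even when its value is None.

```python
FORWARD_KEYS = ("list_gene", "array_coord", "dapi", "img_cellbound", "rna_img")


def resolve_forward_inputs(input_dict=None, **kwargs):
    """Merge input_dict over keyword arguments and validate the RNA inputs."""
    unknown = set(kwargs) - set(FORWARD_KEYS)
    if unknown:
        raise TypeError(f"unexpected arguments: {sorted(unknown)}")
    inputs = {key: kwargs.get(key) for key in FORWARD_KEYS}
    if input_dict is not None:
        # Values from the dictionary take precedence over keyword arguments.
        inputs.update({k: input_dict[k] for k in FORWARD_KEYS if k in input_dict})
    has_genes = inputs["list_gene"] is not None
    has_img = inputs["rna_img"] is not None
    assert has_genes != has_img, "provide exactly one of list_gene or rna_img"
    if has_genes:
        assert inputs["array_coord"] is not None, (
            "array_coord is required with list_gene"
        )
    return inputs


# Strings stand in for tensors in this illustration.
resolved = resolve_forward_inputs(
    input_dict={"rna_img": "rna_from_dict"},
    dapi="dapi", img_cellbound="cb", rna_img="rna_from_arg",
)
try:
    resolve_forward_inputs(list_gene="genes", array_coord="xy", rna_img="rna")
    both_rejected = False
except AssertionError:
    both_rejected = True
```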

load_model(filename, device=None)

Load the model from a file.

Parameters:
  • filename (str) – The path to the file where the model is saved.

  • device (torch.device | None) – The device to load the model on. If None, the model is loaded on the CPU. Defaults to None.

run(path_temp_save, input_dict=None, list_gene=None, array_coord=None, dapi=None, img_cellbound=None, rna_img=None, bounds=None, min_area=0)

Evaluates the model on a batch of images or a single image, and optionally on staining images.

Parameters:
  • path_temp_save (str | Path) – The directory where the results will be saved.

  • input_dict (dict | None) – A dictionary containing the inputs for the model. It can include ‘list_gene’, ‘array_coord’, ‘dapi’, ‘img_cellbound’, ‘rna_img’, and ‘bounds’.

  • list_gene (torch.Tensor | None) – List of gene expressions to use for encoding RNA. Defaults to None.

  • array_coord (torch.Tensor | None) – Coordinates array for the genes, required if list_gene is provided.

  • dapi (torch.Tensor) – DAPI stained image used for encoding.

  • img_cellbound (torch.Tensor) – Image of cell boundaries used for encoding.

  • rna_img (torch.Tensor | None) – RNA image, either encoded or pre-encoded.

  • bounds (list | None) – Bounds for the image, used for transformations. Defaults to None.

  • min_area (int) – The minimum area to consider for detected cells. Defaults to 0 (no filtering).

Returns:

A tuple containing the flow, cell probability, predicted masks, and cells (as a GeoDataFrame).

Return type:

tuple (torch.Tensor, torch.Tensor, np.array, GeoDataFrame)

save_model(filename)

Save the model to a file.

Parameters:

filename (str) – The path to the file where the model will be saved.

rna2seg.dataset_zarr.patches.create_patch_rna2seg(sdata: spatialdata.SpatialData, image_key: str, points_key: str, patch_width: int, patch_overlap: int, min_points_per_patch: int, folder_patch_rna2seg: Path | str | None = None, overwrite: bool = False, gene_column_name: str = 'gene')

Creates patches from the spatial data so that it can be processed in manageable sizes. Saves the patch shapes into the zarr and precomputes transcript.csv files for each patch.

Parameters:
  • sdata – SpatialData object containing the spatial dataset.

  • image_key – Key identifying the image in sdata.

  • points_key – Key identifying the points in sdata.

  • patch_width – Width of each patch.

  • patch_overlap – Overlap between adjacent patches.

  • min_points_per_patch – Minimum number of transcripts required for a patch.

  • folder_patch_rna2seg – Directory where patches will be saved. If None, defaults to sdata.path/.rna2seg.

  • overwrite – Whether to overwrite the existing patch shapes if they already exist in the zarr.

  • gene_column_name – Column name for genes, default is “gene”.
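To give an intuition for how patch_width and patch_overlap relate, the sketch below tiles an image with square patches using a stride of patch_width minus patch_overlap. This is one common tiling scheme and is only an assumption here; the function patch_bounds is hypothetical, and the grid actually produced by create_patch_rna2seg may differ.

```python
def patch_bounds(image_width, image_height, patch_width, patch_overlap):
    """Return (xmin, ymin, xmax, ymax) boxes tiling the image with overlap."""
    if not 0 <= patch_overlap < patch_width:
        raise ValueError("patch_overlap must be in [0, patch_width)")
    stride = patch_width - patch_overlap

    def starts(length):
        # Add patches until the last one reaches the image edge.
        positions = [0]
        while positions[-1] + patch_width < length:
            positions.append(positions[-1] + stride)
        return positions

    return [
        (x, y, min(x + patch_width, image_width), min(y + patch_width, image_height))
        for y in starts(image_height)
        for x in starts(image_width)
    ]


# Hypothetical image size in pixels.
boxes = patch_bounds(2500, 1200, patch_width=1200, patch_overlap=150)
```

Each box follows the (xmin, ymin, xmax, ymax) convention used for bounds elsewhere in this reference; patches on the right and bottom edges are clipped to the image extent.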