Training RNA2seg on Zarr-Saved SpatialData

This notebook demonstrates how to train RNA2seg on spatial transcriptomics data stored in a Zarr file. The process consists of four main steps:

  1. Patch Creation – Extract patches of a reasonable size to process efficiently (saved in the Zarr file).

  2. Filtered Target Generation – Create a curated segmentation mask from a teacher model for RNA2seg training (saved in the Zarr file).

  3. Model Training – Train RNA2seg using the generated patches and filtered segmentation.

  4. Apply to the whole dataset – Use the notebook apply_model_on_zarr.ipynb to apply the trained model to the entire dataset.

Test data for this notebook can be downloaded at: https://cloud.minesparis.psl.eu/index.php/s/qw2HaDVxwwy1EOK
This dataset is a subset of the Mouse Ileum dataset from Petukhov et al., Nat Biotechnol 40, 345–354 (2022). https://doi.org/10.1038/s41587-021-01044-w

Import

[1]:
import rna2seg
rna2seg.__version__
[1]:
'0.1.0'
[2]:
import warnings
warnings.filterwarnings("ignore")
[3]:
import cv2
import torch
import numpy as np
from tqdm import tqdm
import spatialdata as sd
from pathlib import Path
import albumentations as A

from rna2seg.dataset_zarr import (
    RNA2segDataset, custom_collate_fn, compute_consistent_cell
)

Set your own path

[4]:

## path to spatial data
merfish_zarr_path = "/media/tom/Transcend/open_merfish/test_spatial_data/test005/sub_mouse_ileum.zarr"
path_save_model = "/media/tom/Transcend/open_merfish/test_spatial_data/test005"

### load sdata and set path parameters
sdata = sd.read_zarr(merfish_zarr_path)
image_key = "staining_z3"
patch_width = 1200
patch_overlap = 50
points_key = "transcripts"
min_points_per_patch = 0
folder_patch_rna2seg = Path(merfish_zarr_path) / f".rna2seg_{patch_width}_{patch_overlap}"
channels_dapi = ["DAPI"]
channels_cellbound = ["Cellbound1"]
gene_column_name = "gene"

Step 1: Create training patches from Zarr files

In this step, the dataset (image + transcripts) is divided into patches of size patch_width × patch_width with an overlap of patch_overlap. This allows processing images of a manageable size while preserving spatial continuity.

Process

  • The dataset, stored in Zarr format, is loaded.

  • Patch coordinates are saved as a Shapes element in the Zarr store: sopa_patches_rna2seg_[patch_width]_[patch_overlap].

  • A .rna2seg_[patch_width]_[patch_overlap] directory is created inside the Zarr store to hold the transcript data corresponding to each patch.

  • The transcript information for each patch is saved in CSV format for further processing.
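The tiling itself can be illustrated with a short, self-contained sketch. It assumes a simplified scheme in which patches start every patch_width - patch_overlap pixels and the last patch is clipped at the image border; the helper patch_grid is hypothetical and is not part of rna2seg or sopa, whose exact tiling may differ.

```python
import math

def patch_grid(image_size, patch_width, patch_overlap):
    """Return (start, end) pixel intervals of overlapping patches along one axis.

    Hypothetical helper for illustration only (simplified tiling scheme).
    """
    if image_size <= patch_width:
        return [(0, image_size)]
    step = patch_width - patch_overlap  # distance between two patch origins
    n_patches = math.ceil((image_size - patch_overlap) / step)
    # Clip the last patch so it does not extend past the image border
    return [
        (i * step, min(i * step + patch_width, image_size))
        for i in range(n_patches)
    ]

# Values used in this notebook: a 3704 x 3704 image, 1200-pixel patches, 50-pixel overlap
intervals = patch_grid(3704, 1200, 50)
n_total = len(intervals) ** 2  # same grid along x and y

print(intervals)  # [(0, 1200), (1150, 2350), (2300, 3500), (3450, 3704)]
print(n_total)    # 16
```

Under these assumptions a 3704-pixel axis is covered by 4 intervals, so the image is covered by 4 × 4 = 16 patches, which is consistent with the 16 patches reported in the output of the next cell.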

[5]:
from rna2seg.dataset_zarr import create_patch_rna2seg


### create patches in the sdata and precompute the transcript CSV file for each patch with sopa
create_patch_rna2seg(sdata,
                    image_key=image_key,
                    points_key=points_key,
                    patch_width=patch_width,
                    patch_overlap=patch_overlap,
                    min_points_per_patch=min_points_per_patch,
                    folder_patch_rna2seg = folder_patch_rna2seg,
                    overwrite = True,
                    gene_column_name=gene_column_name)

print(sdata)
[INFO] (sopa.patches._patches) Added 16 patche(s) to sdata['sopa_patches_rna2seg_1200_50']
[########################################] | 100% Completed | 3.46 s
[########################################] | 100% Completed | 3.55 s
SpatialData object, with associated Zarr store: /media/tom/Transcend/open_merfish/test_spatial_data/test005/sub_mouse_ileum.zarr
├── Images
│     └── 'staining_z3': DataTree[cyx] (5, 3704, 3704), (5, 1852, 1851), (5, 926, 926), (5, 463, 463), (5, 232, 232)
├── Points
│     └── 'transcripts': DataFrame with shape: (<Delayed>, 10) (3D points)
└── Shapes
      ├── 'Cellbound1': GeoDataFrame shape: (569, 1) (2D shapes)
      ├── 'DAPI': GeoDataFrame shape: (409, 1) (2D shapes)
      ├── 'sopa_patches_rna2seg_1200_50': GeoDataFrame shape: (16, 3) (2D shapes)
      └── 'sopa_patches_rna2seg_1200_150': GeoDataFrame shape: (16, 3) (2D shapes)
with coordinate systems:
    ▸ 'microns', with elements:
        staining_z3 (Images), transcripts (Points), Cellbound1 (Shapes), DAPI (Shapes), sopa_patches_rna2seg_1200_50 (Shapes), sopa_patches_rna2seg_1200_150 (Shapes)

Step 2: Create a Consistent Target for Training RNA2seg

Input: Spatial data with potentially erroneous nucleus and cell segmentations.
Output: Curated cell and nucleus segmentations for training RNA2seg, saved in the Zarr file.
This step refines two segmentations stored in the Zarr file: cell segmentation (key_shape_cell_seg) and nuclei segmentation (key_shape_nuclei_seg).
The goal is to generate a teacher segmentation by filtering out inconsistencies between cells and nuclei.

Process

  1. Load the segmentations (Cellbound1 and DAPI) from the Zarr file.

  2. Apply a consistency check to remove unreliable segmentations:

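    • Consistent cell segmentation: Cellbound1_consistent (stored as Cellbound1_consistent_with_nuclei and Cellbound1_consistent_without_nuclei)

    • Consistent nuclei segmentation: DAPI_consistent (stored as DAPI_consistent_in_cell and DAPI_consistent_not_in_cell)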

  3. Save the refined segmentations back into the Zarr file.

This ensures high-quality annotations for training or fine-tuning RNA2seg.
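To give an intuition for the two thresholds used in the next cell, here is a toy sketch of a cell/nucleus consistency rule. It is not the implementation of compute_consistent_cell: shapes are reduced to axis-aligned boxes (xmin, ymin, xmax, ymax), the functions box_area, overlap_fraction and classify_cell are hypothetical, and the rule itself (a nucleus counts as contained when at least 95% of its area lies inside the cell, and as ambiguous when the overlap is between 5% and 95%) is an assumption about how such thresholds are typically used.

```python
def box_area(box):
    """Area of an axis-aligned box (xmin, ymin, xmax, ymax); 0 if the box is empty."""
    xmin, ymin, xmax, ymax = box
    return max(0.0, xmax - xmin) * max(0.0, ymax - ymin)

def overlap_fraction(nucleus, cell):
    """Fraction of the nucleus area that lies inside the cell."""
    intersection = (
        max(nucleus[0], cell[0]),
        max(nucleus[1], cell[1]),
        min(nucleus[2], cell[2]),
        min(nucleus[3], cell[3]),
    )
    return box_area(intersection) / box_area(nucleus)

def classify_cell(cell, nuclei, contain_thr=0.95, intersect_thr=0.05):
    """Toy rule (assumption): keep a cell only if no nucleus overlaps it ambiguously
    and at most one nucleus is contained in it."""
    fractions = [overlap_fraction(nucleus, cell) for nucleus in nuclei]
    contained = [f for f in fractions if f >= contain_thr]
    ambiguous = [f for f in fractions if intersect_thr <= f < contain_thr]
    if ambiguous or len(contained) > 1:
        return "inconsistent"
    return "with_nucleus" if contained else "without_nucleus"

cell = (0, 0, 10, 10)
inside = (2, 2, 5, 5)          # fully inside the cell
straddling = (8, 8, 12, 12)    # 25% of its area inside the cell
far_away = (20, 20, 22, 22)    # no overlap

print(classify_cell(cell, [inside]))              # with_nucleus
print(classify_cell(cell, [inside, straddling]))  # inconsistent
print(classify_cell(cell, [far_away]))            # without_nucleus
```

In this toy rule a cell is discarded as soon as one nucleus overlaps it only partially, since it is unclear which of the two segmentations is wrong.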

[6]:
key_cell_segmentation = "Cellbound1"
key_nuclei_segmentation="DAPI"
# names of the new shapes that will be created in the sdata
key_cell_consistent = "Cellbound1_consistent"
key_nucleus_consistent = "DAPI_consistent"

sdata, _ = compute_consistent_cell(
    sdata=sdata,
    key_shape_nuclei_seg=key_nuclei_segmentation,
    key_shape_cell_seg=key_cell_segmentation,
    key_cell_consistent=key_cell_consistent,
    key_nuclei_consistent=key_nucleus_consistent,
    image_key=image_key,
    threshold_intersection_contain=0.95,
    threshold_intersection_intersect= 0.05,
    accepted_nb_nuclei_per_cell=None,
    max_cell_nb_intersecting_nuclei=1,
)
print(sdata)
Resolving conflicts: 1344it [00:00, 10513.37it/s]
SpatialData object, with associated Zarr store: /media/tom/Transcend/open_merfish/test_spatial_data/test005/sub_mouse_ileum.zarr
├── Images
│     └── 'staining_z3': DataTree[cyx] (5, 3704, 3704), (5, 1852, 1851), (5, 926, 926), (5, 463, 463), (5, 232, 232)
├── Points
│     └── 'transcripts': DataFrame with shape: (<Delayed>, 10) (3D points)
└── Shapes
      ├── 'Cellbound1': GeoDataFrame shape: (569, 1) (2D shapes)
      ├── 'Cellbound1_consistent_with_nuclei': GeoDataFrame shape: (204, 1) (2D shapes)
      ├── 'Cellbound1_consistent_without_nuclei': GeoDataFrame shape: (248, 1) (2D shapes)
      ├── 'DAPI': GeoDataFrame shape: (409, 1) (2D shapes)
      ├── 'DAPI_consistent_in_cell': GeoDataFrame shape: (204, 1) (2D shapes)
      ├── 'DAPI_consistent_not_in_cell': GeoDataFrame shape: (171, 1) (2D shapes)
      ├── 'sopa_patches_rna2seg_1200_50': GeoDataFrame shape: (16, 3) (2D shapes)
      └── 'sopa_patches_rna2seg_1200_150': GeoDataFrame shape: (16, 3) (2D shapes)
with coordinate systems:
    ▸ 'microns', with elements:
        staining_z3 (Images), transcripts (Points), Cellbound1 (Shapes), Cellbound1_consistent_with_nuclei (Shapes), Cellbound1_consistent_without_nuclei (Shapes), DAPI (Shapes), DAPI_consistent_in_cell (Shapes), DAPI_consistent_not_in_cell (Shapes), sopa_patches_rna2seg_1200_50 (Shapes), sopa_patches_rna2seg_1200_150 (Shapes)

Step 3: Training RNA2seg

Now, we will train RNA2seg using the target segmentation created in Step 2.

Initialize an RNA2segDataset

[7]:
from rna2seg.models import RNA2seg

transform_resize  = A.Compose([
 A.Resize(width=512, height=512, interpolation=cv2.INTER_NEAREST),
])

dataset = RNA2segDataset(
   sdata=sdata,
   channels_dapi=channels_dapi,
   channels_cellbound=channels_cellbound,
   shape_patch_key=f"sopa_patches_rna2seg_{patch_width}_{patch_overlap}", # Created at step 1
   key_cell_consistent=key_cell_consistent, # Created at step 2
   key_nucleus_consistent=key_nucleus_consistent, # Created at step 2
   key_nuclei_segmentation=key_nuclei_segmentation,
   gene_column=gene_column_name,
   density_threshold = None,
   kernel_size_background_density = 10 ,
   kernel_size_rna2img = 0.5,
   max_filter_size_rna2img = 2,
   transform_resize = transform_resize,
   training_mode = True,
   patch_dir = folder_patch_rna2seg,
   patch_width=patch_width, # Same values as in step 1
   patch_overlap=patch_overlap,
   use_cache = True,
)
No module named 'vmunet'
VMUnet not loaded
100%|██████████| 16/16 [00:00<00:00, 42.29it/s]
Number of valid patches: 6
100%|██████████| 16/16 [00:00<00:00, 26.38it/s]
compute density threshold
100%|██████████| 6/6 [00:01<00:00,  4.60it/s]
Time to compute density threshold: 1.305933s

Train / Validation split

[8]:
from sklearn.model_selection import train_test_split
from torch.utils.data.sampler import SubsetRandomSampler

train_indices, val_indices = train_test_split(
    range(len(dataset)), test_size=0.1, shuffle=True, random_state=42
)
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

Initialize Dataloaders

[9]:
training_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=2,
                                              shuffle=False,
                                              num_workers = 0,
                                              sampler=train_sampler,
                                              collate_fn = custom_collate_fn,
                                              )

print( f"len(training_loader) {len(training_loader)}")

validation_loader = torch.utils.data.DataLoader(dataset,
                                                batch_size=2,
                                                shuffle=False,
                                                num_workers = 0,
                                                sampler=valid_sampler,
                                                collate_fn = custom_collate_fn,
                                                )

print( f"len(validation_loader) {len(validation_loader)}")
len(training_loader) 3
len(validation_loader) 1
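The two loader lengths follow from simple arithmetic. The sketch below assumes that len(dataset) equals the 6 valid patches reported when the dataset was created; it uses only the standard library and reproduces the rounding applied by train_test_split (the validation share is rounded up) and by DataLoader with the default drop_last=False (the last, incomplete batch is kept).

```python
import math

n_valid_patches = 6  # "Number of valid patches: 6" in the dataset output (assumed to be len(dataset))
test_size = 0.1
batch_size = 2

# train_test_split rounds the validation share up, so at least one patch is held out
n_val = math.ceil(test_size * n_valid_patches)
n_train = n_valid_patches - n_val

# With drop_last=False, an incomplete final batch still counts as a batch
n_train_batches = math.ceil(n_train / batch_size)
n_val_batches = math.ceil(n_val / batch_size)

print(n_train, n_val)                  # 5 1
print(n_train_batches, n_val_batches)  # 3 1
```

With so few patches the validation loss is computed on a single patch, so it should be read as a smoke test rather than as a reliable estimate.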

Initialize RNA2seg Model

[10]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")
Using device: cpu
[11]:
rna2seg = RNA2seg(
    device,
    net='unet',
    flow_threshold = 0.9,
    cellbound_flow_threshold = 0.4,
    pretrained_model = None,
)
rna2seg = rna2seg.to(device)

optimizer = torch.optim.AdamW(rna2seg.parameters(), lr=0.001, weight_decay=0.01)

initiaisation of CPnet
Initiaisation of ChannelInvariantNet

Training RNA2seg

[12]:
from rna2seg.train import train_one_epoch

best_val_loss = np.inf


for epoch_index in tqdm(range(3)):

    train_one_epoch(
        device=device,
        epoch_index=epoch_index,
        rna2seg=rna2seg,
        training_loader=training_loader,
        optimizer=optimizer,
        print_loss_every = int(len(training_loader) /3),
        tb_writer= None,
        validation_loader=validation_loader,
        path_save_model=path_save_model,
        cellbound_prob= 0.8,
        best_val_loss=best_val_loss
    )
  0%|          | 0/3 [00:00<?, ?it/s]

No cache found for patch 8 Recomputing the patch 8
No cache found for patch 10 Recomputing the patch 10
No cache found for patch 4 Recomputing the patch 4
  validation loss: 11.355034828186035
best_val_loss: 11.355034828186035

No cache found for patch 6 Recomputing the patch 6
No cache found for patch 5 Recomputing the patch 5
  validation loss: 11.071903228759766
best_val_loss: 11.071903228759766

No cache found for patch 9 Recomputing the patch 9
  validation loss: 11.06725025177002
best_val_loss: 11.06725025177002

training: 100%|██████████| 3/3 [00:36<00:00, 12.10s/it]
 33%|███▎      | 1/3 [00:36<01:12, 36.31s/it]


  validation loss: 11.359213829040527
best_val_loss: 11.359213829040527

  validation loss: 11.140854835510254
best_val_loss: 11.140854835510254

  validation loss: 11.295135498046875

training: 100%|██████████| 3/3 [00:17<00:00,  5.67s/it]
 67%|██████▋   | 2/3 [00:53<00:24, 24.99s/it]


  validation loss: 11.087203025817871
best_val_loss: 11.087203025817871

  validation loss: 11.102073669433594

  validation loss: 11.101568222045898

training: 100%|██████████| 3/3 [00:18<00:00,  6.27s/it]
100%|██████████| 3/3 [01:12<00:00, 24.07s/it]
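In the output above, best_val_loss starts again from the first validation loss of each epoch, which is consistent with the loop passing the same initial value (np.inf) to train_one_epoch every time. The sketch below shows how a running minimum across epochs could be tracked. It is a minimal illustration that assumes you collect the validation losses yourself (here they are copied from the output above); it does not rely on any rna2seg API.

```python
import math

# Validation losses copied from the training output above, grouped by epoch
val_losses = [
    [11.355034828186035, 11.071903228759766, 11.06725025177002],
    [11.359213829040527, 11.140854835510254, 11.295135498046875],
    [11.087203025817871, 11.102073669433594, 11.101568222045898],
]

best_val_loss = math.inf
best_epoch = None
for epoch_index, losses in enumerate(val_losses):
    epoch_best = min(losses)
    # Update only when this epoch improves on every previous epoch
    if epoch_best < best_val_loss:
        best_val_loss = epoch_best
        best_epoch = epoch_index

print(best_epoch, best_val_loss)  # 0 11.06725025177002
```

On this small test dataset the lowest validation loss is reached at the end of the first epoch, and the two later epochs do not improve on it.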

