Training RNA2seg on Zarr-Saved SpatialData¶
This notebook demonstrates how to train RNA2seg on spatial transcriptomics data stored in a Zarr file. The process consists of four main steps:
Patch Creation – Extract patches of a reasonable size to process efficiently (saved in the Zarr file).
Filtered Target Generation – Create a curated segmentation mask from a teacher model for RNA2seg training (saved in the Zarr file).
Model Training – Train RNA2seg using the generated patches and filtered segmentation.
Apply to the Whole Dataset – Use the notebook apply_model_on_zarr.ipynb to apply the trained model to the entire dataset.
Import¶
[1]:
import rna2seg
rna2seg.__version__
[1]:
'0.1.0'
[2]:
import warnings
warnings.filterwarnings("ignore")
[3]:
import cv2
import torch
import numpy as np
from tqdm import tqdm
import spatialdata as sd
from pathlib import Path
import albumentations as A
from rna2seg.dataset_zarr import (
    RNA2segDataset, custom_collate_fn, compute_consistent_cell
)
Set your own path¶
[4]:
## path to spatial data
merfish_zarr_path = "/media/tom/Transcend/open_merfish/test_spatial_data/test005/sub_mouse_ileum.zarr"
path_save_model = "/media/tom/Transcend/open_merfish/test_spatial_data/test005"
### load sdata and set path parameters
sdata = sd.read_zarr(merfish_zarr_path)
image_key = "staining_z3"
patch_width = 1200
patch_overlap = 50
points_key = "transcripts"
min_points_per_patch = 0
folder_patch_rna2seg = Path(merfish_zarr_path) / f".rna2seg_{patch_width}_{patch_overlap}"
channels_dapi = ["DAPI"]
channels_cellbound = ["Cellbound1"]
gene_column_name = "gene"
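As a small illustration of the path logic above, the patch folder name is derived from the patch parameters and placed inside the Zarr store. The sketch below uses a hypothetical store path (`/data/example.zarr`) and hypothetical variable names rather than the ones in this notebook, and `PurePosixPath` so that it behaves the same on any platform.

```python
from pathlib import PurePosixPath

# Hypothetical store location and parameters, used only for illustration.
example_zarr_path = "/data/example.zarr"
example_width = 1200
example_overlap = 50

# Same naming scheme as folder_patch_rna2seg in the cell above.
example_patch_folder = (
    PurePosixPath(example_zarr_path) / f".rna2seg_{example_width}_{example_overlap}"
)
print(example_patch_folder)
```

Because the folder name encodes both parameters, changing `patch_width` or `patch_overlap` points to a different folder instead of reusing one computed with other settings.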
Step 1: Create training patches from Zarr files¶
In this step, the dataset (image + transcripts) is divided into patches of size patch_width × patch_width with an overlap of patch_overlap. This allows processing images of a manageable size while preserving spatial continuity.
Process
The dataset, stored in Zarr format, is loaded.
Patch coordinates are saved as a Shape in the Zarr file: sopa_patches_rna2seg_[patch_width]_[patch_overlap].
A .rna2seg directory is created to store the transcript data corresponding to each patch.
The transcript information for each patch is saved in CSV format for further processing.
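To make the patch arithmetic concrete, here is a minimal sketch of a sliding patch grid along one image axis. It assumes a simple stride of `patch_width - patch_overlap` with the last patch clipped to the image border; this is an illustrative assumption, not necessarily the exact algorithm used by sopa, and the helper `patch_grid` is hypothetical.

```python
import math

def patch_grid(size, patch_width, patch_overlap):
    """Return (start, end) pixel intervals covering `size` with overlapping patches."""
    stride = patch_width - patch_overlap
    n = max(1, math.ceil((size - patch_overlap) / stride))
    # The last interval is clipped so that it does not exceed the image border.
    return [(i * stride, min(i * stride + patch_width, size)) for i in range(n)]

# A 3704-pixel axis, as in the 'staining_z3' image used in this notebook.
intervals = patch_grid(3704, 1200, 50)
n_patches_2d = len(intervals) ** 2  # same grid along both axes of a square image
print(intervals)
print(n_patches_2d)
```

Under these assumptions a 3704 × 3704 image gives four intervals per axis, so 16 patches in total.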
[5]:
from rna2seg.dataset_zarr import create_patch_rna2seg
### create patches in the sdata and precompute a transcript CSV for each patch with sopa
create_patch_rna2seg(
    sdata,
    image_key=image_key,
    points_key=points_key,
    patch_width=patch_width,
    patch_overlap=patch_overlap,
    min_points_per_patch=min_points_per_patch,
    folder_patch_rna2seg=folder_patch_rna2seg,
    overwrite=True,
    gene_column_name=gene_column_name,
)
print(sdata)
[INFO] (sopa.patches._patches) Added 16 patche(s) to sdata['sopa_patches_rna2seg_1200_50']
[########################################] | 100% Completed | 3.46 s
[########################################] | 100% Completed | 3.55 s
SpatialData object, with associated Zarr store: /media/tom/Transcend/open_merfish/test_spatial_data/test005/sub_mouse_ileum.zarr
├── Images
│ └── 'staining_z3': DataTree[cyx] (5, 3704, 3704), (5, 1852, 1851), (5, 926, 926), (5, 463, 463), (5, 232, 232)
├── Points
│ └── 'transcripts': DataFrame with shape: (<Delayed>, 10) (3D points)
└── Shapes
├── 'Cellbound1': GeoDataFrame shape: (569, 1) (2D shapes)
├── 'DAPI': GeoDataFrame shape: (409, 1) (2D shapes)
├── 'sopa_patches_rna2seg_1200_50': GeoDataFrame shape: (16, 3) (2D shapes)
└── 'sopa_patches_rna2seg_1200_150': GeoDataFrame shape: (16, 3) (2D shapes)
with coordinate systems:
▸ 'microns', with elements:
staining_z3 (Images), transcripts (Points), Cellbound1 (Shapes), DAPI (Shapes), sopa_patches_rna2seg_1200_50 (Shapes), sopa_patches_rna2seg_1200_150 (Shapes)
Step 2: Create a Consistent Target for Training RNA2seg¶
This step builds a curated training target from a teacher cell segmentation (key_shape_cell_seg) and a nuclei segmentation (key_nuclei_segmentation).
Process
Load the segmentations (Cellbound1 and DAPI) from the Zarr file.
Apply a consistency check to remove unreliable segmentations:
Consistent cell segmentation → Cellbound1_consistent
Consistent nuclei segmentation → DAPI_consistent
Save the refined segmentations back into the Zarr file.
This ensures high-quality annotations for training or fine-tuning RNA2seg.
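The consistency check relies on overlap thresholds between cell and nucleus shapes. The sketch below is one plausible reading of such thresholds, not RNA2seg's implementation: it uses axis-aligned boxes instead of polygons, and the helpers `box_area`, `overlap_fraction`, and `classify_nucleus` are hypothetical. It assumes a nucleus counts as contained when at least 95% of its area lies inside the cell and as intersecting when at least 5% does, mirroring the values 0.95 and 0.05 passed as threshold_intersection_contain and threshold_intersection_intersect in the next cell.

```python
def box_area(box):
    """Area of an (xmin, ymin, xmax, ymax) box; empty boxes have area 0."""
    xmin, ymin, xmax, ymax = box
    return max(0.0, xmax - xmin) * max(0.0, ymax - ymin)

def overlap_fraction(nucleus, cell):
    """Fraction of the nucleus area that lies inside the cell."""
    inter = (
        max(nucleus[0], cell[0]),
        max(nucleus[1], cell[1]),
        min(nucleus[2], cell[2]),
        min(nucleus[3], cell[3]),
    )
    return box_area(inter) / box_area(nucleus)

def classify_nucleus(nucleus, cell, contain=0.95, intersect=0.05):
    """Assumed rule: compare the overlap fraction with the two thresholds."""
    frac = overlap_fraction(nucleus, cell)
    if frac >= contain:
        return "contained"
    if frac >= intersect:
        return "intersecting"
    return "disjoint"

cell = (0.0, 0.0, 10.0, 10.0)
print(classify_nucleus((2.0, 2.0, 6.0, 6.0), cell))      # nucleus fully inside the cell
print(classify_nucleus((8.0, 8.0, 12.0, 12.0), cell))    # a quarter of the nucleus overlaps
print(classify_nucleus((20.0, 20.0, 24.0, 24.0), cell))  # no overlap at all
```

Under this reading, a cell whose only overlapping nucleus is contained would be a natural candidate for the consistent set, while a cell that partially cuts through a nucleus would be treated as unreliable.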
[6]:
key_cell_segmentation = "Cellbound1"
key_nuclei_segmentation="DAPI"
# names of the new shapes that will be created in the sdata
key_cell_consistent = "Cellbound1_consistent"
key_nucleus_consistent = "DAPI_consistent"
sdata, _ = compute_consistent_cell(
    sdata=sdata,
    key_shape_nuclei_seg=key_nuclei_segmentation,
    key_shape_cell_seg=key_cell_segmentation,
    key_cell_consistent=key_cell_consistent,
    key_nuclei_consistent=key_nucleus_consistent,
    image_key=image_key,
    threshold_intersection_contain=0.95,
    threshold_intersection_intersect=0.05,
    accepted_nb_nuclei_per_cell=None,
    max_cell_nb_intersecting_nuclei=1,
)
print(sdata)
Resolving conflicts: 1344it [00:00, 10513.37it/s]
SpatialData object, with associated Zarr store: /media/tom/Transcend/open_merfish/test_spatial_data/test005/sub_mouse_ileum.zarr
├── Images
│ └── 'staining_z3': DataTree[cyx] (5, 3704, 3704), (5, 1852, 1851), (5, 926, 926), (5, 463, 463), (5, 232, 232)
├── Points
│ └── 'transcripts': DataFrame with shape: (<Delayed>, 10) (3D points)
└── Shapes
├── 'Cellbound1': GeoDataFrame shape: (569, 1) (2D shapes)
├── 'Cellbound1_consistent_with_nuclei': GeoDataFrame shape: (204, 1) (2D shapes)
├── 'Cellbound1_consistent_without_nuclei': GeoDataFrame shape: (248, 1) (2D shapes)
├── 'DAPI': GeoDataFrame shape: (409, 1) (2D shapes)
├── 'DAPI_consistent_in_cell': GeoDataFrame shape: (204, 1) (2D shapes)
├── 'DAPI_consistent_not_in_cell': GeoDataFrame shape: (171, 1) (2D shapes)
├── 'sopa_patches_rna2seg_1200_50': GeoDataFrame shape: (16, 3) (2D shapes)
└── 'sopa_patches_rna2seg_1200_150': GeoDataFrame shape: (16, 3) (2D shapes)
with coordinate systems:
▸ 'microns', with elements:
staining_z3 (Images), transcripts (Points), Cellbound1 (Shapes), Cellbound1_consistent_with_nuclei (Shapes), Cellbound1_consistent_without_nuclei (Shapes), DAPI (Shapes), DAPI_consistent_in_cell (Shapes), DAPI_consistent_not_in_cell (Shapes), sopa_patches_rna2seg_1200_50 (Shapes), sopa_patches_rna2seg_1200_150 (Shapes)
Step 3: Training RNA2seg¶
Now, we will train RNA2seg using the target segmentation created in Step 2.
Initialize an RNA2segDataset¶
[7]:
from rna2seg.models import RNA2seg
transform_resize = A.Compose([
    A.Resize(width=512, height=512, interpolation=cv2.INTER_NEAREST),
])
dataset = RNA2segDataset(
    sdata=sdata,
    channels_dapi=channels_dapi,
    channels_cellbound=channels_cellbound,
    shape_patch_key=f"sopa_patches_rna2seg_{patch_width}_{patch_overlap}",  # Created at step 1
    key_cell_consistent=key_cell_consistent,  # Created at step 2
    key_nucleus_consistent=key_nucleus_consistent,  # Created at step 2
    key_nuclei_segmentation=key_nuclei_segmentation,
    gene_column=gene_column_name,
    density_threshold=None,
    kernel_size_background_density=10,
    kernel_size_rna2img=0.5,
    max_filter_size_rna2img=2,
    transform_resize=transform_resize,
    training_mode=True,
    patch_dir=folder_patch_rna2seg,
    patch_width=patch_width,  # same values as in step 1
    patch_overlap=patch_overlap,
    use_cache=True,
)
No module named 'vmunet'
VMUnet not loaded
100%|██████████| 16/16 [00:00<00:00, 42.29it/s]
Number of valid patches: 6
100%|██████████| 16/16 [00:00<00:00, 26.38it/s]
compute density threshold
100%|██████████| 6/6 [00:01<00:00, 4.60it/s]
Time to compute density threshold: 1.305933s
Train / Validation split¶
[8]:
from sklearn.model_selection import train_test_split
from torch.utils.data.sampler import SubsetRandomSampler
train_indices, val_indices = train_test_split(
    range(len(dataset)), test_size=0.1, shuffle=True, random_state=42
)
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
Initialize Dataloaders¶
[9]:
training_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,  # ordering is handled by the sampler
    num_workers=0,
    sampler=train_sampler,
    collate_fn=custom_collate_fn,
)
print(f"len(training_loader) {len(training_loader)}")
validation_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,  # ordering is handled by the sampler
    num_workers=0,
    sampler=valid_sampler,
    collate_fn=custom_collate_fn,
)
print(f"len(validation_loader) {len(validation_loader)}")
len(training_loader) 3
len(validation_loader) 1
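The loader lengths printed above can be checked by hand. The sketch below assumes that `len(dataset)` equals the 6 valid patches reported when the dataset was built, that scikit-learn rounds the validation count up for a fractional `test_size`, and that the last incomplete batch is kept (the `DataLoader` default). The variable names are illustrative only.

```python
import math

n_valid_patches = 6  # "Number of valid patches: 6" (assumed to equal len(dataset))
test_size = 0.1
batch_size = 2

n_val = math.ceil(test_size * n_valid_patches)     # validation samples
n_train = n_valid_patches - n_val                  # remaining training samples
n_train_batches = math.ceil(n_train / batch_size)  # batches per training epoch
n_val_batches = math.ceil(n_val / batch_size)      # batches per validation pass
print(n_train, n_val, n_train_batches, n_val_batches)
```

With so few patches the validation set holds a single patch, so the validation loss reported during training is a rough indicator only.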
Initialize RNA2seg Model¶
[10]:
import torch
# pick the best available device: CUDA GPU, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
Using device: cpu
[11]:
rna2seg = RNA2seg(
    device,
    net='unet',
    flow_threshold=0.9,
    cellbound_flow_threshold=0.4,
    pretrained_model=None,
)
rna2seg = rna2seg.to(device)
optimizer = torch.optim.AdamW(rna2seg.parameters(), lr=0.001, weight_decay=0.01)
initiaisation of CPnet
Initiaisation of ChannelInvariantNet
Training RNA2seg¶
[12]:
from rna2seg.train import train_one_epoch
best_val_loss = np.inf
for epoch_index in tqdm(range(3)):
    train_one_epoch(
        device=device,
        epoch_index=epoch_index,
        rna2seg=rna2seg,
        training_loader=training_loader,
        optimizer=optimizer,
        print_loss_every=int(len(training_loader) / 3),
        tb_writer=None,
        validation_loader=validation_loader,
        path_save_model=path_save_model,
        cellbound_prob=0.8,
        best_val_loss=best_val_loss,
    )
0%| | 0/3 [00:00<?, ?it/s]
No cache found for patch 8 Recomputing the patch 8
No cache found for patch 10 Recomputing the patch 10
No cache found for patch 4 Recomputing the patch 4
validation loss: 11.355034828186035
best_val_loss: 11.355034828186035
No cache found for patch 6 Recomputing the patch 6
No cache found for patch 5 Recomputing the patch 5
validation loss: 11.071903228759766
best_val_loss: 11.071903228759766
No cache found for patch 9 Recomputing the patch 9
validation loss: 11.06725025177002
best_val_loss: 11.06725025177002
training: 100%|██████████| 3/3 [00:36<00:00, 12.10s/it]
33%|███▎ | 1/3 [00:36<01:12, 36.31s/it]
validation loss: 11.359213829040527
best_val_loss: 11.359213829040527
validation loss: 11.140854835510254
best_val_loss: 11.140854835510254
validation loss: 11.295135498046875
training: 100%|██████████| 3/3 [00:17<00:00, 5.67s/it]
67%|██████▋ | 2/3 [00:53<00:24, 24.99s/it]
validation loss: 11.087203025817871
best_val_loss: 11.087203025817871
validation loss: 11.102073669433594
validation loss: 11.101568222045898
training: 100%|██████████| 3/3 [00:18<00:00, 6.27s/it]
100%|██████████| 3/3 [01:12<00:00, 24.07s/it]
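In the log above, a best_val_loss line appears to follow a validation loss only when it improves on the values seen so far within that epoch. A generic way to track the best value across all epochs is sketched below in plain Python. It is not the logic of `train_one_epoch`; the helper `track_best` is hypothetical, and the losses are the values copied from the log.

```python
def track_best(losses, best=float("inf")):
    """Return (best_loss, index_of_best) over a sequence of validation losses."""
    best_index = None
    for i, loss in enumerate(losses):
        if loss < best:
            best, best_index = loss, i
    return best, best_index

# Validation losses in the order they were logged over the three epochs.
logged_val_losses = [
    11.355034828186035, 11.071903228759766, 11.06725025177002,
    11.359213829040527, 11.140854835510254, 11.295135498046875,
    11.087203025817871, 11.102073669433594, 11.101568222045898,
]
best_loss, best_step = track_best(logged_val_losses)
print(best_loss, best_step)
```

Here the lowest validation loss is reached at the third evaluation, at the end of the first epoch, which is unsurprising for a three-epoch demonstration run on six patches.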