{ "cells": [ { "cell_type": "markdown", "id": "45c85f94c95b5ccf", "metadata": {}, "source": [ "# Training RNA2seg on Zarr-Saved SpatialData \n", "\n", "\n", "This notebook demonstrates how to train RNA2seg on spatial transcriptomics data stored in a Zarr file. The process consists of four main steps: \n", "\n", "1. **Patch Creation** – Extract patches of a reasonable size to process efficiently (saved in the Zarr file). \n", "2. **Filtered Target Generation** – Create a curated segmentation mask from a teacher model for RNA2seg training (saved in the Zarr file). \n", "3. **Model Training** – Train RNA2seg using the generated patches and filtered segmentation. \n", "4. **Apply to the whole dataset** – Use the notebook `apply_model_on_zarr.ipynb` to apply the trained model to the entire dataset.\n", "\n", "\n", "**Test data for this notebook can be downloaded at: https://cloud.minesparis.psl.eu/index.php/s/qw2HaDVxwwy1EOK** \\\n", "This dataset is a subset of the Mouse Ileum dataset from Petukhov et al., Nat Biotechnol 40, 345–354 (2022). 
https://doi.org/10.1038/s41587-021-01044-w \n" ] }, { "cell_type": "markdown", "id": "094f120a", "metadata": {}, "source": [ "## Import " ] }, { "cell_type": "code", "execution_count": 1, "id": "6f930c9e-86ac-4c2a-84fe-45839146d9eb", "metadata": { "ExecuteTime": { "end_time": "2025-03-14T11:44:58.841955Z", "start_time": "2025-03-14T11:44:58.833959Z" } }, "outputs": [ { "data": { "text/plain": [ "'0.1.0'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import rna2seg\n", "rna2seg.__version__" ] }, { "cell_type": "code", "execution_count": 2, "id": "cf4aca94", "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "c645b82d3b2173fb", "metadata": { "ExecuteTime": { "end_time": "2025-02-06T14:44:29.532122Z", "start_time": "2025-02-06T14:44:22.498902Z" } }, "outputs": [], "source": [ "import cv2\n", "import torch\n", "import numpy as np\n", "from tqdm import tqdm\n", "import spatialdata as sd\n", "from pathlib import Path\n", "import albumentations as A\n", "\n", "from rna2seg.dataset_zarr import (\n", " RNA2segDataset, custom_collate_fn, compute_consistent_cell\n", ")\n" ] }, { "cell_type": "markdown", "id": "fd139779-deee-4a88-b02a-89628a171fcf", "metadata": {}, "source": [ "### Set your own path " ] }, { "cell_type": "code", "execution_count": 4, "id": "ecc8cc5c-6189-4666-927f-b70cb6e3563a", "metadata": {}, "outputs": [], "source": [ "\n", "## path to spatial data\n", "merfish_zarr_path = \"/media/tom/Transcend/open_merfish/test_spatial_data/test005/sub_mouse_ileum.zarr\"\n", "\n", "path_save_model = \"/media/tom/Transcend/open_merfish/test_spatial_data/test005\"\n", "\n", "### load sdata and set path parameters \n", "sdata = sd.read_zarr(merfish_zarr_path)\n", "image_key = \"staining_z3\"\n", "patch_width = 1200\n", "patch_overlap = 50\n", "points_key = \"transcripts\"\n", "min_points_per_patch = 0\n", 
"folder_patch_rna2seg = Path(merfish_zarr_path) / f\".rna2seg_{patch_width}_{patch_overlap}\"\n", "\n", "\n", "channels_dapi= [\"DAPI\"]\n", "channels_cellbound=[\"Cellbound1\"]\n", "\n", "gene_column_name=\"gene\" " ] }, { "cell_type": "markdown", "id": "c3d958b7bbf928bf", "metadata": {}, "source": [ "## Step 1: Create training patches from Zarr files\n", "\n", "In this step, the dataset (image + transcripts) is divided into patches of size `patch_width × patch_width` with an overlap of `patch_overlap`. This allows processing images of a manageable size while preserving spatial continuity. \n", "\n", "**Process** \n", "- The dataset, stored in Zarr format, is loaded. \n", "- Patches coordinates are saved as a `Shape` in the zarr: `sopa_patches_rna2seg_[patch_width]_[patch_overlap]`. \n", "- A `.rna2seg` directory is created to store the transcript data corresponding to each patch. \n", "- The transcript information for each patch is saved in CSV format for further processing. \n" ] }, { "cell_type": "code", "execution_count": 5, "id": "60be0f9817553592", "metadata": { "ExecuteTime": { "end_time": "2025-01-15T13:28:17.140897Z", "start_time": "2025-01-15T13:28:17.023449Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36;20m[INFO] (sopa.patches._patches)\u001b[0m Added 16 patche(s) to sdata['sopa_patches_rna2seg_1200_50']\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[########################################] | 100% Completed | 3.46 sms\n", "[########################################] | 100% Completed | 3.55 ss\n", "SpatialData object, with associated Zarr store: /media/tom/Transcend/open_merfish/test_spatial_data/test005/sub_mouse_ileum.zarr\n", "├── Images\n", "│ └── 'staining_z3': DataTree[cyx] (5, 3704, 3704), (5, 1852, 1851), (5, 926, 926), (5, 463, 463), (5, 232, 232)\n", "├── Points\n", "│ └── 'transcripts': DataFrame with shape: (, 10) (3D points)\n", "└── Shapes\n", " ├── 'Cellbound1': GeoDataFrame shape: 
(569, 1) (2D shapes)\n", " ├── 'DAPI': GeoDataFrame shape: (409, 1) (2D shapes)\n", " ├── 'sopa_patches_rna2seg_1200_50': GeoDataFrame shape: (16, 3) (2D shapes)\n", " └── 'sopa_patches_rna2seg_1200_150': GeoDataFrame shape: (16, 3) (2D shapes)\n", "with coordinate systems:\n", " ▸ 'microns', with elements:\n", " staining_z3 (Images), transcripts (Points), Cellbound1 (Shapes), DAPI (Shapes), sopa_patches_rna2seg_1200_50 (Shapes), sopa_patches_rna2seg_1200_150 (Shapes)\n" ] } ], "source": [ "from rna2seg.dataset_zarr import create_patch_rna2seg\n", "\n", "\n", "### Create patches in the sdata and precompute transcript.csv for each patch with sopa\n", "create_patch_rna2seg(sdata,\n", " image_key=image_key,\n", " points_key=points_key,\n", " patch_width=patch_width,\n", " patch_overlap=patch_overlap,\n", " min_points_per_patch=min_points_per_patch,\n", " folder_patch_rna2seg = folder_patch_rna2seg,\n", " overwrite = True,\n", " gene_column_name=gene_column_name)\n", "\n", "print(sdata)" ] }, { "cell_type": "markdown", "id": "e75f5cac439b7149", "metadata": {}, "source": [ "## Step 2: Create a Consistent Target for Training RNA2seg\n", "\n", "**Input:** Spatial data with potentially erroneous nucleus and cell segmentations. \n", "**Output:** Curated cell and nucleus segmentations for training RNA2seg, saved in the Zarr file.\n", "\n", "This step refines two segmentations stored in the Zarr file: **cell segmentation** (`key_shape_cell_seg`) and **nuclei segmentation** (`key_shape_nuclei_seg`). \n", "The goal is to generate a **teacher segmentation** by filtering out inconsistencies between cells and nuclei. \n", "\n", "**Process** \n", "1. Load the segmentations (`Cellbound1` and `DAPI`) from the Zarr file. \n", "2. Apply a **consistency check** to remove unreliable segmentations: \n", " - **Consistent cell segmentation** → `Cellbound1_consistent_with_nuclei` and `Cellbound1_consistent_without_nuclei` \n", " - **Consistent nuclei segmentation** → `DAPI_consistent_in_cell` and `DAPI_consistent_not_in_cell` \n", "3. Save the refined segmentations back into the Zarr file. 
\n", "\n", "This ensures high-quality annotations for training or fine-tuning RNA2seg. \n" ] }, { "cell_type": "code", "execution_count": 6, "id": "3508e0a5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Resolving conflicts: 1344it [00:00, 10513.37it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "SpatialData object, with associated Zarr store: /media/tom/Transcend/open_merfish/test_spatial_data/test005/sub_mouse_ileum.zarr\n", "├── Images\n", "│ └── 'staining_z3': DataTree[cyx] (5, 3704, 3704), (5, 1852, 1851), (5, 926, 926), (5, 463, 463), (5, 232, 232)\n", "├── Points\n", "│ └── 'transcripts': DataFrame with shape: (, 10) (3D points)\n", "└── Shapes\n", " ├── 'Cellbound1': GeoDataFrame shape: (569, 1) (2D shapes)\n", " ├── 'Cellbound1_consistent_with_nuclei': GeoDataFrame shape: (204, 1) (2D shapes)\n", " ├── 'Cellbound1_consistent_without_nuclei': GeoDataFrame shape: (248, 1) (2D shapes)\n", " ├── 'DAPI': GeoDataFrame shape: (409, 1) (2D shapes)\n", " ├── 'DAPI_consistent_in_cell': GeoDataFrame shape: (204, 1) (2D shapes)\n", " ├── 'DAPI_consistent_not_in_cell': GeoDataFrame shape: (171, 1) (2D shapes)\n", " ├── 'sopa_patches_rna2seg_1200_50': GeoDataFrame shape: (16, 3) (2D shapes)\n", " └── 'sopa_patches_rna2seg_1200_150': GeoDataFrame shape: (16, 3) (2D shapes)\n", "with coordinate systems:\n", " ▸ 'microns', with elements:\n", " staining_z3 (Images), transcripts (Points), Cellbound1 (Shapes), Cellbound1_consistent_with_nuclei (Shapes), Cellbound1_consistent_without_nuclei (Shapes), DAPI (Shapes), DAPI_consistent_in_cell (Shapes), DAPI_consistent_not_in_cell (Shapes), sopa_patches_rna2seg_1200_50 (Shapes), sopa_patches_rna2seg_1200_150 (Shapes)\n" ] } ], "source": [ "key_cell_segmentation = \"Cellbound1\"\n", "key_nuclei_segmentation=\"DAPI\"\n", "# Base names for the consistent shapes that will be added to the sdata\n", "key_cell_consistent = \"Cellbound1_consistent\"\n", "key_nucleus_consistent = 
\"DAPI_consistent\"\n", "\n", "sdata, _ = compute_consistent_cell(\n", " sdata=sdata,\n", " key_shape_nuclei_seg=key_nuclei_segmentation,\n", " key_shape_cell_seg=key_cell_segmentation,\n", " key_cell_consistent=key_cell_consistent,\n", " key_nuclei_consistent=key_nucleus_consistent,\n", " image_key=image_key,\n", " threshold_intersection_contain=0.95,\n", " threshold_intersection_intersect= 0.05,\n", " accepted_nb_nuclei_per_cell=None,\n", " max_cell_nb_intersecting_nuclei=1,\n", ")\n", "print(sdata)" ] }, { "cell_type": "markdown", "id": "dd9ebdb1", "metadata": {}, "source": [ "## Step 3: Training RNA2seg\n", "\n", "Now, we will train RNA2seg using the target segmentation created in Step 2. \n", "\n", "### Initialize a RNA2segDataset" ] }, { "cell_type": "code", "execution_count": 7, "id": "68afe78c9f9e67da", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No module named 'vmunet'\n", "VMUnet not loaded\n", "100%|██████████| 16/16 [00:00<00:00, 42.29it/s]\n", "Number of valid patches: 6\n", "100%|██████████| 16/16 [00:00<00:00, 26.38it/s]\n", "compute density threshold\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 6/6 [00:01<00:00, 4.60it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Time to compute density threshold: 1.305933s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "from rna2seg.models import RNA2seg\n", "\n", "transform_resize = A.Compose([\n", " A.Resize(width=512, height=512, interpolation=cv2.INTER_NEAREST),\n", "])\n", "\n", "dataset = RNA2segDataset(\n", " sdata=sdata,\n", " channels_dapi=channels_dapi,\n", " channels_cellbound=channels_cellbound,\n", " shape_patch_key=f\"sopa_patches_rna2seg_{patch_width}_{patch_overlap}\", # Created at step 1\n", " key_cell_consistent=key_cell_consistent, # Created at step 2\n", " key_nucleus_consistent=key_nucleus_consistent, # Created at step 2\n", " 
key_nuclei_segmentation=key_nuclei_segmentation,\n", " gene_column=gene_column_name,\n", " density_threshold = None,\n", " kernel_size_background_density = 10 ,\n", " kernel_size_rna2img = 0.5,\n", " max_filter_size_rna2img = 2,\n", " transform_resize = transform_resize,\n", " training_mode = True,\n", " patch_dir = folder_patch_rna2seg,\n", " patch_width=1200,\n", " patch_overlap=50,\n", " use_cache = True, \n", ")" ] }, { "cell_type": "markdown", "id": "33b3cfdf", "metadata": {}, "source": [ "### Train / Validation split" ] }, { "cell_type": "code", "execution_count": 8, "id": "2f02bd10", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "from torch.utils.data.sampler import SubsetRandomSampler\n", "\n", "train_indices, val_indices = train_test_split(\n", " range(len(dataset)), test_size=0.1, shuffle=True, random_state=42\n", ")\n", "train_sampler = SubsetRandomSampler(train_indices) \n", "valid_sampler = SubsetRandomSampler(val_indices)" ] }, { "cell_type": "markdown", "id": "ec226eab", "metadata": {}, "source": [ "### Initialize Dataloaders" ] }, { "cell_type": "code", "execution_count": 9, "id": "441fd505", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "len(training_loader) 3\n", "len(validation_loader) 1\n" ] } ], "source": [ "training_loader = torch.utils.data.DataLoader(dataset,\n", " batch_size=2,\n", " shuffle=False,\n", " num_workers = 0,\n", " sampler=train_sampler,\n", " collate_fn = custom_collate_fn,\n", " )\n", "\n", "print( f\"len(training_loader) {len(training_loader)}\")\n", "\n", "validation_loader = torch.utils.data.DataLoader(dataset,\n", " batch_size=2,\n", " shuffle=False,\n", " num_workers = 0,\n", " sampler=valid_sampler,\n", " collate_fn = custom_collate_fn,\n", " )\n", "\n", "print( f\"len(validation_loader) {len(validation_loader)}\")" ] }, { "cell_type": "markdown", "id": "500719ae4a3014d6", "metadata": {}, "source": [ "### Initialize RNA2seg Model" ] 
}, { "cell_type": "code", "execution_count": 10, "id": "8ca6ce7e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cpu\n" ] } ], "source": [ "import torch\n", "\n", "if torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", "elif torch.backends.mps.is_available():\n", " device = torch.device(\"mps\")\n", "else:\n", " device = torch.device(\"cpu\")\n", "\n", "print(f\"Using device: {device}\")" ] }, { "cell_type": "code", "execution_count": 11, "id": "5da486ed8094a36a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initiaisation of CPnet\n", "Initiaisation of ChannelInvariantNet\n" ] } ], "source": [ "rna2seg = RNA2seg(\n", " device,\n", " net='unet',\n", " flow_threshold = 0.9,\n", " cellbound_flow_threshold = 0.4,\n", " pretrained_model = None,\n", ")\n", "rna2seg = rna2seg.to(device)\n", "\n", "optimizer = torch.optim.AdamW(rna2seg.parameters(), lr=0.001, weight_decay=0.01)\n" ] }, { "cell_type": "markdown", "id": "08c96e56", "metadata": {}, "source": [ "### Training RNA2seg" ] }, { "cell_type": "code", "execution_count": 12, "id": "646fc91640fc8f93", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/3 [00:00