# Training a Custom PlantSeg Model (PyTorch)

This blog post aims to help biologists who are not experienced in programming and deep learning to train their own model for **[PlantSeg](https://github.com/hci-unihd/plant-seg)**. PlantSeg is a versatile tool for 3D segmentation of data with **fluorescently labelled membranes**. It was originally developed for segmenting dense plant tissues at cellular resolution; however, it can be applied to a wide variety of data. PlantSeg combines deep-learning-based prediction of cell boundaries from the fluorescently labelled membranes with graph partitioning methods to obtain the final instance segmentation. You can read more about the inner workings of the method in the [original publication](https://elifesciences.org/articles/57613).

A few pre-trained models are provided, and you can easily segment your data with them in the GUI. You can find instructions on how to use the GUI in the [original repository](https://github.com/hci-unihd/plant-seg). However, training a new model is not possible in the GUI. The training is based on the [pytorch-3dunet repository](https://github.com/wolny/pytorch-3dunet). This blog post tries to simplify the process of training a deep learning model for PlantSeg by showing **how to do it from a Jupyter notebook**, which is more user-friendly than the command line. You can also use these instructions to **fine-tune a pre-trained model** on your data instead of training from scratch.

## This blog post is for you if:
* You have already tried predicting with the provided pre-trained models and found the results good, but not good enough, because your data differs substantially from the plant data the models were trained on.
* You have thought about training your own model but feel intimidated by the instructions in the [*pytorch-3dunet* repository on GitHub](https://github.com/wolny/pytorch-3dunet), because there the training is run from the terminal with YAML configuration files.

## Training a custom PlantSeg model is not for you if:
* You do not have good-quality annotated data (so-called ground truth data).
* You have very little data.
* You don't have access to the hardware needed to handle your data and to train a deep learning model.

## What will you need?
* A **conda/mamba environment** or a **singularity container** with the required packages (described in Step 1). If you don't have conda/mamba installed on your machine yet, you can read [this blog post by Mara](../../mara_lampert/getting_started_with_mambaforge_and_python/readme) on how to set it up. If you want to do the training on an HPC cluster, you can read [this blog post by Till](../../till_korten/???) on how to get started with it.
* **Annotated (labelled) data** with good-quality labels; otherwise, the model will learn bad "knowledge".
* Quite a bit of **data** if you are training from scratch, because deep learning models are **data greedy**. You will need at least a couple of hundred 3D timelapse time points. The model needs to "see" a variety of images during training to "learn" well, that is, to perform well later on "unseen" data.
* **Hardware powerful** enough for deep learning, including an NVIDIA GPU. What is enough also depends on your data and on the training parameters you choose, particularly the patch size, which determines the size of the overlapping patches extracted from your images. The network must be able to process one patch within your GPU memory, but the patch should not be much smaller than the objects in your images. Therefore, a laptop with a 4 or 8 GB GPU might be enough for the training, but you might also need to train on a workstation or an HPC cluster.
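To get a feeling for the numbers, here is a rough back-of-the-envelope sketch of how much memory a single float32 tensor of one patch takes. It is our own illustration and not part of PlantSeg or pytorch-3dunet. It uses the patch shape `[64, 260, 260]` and the 32 feature maps (`f_maps`) from the training configuration used later in this post. Treat the result only as a lower bound. During training the network keeps many such feature maps for backpropagation, plus gradients and optimizer state, so the real GPU memory usage is much higher.

```python
from math import prod

def patch_memory_mib(patch_shape, channels=1, bytes_per_voxel=4):
    """Memory in MiB of one tensor with the given patch shape and number of channels.

    bytes_per_voxel=4 corresponds to float32, the usual precision for training.
    """
    return prod(patch_shape) * channels * bytes_per_voxel / 2**20

patch_shape = [64, 260, 260]  # (z, y, x), as used later in this post

# one single-channel input patch
print(f"input patch:     {patch_memory_mib(patch_shape):.1f} MiB")
# one feature-map tensor with 32 channels at full resolution (first U-Net level)
print(f"32 feature maps: {patch_memory_mib(patch_shape, channels=32):.1f} MiB")
```

Doubling the patch size in every dimension multiplies these numbers by 8. This is why reducing the patch size (or the batch size) is the usual remedy when the training runs out of GPU memory.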

> _Note:_ The training in this blog post example was done on the Taurus cluster at TU Dresden with 4 NVIDIA A100 GPUs.

## Step 1: Prepare the Environment/Container

If you are using conda/mamba to create the environment, the easiest way to install the required packages is the following command:

```
mamba create -n plantseg-env -c pytorch -c conda-forge -c lcerrone -c awolny plantseg

conda activate plantseg-env
```

Note that this command installs the newest available versions of the libraries (or pinned versions in some cases). Depending on your GPU, you might need a different CUDA version so that PyTorch runs on the GPU and not only on the CPU, which is very important for training speed. You can check which versions you need [here](https://pytorch.org/get-started/locally/) and follow the installation instructions accordingly.
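Once the environment is created, you can query `torch.cuda` from Python to check whether the installed PyTorch build can actually use your GPU. The small helper below is our own sketch, and the function name `describe_torch_setup` is made up for this post. It only uses standard PyTorch calls, and it also reports if PyTorch is missing altogether.

```python
import importlib.util

def describe_torch_setup():
    """Summarize which PyTorch build is installed and whether it can see a CUDA GPU."""
    if importlib.util.find_spec("torch") is None:
        return {"torch_installed": False}

    import torch

    info = {
        "torch_installed": True,
        "torch_version": torch.__version__,
        "built_with_cuda": torch.version.cuda,        # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),  # False if no usable GPU/driver
        "gpu_count": torch.cuda.device_count(),
    }
    if info["cuda_available"]:
        info["gpu_names"] = [torch.cuda.get_device_name(i) for i in range(info["gpu_count"])]
    return info

print(describe_torch_setup())
```

Suppose `built_with_cuda` is `None` or `cuda_available` is `False` on a machine that has an NVIDIA GPU. Then you most likely have a CPU-only PyTorch build, or a CUDA version that does not match your driver.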

> _Note:_ If you want to train the model on the alpha partition of the Taurus cluster at TUD, the versions that you need to install can be found [here](environment.txt).

To check whether all the packages we will need are installed, run the following cell and make sure it executes without errors.

```python
from pytorch3dunet.train import main as train_plantseg
from natsort import natsorted
from glob import glob
from pathlib import Path
from sklearn.model_selection import train_test_split
import h5py
import os
import numpy as np
from tqdm import tqdm
from skimage.io import imread
from tifffile import imwrite
import yaml

import random
import torch

from pytorch3dunet.unet3d.config import load_config
from pytorch3dunet.unet3d.trainer import create_trainer
from pytorch3dunet.unet3d.utils import get_logger
```



## Step 2: Split Your Data into Training, Validation and Test Sets
In a typical deep learning training workflow, you split your data (raw and label images) into three parts: training, validation and testing. Put very simply, the model is trained on the training data. The validation data is used during training to assess how well the model is doing so far and to steer the training. In this setup, for example, it decides when the learning rate is lowered and which checkpoint is kept as the best one. The testing data is used to evaluate the final model after training has finished. Because the model has not "seen" the test data during training, it can be used for an unbiased evaluation.

First, we specify the directories containing the 3D *tif* images of the raw data and of the labels, with one file per time point of the timelapse. We then load all images into memory (so they need to fit into RAM), because we will split them into the three sets mentioned above.

```python
# give the path to the folder containing the raw data
trn_raw_dir = "/projects/p_bioimage/laura/data/raw_timepoints"

# get all files with the extension .tif from the raw data folder, in natural order (1, 2, ..., 10)
filenames = natsorted(glob(os.path.join(trn_raw_dir, "*.tif")))

# read all images; glob already returns full paths, so they can be passed to imread directly
train_images = [imread(f) for f in tqdm(filenames)]
```

```
100%|██████████| 345/345 [02:12<00:00,  2.61it/s]
```


```python
# do the same for the labels data; label files must have the same names as the raw files,
# so that both sorted lists are in the same order
trn_labels_dir = "/projects/p_bioimage/laura/data/labels_timepoints"
filenames_lbl = natsorted(glob(os.path.join(trn_labels_dir, "*.tif")))
train_labels = [imread(f) for f in tqdm(filenames_lbl)]
```

```
100%|██████████| 345/345 [02:32<00:00,  2.27it/s]
```
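Raw images and labels are matched only by their position in the two sorted lists, so it is worth checking that both folders contain the same file names before splitting. Here is a minimal sketch. It assumes, as in the rest of this post, that each label file has the same name as its raw image. The function name `check_raw_label_pairs` is our own.

```python
from pathlib import Path

def check_raw_label_pairs(raw_files, label_files):
    """Raise a ValueError unless both (sorted) lists contain the same file names
    in the same order; otherwise return the number of pairs."""
    raw_names = [Path(f).name for f in raw_files]
    label_names = [Path(f).name for f in label_files]
    if len(raw_names) != len(label_names):
        raise ValueError(f"{len(raw_names)} raw files but {len(label_names)} label files")
    # compare the names position by position, exactly as the images will be paired
    mismatches = [(r, l) for r, l in zip(raw_names, label_names) if r != l]
    if mismatches:
        raise ValueError(f"{len(mismatches)} mismatched pairs, first one: {mismatches[0]}")
    return len(raw_names)
```

In the notebook you would call it with the two lists of file paths, e.g. `check_raw_label_pairs(filenames, filenames_lbl)`. If all names match, this returns the number of time points (345 for our dataset).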


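We split the data as follows. The second split below takes 28% of the *remaining* 90% of the data, so the validation set makes up about 25% of the whole dataset. The numbers in brackets are the resulting numbers of time points for our 345 images:
* Training data ~65% (223)
* Validation data ~25% (87)
* Testing data 10% (35)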
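Because `train_test_split` is applied twice, the second `test_size` is a fraction of what is left after the first split, not of the whole dataset. The sketch below uses helper names of our own. It shows how to compute the second `test_size` for a desired overall validation share. It also reproduces the set sizes printed further below by mimicking how scikit-learn sizes a split with a float `test_size`: the test part is rounded up, and the rest goes to training.

```python
from math import ceil

def second_test_size(test_fraction, val_fraction):
    """test_size for the second split so that the validation set is
    `val_fraction` of the *whole* dataset after `test_fraction` was set aside."""
    return val_fraction / (1 - test_fraction)

def split_counts(n_samples, test_size):
    """Sizes (n_train, n_test) of one split with a float test_size:
    the test part is rounded up, the rest goes to the training part."""
    n_test = ceil(test_size * n_samples)
    return n_samples - n_test, n_test

# a 70/20/10 split would need test_size = 0.2 / 0.9 ≈ 0.222 in the second split
print(second_test_size(0.1, 0.2))

# the split used in this post: 345 time points, test_size=0.1, then test_size=0.28
n_rest, n_test = split_counts(345, 0.1)
n_train, n_val = split_counts(n_rest, 0.28)
print(n_train, n_val, n_test)
```

The last line prints `223 87 35`, matching the training, validation and test set sizes in the progress bars of the saving step.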

```python
# set aside 10% of the whole dataset for the final evaluation (test set)
X_train, X_test, Y_train, Y_test = train_test_split(train_images, train_labels, test_size=0.1, shuffle=True, random_state=8)

# drop the lists with all images; the images themselves are now referenced by the split lists
del train_images
del train_labels

# split the remaining 90% into training and validation sets (28% of it for validation)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.28, random_state=8)
```

```python
# a small helper function to save a list of images as numbered .tif files
def save_images(images, path):
    # make the directory, including all parent directories, if it does not exist yet (requires write access)
    Path(path).mkdir(parents=True, exist_ok=True)

    # save the images as 0.tif, 1.tif, ...; raw and label images with the same index get the same name
    for i, img in tqdm(enumerate(images), total=len(images)):
        imwrite(os.path.join(path, f"{i}.tif"), img)
```

```python
# call the function for each set
# save training data
save_images(X_train, "data/train/raw")
save_images(Y_train, "data/train/labels")

# save validation data
save_images(X_val, "data/val/raw")
save_images(Y_val, "data/val/labels")

# save testing data
save_images(X_test, "data/test/raw")
save_images(Y_test, "data/test/labels")
```

```
100%|██████████| 223/223 [00:46<00:00,  4.80it/s]
100%|██████████| 223/223 [01:31<00:00,  2.44it/s]
100%|██████████| 87/87 [00:18<00:00,  4.72it/s]
100%|██████████| 87/87 [00:35<00:00,  2.43it/s]
100%|██████████| 35/35 [00:07<00:00,  4.78it/s]
100%|██████████| 35/35 [00:14<00:00,  2.43it/s]
```


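## Step 3: Convert Data to HDF5 Format
The pytorch-3dunet training code expects the dataset in HDF5 format. We write one HDF5 file per time point, containing the raw image and the corresponding labels as two datasets named `raw` and `label`. These names must match `raw_internal_path` and `label_internal_path` in the training configuration.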

```python
def convert_tif_to_hdf5(trn_raw_dir, trn_labels_dir, val_raw_dir, val_labels_dir, save_dir):
    """Write one .hdf5 file per raw/label .tif pair, with the datasets "raw" and "label"."""
    splits = {"train": (trn_raw_dir, trn_labels_dir), "val": (val_raw_dir, val_labels_dir)}

    for split, (raw_dir, labels_dir) in splits.items():
        # First load only the filenames of the raw data. Corresponding labels must have matching filenames.
        # This way we can iterate over the filenames and keep only one image pair in memory at a time.
        print(f"Loading all .tif filenames from {raw_dir}")
        filenames = [Path(f).name for f in natsorted(glob(os.path.join(raw_dir, "*.tif")))]

        # create a "train" or "val" folder in the given save directory
        out_dir = Path(save_dir) / split
        out_dir.mkdir(parents=True, exist_ok=True)

        for tif_name in tqdm(filenames, desc=f"Converting .tif files to hdf5. {split.capitalize()} dataset"):
            raw_image = imread(os.path.join(raw_dir, tif_name))          # load raw image
            labels_image = imread(os.path.join(labels_dir, tif_name))    # load matching labels image

            # mode "w" creates (or overwrites) the file; with mode "a", re-running the cell
            # would fail because the datasets already exist
            with h5py.File(str(out_dir / (Path(tif_name).stem + ".hdf5")), "w") as hf:
                hf.create_dataset(name="raw", data=raw_image)            # write raw data
                hf.create_dataset(name="label", data=labels_image)       # write label data
```

```python
trn_raw_dir    = "/projects/p_bioimage/laura/data/train/raw"
trn_labels_dir = "/projects/p_bioimage/laura/data/train/labels"
val_raw_dir    = "/projects/p_bioimage/laura/data/val/raw"
val_labels_dir = "/projects/p_bioimage/laura/data/val/labels"
# this folder must match the file_paths set in the training configuration below
save_dir       = "/projects/p_bioimage/laura/data/hdf5_for_plantseg"

convert_tif_to_hdf5(trn_raw_dir, trn_labels_dir, val_raw_dir, val_labels_dir, save_dir)
```

```
Loading all .tif filenames from /projects/p_bioimage/laura/data/val/raw
Converting .tif files to hdf5. Val dataset: 100%|██████████| 87/87 [02:09<00:00,  1.49s/it]
```
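It can save time to check the converted files before training. Each file should contain a `raw` and a `label` dataset of the same shape. The smallest volume also limits how large the patch shape can be (see the patch size section below). Here is a minimal sketch using plain `h5py`. The helper name is our own, and it assumes single-channel 3D data as in this post.

```python
import os
from glob import glob

import h5py

def check_hdf5_dir(hdf5_dir, raw_key="raw", label_key="label"):
    """Check that every .hdf5 file in `hdf5_dir` has raw and label datasets of equal
    shape; return the number of files and the smallest size along each axis."""
    paths = sorted(glob(os.path.join(hdf5_dir, "*.hdf5")))
    if not paths:
        raise FileNotFoundError(f"no .hdf5 files found in {hdf5_dir}")
    min_shape = None
    for path in paths:
        with h5py.File(path, "r") as f:  # read-only; only the shapes are accessed
            raw_shape, label_shape = f[raw_key].shape, f[label_key].shape
        if raw_shape != label_shape:
            raise ValueError(f"{path}: raw shape {raw_shape} != label shape {label_shape}")
        min_shape = raw_shape if min_shape is None else tuple(map(min, min_shape, raw_shape))
    return len(paths), min_shape

# e.g. check_hdf5_dir("/projects/p_bioimage/laura/data/hdf5_for_plantseg/train")
```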


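## Open the Training Configuration File

The training parameters are stored in a YAML configuration file. We load it into a Python dictionary so that we can inspect and change the parameters directly in the notebook.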

```python
with open("/projects/p_bioimage/laura/train_config.yml", "r") as file:
    config = yaml.safe_load(file)

# display the configuration (a nested Python dictionary)
config
```

```
{'model': {'name': 'ResidualUNet3D',
  'in_channels': 1,
  'out_channels': 1,
  'layer_order': 'gcr',
  'f_maps': 32,
  'num_groups': 8,
  'final_sigmoid': True},
 'loss': {'name': 'BCEDiceLoss',
  'ignore_index': None,
  'skip_last_target': True},
 'optimizer': {'learning_rate': 0.0002, 'weight_decay': 1e-05},
 'eval_metric': {'name': 'BoundaryAdaptedRandError',
  'threshold': 0.4,
  'use_last_target': True,
  'use_first_input': True},
 'lr_scheduler': {'name': 'ReduceLROnPlateau',
  'mode': 'min',
  'factor': 0.2,
  'patience': 20},
 'trainer': {'eval_score_higher_is_better': False,
  'checkpoint_dir': 'CHECKPOINT_DIR',
  'resume': None,
  'pre_trained': None,
  'validate_after_iters': 1000,
  'log_after_iters': 500,
  'max_num_epochs': 1000,
  'max_num_iterations': 150000},
 'loaders': {'num_workers': 8,
  'raw_internal_path': '/raw',
  'label_internal_path': '/label',
  'train': {'file_paths': ['PATH_TO_TRAIN_DIR'],
   'slice_builder': {'name': 'FilterSliceBuilder',
    'patch_shap
```

Next, we set the paths to the training and validation HDF5 folders created in Step 3:

```python
config["loaders"]["train"]["file_paths"] = ["/projects/p_bioimage/laura/data/hdf5_for_plantseg/train"]
print(config["loaders"]["train"]["file_paths"])

config["loaders"]["val"]["file_paths"] = ["/projects/p_bioimage/laura/data/hdf5_for_plantseg/val"]
print(config["loaders"]["val"]["file_paths"])
```

```
['/projects/p_bioimage/laura/data/hdf5_for_plantseg/train']
['/projects/p_bioimage/laura/data/hdf5_for_plantseg/val']
```


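Then we adjust the model and training parameters:

```python
# use the plain 3D U-Net instead of the residual variant from the default configuration
config["model"]["name"] = "UNet3D"

# ReduceLROnPlateau: lower the learning rate if the validation score has not improved for 10 evaluations
config["lr_scheduler"]["patience"] = 10

# run validation and write logs every 100 iterations
config["trainer"]["validate_after_iters"] = 100
config["trainer"]["log_after_iters"] = 100

# the training stops as soon as either of these limits is reached
config["trainer"]["max_num_epochs"] = 500
config["trainer"]["max_num_iterations"] = 60000

# initial learning rate
config["optimizer"]["learning_rate"] = 0.001

# directory where the checkpoints of this training run will be saved
config["trainer"]["checkpoint_dir"] = "Model_3"

# optionally change the batch size
# config["loaders"]["batch_size"] = 2
```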
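The single-line changes above can also be collected in one dictionary and applied in a loop. This keeps all modifications in one place and fails loudly on a misspelled key. This is our own small sketch: the names `OVERRIDES`, `set_by_path` and `apply_overrides` are not part of pytorch-3dunet. It only changes keys that already exist in the loaded configuration, so keys that are not in the file yet still have to be added with plain assignment.

```python
def set_by_path(config, dotted_key, value):
    """Set config["a"]["b"]["c"] = value for dotted_key "a.b.c".
    Raises KeyError if any part of the path does not exist yet."""
    *parents, last = dotted_key.split(".")
    node = config
    for key in parents:
        node = node[key]  # a KeyError here points to a typo in the section name
    if last not in node:
        raise KeyError(f"unknown config key: {dotted_key}")
    node[last] = value

def apply_overrides(config, overrides):
    """Apply all {dotted_key: value} pairs to the nested config dictionary."""
    for dotted_key, value in overrides.items():
        set_by_path(config, dotted_key, value)
    return config

# the same changes as in the cell above
OVERRIDES = {
    "model.name": "UNet3D",
    "lr_scheduler.patience": 10,
    "trainer.validate_after_iters": 100,
    "trainer.log_after_iters": 100,
    "trainer.max_num_epochs": 500,
    "trainer.max_num_iterations": 60000,
    "optimizer.learning_rate": 0.001,
    "trainer.checkpoint_dir": "Model_3",
}

# in the notebook: apply_overrides(config, OVERRIDES)
```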

### Patch size

In this case we also need to change the patch size, because a patch cannot be larger than the image volume in any dimension.

```python
# print the patch size in the default configuration (z, y, x)
print(config["loaders"]["train"]["slice_builder"]["patch_shape"])
```

```
[80, 170, 170]
```

```python
config["loaders"]["train"]["slice_builder"]["patch_shape"] = [64, 260, 260]
print(config["loaders"]["train"]["slice_builder"]["patch_shape"])
```

```
[64, 260, 260]
```


### Stride shape

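We also change the stride shape, i.e. the step size (in voxels) between the starting positions of neighbouring patches. A stride smaller than the patch size makes the patches overlap. The smaller the stride, the larger the overlap and the more patches are extracted from each image.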

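```python
# stride expressed as a fraction of the patch size: a smaller fraction means more overlap
STRIDE_MENU = {
    "accurate": 0.5,
    "balanced": 0.75,
    "draft": 0.9
}

def get_stride_shape(patch_shape, stride_key):
    # scale each patch dimension by the chosen fraction (at least 1 voxel)
    return [max(int(p * STRIDE_MENU[stride_key]), 1) for p in patch_shape]
```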

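```python
# print the patch stride in the default configuration
print(config["loaders"]["train"]["slice_builder"]["stride_shape"])
```

```
[20, 40, 40]
```

```python
# "accurate" gives half the patch size, i.e. [32, 130, 130] for a [64, 260, 260] patch
config["loaders"]["train"]["slice_builder"]["stride_shape"] = get_stride_shape(config["loaders"]["train"]["slice_builder"]["patch_shape"], "accurate")
# in this example we then override it manually with a stride of 1 in z (the value used in the training run below)
config["loaders"]["train"]["slice_builder"]["stride_shape"] = [1, 130, 130]
#print(config["loaders"]["train"]["slice_builder"]["stride_shape"])
```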
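The simplified sketch below shows how the patch and stride shapes interact, by counting sliding windows along each axis. Windows start every `stride` voxels, and one extra window is placed flush with the border if the regular grid does not reach it. This follows the usual sliding-window scheme and is only an illustration. pytorch-3dunet's slice builders may differ in details, and `FilterSliceBuilder` can additionally drop patches based on their label content, so the "Number of patches" it logs can be lower. The volume shape `(100, 600, 600)` is hypothetical.

```python
def window_starts(size, patch, stride):
    """Start positions of sliding windows along one axis of length `size`."""
    if patch > size:
        raise ValueError(f"patch ({patch}) is larger than the volume ({size})")
    starts = list(range(0, size - patch + 1, stride))
    if starts[-1] + patch < size:  # the regular grid misses the border: add a last window
        starts.append(size - patch)
    return starts

def count_patches(volume_shape, patch_shape, stride_shape):
    """Number of patches extracted from one volume (product over all axes)."""
    n = 1
    for size, patch, stride in zip(volume_shape, patch_shape, stride_shape):
        n *= len(window_starts(size, patch, stride))
    return n

volume = (100, 600, 600)  # hypothetical (z, y, x) volume shape
print(count_patches(volume, [64, 260, 260], [32, 130, 130]))  # "accurate" stride
print(count_patches(volume, [64, 260, 260], [1, 130, 130]))   # stride 1 in z
```

For this hypothetical volume, the "accurate" stride gives 48 patches. A z stride of 1 gives 592 patches that overlap almost completely along z, which means many more training samples per volume, but very similar ones.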

Now we apply the same patch and stride shapes to the validation dataset:

```python
config["loaders"]["val"]["slice_builder"]["stride_shape"] = config["loaders"]["train"]["slice_builder"]["stride_shape"]
config["loaders"]["val"]["slice_builder"]["patch_shape"] = config["loaders"]["train"]["slice_builder"]["patch_shape"]

# number of worker processes that load the data in parallel
print(config["loaders"]["num_workers"])
```

```
8
```

```python
config["loaders"]["num_workers"] = 6
```

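## Continuing Training from a Pre-trained Model

If a pre-trained model already worked relatively well on your data, it might make sense to take its weights and continue training on your data. This takes less time than training from scratch. To do so, download the pre-trained model and set the path to its checkpoint file in the `pre_trained` parameter of the `trainer` section, as shown (commented out) below. The `model` section of the configuration must describe the same architecture as the pre-trained model (e.g. `ResidualUNet3D` vs. `UNet3D`); otherwise its weights cannot be loaded.

```python
# uncomment and replace the placeholder with the path to the downloaded checkpoint file
# config["trainer"]["pre_trained"] = "/path/to/pretrained_checkpoint.pytorch"

# the checkpoints of this training run will be saved in this directory
config["trainer"]["checkpoint_dir"]
```

```
'Model_3'
```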

```python
# no device in the config: the training cell below chooses it automatically
config["device"] = None

# display the final configuration
config
```

```
{'model': {'name': 'UNet3D',
  'in_channels': 1,
  'out_channels': 1,
  'layer_order': 'gcr',
  'f_maps': 32,
  'num_groups': 8,
  'final_sigmoid': True},
 'loss': {'name': 'BCEDiceLoss',
  'ignore_index': None,
  'skip_last_target': True},
 'optimizer': {'learning_rate': 0.001, 'weight_decay': 1e-05},
 'eval_metric': {'name': 'BoundaryAdaptedRandError',
  'threshold': 0.4,
  'use_last_target': True,
  'use_first_input': True},
 'lr_scheduler': {'name': 'ReduceLROnPlateau',
  'mode': 'min',
  'factor': 0.2,
  'patience': 10},
 'trainer': {'eval_score_higher_is_better': False,
  'checkpoint_dir': 'Model_3',
  'resume': None,
  'pre_trained': None,
  'validate_after_iters': 100,
  'log_after_iters': 100,
  'max_num_epochs': 500,
  'max_num_iterations': 60000},
 'loaders': {'num_workers': 6,
  'raw_internal_path': '/raw',
  'label_internal_path': '/label',
  'train': {'file_paths': ['/projects/p_bioimage/laura/data/hdf5_for_plantseg/train'],
   'slice_builder': {'name': 'FilterSliceBuilde
```
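It is a good idea to save the final configuration next to the checkpoints. You can then see exactly which parameters a model was trained with, or re-run the same training from the terminal with pytorch-3dunet's `train3dunet --config <CONFIG>` command. Here is a minimal sketch using PyYAML; the helper name and the file name are our own choices. It has to run before the training cell below, because that cell replaces `config["device"]` with a `torch.device` object, which `yaml.safe_dump` cannot write.

```python
from pathlib import Path

import yaml

def save_config(config, path):
    """Write the training configuration (a plain nested dict) to a YAML file."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)  # e.g. create the checkpoint directory
    with open(path, "w") as f:
        yaml.safe_dump(config, f, sort_keys=False)  # keep the original key order
    return path

# in the notebook, e.g.:
# save_config(config, Path(config["trainer"]["checkpoint_dir"]) / "train_config.yml")
```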

## Training

```python
# use the device from the config if one is given (falling back to the CPU if CUDA is not available);
# otherwise use the first GPU if there is one, else the CPU
device_str = config.get('device', None)
if device_str is not None:
    print(f"Device specified in config: '{device_str}'")
    if device_str.startswith('cuda') and not torch.cuda.is_available():
        print('CUDA not available, using CPU')
        device_str = 'cpu'
else:
    device_str = "cuda:0" if torch.cuda.is_available() else 'cpu'
    print(f"Using '{device_str}' device")

device = torch.device(device_str)
config['device'] = device
```

```
Using 'cuda:0' device
```


```python
# logger for the messages below (get_logger was imported from pytorch3dunet in Step 1)
logger = get_logger("TrainingSetup")

# optional: make the training reproducible if a seed is given in the config
manual_seed = config.get('manual_seed', None)
if manual_seed is not None:
    logger.info(f'Seed the RNG for all devices with {manual_seed}')
    logger.warning('Using CuDNN deterministic setting. This may slow down the training!')
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)
    # see https://pytorch.org/docs/stable/notes/randomness.html
    torch.backends.cudnn.deterministic = True

# create the trainer (model, loss, optimizer and data loaders) from the config
trainer = create_trainer(config)
# start the training
trainer.fit()
```

```
2023-02-03 11:50:50,569 [MainThread] INFO UNet3DTrainer - Using 4 GPUs for training
2023-02-03 11:50:50,570 [MainThread] INFO UNet3DTrainer - Sending the model to 'cuda:0'
2023-02-03 11:50:54,192 [MainThread] INFO UNet3DTrainer - Number of learnable params 4081267
2023-02-03 11:50:54,193 [MainThread] INFO Dataset - Creating training and validation set loaders...
2023-02-03 11:50:54,390 [MainThread] INFO HDF5Dataset - Loading train set from: /projects/p_bioimage/laura/data/hdf5_for_plantseg/train/0.hdf5...
2023-02-03 11:50:54,972 [MainThread] INFO Dataset - Slice builder config: {'name': 'FilterSliceBuilder', 'patch_shape': [64, 260, 260], 'stride_shape': [1, 130, 130], 'threshold': 0.6, 'slack_acceptance': 0.01}
2023-02-03 11:50:55,403 [MainThread] INFO HDF5Dataset - Number of patches: 1
2023-02-03 11:50:55,404 [MainThread] INFO HDF5Dataset - Loading train set from: /projects/p_bioimage/laura/data/hdf5_for_plantseg/train/1.hdf5...
2023-02-03 11:50:55,893 [MainThread] INFO Dataset - Sli
```