Skip to contents

# Empty top-level accumulators, one statement per line (the two assignments
# were fused onto a single line, which does not parse as R).
# NOTE(review): neither variable is referenced by the code visible below
# (rare_cell_type_simulation() uses its own local recall_num_clusters) --
# confirm whether they are still needed.
seurat_num_clusters <- c()
recall_num_clusters <- c()

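First, we set up two functions for generating the rare cell type scenarios: `downsample_cell_type()`, which downsamples a single group, and `rare_cell_type_simulation()`, which simulates the data and runs the clustering methods.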

#' Downsample one cell group within a Seurat object
#'
#' Keeps every cell whose `Group` metadata value differs from `cell_type`,
#' plus a random sample (drawn with `sample()`, i.e. without replacement) of
#' `num_downsampled` cells whose `Group` equals `cell_type`.
#'
#' @param seurat_obj Object with a `Group` column in its `@meta.data` slot.
#' @param cell_type Value of `Group` identifying the cells to downsample.
#' @param num_downsampled Number of `cell_type` cells to retain.
#' @return `seurat_obj` subset to the sampled cells followed by all cells of
#'   the other groups.
downsample_cell_type <- function(seurat_obj, cell_type, num_downsampled) {
  group_labels <- seurat_obj@meta.data$Group
  all_cells <- Cells(seurat_obj)

  # which() drops NA comparisons, so a cell with a missing Group label ends up
  # in neither set.
  target_cells <- all_cells[which(group_labels == cell_type)]
  other_cells <- all_cells[which(group_labels != cell_type)]

  sampled_cells <- sample(target_cells, size = num_downsampled)

  subset(seurat_obj, cells = c(sampled_cells, other_cells))
}


#' Simulate grouped data and count clusters after downsampling one group
#'
#' Simulates a single grouped count dataset with `splatter::splatSimulate()`
#' and converts it to a Seurat object with an `RNA` assay. Then, for every
#' size in `downsampled_counts` and every replicate, it:
#'   1. downsamples "Group1" to that size with `downsample_cell_type()`,
#'   2. runs NormalizeData / FindVariableFeatures / ScaleData / RunPCA /
#'      FindNeighbors,
#'   3. writes the object to `h5ad_dir/` as h5Seurat and converts it to h5ad,
#'   4. runs recall, sc-SHC and CHOIR and records the number of distinct
#'      cluster labels each one produced.
#'
#' @param total_num_cells Number of cells to simulate (passed as `batchCells`).
#' @param num_genes Number of genes to simulate (passed as `nGenes`).
#' @param cell_type_proportions Group probabilities (passed as `group.prob`);
#'   its length is recorded as the number of groups.
#' @param downsampled_counts Vector of sizes to which "Group1" is downsampled.
#' @param num_replicates Number of replicates to run for each size.
#' @param cores Number of cores handed to recall, sc-SHC and CHOIR. Defaults
#'   to 24, the value that was previously hard-coded.
#' @return A data.frame with one row per (size, replicate) run and columns
#'   `num_groups`, `original_num_cells`, `replicate`, `downsampled_count`,
#'   `recall_num_clusters`, `scSHC_num_clusters`, `CHOIR_num_clusters`.
rare_cell_type_simulation <- function(total_num_cells,
                                      num_genes,
                                      cell_type_proportions,
                                      downsampled_counts,
                                      num_replicates,
                                      cores = 24) {

  sim.groups <- splatter::splatSimulate(group.prob = cell_type_proportions, method = "groups",
                                        verbose = FALSE,
                                        nGenes = num_genes,
                                        batchCells = total_num_cells,
                                        dropout.type = "experiment",
                                        de.prob = 0.05)

  seurat_obj <- Seurat::as.Seurat(sim.groups, counts = "counts", data = NULL)
  seurat_obj <- SeuratObject::RenameAssays(object = seurat_obj, originalexp = 'RNA')

  # Preallocate one slot per run. Logical NA is coerced to whatever type is
  # assigned, so the column types match what the loop stores, and a run with
  # zero iterations still returns a data.frame with all seven columns.
  n_runs <- length(downsampled_counts) * num_replicates
  num_groups <- rep(NA, n_runs)
  original_num_cells <- rep(NA, n_runs)
  downsampled_count <- rep(NA, n_runs)
  replicate <- rep(NA, n_runs)

  recall_num_clusters <- rep(NA, n_runs)
  scSHC_num_clusters <- rep(NA, n_runs)
  CHOIR_num_clusters <- rep(NA, n_runs)

  # The h5Seurat/h5ad files below are written into this directory; create it
  # up front so the first save does not fail on a fresh checkout.
  if (!dir.exists("h5ad_dir")) {
    dir.create("h5ad_dir", recursive = TRUE)
  }

  j <- 0
  for (num_downsampled in downsampled_counts) {
    # seq_len() yields an empty loop for num_replicates == 0, whereas
    # 1:num_replicates would iterate over c(1, 0).
    for (i in seq_len(num_replicates)) {
      j <- j + 1

      print(table(seurat_obj@meta.data$Group))

      downsampled_seurat_obj <- downsample_cell_type(seurat_obj, cell_type = "Group1", num_downsampled)

      print(table(downsampled_seurat_obj@meta.data$Group))


      downsampled_seurat_obj <- NormalizeData(downsampled_seurat_obj)
      downsampled_seurat_obj <- FindVariableFeatures(downsampled_seurat_obj)
      downsampled_seurat_obj <- ScaleData(downsampled_seurat_obj)
      downsampled_seurat_obj <- RunPCA(downsampled_seurat_obj)
      downsampled_seurat_obj <- FindNeighbors(downsampled_seurat_obj)


      # save file to h5ad for scAce clustering
      # this also fixes a bug in CHOIR using Seuratv5
      # The RNA assay is coerced to class "Assay" via a temporary "RNA3" assay
      # and then put back under the name "RNA"; the statement order matters.
      downsampled_seurat_obj[["RNA3"]] <- as(object = downsampled_seurat_obj[["RNA"]], Class = "Assay")
      DefaultAssay(downsampled_seurat_obj) <- "RNA3"
      downsampled_seurat_obj[["RNA"]] <- NULL
      downsampled_seurat_obj[["RNA"]] <- downsampled_seurat_obj[["RNA3"]]
      DefaultAssay(downsampled_seurat_obj) <- "RNA"
      downsampled_seurat_obj[["RNA3"]] <- NULL

      filename <- stringr::str_interp("h5ad_dir/simulated_${total_num_cells}_cells_${length(cell_type_proportions)}_groups_downsampled_${num_downsampled}_replicate_${i}.h5Seurat")
      SaveH5Seurat(downsampled_seurat_obj, filename = filename)
      Convert(filename, dest = "h5ad")


      # run recall
      print("Running recall")
      downsampled_seurat_obj <- recall::FindClustersRecall(downsampled_seurat_obj, cores=cores, reduction_percentage = 0.1)

      # run sc-SHC on the raw counts of the variable features only
      print("Running sc-SHC")
      scSHC_clusters <- scSHC(GetAssayData(downsampled_seurat_obj,
                                           assay = "RNA", layer = "counts")[Seurat::VariableFeatures(downsampled_seurat_obj),],
                              num_features = length(VariableFeatures(downsampled_seurat_obj)),
                              num_PCs = 10,
                              cores = cores)[[1]]

      # run CHOIR on the first 10 PCA embeddings
      print("Running CHOIR")
      downsampled_seurat_obj <- CHOIR(downsampled_seurat_obj,
                          n_cores = cores,
                          reduction = downsampled_seurat_obj@reductions$pca@cell.embeddings[, 1:10],
                          var_features = Seurat::VariableFeatures(downsampled_seurat_obj))



      # store cluster labels
      downsampled_seurat_obj[['scSHC_clusters']] <- scSHC_clusters
      downsampled_seurat_obj[["CHOIR_clusters"]] <- downsampled_seurat_obj@meta.data$CHOIR_clusters_0.05


      print("Smallest cluster size:")
      print(num_downsampled)
      print("Num Clusters:")
      # NOTE(review): this prints the number of factor levels, while the value
      # stored below counts unique labels; the two can differ if there are
      # unused levels -- confirm which one is intended.
      print(length(levels(downsampled_seurat_obj@meta.data$recall_clusters)))

      num_groups[j] <- length(cell_type_proportions)
      original_num_cells[j] <- total_num_cells
      downsampled_count[j] <- num_downsampled
      replicate[j] <- i

      recall_num_clusters[j] <- length(unique(downsampled_seurat_obj@meta.data$recall_clusters))
      scSHC_num_clusters[j] <- length(unique(downsampled_seurat_obj@meta.data$scSHC_clusters))
      CHOIR_num_clusters[j] <- length(unique(downsampled_seurat_obj@meta.data$CHOIR_clusters))

    }
  }

  data.frame(num_groups, original_num_cells, replicate, downsampled_count, recall_num_clusters, scSHC_num_clusters, CHOIR_num_clusters)
}

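Next, we loop over the simulation parameters (the number of groups and the total number of cells), run the simulation for each combination, and write the results of each run to a CSV file.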


# Simulation grid: number of groups x total number of cells.
num_groups_list <- c(5, 10)
num_cells_list <- c(5000, 10000)

# Print a label and then its value, each through print().
print_labelled <- function(label, value) {
  print(label)
  print(value)
}

for (num_groups in num_groups_list) {
  for (num_cells in num_cells_list) {
    print_labelled("num_groups", num_groups)
    print_labelled("num_cells", num_cells)

    # One results file per (num_cells, num_groups) combination.
    filename <- stringr::str_interp("rare_cell_type_${num_cells}_cells_${num_groups}.csv")
    print(filename)

    # Equal proportions for every group.
    cell_type_proportions <- rep(1 / num_groups, num_groups)
    print_labelled("cell_type_proportions", cell_type_proportions)

    # Largest downsampled size: 50 below the per-group cell count, capped at 500.
    downsample_max <- min(num_cells / num_groups - 50, 500)
    print_labelled("downsample_max", downsample_max)

    # Sizes 100, 150, ... up to downsample_max
    # (alternative fixed grid: c(100, 150, 200, 250, 300, 350, 400, 450, 500)).
    downsampled_counts <- seq(from = 100, to = downsample_max, by = 50)
    print(downsampled_counts)

    ret <- rare_cell_type_simulation(
      total_num_cells = num_cells,
      num_genes = 1000,
      cell_type_proportions = cell_type_proportions,
      downsampled_counts = downsampled_counts,
      num_replicates = 1
    )

    print_labelled("writing output csv", filename)

    write.csv(ret, file = filename)
    print(ret)
  }
}