Skip to contents
# Simulation grid for the settings that have one results file each.
num_cells_list <- c(5000, 10000, 25000)
num_groups_list <- c(1, 5, 10)

# Collect every per-file table in a list and row-bind once at the end;
# calling rbind() inside the loop re-copies the accumulated frame each time.
cluster_metrics_tables <- list()
for (cells in num_cells_list) {
  for (group in num_groups_list) {
    filename <- stringr::str_glue("simulations_num_groups_{group}_num_cells_{cells}.csv")

    # The first CSV column is used as row names; the simulation settings are
    # attached as columns so the combined table can be grouped on them later.
    iteration_df <- read.csv(filename, row.names = 1)
    iteration_df$num_cells <- cells
    iteration_df$num_groups <- group

    cluster_metrics_tables[[length(cluster_metrics_tables) + 1L]] <- iteration_df
  }
}

# 50k cell replicates
for (group in num_groups_list) {
  for (replicate in 1:5) {
    filename <- stringr::str_glue("simulations_num_groups_{group}_num_cells_50000_replicate_{replicate}.csv")
    iteration_df <- read.csv(filename, row.names = 1)
    # NOTE(review): the original left the `num_cells` assignment commented
    # out, so these files presumably already carry that column -- confirm.
    # It is filled in only when absent, so the row-bind below cannot fail on
    # mismatched columns and an existing column is never overwritten.
    if (!"num_cells" %in% names(iteration_df)) {
      iteration_df$num_cells <- 50000
    }
    iteration_df$num_groups <- group

    cluster_metrics_tables[[length(cluster_metrics_tables) + 1L]] <- iteration_df
  }
}

# Single row-bind of all tables; columns are matched by name.
cluster_metrics_df <- do.call(rbind, cluster_metrics_tables)

scAce was run separately because it is implemented in Python. We load the results here.

# scAce
# Append the scAce clustering metrics, keeping only rows with fewer than
# 20 groups.
scAce_cluster_metrics_df <- read.csv("scAce_simulation_clustering_metrics.csv", row.names = 1)
scAce_cluster_metrics_df <- scAce_cluster_metrics_df[
  which(scAce_cluster_metrics_df$num_groups < 20), ,
  drop = FALSE
]
cluster_metrics_df <- rbind(cluster_metrics_df, scAce_cluster_metrics_df)

# Replace ARI NA with 1: these come from dividing by zero in the single-group
# case, where the methods in fact got it 100% correct.
ari_is_missing <- is.na(cluster_metrics_df$ari)
cluster_metrics_df$ari[ari_is_missing] <- 1

# Add NAs for copula models that didn't scale. These rows are appended after
# the replacement above, so it does not touch them.
copula_na_df <- read.csv("NA_for_copula.csv", row.names = 1)
copula_na_df <- copula_na_df[which(copula_na_df$num_groups < 20), , drop = FALSE]
cluster_metrics_df <- rbind(cluster_metrics_df, copula_na_df)

# Rename methods: raw method identifier -> label shown on the plot axis.
# Identifiers without an entry here keep their original name.
method_display_names <- c(
  "recall_ZIP" = "recall+ZIP",
  "recall_NB" = "recall+NB",
  "recall_Poisson-copula" = "recall+\nPoisson-copula",
  "recall_NB-copula" = "recall+\nNB-copula",
  "recall_countsplit" = "recall+\ncountsplit",
  "scSHC" = "sc-SHC",
  "scace_cluster" = "scAce"
)

cluster_metrics_df$method <- factor(cluster_metrics_df$method)
current_levels <- levels(cluster_metrics_df$method)
has_display_name <- current_levels %in% names(method_display_names)
current_levels[has_display_name] <- unname(
  method_display_names[current_levels[has_display_name]]
)
levels(cluster_metrics_df$method) <- current_levels

# Fix the left-to-right order of methods on the plot axis.
cluster_metrics_df$method <- factor(
  cluster_metrics_df$method,
  levels = c(
    "recall+ZIP",
    "recall+NB",
    "recall+\nPoisson-copula",
    "recall+\nNB-copula",
    "recall+\ncountsplit",
    "sc-SHC",
    "CHOIR",
    "scAce"
  )
)

# Treat the number of groups as a discrete axis rather than a number.
cluster_metrics_df$num_groups <- as.factor(cluster_metrics_df$num_groups)
# Font sizes used by the heatmap theme.
small_text_size <- 20
large_text_size <- 24

# Facet strip labels, keyed by the `num_cells` value of each facet row.
facet_labels <- c(
  "5000" = "N = 5K",
  "10000" = "N = 10K",
  "25000" = "N = 25K",
  "50000" = "N = 50K"
)


#' Draw a heatmap of one clustering metric per method and number of groups.
#'
#' Tiles are laid out with `method` on the x axis and `num_groups` on the y
#' axis, with one facet row per `num_cells` value. The function reads the
#' script-level objects `facet_labels`, `small_text_size` and
#' `large_text_size`, which must exist when it is called.
#'
#' @param cluster_metrics_df Data frame containing the columns `method`,
#'   `num_groups`, `num_cells` and the column named by `statistic`.
#' @param statistic Single string naming the column mapped to tile fill.
#' @param legend_title Title of the fill legend. When `NULL` (the default) a
#'   display name is looked up for the known metric columns, falling back to
#'   `statistic` itself. Previously every plot was titled "ARI" regardless of
#'   the metric shown.
#' @return A ggplot object.
plot_results_heatmap <- function(cluster_metrics_df, statistic, legend_title = NULL) {
  stopifnot(is.character(statistic), length(statistic) == 1L)
  if (!statistic %in% names(cluster_metrics_df)) {
    stop("Column '", statistic, "' not found in `cluster_metrics_df`.", call. = FALSE)
  }

  if (is.null(legend_title)) {
    known_titles <- c(
      ari = "ARI",
      v_measure = "V-measure",
      fowlkes_mallows = "FMI",
      homogeneity = "Homogeneity",
      completeness = "Completeness",
      jaccard = "Jaccard"
    )
    legend_title <- if (statistic %in% names(known_titles)) {
      known_titles[[statistic]]
    } else {
      statistic
    }
  }

  # The constant colour mapping "NA" exists only to produce a legend entry
  # (titled "No data" below). The manual colour scale gives it no usable
  # colour, so tile outlines fall back to `na.value` (black).
  ggplot2::ggplot(
    cluster_metrics_df,
    ggplot2::aes(
      y = num_groups,
      x = method,
      fill = .data[[statistic]],
      color = "NA"
    )
  ) +
    ggplot2::facet_grid(
      rows = ggplot2::vars(num_cells),
      labeller = ggplot2::as_labeller(facet_labels)
    ) +
    ggplot2::geom_tile(
      lwd = 1,
      linetype = 1,
      width = 0.7,
      height = 0.7
    ) +
    # Missing values are drawn white so they match the "No data" legend key.
    ggplot2::scale_fill_gradientn(
      name = legend_title,
      colours = c("lightgrey", "gold", "darkblue"),
      na.value = "white"
    ) +
    ggplot2::xlab("Method") +
    ggplot2::ylab("Number of Groups") +
    ggplot2::theme_bw() +
    ggplot2::theme(
      axis.ticks.x = ggplot2::element_blank(),
      axis.text.x = ggplot2::element_text(
        size = small_text_size, family = "Courier", colour = "black",
        face = "bold", angle = 0, vjust = 0.5, hjust = 0.5
      ),
      axis.text.y = ggplot2::element_text(size = small_text_size, colour = "black"),
      axis.title.y = ggplot2::element_text(size = large_text_size),
      axis.title.x = ggplot2::element_blank(),
      strip.text = ggplot2::element_text(size = small_text_size),
      legend.text = ggplot2::element_text(size = small_text_size),
      legend.title = ggplot2::element_text(size = small_text_size)
    ) +
    ggplot2::scale_color_manual(values = NA, na.value = "black") +
    ggplot2::guides(
      color = ggplot2::guide_legend("No data", override.aes = list(fill = "white"))
    )
}
  
  
    




# One heatmap per clustering metric column of `cluster_metrics_df`.
ari_results_heatmap <- plot_results_heatmap(cluster_metrics_df, "ari")
v_measure_results_heatmap <- plot_results_heatmap(cluster_metrics_df, "v_measure")
fmi_results_heatmap <- plot_results_heatmap(cluster_metrics_df, "fowlkes_mallows")
homogeneity_results_heatmap <- plot_results_heatmap(cluster_metrics_df, "homogeneity")
completeness_results_heatmap <- plot_results_heatmap(cluster_metrics_df, "completeness")
jaccard_results_heatmap <- plot_results_heatmap(cluster_metrics_df, "jaccard")

# Every heatmap is written at the same canvas size, given in pixels.
heatmap_width_px <- 4 * 1440
heatmap_height_px <- 2 * 1440

# Output file name -> plot to write there.
heatmaps_by_file <- list(
  "simulation_heatmap_ari.png" = ari_results_heatmap,
  "simulation_heatmap_vmeasure.png" = v_measure_results_heatmap,
  "simulation_heatmap_fmi.png" = fmi_results_heatmap,
  "simulation_heatmap_homogeneity.png" = homogeneity_results_heatmap,
  "simulation_heatmap_completeness.png" = completeness_results_heatmap,
  "simulation_heatmap_jaccard.png" = jaccard_results_heatmap
)

# `ggsave` is namespace-qualified so this step does not depend on ggplot2
# having been attached earlier in the script.
for (heatmap_file in names(heatmaps_by_file)) {
  ggplot2::ggsave(
    heatmap_file,
    heatmaps_by_file[[heatmap_file]],
    width = heatmap_width_px,
    height = heatmap_height_px,
    units = "px"
  )
}















# timing memory plotting
library(dplyr)
library(ggplot2)

# Simulation grid for the settings that have one timing/memory file each.
num_cells_list <- c(5000, 10000, 25000)
num_groups_list <- c(1, 5, 10)

# Collect every per-file table in a list and row-bind once at the end;
# calling rbind() inside the loop re-copies the accumulated frame each time.
timing_memory_tables <- list()
for (cells in num_cells_list) {
  for (group in num_groups_list) {
    filename <- stringr::str_glue("simulations_num_groups_{group}_num_cells_{cells}_timing_memory.csv")

    # The first CSV column is used as row names; the simulation settings are
    # attached as columns so the measurements can be grouped on them later.
    iteration_df <- read.csv(filename, row.names = 1)
    iteration_df$num_cells <- cells
    iteration_df$num_groups <- group

    timing_memory_tables[[length(timing_memory_tables) + 1L]] <- iteration_df
  }
}

# 50k cell replicates: five files per group count, all tagged as 50000 cells.
for (group in num_groups_list) {
  for (replicate in 1:5) {
    filename <- stringr::str_glue("simulations_num_groups_{group}_num_cells_50000_replicate_{replicate}_timing_memory.csv")
    iteration_df <- read.csv(filename, row.names = 1)
    iteration_df$num_cells <- 50000
    iteration_df$num_groups <- group

    timing_memory_tables[[length(timing_memory_tables) + 1L]] <- iteration_df
  }
}

# Single row-bind of all tables; columns are matched by name.
timing_memory_df <- do.call(rbind, timing_memory_tables)

# scAce

# GPU run: relabel the method so it is kept apart from the other scAce
# results, drop the `replicate` column (not in other output) and keep only
# rows with fewer than 20 groups.
gpu_scAce_timing_df <- read.csv("gpu_simulations_scAce.csv", row.names = 1)
gpu_scAce_timing_df$method <- "scAce-GPU"
gpu_scAce_timing_df[["replicate"]] <- NULL
gpu_scAce_timing_df <- gpu_scAce_timing_df[
  which(gpu_scAce_timing_df$num_groups < 20), ,
  drop = FALSE
]

# Second scAce results file: same clean-up, method label left as stored.
scAce_timing_df <- read.csv("simulations_scAce.csv", row.names = 1)
scAce_timing_df[["replicate"]] <- NULL
scAce_timing_df <- scAce_timing_df[
  which(scAce_timing_df$num_groups < 20), ,
  drop = FALSE
]

# Append both scAce tables to the measurements loaded so far.
timing_memory_df <- do.call(
  rbind,
  list(timing_memory_df, gpu_scAce_timing_df, scAce_timing_df)
)



# Convert memory from mebibytes to gigabytes (units per the original note).
# 1 MiB is 1024^2 bytes and 1 GB is 1e9 bytes, so the factor is written out
# exactly (0.001048576) instead of as a rounded literal.
timing_memory_df$memory <- timing_memory_df$memory * (1024^2 / 1e9)

# Rename countsplit to have a lowercase "r" in recall.
timing_memory_df$method <- factor(timing_memory_df$method)
levels(timing_memory_df$method)[levels(timing_memory_df$method) == "Recall+countsplit"] <- "recall+countsplit"

# Mean and standard deviation of run time per method and simulation setting.
# The result is ungrouped explicitly so no grouping leaks into later steps.
timing_summary_df <- timing_memory_df %>%
  dplyr::group_by(method, num_cells, num_groups) %>%
  dplyr::summarize(
    mean = mean(time),
    sd = sd(time)
  ) %>%
  dplyr::ungroup()

# Same summary for peak memory.
memory_summary_df <- timing_memory_df %>%
  dplyr::group_by(method, num_cells, num_groups) %>%
  dplyr::summarize(
    mean = mean(memory),
    sd = sd(memory)
  ) %>%
  dplyr::ungroup()

# Order levels: both summaries share one method order. Methods not listed
# here become NA when the factor is rebuilt.
timing_method_order <- c(
  "recall+ZIP",
  "recall+NB",
  "recall+Poisson-copula",
  "recall+NB-copula",
  "recall+countsplit",
  "sc-SHC",
  "CHOIR",
  "scAce",
  "scAce-GPU"
)

memory_summary_df$method <- factor(memory_summary_df$method, levels = timing_method_order)
timing_summary_df$method <- factor(timing_summary_df$method, levels = timing_method_order)



# Font sizes for the timing and memory figures.
small_text_size <- 12
large_text_size <- 16

# Line and point colours supplied to the manual colour scale, one per method.
usage_plot_colours <- c(
  "red", "darkmagenta", "darkgreen", "deepskyblue2", "darkorange1",
  "grey", "black", "tan", "cadetblue"
)

# Build the shared figure layout: mean (with +/- one sd error bars) against
# number of cells, one facet row per group count and one facet column per
# method. `summary_df` needs columns `num_cells`, `num_groups`, `method`,
# `mean` and `sd`; `y_axis_label` is the y-axis title.
plot_usage_by_cells <- function(summary_df, y_axis_label) {
  base_plot <- ggplot2::ggplot(
    summary_df,
    ggplot2::aes(x = num_cells, y = mean, color = method)
  )

  text_theme <- ggplot2::theme(
    axis.text = ggplot2::element_text(size = small_text_size),
    axis.title = ggplot2::element_text(size = large_text_size),
    strip.text = ggplot2::element_text(size = small_text_size),
    legend.text = ggplot2::element_text(size = small_text_size, family = "Courier"),
    legend.title = ggplot2::element_text(size = small_text_size),
    legend.position = "bottom"
  )

  base_plot +
    ggplot2::geom_line(size = 1.5) +
    ggplot2::geom_point(size = 3) +
    ggplot2::facet_grid(
      rows = ggplot2::vars(num_groups),
      cols = ggplot2::vars(method)
    ) +
    ggplot2::geom_errorbar(
      ggplot2::aes(x = num_cells, ymin = mean - sd, ymax = mean + sd)
    ) +
    ggplot2::scale_color_manual(values = usage_plot_colours) +
    ggplot2::theme_bw() +
    ggplot2::xlab("Number of Cells") +
    ggplot2::ylab(y_axis_label) +
    ggplot2::labs(color = "Method") +
    text_theme +
    # Show cell counts in thousands, e.g. 5000 -> "5K".
    ggplot2::scale_x_continuous(
      labels = scales::label_number(scale = .001, suffix = "K")
    )
}

memory_plot <- plot_usage_by_cells(memory_summary_df, "Peak Memory (GB)")

time_plot <- plot_usage_by_cells(timing_summary_df, "Time Taken (Min)")

Finally, save the plots.


# Write the timing and memory figures as PNG files; sizes are in pixels.
ggsave(
  filename = "simulations_timing_plot.png",
  plot = time_plot,
  width = 4 * 1440,
  height = 2 * 1440,
  units = "px"
)
ggsave(
  filename = "simulations_memory_plot.png",
  plot = memory_plot,
  width = 4 * 1440,
  height = 2 * 1440,
  units = "px"
)
Timing Memory