MCMCmagic

Author

Sarah Urbut

Published

August 30, 2024

First how does our data look?


Attaching package: 'dplyr'
The following objects are masked from 'package:stats':

    filter, lag
The following objects are masked from 'package:base':

    intersect, setdiff, setequal, union

Attaching package: 'reshape'
The following object is masked from 'package:dplyr':

    rename

Attaching package: 'MASS'
The following object is masked from 'package:dplyr':

    select

Attaching package: 'tidyr'
The following objects are masked from 'package:reshape':

    expand, smiths
Shape of mu_d_matrix: 12 56 
Shape of mean_lambda: 10 2 
Shape of mu_i: 10 56 

Now we initialize; here’s an adaptive way to initialize length and var:

Code
library(rTensor)
library(MASS)

# Warm initialization function
# Warm initialization function
#
# Builds starting values for the MCMC sampler:
#   1. mu_i: the average disease baseline plus tiny Gaussian jitter
#      (logit scale, so no clamping is required),
#   2. Lambda / Phi: rank-one products of the factor matrices from a
#      Tucker decomposition of the logit-transformed, baseline-centered
#      outcome tensor,
#   3. Gamma: a small random draw,
#   4. fixed GP length/variance scales.
#
# Y           N x D x T binary outcome array.
# mu_d_logit  list of D length-T disease baselines on the logit scale.
# g_i         N x P covariate matrix (only its column count is used).
# n_topics    number of topics K.
warm_gp_initialization <- function(Y, mu_d_logit, g_i, n_topics) {
  N <- dim(Y)[1]
  D <- dim(Y)[2]
  T <- dim(Y)[3]
  P <- ncol(g_i)

  # Individual offsets: mean baseline trajectory + tiny jitter per cell.
  baseline_mean <- colMeans(do.call(rbind, mu_d_logit))
  deviations <- matrix(rnorm(N * T, mean = 0, sd = 0.0001), nrow = N)
  mu_i <- sweep(deviations, 2, baseline_mean, "+")

  # Smooth Y away from exactly 0/1 so the logit stays finite.
  n_cells <- N * D * T
  Y_smooth <- (Y * (n_cells - 1) + 0.5) / n_cells
  Y_logit <- log(Y_smooth / (1 - Y_smooth))

  # Subtract each disease's baseline trajectory.
  Y_centered <- array(0, dim = c(N, D, T))
  for (d in seq_len(D)) {
    Y_centered[, d, ] <- Y_logit[, d, ] -
      matrix(mu_d_logit[[d]], nrow = N, ncol = T, byrow = TRUE)
  }

  # Tucker-decompose the centered tensor; the mode-wise factor matrices
  # seed the individual loadings (mode 1), disease profiles (mode 2),
  # and time trajectories (mode 3).
  tucker_result <- tucker(as.tensor(Y_centered), ranks = c(n_topics, D, T))

  # Lambda (N x K x T) and Phi (K x D x T) as rank-one outer products.
  Lambda <- array(0, dim = c(N, n_topics, T))
  Phi <- array(0, dim = c(n_topics, D, T))
  for (k in seq_len(n_topics)) {
    Lambda[, k, ] <- tucker_result$U[[1]][, k] %o% tucker_result$U[[3]][, k]
    Phi[k, , ] <- tucker_result$U[[2]][, k] %o% tucker_result$U[[3]][, k]
  }

  # Conservative random start for the covariate effects.
  Gamma <- matrix(rnorm(n_topics * P, mean = 0, sd = 0.1),
                  nrow = n_topics, ncol = P)

  list(
    Lambda = Lambda,
    Phi = Phi,
    Gamma = Gamma,
    mu_i = mu_i,
    # Fixed hyperparameters for the GP length and variance scales.
    length_scales_lambda = rep(20, n_topics),
    var_scales_lambda = rep(1, n_topics),
    length_scales_phi = rep(20, n_topics * D),
    var_scales_phi = rep(1, n_topics * D)
  )
}
Code
# GP prior function
# GP prior function
#
# Log-density of x under a Gaussian-process prior with a squared-exponential
# (RBF) covariance over equally spaced time points 1..T:
#   K[s, t] = var_scale * exp(-0.5 * (s - t)^2 / length_scale^2) + 1e-6 * I.
#
# x            numeric vector of function values at times 1..T.
# mean         prior mean: a scalar (recycled) or a vector of length T.
# length_scale kernel length scale (larger = smoother draws).
# var_scale    kernel marginal variance.
#
# Returns the multivariate-normal log-density of x under N(mean, K).
log_gp_prior_vec <- function(x, mean, length_scale, var_scale) {
  T <- length(x)
  time_points <- 1:T
  # Squared-exponential kernel plus a small diagonal jitter for stability.
  K <- var_scale * exp(-0.5 * outer(time_points, time_points, "-")^2 / length_scale^2)
  K <- K + diag(1e-6, T)

  # Ensure mean is the same length as x
  if (length(mean) == 1) {
    mean <- rep(mean, T)
  } else if (length(mean) != T) {
    stop(paste("Length of mean must be 1 or equal to length of x. Length of x:", T, "Length of mean:", length(mean)))
  }

  centered_x <- x - mean

  # A single Cholesky factorization serves both the log-determinant and the
  # quadratic form.  (The original code ran a redundant eigen-decomposition
  # just for the determinant.)  chol() returns the UPPER factor R with
  # K = t(R) %*% R, so log|K| = 2 * sum(log(diag(R))).
  R <- chol(K)
  log_det_K <- 2 * sum(log(diag(R)))

  # Solve t(R) z = centered_x; then sum(z^2) = centered_x' K^{-1} centered_x.
  quad_form <- sum(backsolve(R, centered_x, transpose = TRUE)^2)

  return(-0.5 * (log_det_K + quad_form + T * log(2 * base::pi)))
}
Code
# Optional likelihood function
# Optional likelihood function
#
# Bernoulli log-likelihood of the binary outcome array y under the logistic
# topic model.  The linear predictor is
#   logit(pi[i, d, t]) = sum_k Lambda[i, k, t] * Phi[k, d, t]
#                        + mu_d_logit[[d]][t] + mu_i[i, t] (optional).
# With use_survival = FALSE every cell contributes a Bernoulli term; with
# use_survival = TRUE each (individual, disease) pair contributes
# discrete-time survival terms up to and including its first event, and
# drops out of the likelihood afterwards.
log_likelihood_vec <- function(y, Lambda, Phi, mu_d_logit, mu_i = NULL, use_survival = TRUE) {
  n_individuals <- dim(Lambda)[1]
  n_topics <- dim(Lambda)[2]
  T <- dim(Lambda)[3]
  n_diseases <- dim(Phi)[2]

  # Accumulate the topic contributions on the logit scale, one rank-one
  # (individual x disease) outer product per topic and time point.
  logit_pi <- array(0, dim = c(n_individuals, n_diseases, T))
  for (t in 1:T) {
    for (k in 1:n_topics) {
      logit_pi[, , t] <- logit_pi[, , t] + outer(Lambda[, k, t], Phi[k, , t])
    }
  }

  # Add the disease baselines and (if supplied) the individual offsets.
  for (d in 1:n_diseases) {
    baseline <- matrix(mu_d_logit[[d]], nrow = n_individuals, ncol = T, byrow = TRUE)
    logit_pi[, d, ] <- logit_pi[, d, ] + baseline
    if (!is.null(mu_i)) {
      logit_pi[, d, ] <- logit_pi[, d, ] + mu_i
    }
  }

  # Inverse-logit to event probabilities.
  pi <- 1 / (1 + exp(-logit_pi))

  if (!use_survival) {
    # Plain Bernoulli likelihood over every cell.
    return(sum(dbinom(y, 1, pi, log = TRUE)))
  }

  # Discrete-time survival likelihood.
  log_lik <- 0
  for (i in 1:n_individuals) {
    for (d in 1:n_diseases) {
      at_risk <- which(cumsum(y[i, d, ]) == 0)
      if (length(at_risk) == 0) {
        # Event occurred at the first time point
        log_lik <- log_lik + log(pi[i, d, 1])
      } else {
        event_time <- max(at_risk) + 1
        if (event_time <= T) { # Event occurred within the window
          log_lik <- log_lik + log(pi[i, d, event_time])
        }
        # Survival terms for the at-risk period.
        log_lik <- log_lik + sum(log(1 - pi[i, d, at_risk]))
      }
    }
  }

  log_lik
}
Code
# Element-wise Gaussian random-walk proposal: perturbs each component of
# current_value with independent N(0, proposal_sd) noise.  proposal_sd may
# be a scalar or a vector recycled against length(current_value).
adaptive_random_walk_proposal <- function(current_value, proposal_sd) {
  noise <- rnorm(length(current_value), mean = 0, sd = proposal_sd)
  current_value + noise
}
Code
# Flexible Metropolis-within-Gibbs sampler for the logistic topic model.
#
# Cycles over Lambda (N x K x T individual-topic trajectories), Phi
# (K x D x T topic-disease trajectories), Gamma (K x P covariate effects),
# optionally mu_i (N x T individual offsets), and optionally the GP
# length/variance scales.  Each block is updated by a Metropolis-Hastings
# step; under "random_walk" proposals the proposal sd adapts
# multiplicatively (x1.01 on accept, x0.99 on reject).
#
# Arguments:
#   y               N x D x T binary outcome array.
#   mu_d_logit      list of D length-T disease baselines (logit scale).
#   g_i             N x P covariate matrix.
#   n_iterations    number of MCMC iterations to run and store.
#   initial_values  list with Lambda, Phi, Gamma, mu_i and scale vectors
#                   (e.g. from warm_gp_initialization()).
#   use_mu_i        sample per-individual offsets mu_i?
#   use_survival    use the discrete-time survival likelihood?
#   estimate_scales sample GP length/variance scales (else keep fixed)?
#   proposal_method "random_walk" or "gp".
#   alpha_*/beta_*  hyperparameters of the Gamma priors on the scales and
#                   of the normal prior on Gamma (alpha_Gamma/beta_Gamma).
#
# Returns a list of sample arrays, each with iteration as first dimension.
#
# NOTE(review): the "gp" proposal branch calls propose_gp(), which is not
# defined in this file — confirm it is in scope before using that method.
# NOTE(review): the full data log-likelihood is recomputed for both the
# current and proposed state inside every (i,k) and (k,d) inner loop; this
# is correct but dominates the runtime.
# NOTE(review): proposal-sd adaptation never stops; strictly, continual
# adaptation should diminish (or be frozen after burn-in) for the chain to
# target the exact posterior — confirm this is acceptable here.
mcmc_sampler_flexible <- function(y, mu_d_logit, g_i, n_iterations, initial_values,
                                  use_mu_i = TRUE, use_survival = TRUE, estimate_scales = FALSE,
                                  proposal_method = "random_walk",
                                  alpha_lambda, beta_lambda, alpha_sigma, beta_sigma,
                                  alpha_phi, beta_phi, alpha_sigma_phi, beta_sigma_phi,
                                  alpha_Gamma, beta_Gamma) {

  current_state <- initial_values
  n_individuals <- dim(current_state$Lambda)[1]
  n_topics <- dim(current_state$Lambda)[2]
  T <- dim(current_state$Lambda)[3]
  n_diseases <- dim(current_state$Phi)[2]
  P <- ncol(g_i)

  # Initialize storage for samples (iteration is the first dimension)
  samples <- list(
    Lambda = array(0, dim = c(n_iterations, dim(current_state$Lambda))),
    Phi = array(0, dim = c(n_iterations, dim(current_state$Phi))),
    Gamma = array(0, dim = c(n_iterations, dim(current_state$Gamma)))
  )

  # Initialize adaptive proposal standard deviations (one per scalar)
  proposal_sd <- list(
    Lambda = array(0.1, dim = dim(current_state$Lambda)),
    Phi = array(0.1, dim = dim(current_state$Phi)),
    Gamma = matrix(0.1, nrow = nrow(current_state$Gamma), ncol = ncol(current_state$Gamma))
  )

  if (use_mu_i) {
    samples$mu_i <- array(0, dim = c(n_iterations, dim(current_state$mu_i)))
    proposal_sd$mu_i <- array(0.1, dim = dim(current_state$mu_i))
  }

  if (estimate_scales) {
    samples$length_scales_lambda <- matrix(0, nrow = n_iterations, ncol = n_topics)
    samples$var_scales_lambda <- matrix(0, nrow = n_iterations, ncol = n_topics)
    samples$length_scales_phi <- matrix(0, nrow = n_iterations, ncol = n_topics * n_diseases)
    samples$var_scales_phi <- matrix(0, nrow = n_iterations, ncol = n_topics * n_diseases)
    proposal_sd$length_scales_lambda <- rep(0.1, n_topics)
    proposal_sd$var_scales_lambda <- rep(0.1, n_topics)
    proposal_sd$length_scales_phi <- rep(0.1, n_topics * n_diseases)
    proposal_sd$var_scales_phi <- rep(0.1, n_topics * n_diseases)
  }

  start_time=Sys.time()

  for (iter in 1:n_iterations) {
    # Update Lambda: one MH step per (individual, topic) time trajectory.
    for (i in 1:n_individuals) {
      for (k in 1:n_topics) {
        proposed_Lambda <- current_state$Lambda
        if (proposal_method == "random_walk") {
          proposed_Lambda[i,k,] <- adaptive_random_walk_proposal(current_state$Lambda[i,k,], proposal_sd$Lambda[i,k,])
        } else if (proposal_method == "gp") {
          proposed_Lambda[i,k,] <- propose_gp(current_state$Lambda[i,k,],
                                              current_state$length_scales_lambda[k],
                                              current_state$var_scales_lambda[k])
        }

        current_log_lik <- log_likelihood_vec(y, current_state$Lambda, current_state$Phi,
                                              mu_d_logit,
                                              if(use_mu_i) current_state$mu_i else NULL,
                                              use_survival)
        proposed_log_lik <- log_likelihood_vec(y, proposed_Lambda, current_state$Phi,
                                               mu_d_logit,
                                               if(use_mu_i) current_state$mu_i else NULL,
                                               use_survival)

        # GP prior on the trajectory, centered at the covariate effect
        # Gamma[k,] %*% g_i[i,] (constant over time).
        current_log_prior <- log_gp_prior_vec(current_state$Lambda[i,k,],
                                              current_state$Gamma[k,] %*% g_i[i,],
                                              current_state$length_scales_lambda[k],
                                              current_state$var_scales_lambda[k])
        proposed_log_prior <- log_gp_prior_vec(proposed_Lambda[i,k,],
                                               current_state$Gamma[k,] %*% g_i[i,],
                                               current_state$length_scales_lambda[k],
                                               current_state$var_scales_lambda[k])

        log_accept_ratio <- (proposed_log_lik + proposed_log_prior) -
                            (current_log_lik + current_log_prior)

        if (log(runif(1)) < log_accept_ratio) {
          current_state$Lambda <- proposed_Lambda
          if (proposal_method == "random_walk") {
            proposal_sd$Lambda[i,k,] <- proposal_sd$Lambda[i,k,] * 1.01
          }
        } else if (proposal_method == "random_walk") {
          proposal_sd$Lambda[i,k,] <- proposal_sd$Lambda[i,k,] * 0.99
        }
      }
    }

    # Update Phi: one MH step per (topic, disease) time trajectory; the GP
    # prior on Phi has mean zero.
    for (k in 1:n_topics) {
      for (d in 1:n_diseases) {
        proposed_Phi <- current_state$Phi
        if (proposal_method == "random_walk") {
          proposed_Phi[k,d,] <- adaptive_random_walk_proposal(current_state$Phi[k,d,], proposal_sd$Phi[k,d,])
        } else if (proposal_method == "gp") {
          proposed_Phi[k,d,] <- propose_gp(current_state$Phi[k,d,],
                                           current_state$length_scales_phi[(k-1)*n_diseases + d],
                                           current_state$var_scales_phi[(k-1)*n_diseases + d])
        }

        current_log_lik <- log_likelihood_vec(y, current_state$Lambda, current_state$Phi,
                                              mu_d_logit,
                                              if(use_mu_i) current_state$mu_i else NULL,
                                              use_survival)
        proposed_log_lik <- log_likelihood_vec(y, current_state$Lambda, proposed_Phi,
                                               mu_d_logit,
                                               if(use_mu_i) current_state$mu_i else NULL,
                                               use_survival)

        current_log_prior <- log_gp_prior_vec(current_state$Phi[k,d,],
                                              0,
                                              current_state$length_scales_phi[(k-1)*n_diseases + d],
                                              current_state$var_scales_phi[(k-1)*n_diseases + d])
        proposed_log_prior <- log_gp_prior_vec(proposed_Phi[k,d,],
                                               0,
                                               current_state$length_scales_phi[(k-1)*n_diseases + d],
                                               current_state$var_scales_phi[(k-1)*n_diseases + d])

        log_accept_ratio <- (proposed_log_lik + proposed_log_prior) -
                            (current_log_lik + current_log_prior)

        if (log(runif(1)) < log_accept_ratio) {
          current_state$Phi <- proposed_Phi
          if (proposal_method == "random_walk") {
            proposal_sd$Phi[k,d,] <- proposal_sd$Phi[k,d,] * 1.01
          }
        } else if (proposal_method == "random_walk") {
          proposal_sd$Phi[k,d,] <- proposal_sd$Phi[k,d,] * 0.99
        }
      }
    }

    # Update Gamma: the "likelihood" for Gamma is the GP prior on Lambda
    # (Gamma enters only through the Lambda prior mean), with a normal
    # prior of variance alpha_Gamma / beta_Gamma on each coefficient.
    for (k in 1:n_topics) {
      proposed_Gamma <- current_state$Gamma
      if (proposal_method == "random_walk") {
        proposed_Gamma[k,] <- adaptive_random_walk_proposal(current_state$Gamma[k,], proposal_sd$Gamma[k,])
      } else if (proposal_method == "gp") {
        proposed_Gamma[k,] <- mvrnorm(1, current_state$Gamma[k,], diag(0.1, P))
      }

      current_log_prior <- sum(dnorm(current_state$Gamma[k,], 0, sqrt(alpha_Gamma / beta_Gamma), log = TRUE))
      proposed_log_prior <- sum(dnorm(proposed_Gamma[k,], 0, sqrt(alpha_Gamma / beta_Gamma), log = TRUE))

      current_log_lik <- sum(sapply(1:n_individuals, function(i) {
        log_gp_prior_vec(current_state$Lambda[i,k,],
                         current_state$Gamma[k,] %*% g_i[i,],
                         current_state$length_scales_lambda[k],
                         current_state$var_scales_lambda[k])
      }))
      proposed_log_lik <- sum(sapply(1:n_individuals, function(i) {
        log_gp_prior_vec(current_state$Lambda[i,k,],
                         proposed_Gamma[k,] %*% g_i[i,],
                         current_state$length_scales_lambda[k],
                         current_state$var_scales_lambda[k])
      }))

      log_accept_ratio <- (proposed_log_lik + proposed_log_prior) -
                          (current_log_lik + current_log_prior)

      if (log(runif(1)) < log_accept_ratio) {
        current_state$Gamma <- proposed_Gamma
        if (proposal_method == "random_walk") {
          proposal_sd$Gamma[k,] <- proposal_sd$Gamma[k,] * 1.01
        }
      } else if (proposal_method == "random_walk") {
        proposal_sd$Gamma[k,] <- proposal_sd$Gamma[k,] * 0.99
      }
    }

    # Update mu_i if used: GP prior centered at the average disease
    # baseline, with fixed length scale 20 and variance 1.
    if (use_mu_i) {
      for (i in 1:n_individuals) {
        proposed_mu_i <- current_state$mu_i
        if (proposal_method == "random_walk") {
          proposed_mu_i[i,] <- adaptive_random_walk_proposal(current_state$mu_i[i,], proposal_sd$mu_i[i,])
        } else if (proposal_method == "gp") {
          proposed_mu_i[i,] <- propose_gp(current_state$mu_i[i,], 20, 1)  # Using fixed length scale and variance
        }

        current_log_lik <- log_likelihood_vec(y, current_state$Lambda, current_state$Phi,
                                              mu_d_logit, current_state$mu_i, use_survival)
        proposed_log_lik <- log_likelihood_vec(y, current_state$Lambda, current_state$Phi,
                                               mu_d_logit, proposed_mu_i, use_survival)

        mu_d_logit_mean <- colMeans(do.call(rbind, mu_d_logit))
        current_log_prior <- log_gp_prior_vec(current_state$mu_i[i,], mu_d_logit_mean, 20, 1)
        proposed_log_prior <- log_gp_prior_vec(proposed_mu_i[i,], mu_d_logit_mean, 20, 1)

        log_accept_ratio <- (proposed_log_lik + proposed_log_prior) -
                            (current_log_lik + current_log_prior)

        if (log(runif(1)) < log_accept_ratio) {
          current_state$mu_i <- proposed_mu_i
          if (proposal_method == "random_walk") {
            proposal_sd$mu_i[i,] <- proposal_sd$mu_i[i,] * 1.01
          }
        } else if (proposal_method == "random_walk") {
          proposal_sd$mu_i[i,] <- proposal_sd$mu_i[i,] * 0.99
        }
      }
    }

    # Update scales if estimating.  Note: random-walk proposals can go
    # negative, in which case dgamma() yields zero density (log = -Inf)
    # and the proposal is always rejected.
    if (estimate_scales) {
      # Update length_scales_lambda
      for (k in 1:n_topics) {
        proposed_length_scale <- current_state$length_scales_lambda[k]
        if (proposal_method == "random_walk") {
          proposed_length_scale <- adaptive_random_walk_proposal(current_state$length_scales_lambda[k], proposal_sd$length_scales_lambda[k])
        } else if (proposal_method == "gp") {
          # Log-normal random walk keeps the proposal positive.
          proposed_length_scale <- exp(log(current_state$length_scales_lambda[k]) + rnorm(1, 0, 0.1))
        }

        current_log_prior <- dgamma(current_state$length_scales_lambda[k], shape = alpha_lambda, rate = beta_lambda, log = TRUE)
        proposed_log_prior <- dgamma(proposed_length_scale, shape = alpha_lambda, rate = beta_lambda, log = TRUE)

        # The scale's "likelihood" is the GP prior over all individuals'
        # Lambda trajectories for this topic.
        current_log_lik <- sum(sapply(1:n_individuals, function(i) {
          log_gp_prior_vec(current_state$Lambda[i,k,],
                           current_state$Gamma[k,] %*% g_i[i,],
                           current_state$length_scales_lambda[k],
                           current_state$var_scales_lambda[k])
        }))
        proposed_log_lik <- sum(sapply(1:n_individuals, function(i) {
          log_gp_prior_vec(current_state$Lambda[i,k,],
                           current_state$Gamma[k,] %*% g_i[i,],
                           proposed_length_scale,
                           current_state$var_scales_lambda[k])
        }))

        log_accept_ratio <- (proposed_log_lik + proposed_log_prior) -
                            (current_log_lik + current_log_prior)

        if (log(runif(1)) < log_accept_ratio) {
          current_state$length_scales_lambda[k] <- proposed_length_scale
          if (proposal_method == "random_walk") {
            proposal_sd$length_scales_lambda[k] <- proposal_sd$length_scales_lambda[k] * 1.01
          }
        } else if (proposal_method == "random_walk") {
          proposal_sd$length_scales_lambda[k] <- proposal_sd$length_scales_lambda[k] * 0.99
        }
      }

      # Update var_scales_lambda
      for (k in 1:n_topics) {
        proposed_var_scale <- current_state$var_scales_lambda[k]
        if (proposal_method == "random_walk") {
          proposed_var_scale <- adaptive_random_walk_proposal(current_state$var_scales_lambda[k], proposal_sd$var_scales_lambda[k])
        } else if (proposal_method == "gp") {
          proposed_var_scale <- exp(log(current_state$var_scales_lambda[k]) + rnorm(1, 0, 0.1))
        }

        current_log_prior <- dgamma(current_state$var_scales_lambda[k], shape = alpha_sigma, rate = beta_sigma, log = TRUE)
        proposed_log_prior <- dgamma(proposed_var_scale, shape = alpha_sigma, rate = beta_sigma, log = TRUE)

        current_log_lik <- sum(sapply(1:n_individuals, function(i) {
          log_gp_prior_vec(current_state$Lambda[i,k,],
                           current_state$Gamma[k,] %*% g_i[i,],
                           current_state$length_scales_lambda[k],
                           current_state$var_scales_lambda[k])
        }))
        proposed_log_lik <- sum(sapply(1:n_individuals, function(i) {
          log_gp_prior_vec(current_state$Lambda[i,k,],
                           current_state$Gamma[k,] %*% g_i[i,],
                           current_state$length_scales_lambda[k],
                           proposed_var_scale)
        }))

        log_accept_ratio <- (proposed_log_lik + proposed_log_prior) -
                            (current_log_lik + current_log_prior)

        if (log(runif(1)) < log_accept_ratio) {
          current_state$var_scales_lambda[k] <- proposed_var_scale
          if (proposal_method == "random_walk") {
            proposal_sd$var_scales_lambda[k] <- proposal_sd$var_scales_lambda[k] * 1.01
          }
        } else if (proposal_method == "random_walk") {
          proposal_sd$var_scales_lambda[k] <- proposal_sd$var_scales_lambda[k] * 0.99
        }
      }

      # Update length_scales_phi and var_scales_phi (indexed jointly by
      # (topic, disease) with flat index (k-1)*n_diseases + d).
      for (k in 1:n_topics) {
        for (d in 1:n_diseases) {
          idx <- (k-1)*n_diseases + d

          # Update length_scales_phi
          proposed_length_scale <- current_state$length_scales_phi[idx]
          if (proposal_method == "random_walk") {
            proposed_length_scale <- adaptive_random_walk_proposal(current_state$length_scales_phi[idx], proposal_sd$length_scales_phi[idx])
          } else if (proposal_method == "gp") {
            proposed_length_scale <- exp(log(current_state$length_scales_phi[idx]) + rnorm(1, 0, 0.1))
          }

          current_log_prior <- dgamma(current_state$length_scales_phi[idx], shape = alpha_phi, rate = beta_phi, log = TRUE)
          proposed_log_prior <- dgamma(proposed_length_scale, shape = alpha_phi, rate = beta_phi, log = TRUE)

          current_log_lik <- log_gp_prior_vec(current_state$Phi[k,d,], 0,
                                              current_state$length_scales_phi[idx],
                                              current_state$var_scales_phi[idx])
          proposed_log_lik <- log_gp_prior_vec(current_state$Phi[k,d,], 0,
                                               proposed_length_scale,
                                               current_state$var_scales_phi[idx])

          log_accept_ratio <- (proposed_log_lik + proposed_log_prior) -
                              (current_log_lik + current_log_prior)

          if (log(runif(1)) < log_accept_ratio) {
            current_state$length_scales_phi[idx] <- proposed_length_scale
            if (proposal_method == "random_walk") {
              proposal_sd$length_scales_phi[idx] <- proposal_sd$length_scales_phi[idx] * 1.01
            }
          } else if (proposal_method == "random_walk") {
            proposal_sd$length_scales_phi[idx] <- proposal_sd$length_scales_phi[idx] * 0.99
          }

          # Update var_scales_phi
          proposed_var_scale <- current_state$var_scales_phi[idx]
          if (proposal_method == "random_walk") {
            proposed_var_scale <- adaptive_random_walk_proposal(current_state$var_scales_phi[idx], proposal_sd$var_scales_phi[idx])
          } else if (proposal_method == "gp") {
            proposed_var_scale <- exp(log(current_state$var_scales_phi[idx]) + rnorm(1, 0, 0.1))
          }

          current_log_prior <- dgamma(current_state$var_scales_phi[idx], shape = alpha_sigma_phi, rate = beta_sigma_phi, log = TRUE)
          proposed_log_prior <- dgamma(proposed_var_scale, shape = alpha_sigma_phi, rate = beta_sigma_phi, log = TRUE)

          current_log_lik <- log_gp_prior_vec(current_state$Phi[k,d,], 0,
                                              current_state$length_scales_phi[idx],
                                              current_state$var_scales_phi[idx])
          proposed_log_lik <- log_gp_prior_vec(current_state$Phi[k,d,], 0,
                                               current_state$length_scales_phi[idx],
                                               proposed_var_scale)

          log_accept_ratio <- (proposed_log_lik + proposed_log_prior) -
                              (current_log_lik + current_log_prior)

          if (log(runif(1)) < log_accept_ratio) {
            current_state$var_scales_phi[idx] <- proposed_var_scale
            if (proposal_method == "random_walk") {
              proposal_sd$var_scales_phi[idx] <- proposal_sd$var_scales_phi[idx] * 1.01
            }
          } else if (proposal_method == "random_walk") {
            proposal_sd$var_scales_phi[idx] <- proposal_sd$var_scales_phi[idx] * 0.99
          }
        }
      }
    }

    # Store samples for this iteration
    samples$Lambda[iter,,,] <- current_state$Lambda
    samples$Phi[iter,,,] <- current_state$Phi
    samples$Gamma[iter,,] <- current_state$Gamma
    if (use_mu_i) samples$mu_i[iter,,] <- current_state$mu_i
    if (estimate_scales) {
      samples$length_scales_lambda[iter,] <- current_state$length_scales_lambda
      samples$var_scales_lambda[iter,] <- current_state$var_scales_lambda
      samples$length_scales_phi[iter,] <- current_state$length_scales_phi
      samples$var_scales_phi[iter,] <- current_state$var_scales_phi
    }

     if (iter %% 100 == 0) {
       print(iter)
      # current_time <- Sys.time()
      # elapsed <- as.numeric(difftime(current_time, start_time, units = "secs"))
      # estimated_total <- elapsed * (n_iterations / iter)
      # estimated_remaining <- estimated_total - elapsed
      # 
      # cat(sprintf("Iteration %d / %d (%.2f%%) completed\n",
      #             iter, n_iterations, 100 * iter / n_iterations))
      # cat(sprintf("Time elapsed: %.2f minutes\n", elapsed / 60))
      # cat(sprintf("Estimated time remaining: %.2f minutes\n\n", estimated_remaining / 60))
    }
  }

  # Print total time taken
  end_time <- Sys.time()
  total_time <- as.numeric(difftime(end_time, start_time, units = "mins"))
  cat(sprintf("\nTotal time taken: %.2f minutes\n", total_time))

  return(samples)
}
Code
# Define a custom print method for the progress object
# Custom knit_print method for "progress" objects: renders the iteration
# count, elapsed time, and estimated remaining time as raw (as-is) output
# inside a knitr document.  Expects x to carry the fields iteration,
# total_iterations, percent_complete, elapsed_time, and remaining_time.
knit_print.progress <- function(x, ...) {
  lines <- c(
    sprintf("Iteration %d / %d (%.2f%%) completed\n",
            x$iteration, x$total_iterations, x$percent_complete),
    sprintf("Time elapsed: %.2f minutes\n", x$elapsed_time),
    sprintf("Estimated time remaining: %.2f minutes\n\n", x$remaining_time)
  )
  knitr::asis_output(paste(lines, collapse = ""))
}
Code
# Load necessary libraries
library(rTensor)
library(MASS)

# Assuming you have your data (y, mu_d_logit, g_i) ready
# NOTE(review): y, mu_d_logit, and g_i are not defined in this document's
# visible code — presumably loaded in an earlier (hidden) chunk; confirm.

# Step 1: Set up parameters
n_topics <- 3  # Number of topics (adjust as needed)
n_iterations <- 5000  # Number of MCMC iterations
use_mu_i <- TRUE       # sample per-individual offsets
use_survival <- TRUE   # discrete-time survival likelihood
estimate_scales <- FALSE  # We're using fixed scales

# Fixed scales (applied to initial_values in Step 3 below)
fixed_length_scale <- 15
fixed_var_scale <- 1

# Step 2: Initialize the model via Tucker-decomposition warm start
initial_values <- warm_gp_initialization(y, mu_d_logit, g_i, n_topics)

|
| | 0% |
|=== | 4% |
|====== | 8% |
|======================================================================| 100%

Code
# Step 3: Modify initial values to use fixed scales
initial_values$length_scales_lambda <- rep(fixed_length_scale, n_topics)
initial_values$var_scales_lambda <- rep(fixed_var_scale, n_topics)
# ncol() on the 3-D array y returns dim(y)[2], i.e. the number of diseases,
# so the phi-scale vectors have one entry per (topic, disease) pair.
initial_values$length_scales_phi <- rep(fixed_length_scale, n_topics * ncol(y))
initial_values$var_scales_phi <- rep(fixed_var_scale, n_topics * ncol(y))

# Step 4: Set up hyperparameters (these won't be used for scale updates, but are still required by the function)
alpha_lambda <- beta_lambda <- alpha_sigma <- beta_sigma <- 1
alpha_phi <- beta_phi <- alpha_sigma_phi <- beta_sigma_phi <- 1
alpha_Gamma <- beta_Gamma <- 1


# Step 5: Run the MCMC sampler
# Run the MCMC sampler
# NOTE(review): estimate_scales = FALSE here, so the alpha_*/beta_* values
# above affect only the Gamma prior (alpha_Gamma/beta_Gamma).
mcmc_results <- mcmc_sampler_flexible(
  y = y,
  mu_d_logit = mu_d_logit,
  g_i = g_i,
  n_iterations = 5000,
  initial_values = initial_values,
  use_mu_i = TRUE,
  use_survival = TRUE,
  estimate_scales = FALSE,
  proposal_method = "random_walk",
  alpha_lambda = 1, beta_lambda = 1,
  alpha_sigma = 1, beta_sigma = 1,
  alpha_phi = 1, beta_phi = 1,
  alpha_sigma_phi = 1, beta_sigma_phi = 1,
  alpha_Gamma = 1, beta_Gamma = 1
)

[1] 100 [1] 200 [1] 300 [1] 400 [1] 500 [1] 600 [1] 700 [1] 800 [1] 900 [1] 1000 [1] 1100 [1] 1200 [1] 1300 [1] 1400 [1] 1500 [1] 1600 [1] 1700 [1] 1800 [1] 1900 [1] 2000 [1] 2100 [1] 2200 [1] 2300 [1] 2400 [1] 2500 [1] 2600 [1] 2700 [1] 2800 [1] 2900 [1] 3000 [1] 3100 [1] 3200 [1] 3300 [1] 3400 [1] 3500 [1] 3600 [1] 3700 [1] 3800 [1] 3900 [1] 4000 [1] 4100 [1] 4200 [1] 4300 [1] 4400 [1] 4500 [1] 4600 [1] 4700 [1] 4800 [1] 4900 [1] 5000

Total time taken: 19.86 minutes

Code
library(coda)
# Summarize MCMC output for one parameter: trace plots, posterior
# summaries, and effective sample sizes (via coda) for a few
# representative scalar chains, plus an overall move-rate estimate.
#
# mcmc_results  list of sample arrays as returned by mcmc_sampler_flexible
#               (iteration is always the first dimension).
# parameter     name of the element to analyze, e.g. "Phi", "Lambda",
#               "Gamma", or a scale matrix.
analyze_results <- function(mcmc_results, parameter) {
  # Extract the parameter samples
  samples <- mcmc_results[[parameter]]

  # Get dimensions
  dim_samples <- dim(samples)
  n_iterations <- dim_samples[1]

  if (length(dim_samples) == 4) {  # For 3D parameters like Phi and Lambda
    n_topics <- dim_samples[2]
    n_diseases_or_individuals <- dim_samples[3]
    T <- dim_samples[4]

    # Choose a few representative slices to plot
    slices <- list(
      c(1, 1, T %/% 2),  # First topic, first disease/individual, middle time point
      c(n_topics, n_diseases_or_individuals, T %/% 2),  # Last topic, last disease/individual, middle time point
      c(n_topics %/% 2, n_diseases_or_individuals %/% 2, T)  # Middle topic, middle disease/individual, last time point
    )

    for (slice in slices) {
      chain <- mcmc(samples[, slice[1], slice[2], slice[3]])

      # Trace plot, posterior summary, and effective sample size
      plot(chain, main = paste(parameter, "- Topic", slice[1], "Disease/Individual", slice[2], "Time", slice[3]))
      print(summary(chain))
      print(effectiveSize(chain))
    }
  } else if (length(dim_samples) == 3) {  # For 2D parameters like Gamma
    n_topics <- dim_samples[2]
    n_columns <- dim_samples[3]

    # Choose a few representative slices to plot
    slices <- list(
      c(1, 1),  # First topic, first column
      c(n_topics, n_columns),  # Last topic, last column
      c(n_topics %/% 2, n_columns %/% 2)  # Middle topic, middle column
    )

    for (slice in slices) {
      chain <- mcmc(samples[, slice[1], slice[2]])

      plot(chain, main = paste(parameter, "- Topic", slice[1], "Column", slice[2]))
      print(summary(chain))
      print(effectiveSize(chain))
    }
  } else {  # For 1D parameters like length scales and var scales
    chain <- mcmc(samples)

    plot(chain, main = parameter)
    print(summary(chain))
    print(effectiveSize(chain))
  }

  # Overall move rate: fraction of iteration-to-iteration changes, computed
  # per scalar chain along the iteration dimension.  (The original code
  # flattened the whole array with as.vector() before diffing, which
  # counted a spurious "transition" at the boundary between every pair of
  # adjacent chains.)
  if (is.null(dim_samples)) {
    per_chain_moves <- mean(diff(samples) != 0)
  } else {
    per_chain_moves <- apply(samples, seq_along(dim_samples)[-1],
                             function(chain) mean(diff(chain) != 0))
  }
  acceptance_rates <- mean(per_chain_moves)
  print(paste("Overall acceptance rate:", round(acceptance_rates, 3)))
}


# Usage
analyze_results(mcmc_results, "Phi")


Iterations = 1:5000
Thinning interval = 1 
Number of chains = 1 
Sample size per chain = 5000 

1. Empirical mean and standard deviation for each variable,
   plus standard error of the mean:

          Mean             SD       Naive SE Time-series SE 
    -3.421e-02      1.703e-03      2.408e-05      5.358e-04 

2. Quantiles for each variable:

    2.5%      25%      50%      75%    97.5% 
-0.03777 -0.03496 -0.03383 -0.03313 -0.03128 

   var1 
10.0985 


Iterations = 1:5000
Thinning interval = 1 
Number of chains = 1 
Sample size per chain = 5000 

1. Empirical mean and standard deviation for each variable,
   plus standard error of the mean:

          Mean             SD       Naive SE Time-series SE 
     1.014e-02      5.199e-03      7.353e-05      2.095e-03 

2. Quantiles for each variable:

    2.5%      25%      50%      75%    97.5% 
0.006098 0.007559 0.009092 0.010114 0.028899 

    var1 
6.161467 


Iterations = 1:5000
Thinning interval = 1 
Number of chains = 1 
Sample size per chain = 5000 

1. Empirical mean and standard deviation for each variable,
   plus standard error of the mean:

          Mean             SD       Naive SE Time-series SE 
     6.065e-02      4.191e-03      5.928e-05      2.227e-03 

2. Quantiles for each variable:

   2.5%     25%     50%     75%   97.5% 
0.05496 0.05782 0.05986 0.06177 0.07024 

    var1 
3.542008 
[1] "Overall acceptance rate: 0.441"
Code
analyze_results(mcmc_results, "Lambda")


Iterations = 1:5000
Thinning interval = 1 
Number of chains = 1 
Sample size per chain = 5000 

1. Empirical mean and standard deviation for each variable,
   plus standard error of the mean:

          Mean             SD       Naive SE Time-series SE 
    -4.494e-02      2.558e-03      3.617e-05      9.859e-04 

2. Quantiles for each variable:

    2.5%      25%      50%      75%    97.5% 
-0.05083 -0.04572 -0.04429 -0.04317 -0.04164 

    var1 
6.730941 


Iterations = 1:5000
Thinning interval = 1 
Number of chains = 1 
Sample size per chain = 5000 

1. Empirical mean and standard deviation for each variable,
   plus standard error of the mean:

          Mean             SD       Naive SE Time-series SE 
    -7.885e-03      4.562e-03      6.452e-05      1.539e-03 

2. Quantiles for each variable:

     2.5%       25%       50%       75%     97.5% 
-0.024232 -0.008404 -0.007075 -0.005689 -0.003354 

    var1 
8.793405 


Iterations = 1:5000
Thinning interval = 1 
Number of chains = 1 
Sample size per chain = 5000 

1. Empirical mean and standard deviation for each variable,
   plus standard error of the mean:

          Mean             SD       Naive SE Time-series SE 
     8.519e-02      2.913e-03      4.119e-05      1.272e-03 

2. Quantiles for each variable:

   2.5%     25%     50%     75%   97.5% 
0.07779 0.08363 0.08566 0.08722 0.08897 

    var1 
5.244309 
[1] "Overall acceptance rate: 0.441"
Code
analyze_results(mcmc_results, "Gamma")


Iterations = 1:5000
Thinning interval = 1 
Number of chains = 1 
Sample size per chain = 5000 

1. Empirical mean and standard deviation for each variable,
   plus standard error of the mean:

          Mean             SD       Naive SE Time-series SE 
     -0.304587       0.479222       0.006777       0.123175 

2. Quantiles for each variable:

    2.5%      25%      50%      75%    97.5% 
-1.22503 -0.65061 -0.29321  0.04967  0.63304 

    var1 
15.13664 


Iterations = 1:5000
Thinning interval = 1 
Number of chains = 1 
Sample size per chain = 5000 

1. Empirical mean and standard deviation for each variable,
   plus standard error of the mean:

          Mean             SD       Naive SE Time-series SE 
     -0.186976       0.498230       0.007046       0.120126 

2. Quantiles for each variable:

   2.5%     25%     50%     75%   97.5% 
-1.1841 -0.5214 -0.1647  0.1775  0.7788 

    var1 
17.20226 


Iterations = 1:5000
Thinning interval = 1 
Number of chains = 1 
Sample size per chain = 5000 

1. Empirical mean and standard deviation for each variable,
   plus standard error of the mean:

          Mean             SD       Naive SE Time-series SE 
      -0.24699        0.46739        0.00661        0.10484 

2. Quantiles for each variable:

   2.5%     25%     50%     75%   97.5% 
-1.0480 -0.5924 -0.2710  0.0593  0.6903 

    var1 
19.87526 
[1] "Overall acceptance rate: 0.501"