# Here we write the simulation R Markdown (Rmd) document.

# Load required libraries
# Fix the RNG seed so the simulation, the initial values and the MCMC draws
# below are reproducible from run to run.
set.seed(123)
# Shared simulation code; the startup messages recorded below show that it
# attaches dplyr and Matrix.
# NOTE(review): presumably this script also creates y, g_i and the
# true-parameter objects (lambda_ik, phi_kd, Gamma_k, pi_values, N, D) that
# later code in this file uses without defining them — verify.
source("../logit_factorization/simwithlogit.R")
## 
## Attaching package: 'dplyr'
## The following object is masked from 'package:MASS':
## 
##     select
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union

## 
## Attaching package: 'Matrix'
## The following objects are masked from 'package:pracma':
## 
##     expm, lu, tril, triu

library(MASS)

# Helper Functions
# Squared-exponential (RBF) covariance between time points t1 and t2:
# variance * exp(-(t1 - t2)^2 / (2 * length_scale^2)). Vectorised over t1/t2.
rbf_kernel <- function(t1, t2, length_scale, variance) {
  sq_dist <- (t1 - t2)^2
  variance * exp(-0.5 * sq_dist / length_scale^2)
}

# Softmax of a numeric vector: exp(x) / sum(exp(x)).
# The maximum is subtracted first so the largest term is exp(0) = 1 and
# exp() cannot overflow; the shift cancels in the ratio.
softmax <- function(x) {
  shifted <- x - max(x)
  weights <- exp(shifted)
  weights / sum(weights)
}

# Function to apply softmax to Lambda
# Turn topic logits into topic weights: for every individual i and time t,
# theta[i, , t] = softmax(Lambda[i, , t]).
#
# Lambda: individuals x topics x time array.
# Returns an array of the same dimensions (without dimnames).
apply_softmax_to_lambda <- function(Lambda) {
  dims <- dim(Lambda)
  # apply() over (individual, time) places each softmax result first, giving a
  # topics x individuals x time layout; re-wrapping with array() also keeps
  # the single-topic case three-dimensional.
  topic_major <- array(apply(Lambda, c(1, 3), softmax), dim = dims[c(2, 1, 3)])
  aperm(topic_major, c(2, 1, 3))
}


# Precompute the inverse and log-determinant of a squared-exponential GP
# covariance over the integer time grid 1..T.
#
# K[s, t] = var_scale * exp(-0.5 * (s - t)^2 / length_scale^2) + jitter * I.
# K is symmetric positive definite, so it is factorised once with Cholesky
# (K = R'R): chol2inv() gives an exactly symmetric inverse, the log-determinant
# is 2 * sum(log(diag(R))), and chol() errors loudly if K is not numerically
# positive definite instead of returning an inaccurate inverse.
#
# Args:
#   T            number of time points.
#   length_scale GP length scale.
#   var_scale    GP marginal variance.
#   jitter       diagonal nugget for numerical stability (default 1e-6).
# Returns: list(K_inv = T x T matrix, log_det_K = scalar log|K|).
precompute_K_inv <- function(T, length_scale, var_scale, jitter = 1e-6) {
  time_idx <- seq_len(T)
  sq_dist <- outer(time_idx, time_idx, "-")^2
  K <- var_scale * exp(-0.5 * sq_dist / length_scale^2)
  diag(K) <- diag(K) + jitter
  chol_upper <- chol(K)
  K_inv <- chol2inv(chol_upper)
  log_det_K <- 2 * sum(log(diag(chol_upper)))
  list(K_inv = K_inv, log_det_K = log_det_K)
}

# Multivariate-normal log-density of one GP trajectory.
#
# log N(eta | mean, K) = -0.5 * (log|K| + (eta - mean)' K^{-1} (eta - mean)
#                                + n * log(2 * pi)),
# evaluated from a precomputed K^{-1} and log|K| (see precompute_K_inv()).
#
# Args:
#   eta       numeric vector of length n.
#   mean      prior mean; a scalar or a length-n vector (recycled by `-`).
#   K_inv     n x n inverse covariance matrix.
#   log_det_K log-determinant of the covariance.
# Returns: scalar log-density.
log_gp_prior_vec <- function(eta, mean, K_inv, log_det_K) {
  n_times <- length(eta)
  centered_eta <- eta - mean
  quad_form <- sum(centered_eta * (K_inv %*% centered_eta))
  # base::pi explicitly: this file also uses `pi` as a variable name, and a
  # global `pi` (e.g. from a sourced script) would otherwise be picked up here.
  -0.5 * (log_det_K + quad_form + n_times * log(2 * base::pi))
}



# Inverse-logit: maps real-valued logits to probabilities in (0, 1).
# Works elementwise and keeps the shape of vector/matrix/array input.
logistic <- function(x) {
  denominator <- 1 + exp(-x)
  1 / denominator
}

# Initialization Function
# Build starting values for the MCMC sampler.
#
# Args:
#   y          individuals x diseases x time array of event indicators.
#   g_i        individuals x covariates matrix (only its column count is used here).
#   n_topics   number of latent topics K.
#   n_diseases number of diseases D; must equal dim(y)[2].
#   T          number of time points; must equal dim(y)[3].
# Returns: list with
#   Lambda (N x K x T, N(0, 0.1^2) noise),
#   Phi    (K x D x T, logit of the empirical per-(disease, time) rate + 0.01, plus noise),
#   Gamma  (K x P, N(0, 0.1^2)),
#   mu_d   (length D, logit of each disease's overall empirical rate),
#   and per-topic GP length scales (T / 4) and variance scales (1) for Lambda and Phi.
initialize_mcmc <- function(y, g_i, n_topics, n_diseases, T) {
  # The Phi initialisation reuses y's per-(disease, time) rates, so the stated
  # sizes must match y's layout; otherwise rnorm() output would be recycled or
  # the slice assignment would fail with an unhelpful error.
  stopifnot(
    length(dim(y)) == 3,
    dim(y)[2] == n_diseases,
    dim(y)[3] == T
  )
  N <- dim(y)[1]  # Number of individuals
  P <- ncol(g_i)  # Number of genetic covariates

  # Keep probabilities strictly inside (0, 1) before qlogis(). A disease with
  # no events (rate 0) would give mu_d = -Inf, and a rate of 0.99 or more plus
  # the 0.01 offset gives Inf/NaN. An infinite mu_d makes the Phi GP
  # log-prior non-finite, so the sampler's Metropolis ratios become NaN.
  prob_eps <- 1e-6
  clamp_prob <- function(p) pmin(pmax(p, prob_eps), 1 - prob_eps)

  # Initialize Lambda
  Lambda_init <- array(rnorm(N * n_topics * T, mean = 0, sd = 0.1), dim = c(N, n_topics, T))

  # Initialize Phi: every topic starts near the empirical disease logits.
  empirical_rates <- apply(y, c(2, 3), mean)
  base_logits <- qlogis(clamp_prob(empirical_rates + 0.01))
  Phi_init <- array(0, dim = c(n_topics, n_diseases, T))
  for (k in seq_len(n_topics)) {
    Phi_init[k, , ] <- base_logits + rnorm(n_diseases * T, mean = 0, sd = 0.1)
  }

  # Initialize Gamma
  Gamma_init <- matrix(rnorm(n_topics * P, mean = 0, sd = 0.1), nrow = n_topics, ncol = P)

  # Initialize mu_d (baseline disease risk)
  mu_d_init <- qlogis(clamp_prob(apply(y, 2, mean)))

  # Initialize length scales and variance scales
  length_scales_lambda <- rep(T / 4, n_topics)
  var_scales_lambda <- rep(1, n_topics)
  length_scales_phi <- rep(T / 4, n_topics)
  var_scales_phi <- rep(1, n_topics)

  list(
    Lambda = Lambda_init,
    Phi = Phi_init,
    Gamma = Gamma_init,
    mu_d = mu_d_init,
    length_scales_lambda = length_scales_lambda,
    var_scales_lambda = var_scales_lambda,
    length_scales_phi = length_scales_phi,
    var_scales_phi = var_scales_phi
  )
}

# Log-likelihood function
# Log-likelihood of discrete-time first-event data under the topic mixture.
#
# For individual i, disease d and time t the event probability is
#   prob[i, d, t] = sum_k theta[i, k, t] * logistic(Phi[k, d, t]),
# where theta[i, , t] = softmax(Lambda[i, , t]). Each (i, d) pair contributes
# log(1 - prob) for every time before its first event in y[i, d, ] and
# log(prob) at the first-event time. A series with no event by the last time
# point is censored and contributes only the log(1 - prob) terms.
#
# Args:
#   y      individuals x diseases x time array; assumed to hold non-negative
#          (0/1) event indicators, since the first event is located via cumsum().
#   Lambda individuals x topics x time array of topic logits.
#   Phi    topics x diseases x time array of per-topic disease logits.
# Returns: scalar log-likelihood.
log_likelihood <- function(y, Lambda, Phi) {
  n_individuals <- dim(Lambda)[1]
  n_topics <- dim(Lambda)[2]
  n_diseases <- dim(Phi)[2]
  n_times <- dim(Lambda)[3]

  theta <- apply_softmax_to_lambda(Lambda)
  # Named event_prob rather than `pi` so base::pi is not shadowed.
  event_prob <- array(0, dim = c(n_individuals, n_diseases, n_times))
  for (t in seq_len(n_times)) {
    # matrix() pins the N x K and K x D shapes: plain `[, , t]` indexing drops
    # to a vector when N, K or D is 1, and %*% on vectors then computes an
    # inner product (or fails) instead of the intended matrix product.
    theta_t <- matrix(theta[, , t], nrow = n_individuals, ncol = n_topics)
    phi_t <- matrix(logistic(Phi[, , t]), nrow = n_topics, ncol = n_diseases)
    event_prob[, , t] <- theta_t %*% phi_t
  }

  log_lik <- 0
  for (i in seq_len(n_individuals)) {
    for (d in seq_len(n_diseases)) {
      # Time points strictly before the first event (still at risk).
      at_risk <- which(cumsum(y[i, d, ]) == 0)
      if (length(at_risk) > 0) {
        event_time <- max(at_risk) + 1
        if (event_time <= n_times) {
          log_lik <- log_lik + log(event_prob[i, d, event_time])
        }
        log_lik <- log_lik + sum(log(1 - event_prob[i, d, at_risk]))
      } else {
        # First event occurs at the first time point.
        log_lik <- log_lik + log(event_prob[i, d, 1])
      }
    }
  }
  log_lik
}

# MCMC Sampler
# Sum of the GP log-priors of all Lambda[i, k, ] trajectories. Each trajectory
# is centred on the constant genetic mean g_i[i, ] %*% Gamma[k, ] and uses the
# topic-k precision matrix from K_inv_lambda.
.lambda_gp_log_prior <- function(Lambda, Gamma, g_i, K_inv_lambda) {
  n_times <- dim(Lambda)[3]
  genetic_mean <- g_i %*% t(Gamma)  # individuals x topics
  total <- 0
  for (i in seq_len(dim(Lambda)[1])) {
    for (k in seq_len(dim(Lambda)[2])) {
      total <- total + log_gp_prior_vec(
        Lambda[i, k, ],
        rep(genetic_mean[i, k], n_times),
        K_inv_lambda[[k]]$K_inv,
        K_inv_lambda[[k]]$log_det_K
      )
    }
  }
  as.numeric(total)
}

# Sum of the GP log-priors of all Phi[k, d, ] trajectories, each centred on
# the disease baseline mu_d[d] and using the topic-k precision matrix.
.phi_gp_log_prior <- function(Phi, mu_d, K_inv_phi) {
  total <- 0
  for (k in seq_len(dim(Phi)[1])) {
    for (d in seq_len(dim(Phi)[2])) {
      total <- total + log_gp_prior_vec(
        Phi[k, d, ],
        mu_d[d],
        K_inv_phi[[k]]$K_inv,
        K_inv_phi[[k]]$log_det_K
      )
    }
  }
  as.numeric(total)
}

# Metropolis-within-Gibbs sampler for the softmax topic model.
#
# Each iteration makes one joint random-walk Metropolis proposal for each of
# Lambda, Phi and Gamma (in that order) and stores the resulting state:
#   * Lambda: target = log_likelihood + Lambda GP prior.
#   * Phi:    target = log_likelihood + Phi GP prior (mean mu_d).
#   * Gamma:  target = N(0, 1) prior + Lambda GP prior. Gamma sets the mean of
#             the Lambda GP prior, so that term must be in the ratio; with the
#             N(0, 1) prior alone Gamma would simply sample its prior.
# Proposal SDs are multiplied by 1.01 after an acceptance and 0.99 after a
# rejection, but only during the first n_adapt iterations. Afterwards they are
# frozen, so the remaining draws come from a fixed Metropolis kernel.
# The current log-likelihood and log-priors are cached and only replaced on
# acceptance, instead of being recomputed for every update.
#
# Args:
#   y              individuals x diseases x time event array (see log_likelihood()).
#   g_i            individuals x covariates numeric matrix.
#   n_iterations   number of iterations; every iteration is stored.
#   initial_values list as returned by initialize_mcmc().
#   n_adapt        number of initial iterations during which proposal SDs adapt
#                  (default: the first half of the run).
# Returns: list(Lambda, Phi, Gamma) of arrays whose first dimension is the iteration.
mcmc_sampler_softmax <- function(y, g_i, n_iterations, initial_values,
                                 n_adapt = floor(n_iterations / 2)) {
  current_state <- initial_values
  n_topics <- dim(current_state$Lambda)[2]
  n_times <- dim(current_state$Lambda)[3]

  # Initialize proposal standard deviations (one per parameter element)
  adapt_sd <- list(
    Lambda = array(0.01, dim = dim(current_state$Lambda)),
    Phi = array(0.01, dim = dim(current_state$Phi)),
    Gamma = matrix(0.01, nrow = nrow(current_state$Gamma), ncol = ncol(current_state$Gamma))
  )

  # Initialize storage for samples
  samples <- list(
    Lambda = array(0, dim = c(n_iterations, dim(current_state$Lambda))),
    Phi = array(0, dim = c(n_iterations, dim(current_state$Phi))),
    Gamma = array(0, dim = c(n_iterations, dim(current_state$Gamma)))
  )

  # Precompute inverse covariance matrices for Lambda and Phi (fixed hyperparameters)
  K_inv_lambda <- lapply(seq_len(n_topics), function(k)
    precompute_K_inv(n_times, current_state$length_scales_lambda[k], current_state$var_scales_lambda[k]))
  K_inv_phi <- lapply(seq_len(n_topics), function(k)
    precompute_K_inv(n_times, current_state$length_scales_phi[k], current_state$var_scales_phi[k]))

  # Metropolis test. A NaN log-ratio (e.g. -Inf minus -Inf) is treated as a
  # rejection instead of crashing `if`.
  mh_accept <- function(log_ratio) isTRUE(log(runif(1)) < log_ratio)
  # Multiplicative proposal-scale tuning, switched off after n_adapt iterations.
  tune_sd <- function(sd, accepted, iter) {
    if (iter > n_adapt) {
      return(sd)
    }
    if (accepted) sd * 1.01 else sd * 0.99
  }

  # Cached log-target components of the current state
  log_lik <- log_likelihood(y, current_state$Lambda, current_state$Phi)
  log_prior_lambda <- .lambda_gp_log_prior(current_state$Lambda, current_state$Gamma, g_i, K_inv_lambda)
  log_prior_phi <- .phi_gp_log_prior(current_state$Phi, current_state$mu_d, K_inv_phi)
  log_prior_gamma <- sum(dnorm(current_state$Gamma, 0, 1, log = TRUE))

  for (iter in seq_len(n_iterations)) {
    # Update Lambda
    proposed_Lambda <- current_state$Lambda +
      array(rnorm(prod(dim(current_state$Lambda)), 0, adapt_sd$Lambda),
            dim = dim(current_state$Lambda))
    proposed_log_lik <- log_likelihood(y, proposed_Lambda, current_state$Phi)
    proposed_prior <- .lambda_gp_log_prior(proposed_Lambda, current_state$Gamma, g_i, K_inv_lambda)
    accepted <- mh_accept((proposed_log_lik + proposed_prior) - (log_lik + log_prior_lambda))
    if (accepted) {
      current_state$Lambda <- proposed_Lambda
      log_lik <- proposed_log_lik
      log_prior_lambda <- proposed_prior
    }
    adapt_sd$Lambda <- tune_sd(adapt_sd$Lambda, accepted, iter)

    # Update Phi
    proposed_Phi <- current_state$Phi +
      array(rnorm(prod(dim(current_state$Phi)), 0, adapt_sd$Phi), dim = dim(current_state$Phi))
    proposed_log_lik <- log_likelihood(y, current_state$Lambda, proposed_Phi)
    proposed_prior <- .phi_gp_log_prior(proposed_Phi, current_state$mu_d, K_inv_phi)
    accepted <- mh_accept((proposed_log_lik + proposed_prior) - (log_lik + log_prior_phi))
    if (accepted) {
      current_state$Phi <- proposed_Phi
      log_lik <- proposed_log_lik
      log_prior_phi <- proposed_prior
    }
    adapt_sd$Phi <- tune_sd(adapt_sd$Phi, accepted, iter)

    # Update Gamma: its N(0, 1) prior plus the Lambda GP prior it parameterises
    proposed_Gamma <- current_state$Gamma +
      matrix(rnorm(prod(dim(current_state$Gamma)), 0, adapt_sd$Gamma), nrow = nrow(current_state$Gamma))
    proposed_gamma_prior <- sum(dnorm(proposed_Gamma, 0, 1, log = TRUE))
    proposed_lambda_prior <- .lambda_gp_log_prior(current_state$Lambda, proposed_Gamma, g_i, K_inv_lambda)
    accepted <- mh_accept(
      (proposed_gamma_prior + proposed_lambda_prior) - (log_prior_gamma + log_prior_lambda)
    )
    if (accepted) {
      current_state$Gamma <- proposed_Gamma
      log_prior_gamma <- proposed_gamma_prior
      log_prior_lambda <- proposed_lambda_prior
    }
    adapt_sd$Gamma <- tune_sd(adapt_sd$Gamma, accepted, iter)

    # Store samples
    samples$Lambda[iter, , , ] <- current_state$Lambda
    samples$Phi[iter, , , ] <- current_state$Phi
    samples$Gamma[iter, , ] <- current_state$Gamma

    # Print progress
    if (iter %% 100 == 0) cat("Iteration", iter, "\n")
  }

  samples
}

# Main execution
# Assuming y and g_i are already loaded
# NOTE(review): y must be an individuals x diseases x time array of event
# indicators and g_i an individuals x covariates matrix (the layouts the
# functions above index) — presumably both come from the sourced simulation.
n_topics <- 3  # Set this to your desired number of topics
n_diseases <- dim(y)[2]
# NOTE: assigning `T` masks the TRUE shorthand for the rest of the session.
# Later code reads this global, so it is kept; always write TRUE, never T.
T <- dim(y)[3]

initial_values <- initialize_mcmc(y, g_i, n_topics, n_diseases, T)

# Every iteration's Lambda, Phi and Gamma is stored, so memory grows as
# n_iterations x (N*K*T + K*D*T + K*P) doubles.
n_iterations <- 20000
samples <- mcmc_sampler_softmax(y, g_i, n_iterations, initial_values)

# Save results to the same path the analysis section reloads them from.
saveRDS(samples, "../logit_factorization/mcmc_samples.rds")

# Basic plotting of results: trace plots of one element each of Lambda, Phi
# and Gamma, using samples reloaded from disk.
samples <- readRDS("../logit_factorization/mcmc_samples.rds")
plot(samples$Lambda[, 1, 1, 1], type = "l", main = "Trace plot for Lambda[1,1,1]")

plot(samples$Phi[, 1, 1, 1], type = "l", main = "Trace plot for Phi[1,1,1]")

# Gamma is stored as iterations x topics x covariates, so one element is Gamma[1,1].
plot(samples$Gamma[, 1, 1], type = "l", main = "Trace plot for Gamma[1,1]")

library(coda)
library(bayesplot)
## This is bayesplot version 1.11.1
## - Online documentation and vignettes at mc-stan.org/bayesplot
## - bayesplot theme set to bayesplot::theme_default()
##    * Does _not_ affect other ggplot2 plots
##    * See ?bayesplot_theme_set for details on theme setting
library(ggplot2)

# Convert your samples to mcmc objects
# Each object is one scalar chain: element [1, 1, 1] of Lambda and Phi and
# element [1, 1] of Gamma, using every stored iteration (no burn-in dropped).
lambda_mcmc <- as.mcmc(samples$Lambda[, 1, 1, 1])  # Example for one element
phi_mcmc <- as.mcmc(samples$Phi[, 1, 1, 1])
gamma_mcmc <- as.mcmc(samples$Gamma[, 1, 1])


# Autocorrelation plots
acf(lambda_mcmc)

acf(phi_mcmc)

acf(gamma_mcmc)

# Effective Sample Size
# NOTE(review): the recorded ESS values (about 7, 4 and 63) are tiny relative
# to the number of stored draws, i.e. these chains mix very poorly and the
# posterior summaries that follow should not be trusted.
effectiveSize(lambda_mcmc)
##     var1 
## 7.032337
effectiveSize(phi_mcmc)
##     var1 
## 4.238507
effectiveSize(gamma_mcmc)
##     var1 
## 63.46915
# Posterior summaries
# summary() on an mcmc object reports mean and SD with two standard errors:
# the naive SE ignores autocorrelation, the time-series SE does not. The much
# larger time-series SEs below reflect the low effective sample sizes.
# NOTE(review): these outputs report "Sample size per chain = 5000" although
# n_iterations is set to 20000 above, so they come from a different saved run.
summary(lambda_mcmc)
## 
## Iterations = 1:5000
## Thinning interval = 1 
## Number of chains = 1 
## Sample size per chain = 5000 
## 
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
## 
##           Mean             SD       Naive SE Time-series SE 
##      0.2634967      0.0008982      0.0000127      0.0003387 
## 
## 2. Quantiles for each variable:
## 
##   2.5%    25%    50%    75%  97.5% 
## 0.2617 0.2630 0.2633 0.2640 0.2656
summary(phi_mcmc)
## 
## Iterations = 1:5000
## Thinning interval = 1 
## Number of chains = 1 
## Sample size per chain = 5000 
## 
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
## 
##           Mean             SD       Naive SE Time-series SE 
##     -4.3413758      0.0328702      0.0004649      0.0159660 
## 
## 2. Quantiles for each variable:
## 
##   2.5%    25%    50%    75%  97.5% 
## -4.378 -4.356 -4.348 -4.342 -4.224
# NOTE(review): Gamma[1,1]'s draws here (mean about -0.007, SD about 1.03,
# 95% interval about -1.9 to 2.0) look like the N(0, 1) prior placed on
# Gamma, i.e. this run appears to have learned nothing about Gamma from data.
summary(gamma_mcmc)
## 
## Iterations = 1:5000
## Thinning interval = 1 
## Number of chains = 1 
## Sample size per chain = 5000 
## 
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
## 
##           Mean             SD       Naive SE Time-series SE 
##      -0.007278       1.030893       0.014579       0.129399 
## 
## 2. Quantiles for each variable:
## 
##     2.5%      25%      50%      75%    97.5% 
## -1.92014 -0.75794 -0.04009  0.70549  2.01853
# Compare with true values (assuming you have access to these)
# NOTE(review): lambda_ik, phi_kd and Gamma_k are not defined in this file;
# presumably they come from the sourced simulation script with the same index
# order as the sampler's Lambda [i, k, t], Phi [k, d, t] and Gamma [k, p] — verify.
# phi_kd is passed through qlogis(), i.e. treated as a probability, because the
# sampler's Phi lives on the logit scale (the likelihood applies logistic(Phi)).
# NOTE(review): softmax(Lambda[i, , t]) is unchanged by adding the same
# constant to every topic, so single raw Lambda entries may be only weakly
# identified; comparing softmax-transformed values may be more meaningful.
true_lambda <- lambda_ik[1, 1, 1]  # Example for one element
true_phi <- qlogis(phi_kd[1, 1, 1])
true_gamma <- Gamma_k[1, 1]

cat("True Lambda:", true_lambda, "\n")
## True Lambda: 1.165236
cat("Estimated Lambda (mean):", mean(lambda_mcmc), "\n")
## Estimated Lambda (mean): 0.2634967
cat("True Phi:", true_phi, "\n")
## True Phi: -6.644081
cat("Estimated Phi (mean):", mean(phi_mcmc), "\n")
## Estimated Phi (mean): -4.341376
cat("True Gamma:", true_gamma, "\n")
## True Gamma: 0.3626501
cat("Estimated Gamma (mean):", mean(gamma_mcmc), "\n")
## Estimated Gamma (mean): -0.007278354
# Plot posterior distributions with true values
# Density of the stored Lambda[1,1,1] draws; the red line marks the true value.
ggplot(data.frame(lambda = as.vector(lambda_mcmc)), aes(x = lambda)) +
  geom_density() +
  geom_vline(xintercept = true_lambda, color = "red") +
  ggtitle("Posterior distribution of Lambda[1,1,1]")

# NOTE(review): this line only auto-prints whatever `pi` resolves to at top
# level (base R's constant unless the sourced simulation script defines its
# own `pi`). The `pi` array built inside log_likelihood() is local to that
# function and is not visible here; the posterior estimate is `pi_post` below.
pi

# Plug-in posterior estimate of the event probabilities: softmax of the
# posterior-mean Lambda mixed with logistic() of the posterior-mean Phi.
# NOTE: because both transforms are nonlinear this is not the posterior mean
# of pi, and every stored draw is averaged (no burn-in removed).
lambda_mean <- apply(samples$Lambda, c(2, 3, 4), mean)
phi_mean <- plogis(apply(samples$Phi, c(2, 3, 4), mean))

# Take dimensions from the posterior arrays instead of the simulation globals
# N and D, so this also works for a saved samples file of a different size.
n_post <- dim(lambda_mean)[1]
k_post <- dim(lambda_mean)[2]
t_post <- dim(lambda_mean)[3]
d_post <- dim(phi_mean)[2]

# theta is topics x individuals x time (apply() places each softmax result in
# the first dimension); array() keeps it 3-D when there is a single topic.
theta <- array(
  apply(lambda_mean, c(1, 3), softmax),
  dim = c(k_post, n_post, t_post)
)
pi_post <- array(data = 0, dim = c(n_post, d_post, t_post))
for (time_idx in seq_len(t_post)) {
  # matrix() pins the K x N and K x D shapes even when a dimension is 1.
  theta_t <- matrix(theta[, , time_idx], nrow = k_post, ncol = n_post)
  phi_t <- matrix(phi_mean[, , time_idx], nrow = k_post, ncol = d_post)
  pi_post[, , time_idx] <- t(theta_t) %*% phi_t
}

# Compare the plug-in posterior pi with the true simulated pi for the SAME
# randomly chosen individuals: the estimated panels first, then the true ones,
# in a 2 x 2 layout with one line per disease and a shared y-range per
# individual so the two panels are directly comparable.
# NOTE(review): pi_values is not defined in this file; presumably the sourced
# simulation script provides it with pi_post's individuals x diseases x time
# layout — verify.
n_individuals_post <- dim(pi_post)[1]
shown_individuals <- sample(seq_len(n_individuals_post), min(4, n_individuals_post))

# Draw one individual's diseases x time slice, one line per disease; matrix()
# keeps the slice two-dimensional when there is a single disease.
plot_pi_panel <- function(pi_array, i, title, ylim) {
  slice <- matrix(pi_array[i, , ], nrow = dim(pi_array)[2])
  matplot(
    t(slice),
    type = "l",
    main = title,
    xlab = "Time",
    ylab = "Pi",
    ylim = ylim
  )
}

old_par <- par(mfrow = c(2, 2))
for (i in shown_individuals) {
  plot_pi_panel(pi_post, i, paste("Pi for individual", i),
                ylim = range(pi_post[i, , ], pi_values[i, , ]))
}
for (i in shown_individuals) {
  plot_pi_panel(pi_values, i, paste("True Pi for individual", i),
                ylim = range(pi_post[i, , ], pi_values[i, , ]))
}
par(old_par)