Code
library(dplyr)
Attaching package: 'dplyr'
The following objects are masked from 'package:stats':
filter, lag
The following objects are masked from 'package:base':
intersect, setdiff, setequal, union
Code
library(reshape2)
library(mvtnorm)
library(rstan)
Loading required package: StanHeaders
rstan version 2.32.6 (Stan version 2.32.2)
For execution on a local, multicore CPU with excess RAM we recommend calling
options(mc.cores = parallel::detectCores()).
To avoid recompilation of unchanged Stan programs, we recommend calling
rstan_options(auto_write = TRUE)
For within-chain threading using `reduce_sum()` or `map_rect()` Stan functions,
change `threads_per_chain` option:
rstan_options(threads_per_chain = 1)
Code
# Fix the random number generator state so reruns give the same draws
set.seed(seed = 123)
# 1. Simulation (from sim_noulli.R)
# Run the simulation script. The code further down uses objects such as `y`
# and `g_i` without defining them, so they are presumably created by this
# script -- TODO confirm.
source(file = "../sim_noulli.R")
Attaching package: 'MASS'
The following object is masked from 'package:dplyr':
select
Attaching package: 'Matrix'
The following objects are masked from 'package:pracma':
expm, lu, tril, triu
# A tibble: 20 × 2
age `length(unique(eid))`
<int> <int>
1 1 41
2 2 23
3 3 16
4 4 7
5 5 9
6 6 12
7 7 9
8 8 10
9 9 9
10 10 8
11 11 7
12 12 16
13 13 17
14 14 11
15 15 13
16 16 9
17 17 10
18 18 6
19 19 4
20 20 8
Attaching package: 'plotly'
The following object is masked from 'package:ggplot2':
last_plot
The following object is masked from 'package:MASS':
select
The following object is masked from 'package:stats':
filter
The following object is masked from 'package:graphics':
layout
Attaching package: 'data.table'
The following objects are masked from 'package:reshape2':
dcast, melt
The following objects are masked from 'package:dplyr':
between, first, last
Attaching package: 'DT'
The following objects are masked from 'package:shiny':
dataTableOutput, renderDataTable
Warning: The melt generic in data.table has been passed a array and will
attempt to redirect to the relevant reshape2 method; please note that reshape2
is superseded and is no longer actively developed, and this redirection is now
deprecated. To continue using melt methods from reshape2 while both libraries
are attached, e.g. melt.list, you can prepend the namespace, i.e.
reshape2::melt(topic_disease_time_array). In the next version, this warning
will become an error.
Code
# Problem sizes, read off the objects used by the model
N <- nrow(y) # extent of the first dimension of y
D <- ncol(y) # extent of the second dimension of y
T <- dim(y)[3] # extent of the third dimension of y; masks the T shorthand for TRUE
K <- 3 # Number of topics, adjust as needed
P <- dim(g_i)[2] # number of columns of g_i
# RBF (squared-exponential) kernel:
#   variance * exp(-0.5 * (t1 - t2)^2 / length_scale^2)
# All operations are elementwise, so vector inputs follow R's usual recycling.
rbf_kernel <- function(t1, t2, length_scale, variance) {
  sq_diff <- (t1 - t2)^2
  exponent <- -0.5 * sq_diff / length_scale^2
  variance * exp(exponent)
}
# 2. Initialization functions

# Draw random starting values for Gamma, Lambda and Phi, and compute a
# smoothed logit-prevalence baseline mu_d from the data.
#
# Arguments:
#   y                    3-d array. Its dimensions are read as N x D x T
#                        (T is treated as the time axis; y is averaged over
#                        the first dimension to get a "raw prevalence").
#   g_i                  Matrix with at least N rows and P columns; it is
#                        multiplied by each row of Gamma to give the mean of
#                        Lambda.
#   K                    Number of topics.
#   length_scale_lambda  RBF length scales for Lambda; element k is used for
#                        topic k.
#   length_scale_phi     RBF length scales for Phi; element k is used for
#                        topic k.
#   variance_phi         RBF variances for Phi; element k is used for topic k.
#   variance_lambda      RBF variances for Lambda; element k is used for
#                        topic k.
#   smooth_span          `span` passed to loess() when smoothing mu_d.
#   plot_kernels         If TRUE (the default, which keeps the historical
#                        behaviour) every kernel matrix is drawn with image()
#                        as it is built. Set to FALSE to skip the plots.
#
# Returns a list with elements
#   Lambda  N x K x T array
#   Gamma   K x P matrix
#   Phi     K x D x T array
#   mu_d    D x T matrix
#
# Relies on rbf_kernel() being defined in the calling environment and on the
# MASS package being installed.
initialize_model <- function(y,
                             g_i,
                             K,
                             length_scale_lambda,
                             length_scale_phi,
                             variance_phi,
                             variance_lambda,
                             smooth_span = 0.75,
                             plot_kernels = TRUE) {
  if (length(dim(y)) != 3L) {
    stop("`y` must be a 3-dimensional array.", call. = FALSE)
  }
  if (length(dim(g_i)) != 2L) {
    stop("`g_i` must be a 2-dimensional matrix.", call. = FALSE)
  }

  # Local names avoid `T`, which would shadow the shorthand for TRUE.
  n_ind <- dim(y)[1]
  n_dis <- dim(y)[2]
  n_time <- dim(y)[3]
  n_cov <- ncol(g_i)

  # With fewer rows than N the per-row means below would be NA and the
  # resulting draws would silently be NA as well.
  if (nrow(g_i) < n_ind) {
    stop("`g_i` must have at least dim(y)[1] rows.", call. = FALSE)
  }

  # Each hyperparameter vector is indexed by topic, so it needs K entries.
  hyper <- list(
    length_scale_lambda = length_scale_lambda,
    length_scale_phi = length_scale_phi,
    variance_phi = variance_phi,
    variance_lambda = variance_lambda
  )
  for (nm in names(hyper)) {
    if (length(hyper[[nm]]) < K) {
      stop("`", nm, "` must have at least K elements.", call. = FALSE)
    }
  }

  # Generate time points
  time_points <- seq_len(n_time)
  jitter <- 1e-6 # Small jitter term to improve conditioning

  # Build the T x T RBF covariance for one topic. The kernel depends only on
  # the time difference, so the matrix is Toeplitz and is determined by its
  # first row.
  make_kernel <- function(length_scale, variance) {
    first_row <- vapply(
      time_points,
      function(t) rbf_kernel(time_points[1], t, length_scale, variance),
      numeric(1)
    )
    kern <- toeplitz(first_row)
    if (plot_kernels) {
      image(kern)
    }
    # Add jitter to the diagonal to improve stability
    kern + diag(jitter, nrow(kern))
  }

  # Initialize Gamma
  Gamma_init <- matrix(rnorm(K * n_cov, 0, 0.1), nrow = K, ncol = n_cov)

  # Initialize Lambda: one draw per (row of y, topic), centred on g_i %*% Gamma.
  # Draws are made one at a time, in this order, so that results for a given
  # seed do not change.
  Lambda_init <- array(0, dim = c(n_ind, K, n_time))
  for (k in seq_len(K)) {
    mean_lambda <- g_i %*% Gamma_init[k, ]
    kern <- make_kernel(length_scale_lambda[k], variance_lambda[k])
    for (i in seq_len(n_ind)) {
      Lambda_init[i, k, ] <- MASS::mvrnorm(
        1,
        mu = rep(mean_lambda[i], n_time),
        Sigma = kern
      )
    }
  }

  # Initialize Phi: one zero-mean draw per (topic, second-dimension index)
  Phi_init <- array(0, dim = c(K, n_dis, n_time))
  for (k in seq_len(K)) {
    kern <- make_kernel(length_scale_phi[k], variance_phi[k])
    for (d in seq_len(n_dis)) {
      Phi_init[k, d, ] <- MASS::mvrnorm(1, mu = rep(0, n_time), Sigma = kern)
    }
  }

  # Calculate and smooth mu_d
  mu_d <- matrix(0, n_dis, n_time)
  for (d in seq_len(n_dis)) {
    # Raw prevalence: mean of y over the first dimension at each time point
    raw_prevalence <- vapply(
      time_points,
      function(t) mean(y[, d, t]),
      numeric(1)
    )
    # Squeeze 0 and 1 to 0.5 / N and 1 - 0.5 / N so the logit stays finite
    logit_prev <- qlogis((raw_prevalence * (n_ind - 1) + 0.5) / n_ind)
    # Apply LOESS smoothing over time and keep the fitted values
    smooth_logit <- loess(logit_prev ~ time_points, span = smooth_span)
    mu_d[d, ] <- predict(smooth_logit, time_points)
  }

  list(
    Lambda = Lambda_init,
    Gamma = Gamma_init,
    Phi = Phi_init,
    mu_d = mu_d
  )
}
#### Here we initialize with a different length scale and variance for each topic
# Requires initialize_model() to have been defined above. The hyperparameter
# vectors (length_scales_lambda, length_scales_phi, var_scales_phi,
# var_scales_lambda) are not defined in this section; presumably they come
# from the sourced simulation script -- TODO confirm.
initial_values <- initialize_model(
  y,
  g_i,
  K,
  length_scale_lambda = length_scales_lambda,
  length_scale_phi = length_scales_phi,
  variance_phi = var_scales_phi,
  variance_lambda = var_scales_lambda,
  smooth_span = 0.75
)
Code
# Smoothed baseline for the first row of mu_d
plot(initial_values$mu_d[1, ], type = "l")
# Initial Phi trajectories for the first topic, one line per second-dimension index
matplot(t(initial_values$Phi[1, , ]), type = "l")
Code
# Initial Lambda trajectories for the first row of y, one line per topic
matplot(t(initial_values$Lambda[1, , ]), type = "l")