Commit 5af67f5d authored by Chris Jewell's avatar Chris Jewell
Browse files

Implemented AdaptiveLogDirMRW2 and DirichletNode2 which reparameterises the...

Implemented AdaptiveLogDirMRW2 and DirichletNode2 which reparameterises the Dirichlet distribution as independent Gamma distributions.
parent cf0a9863
......@@ -483,7 +483,7 @@ HaldDP <- R6::R6Class(
if (isTRUE(private$params_fix$r)) {
private$r_updaters[[time]][[sources]] <- function() return(NULL)
} else {
private$r_updaters[[time]][[sources]] <- AdaptiveLogDirMRW$new(private$DPModel_impl$rNodes[[time]][[sources]],
private$r_updaters[[time]][[sources]] <- AdaptiveLogDirMRW2$new(private$DPModel_impl$rNodes[[time]][[sources]],
toupdate = function() sample(private$nTypes, private$n_r),
tune = rep(0.01, private$nTypes))
}
......@@ -1335,7 +1335,6 @@ HaldDP <- R6::R6Class(
update = function(n_iter, append) # FINISHED
{
browser()
if (!missing(append)) {
if (all(is.na(private$posterior$r))) {
## if first time running, set append to false
......
......@@ -299,7 +299,71 @@ AdaptiveLogDirMRW <- R6::R6Class(
(self$ncalls %% self$batchsize == 0)) {
m <- ifelse(self$acceptbatch / self$batchsize > 0.44, 1,-1)
self$tune <-
exp(log(self$tune) + m * min(0.05, 1.0 / sqrt(self$ncalls)))
exp(log(self$tune) + m * min(0.05, 2.0 / sqrt(self$ncalls)))
self$naccept <-
self$naccept + self$acceptbatch
self$acceptbatch <- self$acceptbatch * 0
}
}
)
)
AdaptiveLogDirMRW2 <- R6::R6Class(
"AdaptiveDirMRW2",
public = list(
toupdate = NA,
naccept = NA,
ncalls = NA,
acceptbatch = NA,
batchsize = NA,
node = NA,
tune = NA,
initialize = function(node, toupdate = function() 1:length(node$data),
tune = rep(0.1, length(node$data)),
batchsize = 50) {
self$toupdate = toupdate
self$naccept <- rep(0, length(node$data))
self$ncalls <- 0
self$acceptbatch <-
rep(0, length(node$data))
self$batchsize <- batchsize
self$tune <- tune
self$node <- node
},
update = function() {
self$ncalls <- self$ncalls + 1
for (j in self$toupdate()) {
old_data <- self$node$data
picur <- self$node$logPosterior()
# Propose using MHRW
self$node$data[j] <- old_data[j] * exp( rnorm(1, 0, self$tune[j]) )
pican <- self$node$logPosterior()
alpha <- pican - picur + log(self$node$data[j]/old_data[j])
if (is.finite(alpha) &
log(runif(1)) < alpha) {
self$acceptbatch[j] <- self$acceptbatch[j] + 1
}
else {
self$node$data <- old_data
}
}
if(self$ncalls %% 50 == 0) private$adapt()
},
acceptance = function() {
self$naccept / self$ncalls
}
),
private = list(
adapt = function() {
if ((self$ncalls > 0) &
(self$ncalls %% self$batchsize == 0)) {
m <- ifelse(self$acceptbatch / self$batchsize > 0.44, 1,-1)
self$tune <-
exp(log(self$tune) + m * min(0.05, 2.0 / sqrt(self$ncalls)))
self$naccept <-
self$naccept + self$acceptbatch
self$acceptbatch <- self$acceptbatch * 0
......
......@@ -71,7 +71,6 @@ DPModel_impl <- R6::R6Class(
# Construct y, lambda, a, R, and k for location/time pairs
for (time in unique(Time)) {
# Prevalences
k <- DataNode$new(data = setNames(prev[, time], Sources), name = paste('k',time,sep = '_'))
......@@ -82,7 +81,7 @@ DPModel_impl <- R6::R6Class(
a_r_full <- DataNode$new(data = a_r[, src, time] + X[, src, time], name = paste('a_r', time, src, sep = '_')) # Todo: Prior here
xcol <- R[, src, time] + 0.000001 # Added for numeric stability -- only affects starting values for r
xcol <- xcol/sum(xcol)
self$rNodes[[time]][[src]] <- DirichletNode$new(data = setNames(xcol, Type),
self$rNodes[[time]][[src]] <- DirichletNode2$new(data = setNames(xcol, Type),
alpha = a_r_full,
name = paste('r', time, src, sep = '_'))
}
......
......@@ -290,6 +290,25 @@ DirichletNode <- R6::R6Class(
)
)
DirichletNode2 <- R6::R6Class(
"DirichletNode2",
inherit = StochasticNode,
public = list(
initialize = function(data, alpha, name) {
super$initialize(name = name)
self$data <- data
self$addParent(alpha, name='alpha')
},
logDensity = function() {
sum(dgamma(self$data, shape=self$parents$alpha$getData(), rate=1, log=T))
},
getData = function() {
self$data / sum(self$data)
}
)
)
#' Transformed Dirichlet node
#'
#' Uses transformation due to Betancourt 2013
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment