Commit ec0609fd authored by Chris Jewell's avatar Chris Jewell
Browse files

Added AdaptiveSingleSiteMRW updater. Corrected a problem in AdaptiveLogDirMRW...

Added AdaptiveSingleSiteMRW updater.  Corrected a problem in AdaptiveLogDirMRW where individual components were not being able to adapt independently.
parent 818d3e48
......@@ -7,6 +7,65 @@
# class DAG model. #
##################################
AdaptiveSingleSiteMRW <- R6::R6Class(
"AdaptiveSingleSiteMRW",
public = list(
toupdate = NA,
naccept = NA,
ncalls = NA,
acceptbatch = NA,
batchsize = NA,
node = NA,
tune = NA,
initialize = function(node, tune = 0.1,
batchsize = 50) {
self$naccept <- 0
self$ncalls <- 0
self$acceptbatch <- 0
self$batchsize <- batchsize
self$tune <- tune
self$node <- node
},
update = function() {
self$ncalls <- self$ncalls + 1
old_data <- self$node$getData()
picur <- self$node$logPosterior()
# Propose using MHRW
self$node$data <- old_data * exp( rnorm(1, 0, self$tune) )
pican <- self$node$logPosterior()
alpha <- pican - picur + log(self$node$data/old_data)
if (is.finite(alpha) &
log(runif(1)) < alpha) {
self$acceptbatch <- self$acceptbatch + 1
}
else {
self$node$data <- old_data
}
private$adapt()
},
acceptance = function() {
self$naccept / self$ncalls
}
),
private = list(
adapt = function() {
if ((self$ncalls > 0) &
(self$ncalls %% self$batchsize == 0)) {
m <- ifelse(self$acceptbatch / self$batchsize > 0.44, 1,-1)
self$tune <-
exp(log(self$tune) + m * min(0.05, 1.0 / sqrt(self$ncalls)))
if(!is.finite(self$tune)) self$tune <- 1e-8
self$naccept <-
self$naccept + self$acceptbatch
self$acceptbatch <- 0
}
}
)
)
#' AdaptiveMultiMRW
#'
#' This class implements an Adaptive Multi-site Metropolis random walk
......@@ -258,7 +317,7 @@ AdaptiveLogDirMRW <- R6::R6Class(
batchsize = 50) {
self$toupdate = toupdate
self$naccept <- rep(0, length(node$getData()))
self$ncalls <- 0
self$ncalls <- rep(0, length(node$getData()))
self$acceptbatch <-
rep(0, length(node$getData()))
self$batchsize <- batchsize
......@@ -266,8 +325,9 @@ AdaptiveLogDirMRW <- R6::R6Class(
self$node <- node
},
update = function() {
self$ncalls <- self$ncalls + 1
for (j in self$toupdate()) {
updIdx <- self$toupdate()
for (j in updIdx) {
self$ncalls[j] <- self$ncalls[j] + 1
old_data <- self$node$getData()
picur <- self$node$logPosterior()
......@@ -286,23 +346,28 @@ AdaptiveLogDirMRW <- R6::R6Class(
else {
self$node$data <- old_data
}
private$adapt(j)
}
if(self$ncalls %% 50 == 0) private$adapt()
},
acceptance = function() {
self$naccept <- self$naccept + self$acceptbatch
self$naccept / self$ncalls
}
),
private = list(
adapt = function() {
if ((self$ncalls > 0) &
(self$ncalls %% self$batchsize == 0)) {
m <- ifelse(self$acceptbatch / self$batchsize > 0.44, 1,-1)
self$tune <-
exp(log(self$tune) + m * min(0.05, 2.0 / sqrt(self$ncalls)))
self$naccept <-
self$naccept + self$acceptbatch
self$acceptbatch <- self$acceptbatch * 0
adapt = function(j) {
if ((self$ncalls[j] > 0) &
(self$ncalls[j] %% self$batchsize == 0)) {
m <- ifelse(self$acceptbatch[j] / self$batchsize > 0.44, 1,-1)
self$tune[j] <-
exp(log(self$tune[j]) + m * min(0.05, 1.0 / sqrt(self$ncalls[j])))
if(!is.finite(self$tune[j])) {
warning('AdaptiveLogDirMRW tuning variance non-finite')
self$tune[j] <- 1e-9
}
self$naccept[j] <-
self$naccept[j] + self$acceptbatch[j]
self$acceptbatch[j] <- 0
}
}
)
......@@ -325,7 +390,7 @@ AdaptiveLogDirMRW2 <- R6::R6Class(
batchsize = 50) {
self$toupdate = toupdate
self$naccept <- rep(0, length(node$data))
self$ncalls <- 0
self$ncalls <- rep(0, length(node$data))
self$acceptbatch <-
rep(0, length(node$data))
self$batchsize <- batchsize
......@@ -333,8 +398,9 @@ AdaptiveLogDirMRW2 <- R6::R6Class(
self$node <- node
},
update = function() {
self$ncalls <- self$ncalls + 1
for (j in self$toupdate()) {
updIdx <- self$toupdate()
for (j in updIdx) {
self$ncalls[j] <- self$ncalls[j] + 1
old_data <- self$node$data
picur <- self$node$logPosterior()
......@@ -350,23 +416,24 @@ AdaptiveLogDirMRW2 <- R6::R6Class(
else {
self$node$data <- old_data
}
private$adapt(j)
}
if(self$ncalls %% 50 == 0) private$adapt()
},
acceptance = function() {
self$naccept <- self$naccept + self$acceptbatch
self$naccept / self$ncalls
}
),
private = list(
adapt = function() {
if ((self$ncalls > 0) &
(self$ncalls %% self$batchsize == 0)) {
m <- ifelse(self$acceptbatch / self$batchsize > 0.44, 1,-1)
self$tune <-
exp(log(self$tune) + m * min(0.05, 2.0 / sqrt(self$ncalls)))
self$naccept <-
self$naccept + self$acceptbatch
self$acceptbatch <- self$acceptbatch * 0
adapt = function(j) {
if ((self$ncalls[j] > 0) &
(self$ncalls[j] %% self$batchsize == 0)) {
m <- ifelse(self$acceptbatch[j] / self$batchsize > 0.44, 1,-1)
self$tune[j] <-
exp(log(self$tune[j]) + m * min(0.05, 1.0 / sqrt(self$ncalls[j])))
self$naccept[j] <-
self$naccept[j] + self$acceptbatch[j]
self$acceptbatch[j] <- 0
}
}
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment