Commit f69b9e0d authored by Poppy Miller's avatar Poppy Miller
Browse files

Changed ddirich to calculate on log scale rather than logging after. Changed...

Changed ddirich to calculate on log scale rather than logging after. Changed starting tuning values for r's. Fixed lambda_j_prop calculation (didn't work if lambda j not calculated first)
parent 82964bde
......@@ -378,14 +378,23 @@ HaldDP <- R6::R6Class(
}
},
calc_lambda_j_prop = function(nTimes, nLocations) {
# TODO: check with multiple times and locations
if (is.null(self$lambda_j)) self$calc_lambda_j()
calc_lambda_j_prop = function(n_iter, nSources, nTimes,
nLocations,
namesSources, namesTimes,
namesLocations,
namesIters,
k) {
if (is.null(self$lambda_j)) self$calc_lambda_j(nSources, nTimes,
nLocations, nTypes, n_iter,
namesSources, namesTimes,
namesLocations, namesTypes,
namesIters)
self$lambda_j_prop <- self$lambda_j
for (iter in 1:dim(self$lambda_j)[4]) {
for (times in 1:nTimes) {
for (locations in 1:nLocations) {
self$lambda_j_prop[,times,locations,iter] <- self$lambda_j[,times,locations,iter] / sum(self$lambda_j[,times,locations,iter])
self$lambda_j_prop[,times,locations,iter] <-
self$lambda_j[,times,locations,iter] / sum(self$lambda_j[,times,locations,iter])
}
}
}
......@@ -511,10 +520,15 @@ HaldDP <- R6::R6Class(
if (isTRUE(private$params_fix$r)) {
r_updaters[[time]][[sources]] <- NULL
} else {
# method of moments: beta to choose tuning values
alpha <- private$priors$a_r[, sources, time] + private$X[, sources, time]
alpha_0 <- sum(alpha)
var_alphas <- (alpha * (alpha_0 - alpha)) / ((alpha_0 ^ 2) * (alpha_0 + 1))
tune_val <- 100 * sqrt(var_alphas)
r_updaters[[time]][[sources]] <- AdaptiveLogDirMRW$new(private$DPModel_impl$rNodes[[time]][[sources]],
toupdate = function() sample(private$nTypes, private$n_r),
tune = rep(1.0, private$nTypes),
batchsize=10,
tune = tune_val,
batchsize = 10, # adaptive batch size
name = paste0("r: ", "time ", private$namesTimes[time],
", source ", private$namesSources[sources]))
}
......@@ -1315,7 +1329,15 @@ HaldDP <- R6::R6Class(
}
if ("lambda_j_prop" %in% params) {
private$posterior$calc_lambda_j_prop(private$nTimes, private$nLocations)
private$posterior$calc_lambda_j_prop(private$n_iter,
private$nSources,
private$nTimes,
private$nLocations,
private$namesSources,
private$namesTimes,
private$namesLocations,
1:private$n_iter,
private$k)
}
return(list(params = params, times = times, locations = locations, sources = sources, types = types, iters = iters, flatten = flatten))
......
......@@ -343,6 +343,7 @@ AdaptiveLogDirMRW <- R6::R6Class(
pican <- self$node$logPosterior()
alpha <- pican - picur + log(self$node$data[j]/old_data[j])
if (is.finite(alpha) &
log(runif(1)) < alpha) {
self$acceptbatch[j] <- self$acceptbatch[j] + 1
......
......@@ -286,7 +286,7 @@ DirichletNode <- R6::R6Class(
self$addParent(alpha, name = 'alpha')
},
logDensity = function() {
log(gtools::ddirichlet(self$data, self$parents$alpha$getData()))
sum((self$parents$alpha$getData() - 1) * log(self$data))
}
)
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment