Commit 0ca9ced2 authored by Poppy Miller's avatar Poppy Miller
Browse files

Fixed bug calculating lambda_j_prop posterior

parent cf0a9863
Package: CpjSA
Package: sourceR
Type: Package
Title: What the Package Does (Title Case)
Version: 0.1
......
......@@ -362,12 +362,16 @@ HaldDP <- R6::R6Class(
}
},
calc_lambda_j_prop = function() {
calc_lambda_j_prop = function(nTimes, nLocations) {
# TODO: check with multiple times and locations
if (is.null(self$lambda_j)) self$calc_lambda_j()
self$lambda_j_prop <- self$lambda_j
for (iter in 1:dim(self$lambda_j)[4]) {
self$lambda_j_prop[,,,iter] <- self$lambda_j[,,,iter] / sum(self$lambda_j[,,,iter])
for (times in 1:nTimes) {
for (locations in 1:nLocations) {
self$lambda_j_prop[,times,locations,iter] <- self$lambda_j[,times,locations,iter] / sum(self$lambda_j[,times,locations,iter])
}
}
}
}
)
......@@ -492,9 +496,7 @@ HaldDP <- R6::R6Class(
if (isTRUE(private$params_fix$q)) {
private$q_updaters <- function() return(NULL)
} else {
private$q_updaters <- PoisGammaDPUpdate$new(private$DPModel_impl$qNodes,
alpha = private$priors$a_theta,
beta = private$priors$b_theta)
private$q_updaters <- PoisGammaDPUpdate$new(private$DPModel_impl$qNodes)
}
},
......@@ -813,7 +815,7 @@ HaldDP <- R6::R6Class(
stop("inits$alpha must be a data frame with columns called Location, Source, Time and Value with one row per combination of location, source and time.")
if (!all(private$test_positive_number(inits$alpha$Value))) stop("inits$alpha$Value must contain only positive numbers")
## Extract values from inits r data frame and put into arrays.
## Extract values from inits alpha data frame and put into arrays.
## Surely this is very slow!! TODO: find better way!
for (time in private$namesTimes) {
for (location in private$namesLocations) {
......@@ -1279,7 +1281,7 @@ HaldDP <- R6::R6Class(
}
if ("lambda_j_prop" %in% params) {
private$posterior$calc_lambda_j_prop()
private$posterior$calc_lambda_j_prop(private$nTimes, private$nLocations)
}
return(list(params = params, times = times, locations = locations, sources = sources, types = types, iters = iters, flatten = flatten))
......@@ -1335,7 +1337,6 @@ HaldDP <- R6::R6Class(
update = function(n_iter, append) # FINISHED
{
browser()
if (!missing(append)) {
if (all(is.na(private$posterior$r))) {
## if first time running, set append to false
......
......@@ -181,7 +181,7 @@ AdaptiveDirMRW <- R6::R6Class(
self$node$data <- old_data
}
}
if(self$ncalls %% 50 == 0) private$adapt()
if (self$ncalls %% 50 == 0) private$adapt()
},
acceptance = function() {
self$naccept / self$ncalls
......@@ -344,53 +344,43 @@ PoisGammaDPUpdate <- R6::R6Class(
alpha = NULL,
beta = NULL,
initialize = function(node, alpha, beta) {
initialize = function(node) {
self$node <- node
self$alpha <- alpha
self$beta <- beta
self$alpha <- node$baseShape
self$beta <- node$baseRate
},
update =
function() {
# Get y data
y <-
sapply(self$node$children, function(child)
child$getData()) %>% rowSums
aX <-
sapply(self$node$children, function(child)
child$parents$offset$getData()) %>% rowSums
y <- sapply(self$node$children, function(child)
child$getData()) %>% rowSums
aX <- sapply(self$node$children, function(child)
child$parents$offset$getData()) %>% rowSums
# Update s
nk <- table(self$node$s) # Expensive!
for (i in 1:length(self$node$s)) {
nk_local <-
as.numeric(nk)
names(nk_local) <-
names(nk)
nk_local[self$node$s[i]] <-
nk_local[self$node$s[i]] - 1
nk_local <- as.numeric(nk)
names(nk_local) <- names(nk)
nk_local[self$node$s[i]] <- nk_local[self$node$s[i]] - 1
theta <-
self$node$theta$values()
logpi <-
log(nk_local) + y[i] * log(theta) - theta * aX[i]
theta <- self$node$theta$values()
logpi <- log(nk_local) + y[i] * log(theta) - theta * aX[i]
logpiplus <- log(self$node$conc) +
log(self$beta ^ self$alpha / gamma(self$alpha)) +
lgamma(self$alpha + y[i]) -
(self$alpha + y[i]) * log(self$beta + aX[i])
logpi <-
c(logpi,logpiplus)
logpi[is.nan(logpi)] <-
-Inf
logpi <-
logpi - max(logpi)
logpi <- c(logpi, logpiplus)
logpi[is.nan(logpi)] <- -Inf
logpi <- logpi - max(logpi)
pi <- exp(logpi)
s_old <- self$node$s[i]
idx <-
sample(length(pi),1, prob = pi)
idx <- sample(length(pi),1, prob = pi)
# Check if we're adding another cluster
# n.b. could use has.key(model$s[i], model$theta)
......@@ -398,8 +388,7 @@ PoisGammaDPUpdate <- R6::R6Class(
if (idx == length(pi)) {
self$node$s[i] <- uuid::UUIDgenerate(TRUE)
self$node$theta$insert(self$node$s[i], rgamma(1, self$alpha + y[i], self$beta + aX[i]))
nk[self$node$s[i]] <-
0 # Add new class, increment later
nk[self$node$s[i]] <- 0 # Add new class, increment later
}
else {
self$node$s[i] <- names(pi)[idx]
......@@ -413,13 +402,11 @@ PoisGammaDPUpdate <- R6::R6Class(
# Delete class
{
self$node$theta$erase(s_old)
nk <-
nk[names(nk) != s_old]
nk <- nk[names(nk) != s_old]
if (any(is.na(nk)))
browser()
}
nk[self$node$s[i]] <-
nk[self$node$s[i]] + 1
nk[self$node$s[i]] <- nk[self$node$s[i]] + 1
}
# Update theta -- sapply here?
......@@ -427,9 +414,7 @@ PoisGammaDPUpdate <- R6::R6Class(
{
sumy <- sum(y[self$node$s == label])
aXs <- sum(aX[self$node$s == label])
self$node$theta$insert(label, rgamma(1,
self$alpha + sumy,
self$beta + aXs))
self$node$theta$insert(label, rgamma(1, shape = self$alpha + sumy, rate = self$beta + aXs))
}
}
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment