Commit 03bd651b authored by Chris Jewell's avatar Chris Jewell
Browse files

Added abstracted functions to make the main HaldDP class more sane in length!

parent e5ba6da3
Acceptance = R6::R6Class(
"acceptance",
public = list(
alpha = NULL,
r = NULL,
initialize = function(nSources = private$nSources,
nTimes = private$nTimes,
nLocations = private$nLocations,
nTypes = private$nTypes,
namesSources = private$namesSources,
namesTimes = private$namesTimes,
namesLocations = private$namesLocations,
namesTypes = private$namesTypes) {
self$alpha <- array(
dim = c(nSources,
nTimes,
nLocations),
dimnames = list(
source = namesSources,
time = namesTimes,
location = namesLocations
)
)
self$r <- array(
dim = c(nTypes,
nSources,
nTimes),
dimnames = list(
type = namesTypes, source = namesSources, time = namesTimes
)
)
}
)
)
#####################################################
# Name: credibleIntervals.R #
# Author: Chris Jewell <c.jewell@lancaster.ac.uk> #
# Created: 20161206 #
# Copyright: Chris Jewell 2016 #
# Purpose: Implements credible interval calculation #
#####################################################
ci_chenShao = function(x, alpha) {
n <- length(x)
sorted <- sort(x)
upper_pos <- round(n * (1 - alpha))
ci.lower <- 0
ci.upper <- 0
region_size <- Inf
## checks all intervals that are n*alpha apart, and chooses the shortest one.
for (i in 1:(n - upper_pos)) {
test_interval <- sorted[upper_pos + i] - sorted[i]
if (test_interval < region_size) {
region_size <- test_interval
ci.lower <- sorted[i]
ci.upper <- sorted[upper_pos + i]
}
}
return(c(
median = median(sorted),
lower = ci.lower,
upper = ci.upper
))
}
ci_percentiles <- function(x, alpha) {
n <- length(x)
sorted <- sort(x)
upper_pos <- round(n * (1 - (alpha / 2)))
lower_pos <- round(n * (alpha / 2))
return(c(
median = median(x),
lower = sorted[lower_pos],
upper = sorted[upper_pos]
))
}
ci_SPIn <- function(x, alpha) {
region <- tryCatch({
SPIn(x, conf = 1 - alpha)$spin
},
error = function(cond) {
print("Error calculating SPIn interval.")
return(c(NA, NA))
})
return(c(
median = median(x),
lower = region[1],
upper = region[2]
))
}
#####################################################
# Name: heatmap.R #
# Author: Chris Jewell <c.jewell@lancaster.ac.uk> #
# Created: 20161206 #
# Copyright: Chris Jewell 2016 #
# Purpose: Draws a clustered heatmap #
#####################################################
are_colours <- function(x) {
sapply(x, function(X) {
tryCatch(
is.matrix(col2rgb(X)),
error = function(e)
FALSE
)
})
}
clusterHeatMap <- function(x, cols, xnames = 1:length(x), hclust_method) {
# Check colours
if (length(cols) != 2 |
!mode(cols) %in% c("character") | !all(are_colours(cols))) {
message(
"The argument cols contain colours that are not valid. The defaults will be used instead."
)
cols <- c("blue", "white")
}
groups <-
as.data.frame(apply(groups, 2, function(x)
as.factor(x)))
# compute dissimilarity matrix for the type effect clusters
disim_clust_g <- cluster::daisy(groups)
clu <-
stats::hclust(disim_clust_g, hclust_method) # default method is complete
dend <- stats::as.dendrogram(clu)
# OPTIONAL: change the colour of the heatmap. The lighter the colour
# (when using the default white blue colour scheme),
# the higher the dissimilarity between the 2 types (i.e. the less
# often two type effects are assigned to the same group in the mcmc)
hmcols <- colorRampPalette(cols)(299)
heatmap_data <- as.matrix(disim_clust_g)
rownames(heatmap_data) <- colnames(heatmap_data) <- xnames
gplots::heatmap.2(
heatmap_data,
density.info = "none",
# turns off density plot in the legend
trace = "none",
# turns off trace lines in the heat map
col = hmcols,
# use color palette defined earlier
dendrogram = "col",
# only draw a row dendrogram
Colv = dend,
Rowv = dend,
symm = TRUE,
key = F
)
}
Posterior = R6::R6Class(
"Posterior",
public = list(
q = NULL,
s = NULL,
alpha = NULL,
r = NULL,
lambda_i = NULL,
lambda_j = NULL,
lambda_j_prop = NULL,
initialize = function(nSources, nTimes,
nLocations, nTypes, n_iter,
namesSources, namesTimes,
namesLocations, namesTypes,
namesIters) {
self$alpha <- array(
dim = c(nSources,
nTimes,
nLocations,
n_iter),
dimnames = list(
source = namesSources,
time = namesTimes,
location = namesLocations,
iter = namesIters
)
)
self$q <- array(
dim = c(
nTypes,
n_iter
),
dimnames = list(
type = namesTypes, iter = namesIters
)
)
self$s <- array(
dim = c(
nTypes,
n_iter
),
dimnames = list(
type = namesTypes, iter = namesIters
)
)
self$r <- array(
dim = c(nTypes,
nSources,
nTimes,
n_iter),
dimnames = list(
type = namesTypes, source = namesSources, time = namesTimes, iter = namesIters
)
)
},
calc_lambda_i = function(n_iter, nTimes,
nLocations, nTypes,
namesTimes,
namesLocations, namesTypes,
namesIters,
k) {
self$lambda_i <- array(
dim = c(nTypes,
nTimes,
nLocations,
n_iter),
dimnames = list(
type = namesTypes,
time = namesTimes,
location = namesLocations,
iter = namesIters
)
)
if (!is.null(self$q)) { # don't try to calculate before any iterations have occured
for (i in 1:n_iter) {
for (times in 1:nTimes) {
for (locations in 1:nLocations) {
self$lambda_i[, times, locations, i] <- self$q[, i] * self$r[, , times, i] %*% (k[, times] * self$alpha[, times, locations, i])
}
}
}
} else {
self$lambda_i <- NULL
}
},
calc_lambda_j = function(n_iter, nSources, nTimes,
nLocations,
namesSources, namesTimes,
namesLocations,
namesIters,
k) {
self$lambda_j <- array(
dim = c(nSources,
nTimes,
nLocations,
n_iter),
dimnames = list(
source = namesSources,
time = namesTimes,
location = namesLocations,
iter = namesIters
)
)
if (!is.null(self$q)) { # don't try to calculate before any iterations have occured
for (i in 1:n_iter) {
for (times in 1:nTimes) {
for (locations in 1:nLocations) {
self$lambda_j[, times, locations, i] <- self$alpha[, times, locations, i] * colSums(self$r[, , times, i] * self$q[, i]) * k[, times]
}
}
}
} else {
self$lambda_j <- NULL
}
},
calc_lambda_j_prop = function() {
# TODO: check with multiple times and locations
if (is.null(self$lambda_j)) self$calc_lambda_j()
self$lambda_j_prop <- self$lambda_j
for (iter in 1:dim(self$lambda_j)[4]) {
self$lambda_j_prop[,,,iter] <- self$lambda_j[,,,iter] / sum(self$lambda_j[,,,iter])
}
}
)
)
#####################################################
# Name: utils.R #
# Author: Chris Jewell <c.jewell@lancaster.ac.uk> #
# Created: 20161206 #
# Copyright: Chris Jewell 2016 #
# Purpose: Miscellaneous helper functions #
#####################################################
isFiniteInteger = function(a)
{
return(is.finite(a) & isTRUE(all.equal(a, as.integer(a))))
}
isFiniteLogical = function(a)
{
return(is.finite(a) & is.logical(a))
}
isFinitePositive = function(a)
# FINISHED
{
return(is.finite(a) & a > 0)
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment