Commit 8ab72778 authored by Chris Jewell's avatar Chris Jewell
Browse files

Changed 'print_*' functions to 'get_*'.

parent 5463f4c3
......@@ -140,27 +140,27 @@
#' \code{burn_in} is removed. Running the model for the first time, or changing any
#' model or fitting parameters will set \code{append = FALSE}. }
#'
#' \item{\code{print_data}}{returns a list containing the human data \code{y}
#' \item{\code{get_data}}{returns a list containing the human data \code{y}
#' (an array y[types, times, locations]), the source data \code{X} (an array X[types, sources, times]),
#' the prevalence data (an array k[sources, times]), the type names, source names,
#' time names, location names and number of different types, sources, times and locations.
#' }
#'
#' \item{\code{print_priors}}{returns a list containing the DP concentration
#' \item{\code{get_priors}}{returns a list containing the DP concentration
#' parameter \code{a_q}, and the priors (R6 class with members named \code{a_alpha}
#' (members are array \code{a_alpha[sources, times, locations]}), \code{a_r} (an array
#' \code{a_r[types, sources, times]}), \code{a_theta} and \code{b_theta}).}
#'
#' \item{\code{print_inits}}{returns an R6 class holding the initial values
#' \item{\code{get_inits}}{returns an R6 class holding the initial values
#' (members are \code{alpha} (an array \code{alpha[sources, times, locations]}),
#' \code{theta} (an array \code{theta[types, iters]}), \code{s} (an array
#' \code{s[types, iters]}), and \code{r} (an array \code{r[types, sources, times]})).}
#'
#' \item{\code{print_mcmc_params}}{returns a list of fitting parameters (\code{n_iter},
#' \item{\code{get_mcmc_params}}{returns a list of fitting parameters (\code{n_iter},
#' \code{append}, \code{burn_in}, \code{thin}, \code{params_fix} (R6 class with members
#' \code{alpha}, \code{q}, \code{r})).}
#'
#' \item{\code{print_acceptance}}{returns an R6 class containing the acceptance
#' \item{\code{get_acceptance}}{returns an R6 class containing the acceptance
#' rates for each parameter (members are \code{alpha} (an array \code{alpha[sources, times, locations]}),
#' and \code{r} (an array \code{r[types, sources, times]})).}
#'
......@@ -252,11 +252,11 @@
#' res$mcmc_params(n_iter = 100, burn_in = 10, thin = 1)
#' res$update()
#'
#' dat <- res$print_data()
#' init <- res$print_inits()
#' prior <- res$print_priors()
#' acceptance <- res$print_acceptance()
#' mcmc_params <- res$print_mcmc_params()
#' dat <- res$get_data()
#' init <- res$get_inits()
#' prior <- res$get_priors()
#' acceptance <- res$get_acceptance()
#' mcmc_params <- res$get_mcmc_params()
#'
#' res$plot_heatmap(iters = 10:100, hclust_method = "complete")
#'
......@@ -394,13 +394,13 @@ HaldDP_ <- R6::R6Class(
checkin_data = function(y,x,k)
{
# Check dimension mappings
if(!identical(dimnames(y)$type, dimnames(x)$type))
if(!identical(dimnames(y$x)$type, dimnames(x$x)$type))
stop("Types in x and y do not match")
if(!identical(dimnames(y)$time, dimnames(x)$time))
if(!identical(dimnames(y$x)$time, dimnames(x$x)$time))
stop("Times in x and y do not match")
if(!identical(dimnames(x)$time, dimnames(k)$time))
if(!identical(dimnames(x$x)$time, dimnames(k$x)$time))
stop("Times in x and k do not match")
if(!identical(dimnames(x)$source, dimnames(k)$source))
if(!identical(dimnames(x$x)$source, dimnames(k$x)$source))
stop("Sources in x and k do not match")
private$y = y$x
......@@ -994,59 +994,6 @@ HaldDP_ <- R6::R6Class(
flatten = flatten
)
)
},
calc_acceptance = function()
{
## mcmc is finished, save and print acceptance rate summary for r and alpha
private$acceptance <-
Acceptance$new(
nSources = private$nSources,
nTimes = private$nTimes,
nLocations = private$nLocations,
nTypes = private$nTypes,
namesSources = private$namesSources,
namesTimes = private$namesTimes,
namesLocations = private$namesLocations,
namesTypes = private$namesTypes,
updateSchema = private$update_schema
)
sapply(1:length(private$updaters), function(i) {
tryCatch({
acceptances <- private$updaters[[i]]$acceptance()
names <- private$updaters[[i]]$name
tmp <- strsplit(names, split = " ")[[1]]
if (tmp[1] == "r:") {
t_name <- substr(tmp[3], 1, nchar(tmp[3]) - 1)
s_name <- tmp[5]
private$acceptance$r[,
which(s_name == dimnames(private$acceptance$r)$source),
which(t_name == dimnames(private$acceptance$r)$time)] <-
acceptances
} else if (tmp[1] == "alpha:") {
t_name <- substr(tmp[3], 1, nchar(tmp[3]) - 1)
l_name <- tmp[5]
private$acceptance$alpha[,
which(t_name == dimnames(private$acceptance$alpha)$time),
which(l_name == dimnames(private$acceptance$alpha)$location)] <-
acceptances
}
},
error = function(e) {
NULL
})
})
if ('alpha' %in% private$update_schema) {
cat("\nalpha acceptance: \n")
print(private$acceptance$alpha)
}
if ('r' %in% private$update_schema) {
cat("\nr acceptance: \n")
print(private$acceptance$r)
}
}
),
public = list(
......@@ -1141,42 +1088,28 @@ HaldDP_ <- R6::R6Class(
private$total_iters = private$total_iters + 1
setTxtProgressBar(pb, i)
}
private$calc_acceptance()
},
## Functions to access the data and results
print_data = function()
get_data = function()
{
return(
list(
y = private$y,
X = private$X,
k = private$k,
namesType = private$namesTypes,
namesSource = private$namesSources,
namesTime = private$namesTimes,
namesLocation = private$namesLocations,
nTypes = private$nTypes,
nSources = private$nSources,
nTimes = private$nTimes,
nLocations = private$nLocations
k = private$k
)
)
},
print_priors = function()
get_priors = function()
{
return(list(priors = private$priors,
a_q = private$a_q))
list(priors = private$priors,
a_q = private$a_q)
},
print_inits = function()
get_inits = function()
{
return(private$inits)
private$inits
},
print_mcmc_params = function()
get_mcmc_params = function()
{
return(
list(
n_iter = private$n_iter,
append = private$append,
......@@ -1184,11 +1117,52 @@ HaldDP_ <- R6::R6Class(
thin = private$thin,
params_fix = private$params_fix
)
)
},
print_acceptance = function()
get_acceptance = function()
{
return(private$acceptance)
## mcmc is finished, save acceptance rate summary for r and alpha
acceptance =
Acceptance$new(
nSources = private$nSources,
nTimes = private$nTimes,
nLocations = private$nLocations,
nTypes = private$nTypes,
namesSources = private$namesSources,
namesTimes = private$namesTimes,
namesLocations = private$namesLocations,
namesTypes = private$namesTypes,
updateSchema = private$update_schema
)
sapply(1:length(private$updaters), function(i) {
tryCatch({
acceptances <- private$updaters[[i]]$acceptance()
names <- private$updaters[[i]]$name
tmp <- strsplit(names, split = " ")[[1]]
if (tmp[1] == "r:") {
t_name <- substr(tmp[3], 1, nchar(tmp[3]) - 1)
s_name <- tmp[5]
acceptance$r[,
which(s_name == dimnames(acceptance$r)$source),
which(t_name == dimnames(acceptance$r)$time)] <-
acceptances
} else if (tmp[1] == "alpha:") {
t_name <- substr(tmp[3], 1, nchar(tmp[3]) - 1)
l_name <- tmp[5]
acceptance$alpha[,
which(t_name == dimnames(acceptance$alpha)$time),
which(l_name == dimnames(acceptance$alpha)$location)] <-
acceptances
}
},
error = function(e) {
NULL
})
})
list(alpha=acceptance$alpha, r=acceptance$r)
},
extract = function(params = c("alpha", "q", "s", "r", "lambda_i", "lambda_j", "lambda_j_prop"),
times = NULL,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment