Commit 82964bde authored by Poppy Miller's avatar Poppy Miller
Browse files

Changed init r fropm requiring initial number of cases to actual rs

parent 604ea30f
......@@ -39,11 +39,15 @@
#'
#' \code{q} (a data frame with columns names \code{Value} contining the initial values and \code{Type}), and
#'
#' \code{r} (a data frame with the same format as the data excluding the \code{Human} column,
#' with the initial values in the source columns). }
#' \code{r} (a data frame with columns:
#' a column with the initial r values named \code{Value}
#' (note these must sum to 1 for each source-time combination),
#' a column with the source id's named \code{Source},
#' a column with the time id's named \code{Time},
#' a column with the type id's named \code{Type}.)}
#'
#' \item{\code{fit_params(n_iter = 1000, burn_in = 0, thin = 1,
#' n_r = private$nTypes \%/\% 20, params_fix = NULL)}}{when called, sets the mcmc
#' n_r = ceiling(private$nTypes * 0.2), params_fix = NULL)}}{when called, sets the mcmc
#' parameters.
#'
#' \code{n_iter} sets the number of iterations returned (after removing
......@@ -674,7 +678,8 @@ HaldDP <- R6::R6Class(
if (dim(unique(priors$a_r[, c("Type", "Time", "Source")]))[1] != (private$nTypes * private$nTimes * private$nSources))
stop("priors$a_r must have a single number for each time, type and source combination.")
if (!all(private$test_integer(priors$a_r$Value)) | !all(priors$a_r$Value >= 0)) stop("priors$a_r$Value must contain only positive numbers")
# !all(private$test_integer(priors$a_r$Value)) |
if (!all(priors$a_r$Value >= 0)) stop("priors$a_r$Value must contain only positive numbers")
} else {
if (!isTRUE(length(priors$a_r) == 1) | !private$test_positive_number(priors$a_r)) stop("priors$a_r must be a data frame or a single number.")
}
......@@ -759,7 +764,10 @@ HaldDP <- R6::R6Class(
## r values
if (!("r" %in% names(inits))) {
## default is the source data matrix
inits$r <- private$X
inits$r <- private$X + 0.000001 # Added for numeric stability -- only affects starting values for r
for (times in 1:private$nTimes) {
inits$r[, , times] <- apply(inits$r[, , times], 2, function(x) x / sum(x))
}
} else {
if (!is.data.frame(inits$r))
stop("inits$r must be a data frame.")
......@@ -776,7 +784,7 @@ HaldDP <- R6::R6Class(
dim(inits$r)[1] != (private$nTypes * private$nSources * private$nTimes))
stop("inits$r must be a data frame with columns called Type, Source, Time and Value with one row per combination of type, source and time.")
if (!all(private$test_integer(inits$r$Value)) | !all(inits$r$Value >= 0)) stop("inits$r$Value must contain only positive numbers")
if (!all(is.finite(inits$r$Value)) | !all(inits$r$Value > 0) | !all(inits$r$Value < 1)) stop("inits$r$Value must contain only numbers between 0 and 1.")
inits_r <- array(
NA, dim = c(private$nTypes, private$nSources, private$nTimes),
......@@ -799,7 +807,8 @@ HaldDP <- R6::R6Class(
which(private$namesTimes == time)] <- tmp_init_r
}
## check the initial values sum to 1 within each time and location
# if (isTRUE(all.equal(sum(inits_r[, private$namesSources == sources, private$namesTimes == time]), 1))) stop("inits$r must sum to 1 within each time and location.")
if (!isTRUE(all.equal(sum(inits_r[, private$namesSources == sources, private$namesTimes == time]), 1, tol = 0.000001)))
stop("inits$r must sum to 1 within each time and source.")
}
}
inits$r <- inits_r
......@@ -1349,7 +1358,7 @@ HaldDP <- R6::R6Class(
},
fit_params = function(n_iter = 1000, burn_in = 0, thin = 1, # FINISHED
n_r = private$nTypes %/% 20, params_fix = NULL)
n_r = ceiling(private$nTypes * 0.2), params_fix = NULL)
{
private$set_append(FALSE)
private$set_niter(n_iter)
......
......@@ -78,10 +78,8 @@ DPModel_impl <- R6::R6Class(
self$rNodes[[time]] <- list()
for (src in 1:length(Sources)) {
# Dirichlet prior on r, as a result of Dirichlet/Multinomial conjugacy on R.
a_r_full <- DataNode$new(data = a_r[, src, time] + X[, src, time], name = paste('a_r', time, src, sep = '_')) # Todo: Prior here
xcol <- R[, src, time] + 0.000001 # Added for numeric stability -- only affects starting values for r
xcol <- xcol/sum(xcol)
self$rNodes[[time]][[src]] <- DirichletNode$new(data = setNames(xcol, Type),
a_r_full <- DataNode$new(data = a_r[, src, time] + X[, src, time], name = paste('a_r', time, src, sep = '_'))
self$rNodes[[time]][[src]] <- DirichletNode$new(data = setNames(R[, src, time], Type),
alpha = a_r_full,
name = paste('r', time, src, sep = '_'))
}
......@@ -99,7 +97,7 @@ DPModel_impl <- R6::R6Class(
# Location specific alpha
alpha_tl <- DirichletNode$new(data = setNames(alpha[, time, location], Sources),
alpha = a_tl,
name = paste('alpha', time, location, sep = '_')) # Prior here
name = paste('alpha', time, location, sep = '_'))
# Construct the lambda_i node
lambdaPrime <- self$LambdaNode$new(k = k,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment