Commit 14c03ff9 authored by Chris Jewell
Browse files

Completed new data handling architecture. Need to propagate to prior...

Completed the new data handling architecture. The new lower-case dimension names (type, source, time, location) still need to be propagated to the handling of prior parameters and initial values.
parent 85ffdb79
......@@ -16,7 +16,7 @@ Data_ = R6::R6Class(
{
dn = dimnames(self$x)
self$x = do.call('[', c(list(self$x), lapply(dn, function(a)
gtools::mixedorder(a))))
gtools::mixedorder(a)), list(drop=FALSE)))
}
)
)
......@@ -40,6 +40,9 @@ Y_ = R6::R6Class('Y',
}
data = data[, c(y, type, time, location)]
names(data) = c('y', 'type', 'time', 'location')
data$type = as.character(data$type)
data$time = as.character(data$time)
data$location = as.character(data$location)
self$x = tryCatch(
reshape2::acast(data, type ~ time ~ location, value.var = 'y',drop=F),
condition = function(c)
......@@ -70,7 +73,7 @@ Y_ = R6::R6Class('Y',
#' @return A Y disease count data structure for use in sourceR models
#' @export
Y = function(data, y, type, time = NULL, location = NULL)
Y_$new(data, y, type, time = NULL, location = NULL)
Y_$new(data, y, type, time, location)
X_ = R6::R6Class('X',
......@@ -84,7 +87,9 @@ X_ = R6::R6Class('X',
}
data = data[, c(x, type, source, time)]
names(data) = c('x', 'type', 'source', 'time')
data$type = as.character(data$type)
data$source = as.character(data$source)
data$time = as.character(data$time)
self$x = tryCatch(
reshape2::acast(data, type ~ source ~ time, value.var = 'x', drop=F),
condition = function(c)
......@@ -128,19 +133,20 @@ Prev_ = R6::R6Class('Prev',
private = list(
pack = function(data, prev, source, time = NULL)
{
if (is.null(time))
if (is.null(time)) {
data$time = rep('1', nrow(data))
data[, c(prev, source, time)]
time = 'time'
}
data = data[, c(prev, source, time)]
names(data) = c('prev', 'source', 'time')
data$source = as.character(data$source)
data$time = as.character(data$time)
self$x = tryCatch(
acast(data, source ~ time, value.var = 'prev', drop=F),
reshape2::acast(data, source ~ time, value.var = 'prev', drop=F),
condition = function(c)
stop('Prevalence must have one value per source/time')
)
names(dimnames(self$x)) = c('source', 'time')
self$x = self$x[order(dimnames(self$x)$source),
order(dimnames(self$x)$time)]
},
check = function()
{
......
......@@ -408,12 +408,12 @@ HaldDP <- R6::R6Class(
private$nTypes = dim(private$y)[1]
private$nTimes = dim(private$y)[2]
private$nLocations = dim(private$y)[3]
private$nSources = dim(private$x)[2]
private$nSources = dim(private$X)[2]
private$namesTypes = dimnames(private$y)$type
private$namesTimes = dimnames(private$y)$time
private$namesLocations = dimnames(private$y)$location
private$namesSources = dimnames(private$x)$source
private$namesSources = dimnames(private$X)$source
},
set_a_q = function(a_q)
{
......@@ -472,7 +472,6 @@ HaldDP <- R6::R6Class(
!isFinitePositive(priors$a_r))
stop("priors$a_r must be a data frame or a single number.")
}
browser()
a_r <- array(
NA,
dim = c(private$nTypes, private$nSources, private$nTimes),
......@@ -580,27 +579,27 @@ HaldDP <- R6::R6Class(
## r values
if (!("r" %in% names(inits))) {
## default is the source data matrix (plus stability factor)
inits$r = apply(private$X + 1e-6, c('Source', 'Time'), function(x)
x / sum(x)) %>% aperm(c('Type', 'Source', 'Time'))
inits$r = apply(private$X + 1e-6, c('source', 'time'), function(x)
x / sum(x)) %>% aperm(c('type', 'source', 'time'))
} else {
if (!is.data.frame(inits$r))
stop("inits$r must be a data frame.")
if (!all(c("Type", "Source", "Time", "Value") %in% colnames(inits$r)))
stop("inits$r must be a data frame with columns called Type, Source, Time and Value")
if (!all(c("type", "source", "time", "value") %in% colnames(inits$r)))
stop("inits$r must be a data frame with columns called type, source, time and value")
if (!setequal(unique(inits$r$Type), private$namesTypes) |
!setequal(unique(inits$r$Time), private$namesTimes) |
!setequal(unique(inits$r$Source), private$namesSources))
stop(
"inits$r must be a data frame with columns called Type, Source, Time and Value with one row per combination of type, source and time."
"inits$r must be a data frame with columns called type, source, time and value with one row per combination of type, source and time."
)
initr_r = tryCatch(
acast(inits$r, Type ~ Source ~ Time, value.var = 'Value'),
condition=function(c) stop("inits$r must have a non-NA row for each Type/Time/Source combination.")
)
names(dimnames(initr_r)) = c('Type','Source','Time')
sumToUnit = apply(initr_r, c('Type','Time'), sum)
names(dimnames(initr_r)) = c('type','source','time')
sumToUnit = apply(initr_r, c('type','time'), sum)
if(!isTRUE(all.equal(sumToUnit, 1, tol=1e-5)))
stop("inits$r must sum to 1 within each time and source")
inits$r <- inits_r
......@@ -629,8 +628,8 @@ HaldDP <- R6::R6Class(
if (!is.data.frame(inits$alpha))
stop("inits$alpha must be a data frame.")
inits$alpha <- na.omit(inits$alpha)
if (!all(c("Location", "Source", "Time", "Value") %in% colnames(inits$alpha)))
stop("inits$alpha must be a data frame with columns called Location, Source, Time and Value")
if (!all(c("location", "source", "time", "value") %in% colnames(inits$alpha)))
stop("inits$alpha must be a data frame with columns called location, source, time and value")
inits$alpha$Location <- as.factor(inits$alpha$Location)
inits$alpha$Source <- as.factor(inits$alpha$Source)
......@@ -647,7 +646,7 @@ HaldDP <- R6::R6Class(
))) == private$namesSources) |
dim(inits$alpha)[1] != (private$nLocations * private$nSources * private$nTimes))
stop(
"inits$alpha must be a data frame with columns called Location, Source, Time and Value with one row per combination of location, source and time."
"inits$alpha must be a data frame with columns called location, source, time and value with one row per combination of location, source and time."
)
if (!all(isFinitePositive(inits$alpha$Value)))
stop("inits$alpha$Value must contain only positive numbers")
......@@ -702,15 +701,15 @@ HaldDP <- R6::R6Class(
if (!is.data.frame(inits$q))
stop("inits$q must be a data frame.")
inits$q <- na.omit(inits$q)
if (!all(c("Type", "Value") %in% colnames(inits$q)))
stop("inits$q must be a data frame with columns called Type and Value")
if (!all(c("type", "value") %in% colnames(inits$q)))
stop("inits$q must be a data frame with columns called type and value")
inits$q$Type <- as.factor(inits$q$Type)
if (!all(paste(gtools::mixedsort(unique(inits$q$Type))) == private$namesTypes) |
dim(inits$q)[1] != private$nTypes)
stop("inits$q must be a data frame with columns called Type and Value with one row per type.")
stop("inits$q must be a data frame with columns called type and value with one row per type.")
if (!all(isFinitePositive(inits$q$Value)))
stop("inits$q$Value must contain only positive numbers.")
stop("inits$q$value must contain only positive numbers.")
inits$q <- inits$q[gtools::mixedsort(inits$q$Type),]
inits$theta <- sort(unique(inits$q$Value))
......@@ -1263,7 +1262,7 @@ HaldDP <- R6::R6Class(
})
if (isTRUE(flatten)) {
res = lapply(res, melt)
res = lapply(res, reshape2::melt)
}
res
},
......
......@@ -66,15 +66,10 @@ DPModel_impl <- R6::R6Class(
initialize = function(y, X, R, Time, Location, Sources, Type, prev, a_q, a_theta,
b_theta, a_r, a_alpha, s, theta, alpha)
{
## test y column is correct
# merge(data.frame(Human = y[, 1, 1], Type = dimnames(y)$type), data.frame(Human2 = campy$Human, Type = as.character(campy$Type)))
## Test x column is correct
# merge(as.data.frame(cbind(X[,,1], Type = dimnames(X)$type)), campy[, c(2:7, 10)])
self$qNodes <- DirichletProcessNode$new(theta = theta, s = s, alpha = a_q,
base = dgamma, shape = a_theta,
rate = b_theta, name = 'q')
# Node lists
self$alphaNodes <- list()
self$yNodes <- list()
......
......@@ -69,7 +69,7 @@ priors <-
test_that("HaldDP model construction", {
set.seed(1)
y = Y(dat, y = 'Human', type = 'Type')
src = melt(
src = reshape2::melt(
dat,
id.vars = c('Time', 'Type'),
measure.vars = c(
......@@ -96,7 +96,7 @@ test_that("HaldDP model construction", {
priors = priors,
a_q = 0.1
)
model$fit_params(n_iter = 100,
model$mcmc_params(n_iter = 100,
burn_in = 0,
thin = 1)
# Test model barfs if summary called on empty posterior
......@@ -107,7 +107,7 @@ test_that("HaldDP model construction", {
expect_equal_to_reference(model$summary(), "haldDPres2.rds")
expect_equal_to_reference(model$extract('lambda_j'), "haldDPlambdaj.rds")
expect_equal_to_reference(model$extract(flatten = TRUE), 'haldDPFlat.rds')
model$fit_params(n_iter = 100,
model$mcmc_params(n_iter = 100,
burn_in = 10,
thin = 5)
expect_error(model$summary())
......
......@@ -13,13 +13,33 @@ test_that("Test time/location data structures", {
set.seed(1)
ss = group_by(sim_SA_data, Time, Location) %>%
slice(1:10) %>% ungroup %>% as.data.frame
src = reshape2::melt(
ss[ss$Location == 'A', c(1, 2:9)],
id.vars = c('Type', 'Time'),
measure.vars = c(
'Source1',
'Source2',
'Source3',
'Source4',
'Source5',
'Source6'
),
variable.name = 'Source',
value.name = 'Count'
)
y = Y(data=ss, y = 'Human', type='Type', time = 'Time', location = 'Location')
x = X(data=src, x='Count', type='Type', time='Time', source='Source')
k = Prev(sim_SA_prev, prev='Value', time='Time', source='Source')
model = HaldDP$new(
data = ss,
k = sim_SA_prev,
y = y,
x = x,
k = k,
priors = priors,
a_q = 0.1
)
model$fit_params(n_iter = 100,
model$mcmc_params(n_iter = 100,
burn_in = 0,
thin = 1)
model$update()
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment