Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Chris Jewell
covid19uk
Commits
552b1d03
Commit
552b1d03
authored
Jul 09, 2020
by
Chris Jewell
Browse files
Stochastic simulation model update
parent
0b337d47
Changes
6
Hide whitespace changes
Inline
Side-by-side
covid/impl/occult_events_mh.py
View file @
552b1d03
...
...
@@ -41,12 +41,20 @@ class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
"""UncalibratedEventTimesUpdate"""
def
__init__
(
self
,
target_log_prob_fn
,
target_event_id
,
nmax
,
seed
=
None
,
name
=
None
,
self
,
target_log_prob_fn
,
target_event_id
,
nmax
,
t_range
=
None
,
seed
=
None
,
name
=
None
,
):
"""An uncalibrated random walk for event times.
:param target_log_prob_fn: the log density of the target distribution
:param target_event_id: the position in the last dimension of the events
tensor that we wish to move
:param t_range: a tuple containing earliest and latest times between which
to update occults.
:param seed: a random seed
:param name: the name of the update step
"""
...
...
@@ -57,6 +65,7 @@ class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
target_log_prob_fn
=
target_log_prob_fn
,
target_event_id
=
target_event_id
,
nmax
=
nmax
,
t_range
=
t_range
,
seed
=
seed
,
name
=
name
,
)
...
...
@@ -101,7 +110,9 @@ class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
def
true_fn
():
with
tf
.
name_scope
(
"true_fn"
):
proposal
=
AddOccultProposal
(
current_events
,
self
.
parameters
[
"nmax"
]
events
=
current_events
,
n_max
=
self
.
parameters
[
"nmax"
],
t_range
=
self
.
parameters
[
"t_range"
],
)
update
=
proposal
.
sample
()
next_state
=
_add_events
(
...
...
@@ -128,7 +139,11 @@ class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
x
=
[
self
.
tx_topology
.
target
],
x_star
=
tf
.
cast
(
-
update
[
"x_star"
],
current_events
.
dtype
),
)
reverse
=
AddOccultProposal
(
next_state
,
self
.
parameters
[
"nmax"
])
reverse
=
AddOccultProposal
(
events
=
next_state
,
n_max
=
self
.
parameters
[
"nmax"
],
t_range
=
self
.
parameters
[
"t_range"
],
)
q_fwd
=
tf
.
reduce_sum
(
proposal
.
log_prob
(
update
))
q_rev
=
tf
.
reduce_sum
(
reverse
.
log_prob
(
update
))
log_acceptance_correction
=
q_rev
-
q_fwd
...
...
covid/impl/occult_proposal.py
View file @
552b1d03
...
...
@@ -7,7 +7,10 @@ from covid.impl.Categorical2 import Categorical2
tfd
=
tfp
.
distributions
def
AddOccultProposal
(
events
,
n_max
,
dtype
=
tf
.
int32
,
name
=
None
):
def
AddOccultProposal
(
events
,
n_max
,
t_range
=
None
,
dtype
=
tf
.
int32
,
name
=
None
):
if
t_range
is
None
:
t_range
=
[
0
,
events
.
shape
[
-
2
]]
def
m
():
"""Select a metapopulation"""
with
tf
.
name_scope
(
"m"
):
...
...
@@ -16,7 +19,7 @@ def AddOccultProposal(events, n_max, dtype=tf.int32, name=None):
def
t
():
"""Select a timepoint"""
with
tf
.
name_scope
(
"t"
):
return
UniformInteger
(
low
=
[
0
],
high
=
[
events
.
shap
e
[
1
]],
dtype
=
dtype
)
return
UniformInteger
(
low
=
[
t_range
[
0
]],
high
=
[
t_rang
e
[
1
]],
dtype
=
dtype
)
def
x_star
():
"""Draw num to add"""
...
...
covid/model.py
View file @
552b1d03
...
...
@@ -155,7 +155,7 @@ class CovidUKStochastic(CovidUK):
*
commute_volume
*
tf
.
linalg
.
matvec
(
self
.
C
,
state
[...,
2
]
/
self
.
N
)
)
infec_rate
=
infec_rate
/
self
.
N
# Vector of length nc
infec_rate
=
infec_rate
/
self
.
N
+
0.00000001
# Vector of length nc
ei
=
tf
.
broadcast_to
(
[
param
[
"nu"
]],
shape
=
[
state
.
shape
[
0
]]
...
...
covid_stochastic.py
View file @
552b1d03
...
...
@@ -5,9 +5,6 @@ import pickle as pkl
import
tensorflow
as
tf
import
tensorflow_probability
as
tfp
tfd
=
tfp
.
distributions
tfb
=
tfp
.
bijectors
import
numpy
as
np
import
matplotlib.pyplot
as
plt
import
yaml
...
...
@@ -15,6 +12,11 @@ import yaml
from
covid.model
import
CovidUKStochastic
,
load_data
from
covid.util
import
sanitise_parameter
,
sanitise_settings
,
seed_areas
tfd
=
tfp
.
distributions
tfb
=
tfp
.
bijectors
DTYPE
=
np
.
float64
...
...
@@ -157,81 +159,63 @@ def plot_age_attack_rate(ax, sim, N, label):
ax
.
plot
(
data
[
"age_groups"
],
attack_rate
,
"o-"
,
label
=
label
)
if
__name__
==
"__main__"
:
parser
=
optparse
.
OptionParser
()
parser
.
add_option
(
"--config"
,
"-c"
,
dest
=
"config"
,
default
=
"ode_config.yaml"
,
help
=
"configuration file"
,
)
options
,
args
=
parser
.
parse_args
()
with
open
(
options
.
config
,
"r"
)
as
ymlfile
:
config
=
yaml
.
load
(
ymlfile
)
param
=
sanitise_parameter
(
config
[
"parameter"
])
settings
=
sanitise_settings
(
config
[
"settings"
])
parser
=
optparse
.
OptionParser
()
parser
.
add_option
(
"--config"
,
"-c"
,
dest
=
"config"
,
default
=
"ode_config.yaml"
,
help
=
"configuration file"
,
)
options
,
args
=
parser
.
parse_args
()
with
open
(
options
.
config
,
"r"
)
as
ymlfile
:
config
=
yaml
.
load
(
ymlfile
)
param
=
sanitise_parameter
(
config
[
"parameter"
])
settings
=
sanitise_settings
(
config
[
"settings"
])
data
=
load_data
(
config
[
"data"
],
settings
,
DTYPE
)
data
[
"pop"
]
=
data
[
"pop"
].
sum
(
level
=
0
)
model
=
CovidUKStochastic
(
C
=
data
[
"C"
],
N
=
data
[
"pop"
][
"n"
].
to_numpy
(),
W
=
data
[
"W"
],
date_range
=
settings
[
"prediction_period"
],
holidays
=
settings
[
"holiday"
],
lockdown
=
settings
[
"lockdown"
],
time_step
=
1.0
,
)
# seeding = seed_areas(data['pop']['n'].to_numpy(), data['pop']['Area.name.2']) # Seed 40-44 age group, 30 seeds by popn size
# seeding = tf.one_hot(tf.squeeze(tf.where(data['pop'].index=='E09000008')), depth=data['pop'].size, dtype=DTYPE)
seeding
=
tf
.
one_hot
(
58
,
depth
=
model
.
N
.
shape
[
0
],
dtype
=
DTYPE
)
# Manchester
state_init
=
model
.
create_initial_state
(
init_matrix
=
seeding
)
start
=
time
.
perf_counter
()
t
,
sim
=
model
.
simulate
(
param
,
state_init
)
end
=
time
.
perf_counter
()
print
(
f
"Run 1 Complete in
{
end
-
start
}
seconds"
)
start
=
time
.
perf_counter
()
for
i
in
range
(
1
):
t
,
upd
=
model
.
simulate
(
param
,
state_init
)
end
=
time
.
perf_counter
()
print
(
f
"Run 2 Complete in
{
(
end
-
start
)
/
1.
}
seconds"
)
# Plotting functions
fig_uk
=
plt
.
figure
()
sim
=
tf
.
reduce_sum
(
upd
,
axis
=-
2
)
# plot_age_attack_rate(fig_attack.gca(), sim, data['pop']['n'].to_numpy(), "Attack Rate")
# fig_attack.suptitle("Attack Rate")
# plot_infec_curve(fig_uk.gca(), sim.numpy(), "Infections")
fig_uk
.
gca
().
plot
(
sim
[:,
:,
2
])
fig_uk
.
suptitle
(
"UK Infections"
)
fig_uk
.
autofmt_xdate
()
fig_uk
.
gca
().
grid
(
True
)
plt
.
show
()
with
open
(
"stochastic_sim_covid.pkl"
,
"wb"
)
as
f
:
pkl
.
dump
({
"events"
:
upd
.
numpy
(),
"state_init"
:
state_init
.
numpy
()},
f
)
# parser = optparse.OptionParser()
# parser.add_option(
# "--config",
# "-c",
# dest="config",
# default="ode_config.yaml",
# help="configuration file",
# )
# options, args = parser.parse_args([])
with
open
(
"ode_config.yaml"
,
"r"
)
as
ymlfile
:
config
=
yaml
.
load
(
ymlfile
)
param
=
sanitise_parameter
(
config
[
"parameter"
])
settings
=
sanitise_settings
(
config
[
"settings"
])
data
=
load_data
(
config
[
"data"
],
settings
,
DTYPE
)
data
[
"pop"
]
=
data
[
"pop"
].
sum
(
level
=
0
)
model
=
CovidUKStochastic
(
C
=
data
[
"C"
],
N
=
data
[
"pop"
][
"n"
].
to_numpy
(),
W
=
data
[
"W"
],
date_range
=
settings
[
"prediction_period"
],
holidays
=
settings
[
"holiday"
],
lockdown
=
settings
[
"lockdown"
],
time_step
=
1.0
,
)
# seeding = seed_areas(data['pop']['n'].to_numpy(), data['pop']['Area.name.2']) # Seed 40-44 age group, 30 seeds by popn size
# seeding = tf.one_hot(tf.squeeze(tf.where(data['pop'].index=='E09000008')), depth=data['pop'].size, dtype=DTYPE)
seeding
=
tf
.
one_hot
(
58
,
depth
=
model
.
N
.
shape
[
0
],
dtype
=
DTYPE
)
# Manchester
state_init
=
model
.
create_initial_state
(
init_matrix
=
seeding
)
start
=
time
.
perf_counter
()
t
,
sim
=
model
.
simulate
(
param
,
state_init
)
end
=
time
.
perf_counter
()
print
(
f
"Run 1 Complete in
{
end
-
start
}
seconds"
)
start
=
time
.
perf_counter
()
for
i
in
range
(
1
):
t
,
upd
=
model
.
simulate
(
param
,
state_init
)
end
=
time
.
perf_counter
()
print
(
f
"Run 2 Complete in
{
(
end
-
start
)
/
1.
}
seconds"
)
# Plotting functions
fig_uk
=
plt
.
figure
()
sim
=
tf
.
reduce_sum
(
upd
,
axis
=-
2
)
# plot_age_attack_rate(fig_attack.gca(), sim, data['pop']['n'].to_numpy(), "Attack Rate")
# fig_attack.suptitle("Attack Rate")
# plot_infec_curve(fig_uk.gca(), sim.numpy(), "Infections")
fig_uk
.
gca
().
plot
(
sim
[:,
:,
2
])
fig_uk
.
suptitle
(
"UK Infections"
)
fig_uk
.
autofmt_xdate
()
fig_uk
.
gca
().
grid
(
True
)
plt
.
show
()
with
open
(
"stochastic_sim_covid1.pkl"
,
"wb"
)
as
f
:
pkl
.
dump
({
"events"
:
upd
.
numpy
(),
"state_init"
:
state_init
.
numpy
()},
f
)
mcmc.py
View file @
552b1d03
...
...
@@ -72,11 +72,11 @@ model = CovidUKStochastic(
# Load data
with
open
(
"stochastic_sim_covid.pkl"
,
"rb"
)
as
f
:
with
open
(
"stochastic_sim_covid
1
.pkl"
,
"rb"
)
as
f
:
example_sim
=
pkl
.
load
(
f
)
event_tensor
=
example_sim
[
"events"
]
# shape [T, M, S, S]
event_tensor
=
event_tensor
[:
8
0
,
...]
event_tensor
=
event_tensor
[:
6
0
,
...]
num_times
=
event_tensor
.
shape
[
0
]
num_meta
=
event_tensor
.
shape
[
1
]
state_init
=
example_sim
[
"state_init"
]
...
...
@@ -159,6 +159,7 @@ def make_occults_step(target_event_id):
target_log_prob_fn
=
logp
,
target_event_id
=
target_event_id
,
nmax
=
config
[
"mcmc"
][
"occult_nmax"
],
t_range
=
[
se_events
.
shape
[
0
]
-
21
,
se_events
.
shape
[
0
]],
),
name
=
"occult_update"
,
)
...
...
@@ -253,7 +254,7 @@ def sample(n_samples, init_state, par_scale, num_event_updates):
state
[
2
],
results
[
3
]
=
se_occult
(
occult_logp
).
one_step
(
state
[
2
],
forward_results
(
results
[
2
],
results
[
3
])
)
# results[3] = forward_results(results[2], results[3])
#
results[3] = forward_results(results[2], results[3])
state
[
2
],
results
[
4
]
=
ei_occult
(
occult_logp
).
one_step
(
state
[
2
],
forward_results
(
results
[
3
],
results
[
4
])
)
...
...
@@ -324,7 +325,7 @@ event_samples = posterior.create_dataset(
"samples/events"
,
event_size
,
dtype
=
DTYPE
,
chunks
=
(
min
(
NUM_BURSTS
*
NUM_BURST_SAMPLES
,
1000
)
,)
+
tuple
(
ev
ent_s
ize
[
1
:]
),
chunks
=
(
10
,)
+
tuple
(
curr
ent_s
tate
[
1
].
shape
),
compression
=
"gzip"
,
compression_opts
=
1
,
)
...
...
@@ -332,7 +333,7 @@ occult_samples = posterior.create_dataset(
"samples/occults"
,
event_size
,
dtype
=
DTYPE
,
chunks
=
(
min
(
NUM_BURSTS
*
NUM_BURST_SAMPLES
,
1000
)
,)
+
tuple
(
ev
ent_s
ize
[
1
:]
),
chunks
=
(
10
,)
+
tuple
(
curr
ent_s
tate
[
1
].
shape
),
compression
=
"gzip"
,
compression_opts
=
1
,
)
...
...
@@ -361,7 +362,7 @@ output_results = [
print
(
"Initial logpi:"
,
logp
(
*
current_state
))
par_scale
=
tf
.
linalg
.
diag
(
tf
.
ones
(
current_state
[
0
].
shape
,
dtype
=
current_state
[
0
].
dtype
)
*
0
.0
000001
tf
.
ones
(
current_state
[
0
].
shape
,
dtype
=
current_state
[
0
].
dtype
)
*
1
.0
)
# We loop over successive calls to sample because we have to dump results
...
...
ode_config.yaml
View file @
552b1d03
...
...
@@ -10,7 +10,7 @@ data:
parameter
:
beta1
:
0.6
# R0 2.4
beta2
:
0.
33
# Contact with commuters 1/3rd of the time
beta2
:
0.
5
# Contact with commuters 1/3rd of the time
beta3
:
0.25
# lockdown vs normal
omega
:
1.0
# Non-linearity parameter for commuting volume
nu
:
0.5
# E -> I transition rate
...
...
@@ -35,10 +35,10 @@ settings:
mcmc
:
dmax
:
16
nmax
:
16
0
nmax
:
2
0
m
:
1
occult_nmax
:
250
num_event_time_updates
:
1
00
occult_nmax
:
15
num_event_time_updates
:
1
49
num_bursts
:
100
num_burst_samples
:
100
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment