Commit 552b1d03 authored by Chris Jewell's avatar Chris Jewell
Browse files

Stochastic simulation model update

parent 0b337d47
......@@ -41,12 +41,20 @@ class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
"""UncalibratedEventTimesUpdate"""
def __init__(
self, target_log_prob_fn, target_event_id, nmax, seed=None, name=None,
self,
target_log_prob_fn,
target_event_id,
nmax,
t_range=None,
seed=None,
name=None,
):
"""An uncalibrated random walk for event times.
:param target_log_prob_fn: the log density of the target distribution
:param target_event_id: the position in the last dimension of the events
tensor that we wish to move
:param t_range: a tuple containing earliest and latest times between which
to update occults.
:param seed: a random seed
:param name: the name of the update step
"""
......@@ -57,6 +65,7 @@ class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
target_log_prob_fn=target_log_prob_fn,
target_event_id=target_event_id,
nmax=nmax,
t_range=t_range,
seed=seed,
name=name,
)
......@@ -101,7 +110,9 @@ class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
def true_fn():
with tf.name_scope("true_fn"):
proposal = AddOccultProposal(
current_events, self.parameters["nmax"]
events=current_events,
n_max=self.parameters["nmax"],
t_range=self.parameters["t_range"],
)
update = proposal.sample()
next_state = _add_events(
......@@ -128,7 +139,11 @@ class UncalibratedOccultUpdate(tfp.mcmc.TransitionKernel):
x=[self.tx_topology.target],
x_star=tf.cast(-update["x_star"], current_events.dtype),
)
reverse = AddOccultProposal(next_state, self.parameters["nmax"])
reverse = AddOccultProposal(
events=next_state,
n_max=self.parameters["nmax"],
t_range=self.parameters["t_range"],
)
q_fwd = tf.reduce_sum(proposal.log_prob(update))
q_rev = tf.reduce_sum(reverse.log_prob(update))
log_acceptance_correction = q_rev - q_fwd
......
......@@ -7,7 +7,10 @@ from covid.impl.Categorical2 import Categorical2
tfd = tfp.distributions
def AddOccultProposal(events, n_max, dtype=tf.int32, name=None):
def AddOccultProposal(events, n_max, t_range=None, dtype=tf.int32, name=None):
if t_range is None:
t_range = [0, events.shape[-2]]
def m():
"""Select a metapopulation"""
with tf.name_scope("m"):
......@@ -16,7 +19,7 @@ def AddOccultProposal(events, n_max, dtype=tf.int32, name=None):
def t():
"""Select a timepoint"""
with tf.name_scope("t"):
return UniformInteger(low=[0], high=[events.shape[1]], dtype=dtype)
return UniformInteger(low=[t_range[0]], high=[t_range[1]], dtype=dtype)
def x_star():
"""Draw num to add"""
......
......@@ -155,7 +155,7 @@ class CovidUKStochastic(CovidUK):
* commute_volume
* tf.linalg.matvec(self.C, state[..., 2] / self.N)
)
infec_rate = infec_rate / self.N # Vector of length nc
infec_rate = infec_rate / self.N + 0.00000001 # Vector of length nc
ei = tf.broadcast_to(
[param["nu"]], shape=[state.shape[0]]
......
......@@ -5,9 +5,6 @@ import pickle as pkl
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
import numpy as np
import matplotlib.pyplot as plt
import yaml
......@@ -15,6 +12,11 @@ import yaml
from covid.model import CovidUKStochastic, load_data
from covid.util import sanitise_parameter, sanitise_settings, seed_areas
tfd = tfp.distributions
tfb = tfp.bijectors
DTYPE = np.float64
......@@ -157,81 +159,63 @@ def plot_age_attack_rate(ax, sim, N, label):
ax.plot(data["age_groups"], attack_rate, "o-", label=label)
if __name__ == "__main__":
parser = optparse.OptionParser()
parser.add_option(
"--config",
"-c",
dest="config",
default="ode_config.yaml",
help="configuration file",
)
options, args = parser.parse_args()
with open(options.config, "r") as ymlfile:
config = yaml.load(ymlfile)
param = sanitise_parameter(config["parameter"])
settings = sanitise_settings(config["settings"])
parser = optparse.OptionParser()
parser.add_option(
"--config",
"-c",
dest="config",
default="ode_config.yaml",
help="configuration file",
)
options, args = parser.parse_args()
with open(options.config, "r") as ymlfile:
config = yaml.load(ymlfile)
param = sanitise_parameter(config["parameter"])
settings = sanitise_settings(config["settings"])
data = load_data(config["data"], settings, DTYPE)
data["pop"] = data["pop"].sum(level=0)
model = CovidUKStochastic(
C=data["C"],
N=data["pop"]["n"].to_numpy(),
W=data["W"],
date_range=settings["prediction_period"],
holidays=settings["holiday"],
lockdown=settings["lockdown"],
time_step=1.0,
)
# seeding = seed_areas(data['pop']['n'].to_numpy(), data['pop']['Area.name.2']) # Seed 40-44 age group, 30 seeds by popn size
# seeding = tf.one_hot(tf.squeeze(tf.where(data['pop'].index=='E09000008')), depth=data['pop'].size, dtype=DTYPE)
seeding = tf.one_hot(58, depth=model.N.shape[0], dtype=DTYPE) # Manchester
state_init = model.create_initial_state(init_matrix=seeding)
start = time.perf_counter()
t, sim = model.simulate(param, state_init)
end = time.perf_counter()
print(f"Run 1 Complete in {end - start} seconds")
start = time.perf_counter()
for i in range(1):
t, upd = model.simulate(param, state_init)
end = time.perf_counter()
print(f"Run 2 Complete in {(end - start)/1.} seconds")
# Plotting functions
fig_uk = plt.figure()
sim = tf.reduce_sum(upd, axis=-2)
# plot_age_attack_rate(fig_attack.gca(), sim, data['pop']['n'].to_numpy(), "Attack Rate")
# fig_attack.suptitle("Attack Rate")
# plot_infec_curve(fig_uk.gca(), sim.numpy(), "Infections")
fig_uk.gca().plot(sim[:, :, 2])
fig_uk.suptitle("UK Infections")
fig_uk.autofmt_xdate()
fig_uk.gca().grid(True)
plt.show()
with open("stochastic_sim_covid.pkl", "wb") as f:
pkl.dump({"events": upd.numpy(), "state_init": state_init.numpy()}, f)
# parser = optparse.OptionParser()
# parser.add_option(
# "--config",
# "-c",
# dest="config",
# default="ode_config.yaml",
# help="configuration file",
# )
# options, args = parser.parse_args([])
with open("ode_config.yaml", "r") as ymlfile:
config = yaml.load(ymlfile)
param = sanitise_parameter(config["parameter"])
settings = sanitise_settings(config["settings"])
data = load_data(config["data"], settings, DTYPE)
data["pop"] = data["pop"].sum(level=0)
model = CovidUKStochastic(
C=data["C"],
N=data["pop"]["n"].to_numpy(),
W=data["W"],
date_range=settings["prediction_period"],
holidays=settings["holiday"],
lockdown=settings["lockdown"],
time_step=1.0,
)
# seeding = seed_areas(data['pop']['n'].to_numpy(), data['pop']['Area.name.2']) # Seed 40-44 age group, 30 seeds by popn size
# seeding = tf.one_hot(tf.squeeze(tf.where(data['pop'].index=='E09000008')), depth=data['pop'].size, dtype=DTYPE)
seeding = tf.one_hot(58, depth=model.N.shape[0], dtype=DTYPE) # Manchester
state_init = model.create_initial_state(init_matrix=seeding)
start = time.perf_counter()
t, sim = model.simulate(param, state_init)
end = time.perf_counter()
print(f"Run 1 Complete in {end - start} seconds")
start = time.perf_counter()
for i in range(1):
t, upd = model.simulate(param, state_init)
end = time.perf_counter()
print(f"Run 2 Complete in {(end - start)/1.} seconds")
# Plotting functions
fig_uk = plt.figure()
sim = tf.reduce_sum(upd, axis=-2)
# plot_age_attack_rate(fig_attack.gca(), sim, data['pop']['n'].to_numpy(), "Attack Rate")
# fig_attack.suptitle("Attack Rate")
# plot_infec_curve(fig_uk.gca(), sim.numpy(), "Infections")
fig_uk.gca().plot(sim[:, :, 2])
fig_uk.suptitle("UK Infections")
fig_uk.autofmt_xdate()
fig_uk.gca().grid(True)
plt.show()
with open("stochastic_sim_covid1.pkl", "wb") as f:
pkl.dump({"events": upd.numpy(), "state_init": state_init.numpy()}, f)
......@@ -72,11 +72,11 @@ model = CovidUKStochastic(
# Load data
with open("stochastic_sim_covid.pkl", "rb") as f:
with open("stochastic_sim_covid1.pkl", "rb") as f:
example_sim = pkl.load(f)
event_tensor = example_sim["events"] # shape [T, M, S, S]
event_tensor = event_tensor[:80, ...]
event_tensor = event_tensor[:60, ...]
num_times = event_tensor.shape[0]
num_meta = event_tensor.shape[1]
state_init = example_sim["state_init"]
......@@ -159,6 +159,7 @@ def make_occults_step(target_event_id):
target_log_prob_fn=logp,
target_event_id=target_event_id,
nmax=config["mcmc"]["occult_nmax"],
t_range=[se_events.shape[0] - 21, se_events.shape[0]],
),
name="occult_update",
)
......@@ -253,7 +254,7 @@ def sample(n_samples, init_state, par_scale, num_event_updates):
state[2], results[3] = se_occult(occult_logp).one_step(
state[2], forward_results(results[2], results[3])
)
# results[3] = forward_results(results[2], results[3])
# results[3] = forward_results(results[2], results[3])
state[2], results[4] = ei_occult(occult_logp).one_step(
state[2], forward_results(results[3], results[4])
)
......@@ -324,7 +325,7 @@ event_samples = posterior.create_dataset(
"samples/events",
event_size,
dtype=DTYPE,
chunks=(min(NUM_BURSTS * NUM_BURST_SAMPLES, 1000),) + tuple(event_size[1:]),
chunks=(10,) + tuple(current_state[1].shape),
compression="gzip",
compression_opts=1,
)
......@@ -332,7 +333,7 @@ occult_samples = posterior.create_dataset(
"samples/occults",
event_size,
dtype=DTYPE,
chunks=(min(NUM_BURSTS * NUM_BURST_SAMPLES, 1000),) + tuple(event_size[1:]),
chunks=(10,) + tuple(current_state[1].shape),
compression="gzip",
compression_opts=1,
)
......@@ -361,7 +362,7 @@ output_results = [
print("Initial logpi:", logp(*current_state))
par_scale = tf.linalg.diag(
tf.ones(current_state[0].shape, dtype=current_state[0].dtype) * 0.0000001
tf.ones(current_state[0].shape, dtype=current_state[0].dtype) * 1.0
)
# We loop over successive calls to sample because we have to dump results
......
......@@ -10,7 +10,7 @@ data:
parameter:
beta1: 0.6 # R0 2.4
beta2: 0.33 # Contact with commuters 1/3rd of the time
beta2: 0.5 # Contact with commuters 1/3rd of the time
beta3: 0.25 # lockdown vs normal
omega: 1.0 # Non-linearity parameter for commuting volume
nu: 0.5 # E -> I transition rate
......@@ -35,10 +35,10 @@ settings:
mcmc:
dmax: 16
nmax: 160
nmax: 20
m: 1
occult_nmax: 250
num_event_time_updates: 100
occult_nmax: 15
num_event_time_updates: 149
num_bursts: 100
num_burst_samples: 100
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment