Commit c11e821c authored by Chris Jewell

Changes:

1. Wrote wrapper `FilteredEventTimeProposal` around `EventTimeProposal` to subset metapopulations. It currently selects just one metapopulation, but the plan is to select more. A usage sketch follows below.
2. Moved the test fixture to a folder inside `tests`.
3. Wrote a (currently failing) `TestFilteredEventTimesProposal` test.
parent 67463e3d
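For orientation, a minimal sketch of how the new proposal is intended to be driven; the tensor shapes and topology indices here are illustrative assumptions, not part of this commit:

import tensorflow as tf
from covid.impl.event_time_proposal import (TransitionTopology,
                                            FilteredEventTimeProposal)

# Illustrative sizes: M=10 metapopulations, T=50 timepoints, X=3 transitions.
events = tf.ones([10, 50, 3])       # [M, T, X] event counts
initial_state = tf.ones([10, 4])    # [M, S] initial compartment sizes
topology = TransitionTopology(prev=0, target=1, next=2)
proposal = FilteredEventTimeProposal(events, initial_state, topology,
                                     d_max=4, n_max=20)
move = proposal.sample()   # dict: 'm', 'inner_move', 't', 'delta_t', 'x_star'
log_q = proposal.log_prob(move)   # proposal log-density of the sampled move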
@@ -12,7 +12,8 @@ tfd = tfp.distributions
class UniformInteger(tfd.Distribution):
    def __init__(self, low=0, high=1, validate_args=False,
-                 allow_nan_stats=True, dtype=tf.int32, name='UniformInteger'):
+                 allow_nan_stats=True, dtype=tf.int32, float_dtype=tf.float64,
+                 name='UniformInteger'):
        """Initialise a UniformInteger random variable on `[low, high)`.
        Args:
@@ -45,6 +46,7 @@ class UniformInteger(tfd.Distribution):
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name)
+        self.float_dtype = float_dtype
    @staticmethod
    def _param_shapes(sample_shape):
@@ -103,9 +105,9 @@ class UniformInteger(tfd.Distribution):
                       self.dtype)
    def _prob(self, x):
-        low = tf.cast(self.low, tf.float32)
-        high = tf.cast(self.high, tf.float32)
-        x = tf.cast(x, dtype=tf.float32)
+        low = tf.cast(self.low, self.float_dtype)
+        high = tf.cast(self.high, self.float_dtype)
+        x = tf.cast(x, dtype=self.float_dtype)
        return tf.where(
            tf.math.is_nan(x),
@@ -116,4 +118,5 @@ class UniformInteger(tfd.Distribution):
                tf.ones_like(x) / self._range(low=low, high=high)))
    def _log_prob(self, x):
-        return tf.math.log(self._prob(x))
+        res = tf.math.log(self._prob(x))
+        return res
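A short sketch of what the new `float_dtype` argument buys (values illustrative): the probability is now computed in a configurable float type, so log-probabilities can match a float64 model without extra casts.

X = UniformInteger(low=0, high=10, float_dtype=tf.float64)
X.log_prob(3)   # log(1/10) = -2.3026, now returned as float64 rather than float32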
import collections
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.util import SeedStream
from covid import config
from covid.impl.event_time_proposal import TransitionTopology, \
    FilteredEventTimeProposal
from covid.impl.mcmc import KernelResults
from covid.impl.util import which
tfd = tfp.distributions
DTYPE = config.floatX
-TransitionTopology = collections.namedtuple('TransitionTopology',
-                                            ('prev',
-                                             'target',
-                                             'next'))
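The topology (now imported from covid.impl.event_time_proposal) simply names the target transition and its neighbours; for an SEIR-style event tensor a plausible layout is (indices illustrative):

# Transition columns, e.g. 0: S->E, 1: E->I, 2: I->R.
# Moving E->I events is constrained by S->E (prev) and I->R (next):
topology = TransitionTopology(prev=0, target=1, next=2)
# An end transition has no neighbour on one side:
se_topology = TransitionTopology(prev=None, target=0, next=1)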
def _is_within(x, low, high):
    """Returns true if low <= x < high"""
@@ -45,14 +38,16 @@ def _max_free_events(events, initial_state,
    def true_fn():
        target_events_ = tf.gather(events, target_id, axis=-1)
        target_cumsum = tf.cumsum(target_events_, axis=0)
-        constraining_events = tf.gather(events, constraint_id, axis=-1)  # TxM
-        constraining_cumsum = tf.cumsum(constraining_events, axis=0)  # TxM
-        constraining_init_state = tf.gather(initial_state, constraint_id + 1, axis=-1)
+        constraining_events = tf.gather(events, constraint_id, axis=-1)  # TxM
+        constraining_cumsum = tf.cumsum(constraining_events, axis=0)  # TxM
+        constraining_init_state = tf.gather(initial_state, constraint_id + 1,
+                                            axis=-1)
        n1 = tf.gather(target_cumsum, constraint_t, axis=0)
        n2 = tf.gather(constraining_cumsum, constraint_t, axis=0)
        free_events = tf.abs(n1 - n2) + constraining_init_state
        max_free_events = tf.minimum(free_events,
-                                     tf.gather(target_events_, target_t, axis=0))
+                                     tf.gather(target_events_, target_t,
+                                               axis=0))
        return max_free_events
    # Manual broadcasting of n_events_t is required here so that the XLA
@@ -61,23 +56,26 @@ def _max_free_events(events, initial_state,
    # propagated right through the algorithm, so the return value has known shape.
    def false_fn():
        n_events_t = tf.gather(events[..., target_id], target_t, axis=0)
-        return tf.broadcast_to([n_events_t], [constraint_t.shape[0]] + [n_events_t.shape[0]])
+        return tf.broadcast_to([n_events_t],
+                               [constraint_t.shape[0]] + [n_events_t.shape[0]])
    ret_val = tf.cond(constraint_id != -1, true_fn, false_fn)
    return ret_val
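The comment above is the key point: both branches of tf.cond must return tensors with the same static shape, or XLA cannot compile the graph. A standalone illustration (toy shapes, not from this module):

import tensorflow as tf

t = tf.constant([0, 1, 2])   # stands in for constraint_t, shape [3]
x = tf.constant([5., 7.])    # stands in for n_events_t, shape [M=2]

def true_fn():
    return tf.zeros([t.shape[0], x.shape[0]])   # shape [3, 2]

def false_fn():
    # Broadcasting x up to the same [3, 2] shape keeps both branches consistent.
    return tf.broadcast_to([x], [t.shape[0]] + [x.shape[0]])

y = tf.cond(tf.constant(True), true_fn, false_fn)   # static shape [3, 2]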
def _move_events(event_tensor, event_id, from_t, to_t, n_move):
-    """Subtracts n_move from event_tensor[from_t, :, event_id]
-    and adds n_move to event_tensor[to_t, :, event_id]."""
-    num_meta = event_tensor.shape[1]
-    indices = tf.stack([tf.broadcast_to(from_t, [num_meta]),  # Timepoint
-                        tf.range(num_meta),  # All meta-populations
-                        tf.broadcast_to([event_id], [num_meta])], axis=-1)  # Event
+    """Subtracts n_move from event_tensor[:, from_t, event_id]
+    and adds n_move to event_tensor[:, to_t, event_id]."""
+    num_meta = event_tensor.shape[0]
+    indices = tf.stack([tf.range(num_meta),  # All meta-populations
+                        from_t,
+                        tf.broadcast_to([event_id], [num_meta])],
+                       axis=-1)  # Event
    # Subtract x_star from the [from_t, :, event_id] row of the state tensor
    n_move = tf.cast(n_move, event_tensor.dtype)
    next_state = tf.tensor_scatter_nd_sub(event_tensor, indices, n_move)
-    indices = tf.stack([tf.broadcast_to(to_t, [num_meta]),
-                        tf.range(num_meta),
+    indices = tf.stack([tf.range(num_meta),
+                        to_t,
                        tf.broadcast_to(event_id, [num_meta])], axis=-1)
    # Add x_star to the [to_t, :, event_id] row of the state tensor
    next_state = tf.tensor_scatter_nd_add(next_state, indices, n_move)
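A toy sketch of the [M, T, X] index layout that _move_events now scatters over (sizes assumed for illustration):

import tensorflow as tf

events = tf.ones([3, 5, 2])        # [M=3, T=5, X=2]
from_t = tf.constant([1, 2, 0])    # a per-metapopulation source time
event_id = 1
num_meta = events.shape[0]
indices = tf.stack([tf.range(num_meta),   # m = 0..M-1
                    from_t,               # t_m for each m
                    tf.broadcast_to([event_id], [num_meta])],
                   axis=-1)               # rows of (m, t_m, x)
moved = tf.tensor_scatter_nd_sub(events, indices, tf.ones([3]))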
@@ -110,14 +108,15 @@ class EventTimesUpdate(tfp.mcmc.TransitionKernel):
"""
self._seed_stream = SeedStream(seed, salt='EventTimesUpdate')
self._impl = tfp.mcmc.MetropolisHastings(
inner_kernel=UncalibratedEventTimesUpdate(target_log_prob_fn=target_log_prob_fn,
target_event_id=target_event_id,
prev_event_id=prev_event_id,
next_event_id=next_event_id,
dmax=dmax,
mmax=mmax,
nmax=nmax,
initial_state=initial_state))
inner_kernel=UncalibratedEventTimesUpdate(
target_log_prob_fn=target_log_prob_fn,
target_event_id=target_event_id,
prev_event_id=prev_event_id,
next_event_id=next_event_id,
dmax=dmax,
mmax=mmax,
nmax=nmax,
initial_state=initial_state))
self._parameters = self._impl.inner_kernel.parameters.copy()
self._parameters['seed'] = seed
@@ -144,7 +143,8 @@ class EventTimesUpdate(tfp.mcmc.TransitionKernel):
        :param previous_kernel_results: a named tuple of results.
        :returns: (next_state, kernel_results)
        """
-        next_state, kernel_results = self._impl.one_step(current_state, previous_kernel_results)
+        next_state, kernel_results = self._impl.one_step(current_state,
+                                                         previous_kernel_results)
        return next_state, kernel_results
    def bootstrap_results(self, init_state):
@@ -174,7 +174,8 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
        :param name: the name of the update step
        """
        self._target_log_prob_fn = target_log_prob_fn
-        self._seed_stream = SeedStream(seed, salt='UncalibratedEventTimesUpdate')
+        self._seed_stream = SeedStream(seed,
+                                       salt='UncalibratedEventTimesUpdate')
        self._name = name
        self._parameters = dict(
            target_log_prob_fn=target_log_prob_fn,
@@ -187,7 +188,8 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
            nmax=nmax,
            seed=seed,
            name=name)
-        self.tx_topology = TransitionTopology(prev_event_id, target_event_id, next_event_id)
+        self.tx_topology = TransitionTopology(prev_event_id, target_event_id,
+                                              next_event_id)
        self.time_offsets = tf.range(self.parameters['dmax'])
    @property
@@ -225,79 +227,61 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
    def one_step(self, current_events, previous_kernel_results):
        """One update of event times.
-        :param current_events: a [T, M, X] tensor containing number of events per time t, metapopulation m,
-        and transition x.
-        :param previous_kernel_results: an object of type UncalibratedRandomWalkResults.
-        :returns: a tuple containing new_state and UncalibratedRandomWalkResults.
+        :param current_events: a [T, M, X] tensor containing number of events
+            per time t, metapopulation m,
+            and transition x.
+        :param previous_kernel_results: an object of type
+            UncalibratedRandomWalkResults.
+        :returns: a tuple containing new_state and UncalibratedRandomWalkResults
        """
        with tf.name_scope('uncalibrated_event_times_rw/onestep'):
+            current_events = tf.transpose(current_events, perm=(1, 0, 2))
            target_events = current_events[..., self.tx_topology.target]
            num_times = target_events.shape[0]
-            # 1. Choose a timepoint to move, conditional on it having events to move
-            current_p = _nonzero_rows(target_events)
-            current_t = tf.squeeze(tf.random.categorical(logits=[tf.math.log(current_p)],
-                                                         num_samples=1,
-                                                         seed=self._seed_stream(),
-                                                         dtype=tf.int32))
-            # 2. time_delta has a magnitude and sign -- a jump in time for which to move events
-            # tfp.math.random_rademacker
-            # bernoulli * 2 - 1
-            u = tf.squeeze(tf.random.uniform(shape=[1], seed=self._seed_stream(),  # 0 is backwards
-                                             minval=0, maxval=2, dtype=tf.int32))  # 1 is forwards
-            jump_sign = tf.gather([-1, 1], u)
-            jump_magnitude = tf.squeeze(tf.random.uniform([1], seed=self._seed_stream(),
-                                                          minval=0, maxval=self.parameters['dmax'],
-                                                          dtype=tf.int32)) + 1
-            time_delta = jump_sign * jump_magnitude
-            next_t = current_t + time_delta
-            # Compute the constraint times (current_t, time_offsets, (target, prev, next),
-            # events_tensor, initial state, distance)
-            n_max = self.compute_constraints(current_events, current_t, time_delta)
-            # Draw number to move uniformly from n_max
-            p_msk = tf.cast(n_max > 0., dtype=tf.float32)
-            W = tfd.OneHotCategorical(logits=tf.math.log(p_msk))
-            msk = tf.cast(W.sample(), n_max.dtype)
-            clip_max = 20.
-            n_max = tf.clip_by_value(n_max, clip_value_min=0., clip_value_max=clip_max)
-            x_star = tf.floor(tf.random.uniform(n_max.shape, minval=0., maxval=(n_max + 1.),
-                                                dtype=current_events.dtype)) * msk
-            # Propose next_state
-            next_state = _move_events(event_tensor=current_events, event_id=self.tx_topology.target,
-                                      from_t=current_t, to_t=next_t,
-                                      n_move=x_star)
-            next_target_log_prob = self.target_log_prob_fn(next_state)
+            proposal = FilteredEventTimeProposal(current_events,
+                                                 self.parameters[
+                                                     'initial_state'],
+                                                 self.tx_topology,
+                                                 self.parameters['dmax'],
+                                                 self.parameters['nmax'])
+            move = proposal.sample()
+            next_state = _move_events(event_tensor=current_events,
+                                      event_id=self.tx_topology.target,
+                                      from_t=move['t'],
+                                      to_t=move['t'] + move['delta_t'],
+                                      n_move=move['x_star'])
+            next_state_tr = tf.transpose(next_state, perm=(1, 0, 2))
+            next_target_log_prob = self.target_log_prob_fn(next_state_tr)
            # Trap out-of-bounds moves that go outside [0, num_times)
-            next_target_log_prob = tf.where(_is_within(next_t, 0, num_times),
-                                            next_target_log_prob,
-                                            tf.constant(-np.inf, dtype=current_events.dtype))
-            # Calculate proposal density
-            # 1. Calculate probability of choosing a timepoint
-            next_p = _nonzero_rows(next_state[..., self.target_event_id])
-            log_acceptance_correction = tf.math.log(tf.reduce_sum(current_p)) - \
-                                        tf.math.log(tf.reduce_sum(next_p))
-            # 2. Calculate probability of selecting events
-            next_n_max = self.compute_constraints(next_state, next_t, -time_delta)
-            next_n_max = tf.clip_by_value(next_n_max, clip_value_min=0., clip_value_max=clip_max)
-            log_acceptance_correction += tf.reduce_sum(tf.math.log(n_max + 1.) - tf.math.log(next_n_max + 1.))
-            # 3. Prob of choosing a non-zero element to move
-            log_acceptance_correction = tf.math.log(
-                tf.math.count_nonzero(n_max, dtype=log_acceptance_correction.dtype)) - tf.math.log(
-                tf.math.count_nonzero(next_n_max, dtype=log_acceptance_correction.dtype))
-            return [next_state,
+            next_target_log_prob = tf.where(
+                _is_within(move['t'] + move['delta_t'], 0,
                           num_times),
+                next_target_log_prob,
+                tf.constant(-np.inf,
+                            dtype=current_events.dtype))
+            # Calculate proposal mass ratio
+            q_fwd = proposal.log_prob(move)
+            move['t'] = move['t'] + move['delta_t']
+            move['delta_t'] = -move['delta_t']
+            q_rev = FilteredEventTimeProposal(event_tensor=next_state,
+                                              initial_state=self.parameters[
+                                                  'initial_state'],
+                                              topology=self.tx_topology,
+                                              d_max=self.parameters['dmax'],
+                                              n_max=self.parameters[
+                                                  'nmax']).log_prob(move)
+            log_acceptance_correction = q_rev - q_fwd
+            return [next_state_tr,
                    KernelResults(
                        log_acceptance_correction=log_acceptance_correction,
                        target_log_prob=next_target_log_prob,
-                        extra=tf.concat([x_star, n_max], axis=0)
+                        extra=tf.concat(move['x_star'], axis=0)
                    )]
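The hand-computed correction terms are thus replaced by a plain Hastings ratio. Schematically (a sketch of the identity, not library code; the names mirror the variables above):

# Metropolis-Hastings accepts a move x -> x' with probability
#   min{1, [pi(x') q(x | x')] / [pi(x) q(x' | x)]}.
# The kernel returns log q(x | x') - log q(x' | x) (= q_rev - q_fwd) as
# log_acceptance_correction; tfp.mcmc.MetropolisHastings adds the target
# log-probability difference and performs the accept/reject step.
def log_accept_ratio(next_tlp, curr_tlp, q_fwd, q_rev):
    return (next_tlp - curr_tlp) + (q_rev - q_fwd)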
    def compute_constraints(self, current_events, current_t, time_delta):
@@ -315,12 +299,17 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
                                      self.tx_topology.next or -1)
        # 3. Calculate max number of events to move subject to constraints
-        n_max = _max_free_events(events=current_events, initial_state=self.parameters['initial_state'],
-                                 target_t=current_t, target_id=self.tx_topology.target,
-                                 constraint_t=constraint_time_idx, constraint_id=constraining_event_id)
+        n_max = _max_free_events(events=current_events,
+                                 initial_state=self.parameters['initial_state'],
+                                 target_t=current_t,
+                                 target_id=self.tx_topology.target,
+                                 constraint_t=constraint_time_idx,
+                                 constraint_id=constraining_event_id)
        inf_mask = tf.cumsum(tf.one_hot(tf.math.abs(time_delta),
-                                        self.parameters['dmax'], dtype=tf.int32)) * tf.int32.max
-        n_max = tf.reduce_min(tf.cast(inf_mask[:, None], n_max.dtype) + n_max, axis=0)
+                                        self.parameters['dmax'],
+                                        dtype=tf.int32)) * tf.int32.max
+        n_max = tf.reduce_min(tf.cast(inf_mask[:, None], n_max.dtype) + n_max,
+                              axis=0)
        return n_max
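How the one-hot cumulative sum disables offsets at or beyond |time_delta| (toy values, standalone):

import tensorflow as tf

dmax = 4
time_delta = tf.constant(-2)
# one_hot(2, 4) = [0, 0, 1, 0]; cumsum -> [0, 0, 1, 1]; scaling by int32.max
# makes every offset >= |time_delta| an effectively infinite bound, so the
# reduce_min over offsets only sees the reachable timepoints.
inf_mask = tf.cumsum(tf.one_hot(tf.math.abs(time_delta),
                                dmax, dtype=tf.int32)) * tf.int32.max
# inf_mask == [0, 0, 2147483647, 2147483647]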
    def bootstrap_results(self, init_state):
@@ -70,6 +70,30 @@ def _abscumdiff(events, initial_state,
    return ret_val
+class Deterministic2(tfd.Deterministic):
+    def __init__(self,
+                 loc,
+                 atol=None,
+                 rtol=None,
+                 validate_args=False,
+                 allow_nan_stats=True,
+                 log_prob_dtype=tf.float32,
+                 name='Deterministic'):
+        parameters = dict(locals())
+        super(Deterministic2, self).__init__(
+            loc,
+            atol=atol,
+            rtol=rtol,
+            validate_args=validate_args,
+            allow_nan_stats=allow_nan_stats,
+            name=name
+        )
+        self.log_prob_dtype = log_prob_dtype
+    def _prob(self, x):
+        return tf.constant(1, dtype=self.log_prob_dtype)
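Deterministic2 only overrides _prob to return a constant 1 in a chosen dtype, so the deterministic components of the joint proposal below contribute zero log-probability, in a dtype matching the rest of the model. A usage sketch (values assumed):

d = Deterministic2(loc=tf.constant([3, 1, 4]), log_prob_dtype=tf.float64)
d.sample()             # always [3, 1, 4]
d.log_prob([3, 1, 4])  # 0.0 in float64: log of the overridden _prob == 1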
def TimeDelta(dmax, name=None):
    outcomes = tf.concat([-tf.range(1, dmax + 1), tf.range(1, dmax + 1)],
                         axis=0)
@@ -140,3 +164,49 @@ def EventTimeProposal(events, initial_state, topology, d_max, n_max,
    return tfd.JointDistributionNamed(dict(t=t,
                                           delta_t=delta_t,
                                           x_star=x_star), name=name)
+def FilteredEventTimeProposal(events, initial_state, topology, d_max, n_max,
+                              dtype=tf.int32, name=None):
+    """FilteredEventTimeProposal allows us to choose a subset of indices
+    in `range(events.shape[0])` for which to propose an update.  The
+    results are then broadcast back to `events.shape[0]`."""
+    target_events = tf.gather(events, topology.target, axis=-1)
+    def m():
+        hot_meta = tf.math.count_nonzero(target_events, axis=1) > 0
+        logits = tf.math.log(tf.cast(hot_meta, tf.float64))
+        return tfd.Categorical(logits=[logits], name='m')
+    def inner_move(m):
+        select_meta = tf.gather(events, m, axis=0)
+        select_init = tf.gather(initial_state, m, axis=0)
+        return EventTimeProposal(select_meta, select_init, topology, d_max, n_max,
+                                 dtype=dtype, name=None)
+    def t(m, inner_move):
+        """Scatter metapop updates to original metapop dimension"""
+        return Deterministic2(
+            tf.scatter_nd([m], inner_move['t'], [events.shape[0]]),
+            log_prob_dtype=events.dtype,
+            name='t'
+        )
+    def delta_t(inner_move):
+        return Deterministic2(inner_move['delta_t'],
+                              log_prob_dtype=events.dtype,
+                              name='delta_t')
+    def x_star(m, inner_move):
+        """Scatter metapop updates to original metapop dimension"""
+        return Deterministic2(
+            tf.scatter_nd([m], inner_move['x_star'], [events.shape[0]]),
+            log_prob_dtype=events.dtype,
+            name='x_star'
+        )
+    return tfd.JointDistributionNamed(dict(m=m,
+                                           inner_move=inner_move,
+                                           t=t,
+                                           delta_t=delta_t,
+                                           x_star=x_star))
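The scatter in t and x_star is the "broadcast back to events.shape[0]" mentioned in the docstring; in isolation (toy sizes assumed):

import tensorflow as tf

# An update drawn for metapopulation m=2 is placed back into a length-M
# vector, with zeros for the untouched metapopulations (M=5 here):
m = tf.constant([2])
t_inner = tf.constant([7])                 # time drawn for metapop 2
t_full = tf.scatter_nd([m], t_inner, [5])  # -> [0, 0, 7, 0, 0]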
import os
-import tqdm
-import pickle as pkl
-import yaml
-import h5py
+import h5py
+import numpy as np
+import tensorflow as tf
+import tensorflow_probability as tfp
+import tqdm
+import yaml
tfd = tfp.distributions
tfb = tfp.bijectors
from covid import config
from covid.model import load_data, CovidUKStochastic
-from covid.util import sanitise_parameter, sanitise_settings, seed_areas
+from covid.util import sanitise_parameter, sanitise_settings
from covid.impl.util import make_transition_matrix
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
from covid.impl.event_time_mh import EventTimesUpdate
@@ -20,7 +21,7 @@ from covid.impl.event_time_mh import EventTimesUpdate
DTYPE = config.floatX
# Random moves of events. What invalidates an epidemic, how can we test for it?
-with open('ode_config.yaml','r') as f:
+with open('ode_config.yaml', 'r') as f:
    config = yaml.load(f)
param = sanitise_parameter(config['parameter'])
@@ -32,7 +33,7 @@ data = load_data(config['data'], settings, DTYPE)
data['pop'] = data['pop'].sum(level=0)
model = CovidUKStochastic(C=data['C'][:10, :10],
-                          N=[1000]*10, #data['pop']['n'].to_numpy(),
+                          N=[1000] * 10,  # data['pop']['n'].to_numpy(),
                          W=data['W'],
                          date_range=settings['inference_period'],
                          holidays=settings['holiday'],
@@ -51,20 +52,23 @@ se_events = event_tensor[:, :, 0, 1]
ei_events = event_tensor[:, :, 1, 2]
ir_events = event_tensor[:, :, 2, 3]
def logp(par, events):
    p = param
    p['beta1'] = tf.convert_to_tensor(par[0], dtype=DTYPE)
-    #p['beta2'] = tf.convert_to_tensor(par[1], dtype=DTYPE)
-    #p['beta3'] = tf.convert_to_tensor(par[2], dtype=DTYPE)
+    # p['beta2'] = tf.convert_to_tensor(par[1], dtype=DTYPE)
+    # p['beta3'] = tf.convert_to_tensor(par[2], dtype=DTYPE)
    p['gamma'] = tf.convert_to_tensor(par[1], dtype=DTYPE)
    beta1_logp = tfd.Gamma(concentration=tf.constant(1., dtype=DTYPE),
-                           rate=tf.constant(1., dtype=DTYPE)).log_prob(p['beta1'])
-    #beta2_logp = tfd.Gamma(concentration=tf.constant(1., dtype=DTYPE),
+                           rate=tf.constant(1., dtype=DTYPE)).log_prob(
+        p['beta1'])
+    # beta2_logp = tfd.Gamma(concentration=tf.constant(1., dtype=DTYPE),
    #                        rate=tf.constant(1., dtype=DTYPE)).log_prob(p['beta2'])
-    #beta3_logp = tfd.Gamma(concentration=tf.constant(2., dtype=DTYPE),
+    # beta3_logp = tfd.Gamma(concentration=tf.constant(2., dtype=DTYPE),
    #                        rate=tf.constant(2., dtype=DTYPE)).log_prob(p['beta3'])
    gamma_logp = tfd.Gamma(concentration=tf.constant(100., dtype=DTYPE),
-                           rate=tf.constant(400., dtype=DTYPE)).log_prob(p['gamma'])
+                           rate=tf.constant(400., dtype=DTYPE)).log_prob(
+        p['gamma'])
    with tf.name_scope('main_log_p'):
        event_tensor = make_transition_matrix(events,
                                              [[0, 1], [1, 2], [2, 3]],
@@ -85,9 +89,11 @@ def make_parameter_kernel(scale, bounded_convergence):
    def kernel_func(logp):
        return tfp.mcmc.MetropolisHastings(
            inner_kernel=UncalibratedLogRandomWalk(
-                target_log_prob_fn=logp,
-                new_state_fn=random_walk_mvnorm_fn(scale, p_u=bounded_convergence)
-            ), name='parameter_update')
+                target_log_prob_fn=logp,
+                new_state_fn=random_walk_mvnorm_fn(scale,
+                                                   p_u=bounded_convergence)
+            ), name='parameter_update')
    return kernel_func
@@ -101,6 +107,7 @@ def make_events_step(target_event_id, prev_event_id=None, next_event_id=None):
            mmax=1,
            nmax=20,
            initial_state=state_init)
+    return kernel_func
@@ -118,15 +125,18 @@ def trace_results_fn(results):
    return tf.concat([[log_prob], [accepted], proposed], axis=0)
-#@tf.function #(autograph=False, experimental_compile=True)
+# @tf.function #(autograph=False, experimental_compile=True)
def sample(n_samples, init_state, par_scale):
    init_state = init_state.copy()
    par_func = make_parameter_kernel(par_scale, 0.95)
    se_func = make_events_step(0, None, 1)
-    ei_func = make_events_step(target_event_id=1, prev_event_id=0, next_event_id=2)
+    ei_func = make_events_step(target_event_id=1, prev_event_id=0,
+                               next_event_id=2)
-    # Based on Gibbs idea posted by Pavel Sountsov https://github.com/tensorflow/probability/issues/495
-    results = ei_func(lambda s: logp(init_state[0], s)).bootstrap_results(init_state[1])
+    # Based on Gibbs idea posted by Pavel Sountsov
+    # https://github.com/tensorflow/probability/issues/495
+    results = ei_func(lambda s: logp(init_state[0], s)).bootstrap_results(
+        init_state[1])
    samples_arr = [tf.TensorArray(s.dtype, size=n_samples) for s in init_state]
    results_arr = [tf.TensorArray(DTYPE, size=n_samples) for r in range(3)]
@@ -136,73 +146,95 @@ def sample(n_samples, init_state, par_scale):
        def par_logp(par_state):
            state[0] = par_state  # close over state from outer scope
            return logp(*state)
-        state[0], par_results = par_func(par_logp).one_step(state[0], prev_results)
+        state[0], par_results = par_func(par_logp).one_step(state[0],
+                                                            prev_results)
        # States
        def state_logp(event_state):
            state[1] = event_state
            return logp(*state)
-        state[1], se_results = se_func(state_logp).one_step(state[1], par_results)
-        state[1], ei_results = ei_func(state_logp).one_step(state[1], se_results)
+        state[1], se_results = se_func(state_logp).one_step(state[1],
+                                                            par_results)
+        state[1], ei_results = ei_func(state_logp).one_step(state[1],
+                                                            se_results)
        samples = [samples[k].write(i, s) for k, s in enumerate(state)]
        results = [results[k].write(i, trace_results_fn(r))
                   for k, r in enumerate([par_results, se_results, ei_results])]
-        return i+1, state, ei_results, samples, results
+        return i + 1, state, ei_results, samples, results
    def cond(i, _1, _2, _3, _4):
        return i < n_samples
    _1, _2, _3, samples, results = tf.while_loop(cond=cond, body=body,
-                                                 loop_vars=[0, init_state, results, samples_arr, results_arr])
+                                                 loop_vars=[0, init_state,
+                                                            results,
+                                                            samples_arr,
+                                                            results_arr])
    return [s.stack() for s in samples], [r.stack() for r in results]
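The loop body above is Metropolis-within-Gibbs: each kernel targets the joint log-density with the other block held fixed through the closed-over state list. A stripped-down sketch of the pattern (the factory and logp signatures are assumed to match those above):

def gibbs_sweep(par_kernel_fn, ev_kernel_fn, logp, state,
                par_results, ev_results):
    """One sweep over the parameter and event blocks."""
    def par_logp(par):
        return logp(par, state[1])      # events block held fixed
    state[0], par_results = par_kernel_fn(par_logp).one_step(state[0],
                                                             par_results)
    def ev_logp(events):
        return logp(state[0], events)   # parameter block held fixed
    state[1], ev_results = ev_kernel_fn(ev_logp).one_step(state[1],
                                                          ev_results)
    return state, par_results, ev_results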
-if __name__=='__main__':
+if __name__ == '__main__':
    num_loop_iterations = 1000
    num_loop_samples = 100
-    current_state = [np.array([0.15, 0.25], dtype=DTYPE), tf.stack([se_events, ei_events, ir_events], axis=-1)]
-    posterior = h5py.File(os.path.expandvars(config['output']['posterior']), 'w')
-    event_size = [num_loop_iterations * num_loop_samples] + list(current_state[1].shape)
-    par_samples = posterior.create_dataset('samples/parameter', [num_loop_iterations*num_loop_samples,
-                                                                 current_state[0].shape[0]], dtype=np.float64)
-    se_samples = posterior.create_dataset('samples/events', event_size, dtype=DTYPE)
-    par_results = posterior.create_dataset('acceptance/parameter', (num_loop_iterations * num_loop_samples, 22), dtype=DTYPE)
-    se_results = posterior.create_dataset('acceptance/S->E', (num_loop_iterations * num_loop_samples, 22), dtype=DTYPE)
-    ei_results = posterior.create_dataset('acceptance/E->I', (num_loop_iterations * num_loop_samples, 22), dtype=DTYPE)
+    current_state = [np.array([0.15, 0.25], dtype=DTYPE),
+                     tf.stack([se_events, ei_events, ir_events], axis=-1)]
+    posterior = h5py.File(os.path.expandvars(config['output']['posterior']),
+                          'w')
+    event_size = [num_loop_iterations * num_loop_samples] + list(
+        current_state[1].shape)
+    par_samples = posterior.create_dataset('samples/parameter', [
+        num_loop_iterations * num_loop_samples,
+        current_state[0].shape[0]], dtype=np.float64)
+    se_samples = posterior.create_dataset('samples/events', event_size,
+                                          dtype=DTYPE)
+    par_results = posterior.create_dataset('acceptance/parameter', (
+        num_loop_iterations * num_loop_samples, 22), dtype=DTYPE)
+    se_results = posterior.create_dataset('acceptance/S->E', (
+        num_loop_iterations * num_loop_samples, 22), dtype=DTYPE)
+    ei_results = posterior.create_dataset('acceptance/E->I', (
+        num_loop_iterations * num_loop_samples, 22), dtype=DTYPE)
    print("Initial logpi:", logp(*current_state))
-    par_scale = tf.linalg.diag(tf.ones(current_state[0].shape, dtype=current_state[0].dtype) * 0.1)
+    par_scale = tf.linalg.diag(