Commit f373d432 authored by Chris Jewell


First pass implementation of data augmentation MCMC.  This is slow, buggy, and possibly not correct yet!  Use with caution!
parent 33e24006
@@ -4,7 +4,7 @@ import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
from covid.impl.util import make_transition_rate_matrix
from covid.impl.util import make_transition_matrix
def approx_expm(rates):
......
"""MCMC Update classes for stochastic epidemic models"""
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.util import SeedStream
@@ -19,9 +20,10 @@ def matrix_where(condition):
msk = tf.reshape(condition, [-1])
msk_idx = tf.boolean_mask(tf.range(tf.size(msk), dtype=tf.int64), msk)
true_coords = tf.stack([msk_idx // ncol, msk_idx % ncol],
axis=-1)
axis=-1)
return true_coords
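# A minimal sketch (not part of this commit) of what matrix_where computes,
# assuming `ncol` from the elided enclosing scope is the column count; the
# tf.where call introduced below is the drop-in replacement:
_counts_demo = tf.constant([[0, 2], [3, 0]])
_ncol = tf.cast(tf.shape(_counts_demo)[1], tf.int64)
_msk = tf.reshape(_counts_demo > 0, [-1])
_msk_idx = tf.boolean_mask(tf.range(tf.size(_msk), dtype=tf.int64), _msk)
_coords = tf.stack([_msk_idx // _ncol, _msk_idx % _ncol], axis=-1)
# _coords == [[0, 1], [1, 0]] -- identical to tf.where(_counts_demo > 0)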
def make_event_time_move(counts_matrix, q, p, alpha):
"""Returns a proposal to move infection times.
@@ -40,9 +42,10 @@ def make_event_time_move(counts_matrix, q, p, alpha):
alpha = tf.convert_to_tensor(alpha, dtype=DTYPE)
q = tf.convert_to_tensor(q, dtype=DTYPE)
nz_idx = matrix_where(counts_matrix > 0)
nz_idx = tf.where(counts_matrix > 0)
# Choose which elements to move
# Todo there's a bug here if no elements are chosen.
ix = tfd.Sample(tfd.Bernoulli(probs=q, dtype=tf.bool), [tf.shape(nz_idx)[0]], name='ix')
def tm(ix):
@@ -64,7 +67,8 @@ def make_event_time_move(counts_matrix, q, p, alpha):
def distance(dir, d_mag):
# Compute the distance to move as product of direction and distance
return DeterministicFloatX(tf.gather(tf.constant([-1., 1.], dtype=DTYPE), dir)*d_mag, name='distance')
return DeterministicFloatX(tf.gather(tf.constant([-1., 1.], dtype=DTYPE), dir) * (d_mag + 1),
name='distance')
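# Note (sketch, not in the commit): dir indexes into {-1., +1.} and d_mag is
# the sampled magnitude, so the (d_mag + 1) shift guarantees a minimum move of
# one whole time step, e.g. dir=[0, 1], d_mag=[0., 2.] gives distance=[-1., 3.].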
return tfd.JointDistributionNamed({
'ix': ix,
@@ -78,7 +82,6 @@ def make_event_time_move(counts_matrix, q, p, alpha):
class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
def __init__(self,
target_log_prob_fn,
transition_coord,
q,
p,
alpha,
@@ -86,7 +89,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
name=None):
"""An uncalibrated random walk for event times.
:param target_log_prob_fn: the log density of the target distribution
:param transition_coord: the coordinate of the transition in the transition matrix
:param q: the probability that any given eligible event cell is selected for a move
:param p: the proportion of events to move
:param alpha: the magnitude of the distance over which to move
:param seed: a random seed stream
@@ -97,7 +99,6 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
self._name = name
self._parameters = dict(
target_log_prob_fn=target_log_prob_fn,
transition_coord=tf.convert_to_tensor(transition_coord, dtype=tf.int64),
q=q,
p=p,
alpha=alpha,
@@ -131,19 +132,15 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
def one_step(self, current_state, previous_kernel_results):
with tf.name_scope('uncalibrated_event_times_rw/onestep'):
proposal = make_event_time_move(current_state[..., self.transition_coord[0],
self.transition_coord[1]],
proposal = make_event_time_move(current_state,
self._parameters['q'],
self._parameters['p'],
self._parameters['alpha'])
x_star = proposal.sample(seed=self.seed) # This is the move to make
n_move = tf.shape(x_star['tm'],out_type=x_star['tm'].dtype)[0] # Number of time/metapop moves
coord_dtype = x_star['tm'].dtype
state_coord = tf.broadcast_to(self.transition_coord,
[n_move, self.transition_coord.shape[0]])
n_move = tf.shape(x_star['tm'], out_type=x_star['tm'].dtype)[0] # Number of time/metapop moves
# Calculate the coordinate that we'll move events to
coord_dtype = x_star['tm'].dtype
indices = tf.stack([tf.range(n_move, dtype=coord_dtype),
tf.zeros(n_move, dtype=coord_dtype)], axis=-1)
coord_to_move_to = tf.tensor_scatter_nd_add(tensor=x_star['tm'],
@@ -151,16 +148,17 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
updates=tf.cast(x_star['distance'], tf.int64))
# Update the state
indices = [tf.concat([x_star['tm'], state_coord], axis=-1),
tf.concat([coord_to_move_to, state_coord], axis=-1)]
indices = tf.concat([x_star['tm'], coord_to_move_to], axis=0)
updates = tf.concat([-x_star['n_events'], x_star['n_events']], axis=0)
next_state = tf.tensor_scatter_nd_add(tensor=current_state,
indices=indices,
updates=[-x_star['n_events'], x_star['n_events']]) # Update state based on move
updates=updates) # Update state based on move
next_target_log_prob = self.target_log_prob_fn(next_state)
reverse_n = tf.gather_nd(next_state, indices[1])
log_acceptance_correction = tfd.Binomial(reverse_n, probs=self._parameters['p']).log_prob(x_star['n_events'])
log_acceptance_correction = tfd.Binomial(reverse_n, probs=self._parameters['p']).log_prob(
x_star['n_events'])
log_acceptance_correction -= proposal.log_prob(x_star) # move old->new
log_acceptance_correction = tf.reduce_sum(log_acceptance_correction)
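# Note (sketch of the intent, not in the commit): the running total is
# log q(x | x*) - log q(x* | x), the Hastings correction for this
# asymmetric event-move proposal.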
@@ -172,7 +170,7 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
def bootstrap_results(self, init_state):
with tf.name_scope('uncalibrated_event_times_rw/bootstrap_results'):
init_state=tf.convert_to_tensor(init_state, dtype=DTYPE)
init_state = tf.convert_to_tensor(init_state, dtype=DTYPE)
init_target_log_prob = self.target_log_prob_fn(init_state)
return tfp.mcmc.random_walk_metropolis.UncalibratedRandomWalkResults(
log_acceptance_correction=tf.constant(0., dtype=DTYPE),
@@ -180,23 +178,68 @@ class UncalibratedEventTimesUpdate(tfp.mcmc.TransitionKernel):
)
class Gibbs(tfp.mcmc.TransitionKernel):
def get_accepted_results(results):
if hasattr(results, 'accepted_results'):
return results.accepted_results
else:
return get_accepted_results(results.inner_results)
def set_accepted_results(results, accepted_results):
if hasattr(results, 'accepted_results'):
results = results._replace(accepted_results=accepted_results)
return results
else:
next_inner_results = set_accepted_results(results.inner_results, accepted_results)
return results._replace(inner_results=next_inner_results)
def advance_target_log_prob(next_results, previous_results):
prev_accepted_results = get_accepted_results(previous_results)
next_accepted_results = get_accepted_results(next_results)
next_accepted_results = next_accepted_results._replace(target_log_prob=prev_accepted_results.target_log_prob)
return set_accepted_results(next_results, next_accepted_results)
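# A toy check (not part of this commit) of the accepted-results plumbing above:
# plain namedtuples mimic the nested TFP kernel-results layout.
from collections import namedtuple
_Accepted = namedtuple('Accepted', ['target_log_prob'])
_MH = namedtuple('MH', ['accepted_results'])
_Transformed = namedtuple('Transformed', ['inner_results'])
_prev = _Transformed(_MH(_Accepted(target_log_prob=-10.0)))
_next = _Transformed(_MH(_Accepted(target_log_prob=-12.0)))
_carried = advance_target_log_prob(_next, _prev)
assert get_accepted_results(_carried).target_log_prob == -10.0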
class MH_within_Gibbs(tfp.mcmc.TransitionKernel):
def __init__(self, target_log_prob_fn, make_kernel_fns):
"""Metropolis within Gibbs sampling.
Based on Gibbs idea posted by Pavel Sountsov https://github.com/tensorflow/probability/issues/495
:param target_log_prob_fn: a function which, given a list of state parts, calculates the joint logp
:param make_kernel_fns: a list of functions that return an MH-compatible kernel. Functions accept a
log_prob function which in turn takes a state part.
"""
self._target_log_prob_fn = target_log_prob_fn
self._make_kernel_fns = make_kernel_fns
def is_calibrated(self):
return True
def one_step(self, state, _):
def one_step(self, state, step_results):
prev_step = np.roll(np.arange(len(state)), 1)
for i, make_kernel_fn in enumerate(self._make_kernel_fns):
def _target_log_prob_fn_part(state_part):
state[i] = state_part
return self._target_log_prob_fn(*state)
kernel = make_kernel_fn(_target_log_prob_fn_part)
state[i], _ = kernel.one_step(state[i], kernel.bootstrap_results(state[i]))
return state, ()
# results = advance_target_log_prob(step_results[i],
# step_results[prev_step[i]]) or kernel.bootstrap_results(
# state[i])
results = kernel.bootstrap_results(state[i])
state[i], step_results[i] = kernel.one_step(state[i], results)
return state, step_results
def bootstrap_results(self, state):
return ()
results = []
for i, make_kernel_fn in enumerate(self._make_kernel_fns):
def _target_log_prob_fn_part(state_part):
state[i] = state_part
return self._target_log_prob_fn(*state)
kernel = make_kernel_fn(_target_log_prob_fn_part)
results.append(kernel.bootstrap_results(state[i]))
return results
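# A minimal usage sketch (not part of this commit), with an assumed two-part
# standard-normal target:
def _demo_joint_logp(x, y):
    return tfd.Normal(0., 1.).log_prob(x) + tfd.Normal(0., 1.).log_prob(y)

_make_rwm = lambda lp: tfp.mcmc.RandomWalkMetropolis(lp)
_gibbs = MH_within_Gibbs(_demo_joint_logp, [_make_rwm, _make_rwm])
_state = [tf.constant(0.1), tf.constant(-0.3)]
_state, _results = _gibbs.one_step(_state, _gibbs.bootstrap_results(_state))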
@@ -21,7 +21,7 @@ def _gen_index(state, trm_coords):
return tf.reshape(idx, i_shp)
def make_transition_rate_matrix(rates, rate_coords, state):
def make_transition_matrix(rates, rate_coords, state):
"""Create a transition rate matrix
:param rates: batched transition rate tensors
:param rate_coords: coordinates of rates in resulting transition matrix
......
@@ -5,7 +5,7 @@ from tensorflow_probability.python.internal import dtype_util
import numpy as np
from covid import config
from covid.impl.util import make_transition_rate_matrix
from covid.impl.util import make_transition_matrix
from covid.rdata import load_mobility_matrix, load_population, load_age_mixing
from covid.pydata import load_commute_volume, collapse_commute_data, collapse_pop
from covid.impl.discrete_markov import discrete_markov_simulation, discrete_markov_log_prob
@@ -247,7 +247,7 @@ class CovidUKStochastic(CovidUK):
ei = tf.broadcast_to([param['nu']], shape=[state.shape[0]]) # Vector of length nc
ir = tf.broadcast_to([param['gamma']], shape=[state.shape[0]]) # Vector of length nc
rate_matrix = make_transition_rate_matrix([infec_rate, ei, ir], [[0, 1], [1, 2], [2, 3]], state)
rate_matrix = make_transition_matrix([infec_rate, ei, ir], [[0, 1], [1, 2], [2, 3]], state)
return rate_matrix
return h
......
"""Inference on stochastic models"""
import optparse
import time
import pickle as pkl
import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import yaml
import h5py
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
import numpy as np
import matplotlib.pyplot as plt
import yaml
from covid.model import CovidUKStochastic, load_data
from covid import config
from covid.model import load_data, CovidUKStochastic
from covid.util import sanitise_parameter, sanitise_settings, seed_areas
DTYPE = np.float64
def random_walk_mvnorm_fn(covariance, name=None):
"""Returns callable that adds Multivariate Normal noise to the input"""
covariance = covariance + tf.eye(covariance.shape[0], dtype=tf.float64) * 1.e-9
from covid.impl.util import make_transition_matrix
from covid.impl.mcmc import UncalibratedEventTimesUpdate, MH_within_Gibbs, get_accepted_results
from covid.pydata import phe_linelist_timeseries, zero_cases, collapse_pop
DTYPE = config.floatX
phe = phe_linelist_timeseries('/home/jewellcp/Insync/jewellcp@lancaster.ac.uk/OneDrive Biz - Shared/covid19/data/PHE_2020-04-08/Anonymised Line List 20200409.csv')
pop = collapse_pop('data/c2019modagepop.csv')
cases = zero_cases(phe, pop)
ei = cases.copy()
se = cases.copy()
idx_ei = cases.index.to_frame().copy()
idx_ei['date'] = idx_ei['date'] - pd.to_timedelta(2, 'D')
ei.index = pd.MultiIndex.from_frame(idx_ei)
idx_se = cases.index.to_frame().copy()
idx_se['date'] = idx_se['date'] - pd.to_timedelta(6, 'D')
se.index = pd.MultiIndex.from_frame(idx_se)
events = pd.concat([se, ei, cases], axis=1)
events.columns = ['se','ei','ir']
events.index.set_names(['date','UTLA_code','Age'], range(3), inplace=True)
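# Note (assumed intent, not stated in the commit): back-shifting the detection
# dates by 6 and 2 days imputes S->E and E->I event times from the observed
# case (I->R) series.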
# # events.to_csv('tmp_events.csv')
#
# events = pd.read_csv('tmp_events.csv')
# events.index = pd.MultiIndex.from_frame(events[['date','UTLA_code','Age']])
# events = events[['se','ei','ir']]
events[events.isna()] = 0.0
num_times = events.index.get_level_values(0).unique().shape[0]
se_events = events['se'].to_numpy().reshape([num_times,2533])
ei_events = events['ei'].to_numpy().reshape([num_times,2533])
ir_events = events['ir'].to_numpy().reshape([num_times,2533])
event_tensor = make_transition_matrix([se_events, ei_events, ir_events], [[0, 1], [1, 2], [2, 3]],
tf.zeros([num_times, 2533,4]))
event_tensor = tf.cast(event_tensor, dtype=DTYPE)
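# A toy illustration (hand-rolled, not the library code) of the packing that
# make_transition_matrix presumably performs here: each [num_times, n_units]
# event array is routed to its (row, col) cell of a 4x4 transition matrix.
_T, _M = 2, 3  # toy sizes in place of num_times, 2533
_rates = tf.stack([tf.ones([_T, _M]), 2. * tf.ones([_T, _M]), 3. * tf.ones([_T, _M])], axis=-1)
_route = tf.scatter_nd([[0, 1], [1, 2], [2, 3]], tf.eye(3), [4, 4, 3])
_toy = tf.einsum('tmr,ijr->tmij', _rates, _route)  # [T, M, 4, 4]
# _toy[..., 0, 1] == 1. (S->E), _toy[..., 1, 2] == 2. (E->I), _toy[..., 2, 3] == 3. (I->R)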
# 53, 1342
# Fill in susceptibles
print(f"Sum cases: {events.to_numpy().sum()}")
print(f"Sum events: {tf.reduce_sum(event_tensor)}")
init_state = tf.zeros([4])
timeline = tf.cumsum(tf.reduce_sum(event_tensor, axis=[1, -2]), axis=0)
# Random moves of events. What invalidates an epidemic, how can we test for it?
with open('ode_config.yaml','r') as f:
config = yaml.load(f)
param = sanitise_parameter(config['parameter'])
param = {k: tf.constant(v, dtype=DTYPE) for k, v in param.items()}
settings = sanitise_settings(config['settings'])
data = load_data(config['data'], settings, DTYPE)
model = CovidUKStochastic(M_tt=data['M_tt'],
M_hh=data['M_hh'],
C=data['C'],
N=data['pop']['n'].to_numpy(),
W=data['W'],
date_range=settings['inference_period'],
holidays=settings['holiday'],
lockdown=settings['lockdown'],
time_step=1.)
seeding = seed_areas(data['pop']['n']) # Seed 40-44 age group, 30 seeds by popn size
state_init = model.create_initial_state(init_matrix=seeding)
def logp(par, se, ei):
p = param
p['beta1'] = tf.convert_to_tensor(par[0], dtype=DTYPE)
p['gamma'] = tf.convert_to_tensor(par[1], dtype=DTYPE)
beta_logp = tfd.Gamma(concentration=tf.constant(1., dtype=DTYPE),
rate=tf.constant(1., dtype=DTYPE)).log_prob(p['beta1'])
gamma_logp = tfd.Gamma(concentration=tf.constant(100., dtype=DTYPE),
rate=tf.constant(400., dtype=DTYPE)).log_prob(p['gamma'])
event_tensor = make_transition_matrix([se, ei, ir_events], # ir_events global scope
[[0, 1], [1, 2], [2, 3]],
tf.zeros([num_times, 2533, 4])) # Todo: remove constant
y_logp = tf.reduce_sum(model.log_prob(event_tensor, p, state_init))
logp = beta_logp + gamma_logp + y_logp
return logp
def random_walk_mvnorm_fn(covariance, p_u=0.95, name=None):
"""Returns callable that adds Multivariate Normal noise to the input
:param covariance: the covariance matrix of the mvnorm proposal
:param p_u: the bounded convergence parameter. If equal to 1, all proposals are drawn from the
MVN(0, covariance) distribution; if less than 1, proposals are drawn from MVN(0, covariance)
with probability p_u, and from MVN(0, 0.1^2 I_d / d) otherwise.
:returns: a function implementing the proposal.
"""
covariance = covariance + tf.eye(covariance.shape[0], dtype=DTYPE) * 1.e-9
scale_tril = tf.linalg.cholesky(covariance)
rv = tfp.distributions.MultivariateNormalTriL(loc=tf.zeros(covariance.shape[0], dtype=tf.float64),
scale_tril=scale_tril)
rv_adapt = tfp.distributions.MultivariateNormalTriL(loc=tf.zeros(covariance.shape[0], dtype=DTYPE),
scale_tril=scale_tril)
rv_fix = tfp.distributions.Normal(loc=tf.zeros(covariance.shape[0], dtype=DTYPE),
scale=0.01/covariance.shape[0],)
u = tfp.distributions.Bernoulli(probs=p_u)
def _fn(state_parts, seed):
with tf.name_scope(name or 'random_walk_mvnorm_fn'):
new_state_parts = [rv.sample() + state_part for state_part in state_parts]
def proposal():
rv = tf.stack([rv_fix.sample(), rv_adapt.sample()])
uv = u.sample()
return tf.gather(rv, uv)
new_state_parts = [proposal() + state_part for state_part in state_parts]
return new_state_parts
return _fn
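# Usage sketch (not part of this commit): draw one proposal around a 2-vector
# state part with a toy diagonal covariance; `seed` is accepted but unused by
# this proposal function.
_cov = tf.constant([[0.01, 0.], [0., 0.01]], dtype=DTYPE)
_proposal_fn = random_walk_mvnorm_fn(_cov, p_u=0.95)
_new_parts = _proposal_fn([tf.zeros([2], dtype=DTYPE)], seed=None)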
unconstraining_bijector = [tfb.Exp()]
#initial_mcmc_state = event_tensor # tf.constant([0.09, 0.5], dtype=DTYPE) # beta1, gamma, I0
print("Initial log likelihood:", logp([0.05, 0.24], se_events, ei_events))
def trace_fn(state, prev_results):
return (prev_results.is_accepted,
prev_results.accepted_results.target_log_prob)
# Pavel's suggestion for a Gibbs kernel requires
# kernel factory functions.
def make_parameter_kernel(scale, bounded_convergence):
def kernel_func(logp):
return tfp.mcmc.TransformedTransitionKernel(
inner_kernel=tfp.mcmc.RandomWalkMetropolis(
target_log_prob_fn=logp,
new_state_fn=random_walk_mvnorm_fn(scale, p_u=bounded_convergence)
),
bijector=unconstraining_bijector)
return kernel_func
def make_events_step(q, p, alpha):
def kernel_func(logp):
return tfp.mcmc.MetropolisHastings(
inner_kernel=UncalibratedEventTimesUpdate(target_log_prob_fn=logp,
q=q,
p=p,
alpha=alpha)
)
return kernel_func
def is_accepted(result):
if hasattr(result, 'is_accepted'):
return result.is_accepted
else:
return is_accepted(result.inner_results)
@tf.function # (autograph=False, experimental_compile=True)
def sample(n_samples, init_state, par_scale):
init_state = init_state.copy()
par_func = make_parameter_kernel(par_scale, 0.95)
kernel_func1 = make_events_step(q=100./192508., p=0.2, alpha=0.3)
kernel_func2 = make_events_step(q=100./192508., p=0.2, alpha=0.5)
# Based on Gibbs idea posted by Pavel Sountsov https://github.com/tensorflow/probability/issues/495
gibbs = MH_within_Gibbs(logp, [par_func, kernel_func1, kernel_func2])
results = gibbs.bootstrap_results(init_state)
samples_arr = [tf.TensorArray(s.dtype, size=n_samples) for s in init_state]
results_arr = [tf.TensorArray(tf.bool, size=n_samples) for r in results]
def body(i, state, prev_results, samples, results):
new_state, new_results = gibbs.one_step(state, prev_results)
samples = [samples[k].write(i, s) for k, s in enumerate(new_state)]
results = [results[k].write(i, is_accepted(r)) for k, r in enumerate(new_results)]
return i+1, new_state, new_results, samples, results
def cond(i, _1, _2, _3, _4):
return i < n_samples
_1, _2, _3, samples, results = tf.while_loop(cond=cond, body=body,
loop_vars=[0, init_state, results, samples_arr, results_arr])
return [s.stack() for s in samples], [r.stack() for r in results]
if __name__=='__main__':
num_loop_iterations = 20
num_loop_samples = 50
current_state = [np.array([0.05, 0.24], dtype=DTYPE),
se_events, ei_events]
posterior = h5py.File('posterior.h5','w')
event_size = [num_loop_iterations * num_loop_samples] + list(current_state[1].shape)
par_samples = posterior.create_dataset('samples/parameter', [num_loop_iterations*num_loop_samples, 2], dtype=np.float64)
se_samples = posterior.create_dataset('samples/S->E', event_size, dtype=DTYPE)
ei_samples = posterior.create_dataset('samples/E->I', event_size, dtype=DTYPE)
par_results = posterior.create_dataset('acceptance/parameter', (num_loop_iterations * num_loop_samples,), dtype=np.bool)
se_results = posterior.create_dataset('acceptance/S->E', (num_loop_iterations * num_loop_samples,), dtype=np.bool)
ei_results = posterior.create_dataset('acceptance/E->I', (num_loop_iterations * num_loop_samples,), dtype=np.bool)
par_scale = tf.convert_to_tensor([0.001, 0.001], dtype=DTYPE)
# We loop over successive calls to sample because we have to dump results
# to disc, or else we run out of memory (even on a 32GB system).
for i in tqdm.tqdm(range(num_loop_iterations), unit_scale=num_loop_samples):
samples, results = sample(num_loop_samples, init_state=current_state, par_scale=par_scale)
current_state = [s[-1] for s in samples]
s = slice(i*num_loop_samples, i*num_loop_samples+num_loop_samples)
par_samples[s, ...] = samples[0].numpy()
se_samples[s, ...] = samples[1].numpy()
ei_samples[s, ...] = samples[2].numpy()
par_results[s, ...] = results[0].numpy()
se_results[s, ...] = results[1].numpy()
ei_results[s, ...] = results[2].numpy()
print("Acceptance0:", tf.reduce_mean(tf.cast(results[0], tf.float32)))
print("Acceptance1:", tf.reduce_mean(tf.cast(results[1], tf.float32)))
print("Acceptance2:", tf.reduce_mean(tf.cast(results[2], tf.float32)))
print(f'Acceptance param: {par_results[:].mean()}')
print(f'Acceptance S->E: {se_results[:].mean()}')
print(f'Acceptance E->I: {ei_results[:].mean()}')
posterior.close()
if __name__ == '__main__':
parser = optparse.OptionParser()
parser.add_option("--config", "-c", dest="config", default="ode_config.yaml",
help="configuration file")
options, args = parser.parse_args()
with open(options.config, 'r') as ymlfile:
config = yaml.load(ymlfile)
param = sanitise_parameter(config['parameter'])
settings = sanitise_settings(config['settings'])
data = load_data(config['data'], settings, DTYPE)
model = CovidUKStochastic(M_tt=data['M_tt'],
M_hh=data['M_hh'],
C=data['C'],
N=data['pop']['n'].to_numpy(),
W=data['W'],
date_range=settings['inference_period'],
holidays=settings['holiday'],
lockdown=settings['lockdown'],
time_step=1.)
with open('stochastic_sim.pkl', 'rb') as f:
sim = pkl.load(f)
events = tf.convert_to_tensor(sim['events'], dtype=DTYPE)
state_init = tf.convert_to_tensor(sim['state_init'], dtype=DTYPE)
param = {k: tf.constant(v, dtype=DTYPE) for k, v in param.items()}
def logp(par):
print("Tracing logp")
p = param
p['beta1'] = par[0]
p['beta3'] = par[1]
p['gamma'] = par[2]
beta_logp = tfd.Gamma(concentration=tf.constant(1., dtype=DTYPE),
rate=tf.constant(1., dtype=DTYPE)).log_prob(p['beta1'])
beta3_logp = tfd.Gamma(concentration=tf.constant(20., dtype=DTYPE),
rate=tf.constant(20., dtype=DTYPE)).log_prob(p['beta3'])
gamma_logp = tfd.Gamma(concentration=tf.constant(100., dtype=DTYPE),
rate=tf.constant(400., dtype=DTYPE)).log_prob(p['gamma'])
y_logp = model.log_prob(events, p, state_init)
logp = beta_logp + beta3_logp + gamma_logp + y_logp
return logp
unconstraining_bijector = [tfb.Exp()]
initial_mcmc_state = tf.constant([0.05, 0.5, 0.25], dtype=tf.float64) # beta1, gamma, I0
print("Initial log likelihood:", logp(initial_mcmc_state))
@tf.function(experimental_compile=True)
def sample(n_samples, init_state, scale, num_burnin_steps=0):
return tfp.mcmc.sample_chain(
num_results=n_samples,
num_burnin_steps=num_burnin_steps,
current_state=init_state,
kernel=tfp.mcmc.TransformedTransitionKernel(
inner_kernel=tfp.mcmc.RandomWalkMetropolis(
target_log_prob_fn=logp,
new_state_fn=random_walk_mvnorm_fn(scale)
),
bijector=unconstraining_bijector),
trace_fn=lambda _, pkr: pkr.inner_results.is_accepted)
joint_posterior = tf.zeros([0] + list(initial_mcmc_state.shape), dtype=DTYPE)
scale = np.diag([0.1, 0.1, 0.1])
overall_start = time.perf_counter()
num_covariance_estimation_iterations = 20
num_covariance_estimation_samples = 50
num_final_samples = 10000
start = time.perf_counter()
for i in range(num_covariance_estimation_iterations):
step_start = time.perf_counter()
samples, results = sample(num_covariance_estimation_samples,
initial_mcmc_state,
scale)
step_end = time.perf_counter()
print(f'{i} time {step_end - step_start}')
print("Acceptance: ", results.numpy().mean())
joint_posterior = tf.concat([joint_posterior, samples], axis=0)
cov = tfp.stats.covariance(tf.math.log(joint_posterior))
print(cov.numpy())
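# Note: 2.38^2/d below is the classic optimal-scaling rule for random-walk
# Metropolis (Roberts & Rosenthal), applied to the covariance estimated from
# the log-parameters.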
scale = cov * 2.38**2 / joint_posterior.shape[1]
initial_mcmc_state = joint_posterior[-1, :]
step_start = time.perf_counter()
#tf.profiler.experimental.start('mcmc_logdir')
samples, results = sample(num_final_samples,
init_state=joint_posterior[-1, :], scale=scale,)
#tf.profiler.experimental.stop()
joint_posterior = tf.concat([joint_posterior, samples], axis=0)
step_end = time.perf_counter()
print(f'Sampling step time {step_end - step_start}')
end = time.perf_counter()
print(f"Simulation complete in {end-start} seconds")
print("Acceptance: ", np.mean(results.numpy()))
print(tfp.stats.covariance(tf.math.log(joint_posterior)))
fig, ax = plt.subplots(1, joint_posterior.shape[1])
for i in range(joint_posterior.shape[1]):
ax[i].plot(joint_posterior[:, i])
plt.show()
print(f"Posterior mean: {np.mean(joint_posterior, axis=0)}")
with open('stochastic_posterior.pkl', 'wb') as f:
pkl.dump(joint_posterior, f)
\ No newline at end of file