Commit c4134880 authored by Chris Jewell's avatar Chris Jewell
Browse files

Re-wrote Gibbs sampler logic

parent ce398f4b
......@@ -170,11 +170,21 @@ def trace_results_fn(results):
return tf.concat([[log_prob], [accepted], [q_ratio]], axis=0)
def forward_results(prev_results, next_results):
accepted_results = next_results.accepted_results._replace(
def get_tlp(results):
return results.accepted_results.target_log_prob
def put_tlp(results, target_log_prob):
accepted_results = results.accepted_results._replace(
return next_results._replace(accepted_results=accepted_results)
return results._replace(accepted_results=accepted_results)
def invoke_one_step(kernel, state, previous_results, target_log_prob):
current_results = put_tlp(previous_results, target_log_prob)
new_state, new_results = kernel.one_step(state, current_results)
return new_state, new_results, get_tlp(new_results)
@tf.function(autograph=False, experimental_compile=True)
......@@ -210,21 +220,18 @@ def sample(n_samples, init_state, par_scale, num_event_updates):
samples_arr = [tf.TensorArray(s.dtype, size=n_samples) for s in init_state]
results_arr = [tf.TensorArray(DTYPE, size=n_samples) for r in range(5)]
def body(i, state, results, sample_accum, results_accum):
def body(i, state, results, target_log_prob, sample_accum, results_accum):
# Parameters
def par_logp(par_state):
state[0] = par_state # close over state from outer scope
return logp(*state)
par_kernel = par_func(par_logp)
state[0], results[0] = par_kernel.one_step(
state[0], par_kernel.bootstrap_results(state[0])
state[0], results[0], target_log_prob = invoke_one_step(
par_func(par_logp), state[0], results[0], target_log_prob,
# States
results[4] = forward_results(results[0], results[4])
def infec_body(j, state, results):
def infec_body(j, state, results, target_log_prob):
def state_logp(event_state):
state[1] = event_state
return logp(*state)
......@@ -233,30 +240,32 @@ def sample(n_samples, init_state, par_scale, num_event_updates):
state[2] = occult_state
return logp(*state)
state[1], results[1] = se_func(state_logp).one_step(
state[1], forward_results(results[4], results[1])
state[1], results[1], target_log_prob = invoke_one_step(
se_func(state_logp), state[1], results[1], target_log_prob
state[1], results[2] = ei_func(state_logp).one_step(
state[1], forward_results(results[1], results[2])
state[1], results[2], target_log_prob = invoke_one_step(
ei_func(state_logp), state[1], results[2], target_log_prob
state[2], results[3] = se_occult(occult_logp).one_step(
state[2], forward_results(results[2], results[3])
state[2], results[3], target_log_prob = invoke_one_step(
se_occult(occult_logp), state[2], results[3], target_log_prob
# results[3] = forward_results(results[2], results[3])
state[2], results[4] = ei_occult(occult_logp).one_step(
state[2], forward_results(results[3], results[4])
state[2], results[4], target_log_prob = invoke_one_step(
ei_occult(occult_logp), state[2], results[4], target_log_prob
# results[4] = forward_results(results[3], results[4])
j += 1
return j, state, results
return j, state, results, target_log_prob
def infec_cond(j, state, results):
def infec_cond(j, state, results, target_log_prob):
return j < num_event_updates
_, state, results = tf.while_loop(
_, state, results, target_log_prob = tf.while_loop(
loop_vars=[tf.constant(0, tf.int32), state, results],
loop_vars=[tf.constant(0, tf.int32), state, results, target_log_prob],
sample_accum = [sample_accum[k].write(i, s) for k, s in enumerate(state)]
......@@ -264,15 +273,22 @@ def sample(n_samples, init_state, par_scale, num_event_updates):
results_accum[k].write(i, trace_results_fn(r))
for k, r in enumerate(results)
return i + 1, state, results, sample_accum, results_accum
return i + 1, state, results, target_log_prob, sample_accum, results_accum
def cond(i, _1, _2, _3, _4):
def cond(i, *_):
return i < n_samples
_1, _2, _3, samples, results = tf.while_loop(
_1, _2, _3, target_log_prob, samples, results = tf.while_loop(
loop_vars=[0, init_state, results, samples_arr, results_arr],
return [s.stack() for s in samples], [r.stack() for r in results]
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment