Commit 2a5043c0 authored by Chris Jewell's avatar Chris Jewell
Browse files

Chunked calculation of Rt metrics to avoid OOM on the GPU.

parent 4afc60f0
......@@ -47,24 +47,36 @@ def calc_posterior_rit(samples, initial_state, times, covar_data):
)
CHUNKSIZE = 50
def reproduction_number(input_files, output_file):
covar_data = xarray.open_dataset(input_files[0], group="constant_data")
with open(input_files[1], "rb") as f:
samples = pkl.load(f)
num_samples = samples["seir"].shape[0]
initial_state = samples["initial_state"]
del samples["initial_state"]
times = np.arange(covar_data.coords["time"].shape[0])
# Compute ngm posterior
r_it = calc_posterior_rit(samples, initial_state, times, covar_data)
# Compute ngm posterior in chunks to prevent over-memory
r_its = []
for i in range(0, num_samples, CHUNKSIZE):
start = i
end = np.minimum(i + CHUNKSIZE, num_samples)
print(f"Chunk {start}:{end}", flush=True)
subsamples = {k: v[start:end] for k, v in samples.items()}
r_it = calc_posterior_rit(subsamples, initial_state, times, covar_data)
r_its.append(r_it)
r_it = xarray.DataArray(
r_it,
tf.concat(r_its, axis=0),
coords=[
np.arange(r_it.shape[0]),
np.arange(num_samples),
covar_data.coords["time"][times],
covar_data.coords["location"],
],
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment