Commit 68b27bba authored by Chris Jewell

Implemented LTLA-level COVID-19 model

Changes:

1. Replaced the 149-UTLA mixing matrix with a 315-LTLA mixing matrix;
2. Wrote geometric initialisation for censored event times (see the sketch below);
3. Modified data ingester to take PHE Anonymised Line Listing data.
parent f0a08b1a
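Change 2 (the geometric initialisation) is implemented by `distribute_geom` / `impute_previous_cases` in the `covid/util.py` hunk below: case counts observed on day t are pushed back in time by geometric waiting times with success probability 1 - exp(-rate * delta_t), giving plausible initial values for the censored E->I and S->E event times. A minimal single-region NumPy sketch of the idea (the real code runs batched over metapopulations in TensorFlow; the function name, numbers, and return convention here are illustrative only):

import numpy as np

rng = np.random.default_rng(1)

def impute_backwards(observed, rate, delta_t=1.0):
    """Toy geometric initialisation: each event observed on day t is moved
    back by a geometrically distributed number of days, giving a guess at
    its earlier, censored event time."""
    p = 1.0 - np.exp(-rate * delta_t)            # per-step transition probability
    lags = [rng.geometric(p, size=int(n)) for n in observed]
    max_lag = int(max((l.max() for l in lags if l.size), default=1))
    out = np.zeros(len(observed) + max_lag, dtype=int)
    for t, l in enumerate(lags):
        for lag in l:
            out[t + max_lag - lag] += 1          # place the event `lag` days earlier
    return out, max_lag                          # padded series, padding length

observed = np.array([0, 0, 1, 3, 5, 8])          # illustrative daily case counts
earlier, pad = impute_backwards(observed, rate=0.25)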
@@ -6,8 +6,8 @@ import numpy as np
from covid import config
from covid.impl.util import make_transition_matrix
from covid.rdata import load_mobility_matrix, load_population, load_age_mixing
from covid.pydata import load_commute_volume, collapse_commute_data, collapse_pop
from covid.rdata import load_age_mixing
from covid.pydata import load_commute_volume, load_mobility_matrix, load_population
from covid.impl.discrete_markov import (
discrete_markov_simulation,
discrete_markov_log_prob,
@@ -50,20 +50,20 @@ def load_data(paths, settings, dtype=DTYPE):
M_tt, age_groups = load_age_mixing(paths["age_mixing_matrix_term"])
M_hh, _ = load_age_mixing(paths["age_mixing_matrix_hol"])
C = collapse_commute_data(paths["mobility_matrix"])
C = load_mobility_matrix(paths["mobility_matrix"])
la_names = C.index.to_numpy()
w_period = [settings["inference_period"][0], settings["prediction_period"][1]]
W = load_commute_volume(paths["commute_volume"], w_period)["percent"]
pop = collapse_pop(paths["population_size"])
pop = load_population(paths["population_size"])
M_tt = M_tt.astype(DTYPE)
M_hh = M_hh.astype(DTYPE)
C = C.to_numpy().astype(DTYPE)
np.fill_diagonal(C, 0.0)
W = W.astype(DTYPE)
pop["n"] = pop["n"].astype(DTYPE)
W = W.to_numpy().astype(DTYPE)
pop = pop.to_numpy().astype(DTYPE)
return {
"M_tt": M_tt,
@@ -155,7 +155,7 @@ class CovidUKStochastic(CovidUK):
* commute_volume
* tf.linalg.matvec(self.C, state[..., 2] / self.N)
)
infec_rate = infec_rate / self.N + 0.00000001 # Vector of length nc
infec_rate = infec_rate / self.N + 0.000000001 # Vector of length nc
ei = tf.broadcast_to(
[param["nu"]], shape=[state.shape[0]]
@@ -11,210 +11,89 @@ import pyreadr as pyr
def load_commute_volume(filename, date_range):
"""Loads commute data and clips or extends date range"""
commute_raw = pd.read_csv(filename, index_col='date')
commute_raw.index = pd.to_datetime(commute_raw.index, format='%d/%m/%Y')
commute_raw = pd.read_csv(filename, index_col="date")
commute_raw.index = pd.to_datetime(commute_raw.index, format="%d/%m/%Y")
commute_raw.sort_index(axis=0, inplace=True)
commute = pd.DataFrame(index=np.arange(date_range[0], date_range[1], np.timedelta64(1,'D')))
commute = commute.merge(commute_raw, left_index=True, right_index=True, how='left')
commute = pd.DataFrame(
index=np.arange(date_range[0], date_range[1], np.timedelta64(1, "D"))
)
commute = commute.merge(commute_raw, left_index=True, right_index=True, how="left")
commute[commute.index < commute_raw.index[0]] = commute_raw.iloc[0, 0]
commute[commute.index > commute_raw.index[-1]] = commute_raw.iloc[-1, 0]
return commute
def group_ages(df):
"""
Sums age groups
:param df: a dataframe with columns 0,1,2,...,90
:return: a dataframe with 5-year age groups
"""
ages = np.arange(90).reshape([90//5, 5]).astype(np.str)
grouped_ages = pd.DataFrame()
for age_group in ages:
grouped_ages[f"[{age_group[0]}-{int(age_group[-1])+1})"] = df[age_group].sum(axis=1)
grouped_ages['[90,)'] = df[['90']]
grouped_ages['[80,inf)'] = grouped_ages[['[80-85)', '[85-90)', '[90,)']].sum(axis=1)
grouped_ages = grouped_ages.drop(columns=['[80-85)', '[85-90)', '[90,)'])
return grouped_ages
def ingest_data(lad_shp, lad_pop):
pop = pd.read_csv(lad_pop, skiprows=4, thousands=',')
age_pop = group_ages(pop)
age_pop.index = pop['Code']
lad = gp.read_file(lad_shp)
lad.index = lad['lad19cd'].rename('Code')
lad = lad.iloc[lad.index.str.match('^E0[6-9]'), :]
lad = lad.merge(age_pop, on='Code')
lad.sort_index(inplace=True)
lad.drop(columns=['objectid', 'lad19cd' ,'long', 'lat'])
N = lad.iloc[:, lad.columns.str.match(pat='^[[0-9]')].stack()
print(f"Found {lad.shape[0]} LADs")
return {'geo': lad, 'N': N}
def phe_death_timeseries(filename, date_range=['2020-02-02', '2020-03-21']):
date_range = [np.datetime64(x) for x in date_range]
csv = pd.read_excel(filename)
cases = pd.DataFrame({'hospital': csv.groupby(['Hospital admission date (non-HCID)', 'Region']).size(),
'deaths': csv.groupby(['PATIENT_DEATH_DATE', 'Region']).size()})
cases.index.rename(['date', 'region'], [0, 1], inplace=True)
cases.reset_index(inplace=True)
cases = cases.pivot(index='date', columns='region')
dates = pd.DataFrame(index=pd.DatetimeIndex(np.arange(*date_range, np.timedelta64(1, 'D'))))
combined = dates.merge(cases, how='left', left_index=True, right_index=True)
combined.columns = pd.MultiIndex.from_tuples(combined.columns, names=['timeseries','region'])
combined[combined.isna()] = 0.0
output = {k: combined.loc[:, [k, None]] for k in combined.columns.levels[0]}
return output
def phe_death_hosp_to_death(filename, date_range=['2020-02-02', '2020-03-21']):
date_range = [np.datetime64(x) for x in date_range]
csv = pd.read_excel(filename)
data = csv.loc[:, ['Sex', 'Age', 'Underlying medical condition?', 'Hospital admission date (non-HCID)',
'PATIENT_DEATH_DATE']]
data.columns = ['sex','age','underlying_condition', 'hosp_adm_date', 'death_date']
data.loc[:, 'underlying_condition'] = data['underlying_condition'] == 'Yes'
data['adm_to_death'] = (data['death_date'] - data['hosp_adm_date']) / np.timedelta64(1, 'D')
return data.dropna(axis=0)
def load_mobility_matrix(flow_file):
"""Loads mobility matrix from rds file"""
mobility = list(pyr.read_r(flow_file).values())[0]
mobility = mobility[
mobility["From"].str.startswith("E") & mobility["To"].str.startswith("E")
]
mobility = mobility.sort_values(["From", "To"])
mobility = mobility.groupby(["From", "To"]).agg({"Flow": sum}).reset_index()
mob_matrix = mobility.pivot(index="To", columns="From", values="Flow")
mob_matrix[mob_matrix.isna()] = 0.0
return mob_matrix
def load_population(pop_file):
pop = pd.read_csv(pop_file, index_col="lad19cd")
pop = pop[pop.index.str.startswith("E")]
pop = pop.sum(axis=1)
pop = pop.sort_index()
pop.name = "n"
return pop
def phe_linelist_timeseries(filename, spec_date='specimen_date', utla='UTLA_code', age='Age',
date_range=None):
def linelist2timeseries(date, region_code, date_range=None):
"""Constructs a daily aggregated timeseries given dates and region code
Optionally accepts a list expressing a required date range."""
linelist = pd.read_csv(filename)
linelist = linelist[[spec_date, utla, age]]
linelist = pd.DataFrame(dict(date=pd.to_datetime(date), region_code=region_code))
# 1. clip dates
one_day = np.timedelta64(1, 'D')
linelist.index = pd.Index(pd.to_datetime(linelist[spec_date], format="%d/%m/%Y"),name='date')
if date_range is not None:
linelist = linelist[date_range[0]:date_range[1]]
linelist = linelist[
(date_range[0] <= linelist["date"]) & (linelist["date"] <= date_range[1])
]
raw_len = linelist.shape[0]
# 2. Remove NA rows
linelist = linelist.dropna(axis=0) # remove na's
warn(f"Removed {raw_len - linelist.shape[0]} rows of {raw_len} due to missing data \
({100. * (raw_len - linelist.shape[0])/raw_len}%)")
# 2a. Aggregate London/Westminster and Cornwall/Scilly
london = ['E09000001', 'E09000033']
corn_scilly = ['E06000052', 'E06000053']
linelist.loc[linelist[utla].isin(london), utla] = ','.join(london)
linelist.loc[linelist[utla].isin(corn_scilly), utla] = ','.join(corn_scilly)
warn(
f"Removed {raw_len - linelist.shape[0]} rows of {raw_len} due to missing data \
({100. * (raw_len - linelist.shape[0])/raw_len}%)"
)
# 3. Create age groups
linelist['age_group'] = np.clip(linelist[age] // 5, a_min=0, a_max=16).astype(np.int64) * 5 # id of 5-year age group
# 4. Group by UTLA/age
case_counts = linelist.groupby(['date', utla, 'age_group']).size()
# 3. Aggregate by date/region and sort on index
case_counts = linelist.groupby(["date", "region_code"]).size()
case_counts.sort_index(axis=0, inplace=True)
case_counts.index.names = ['date','UTLA19CD','age_group']
# 4. Reindex by day
one_day = np.timedelta64(1, "D")
full_dates = pd.date_range(
case_counts.index.levels[0].min(), case_counts.index.levels[0].max() + one_day,
)
index = pd.MultiIndex.from_product(
[full_dates, case_counts.index.levels[1]], names=["date", "region_code"]
)
case_counts = case_counts.reindex(index)
case_counts.loc[case_counts.isna()] = 0.0
case_counts.name = "count"
return case_counts
def zero_cases(case_timeseries, population):
"""Creates a full case matrix, filling in dates, lads, and age groups not represented
in the main dataset. It is explicitly assumed that missing date/lad/age combos in the
case_timeseries are true 0s.
:param case_timeseries: an indexed [date, UTLA_code, age_group] pd.Series containing case counts
:param population: a dataset indexed with all UTLA_codes and age_groups
"""
dates = np.arange(case_timeseries.index.levels[0].min(),
case_timeseries.index.levels[0].max() + np.timedelta64(1, 'D'), # inclusive interval
np.timedelta64(1, 'D'))
fullidx = pd.MultiIndex.from_product([dates, *population.index.levels],
names=['date', *population.index.names])
y = case_timeseries.reindex(fullidx)
y[y.isna()] = 0. # Big assumption that a missing value is a true 0!
return y
def collapse_commute_data(flow_file):
"""Collapses LTLA-based commuting data in England to UTLA areas.
Merges commuting data at LTLA areal basis onto modified UTLA Dec 2019 area.
Modifications:
E06000052, E06000053 combined
E09000001, E09000033 combined
"""
filedir = os.path.dirname(os.path.abspath(__file__))
commuting = list(pyr.read_r(flow_file).values())[0]
lt_map = pd.read_csv(filedir + '/../data/Lower_Tier_Local_Authority_to_Upper_Tier_Local_Authority_April_2019_Lookup_in_England_and_Wales.csv')
lt_map = lt_map[['LTLA19CD', 'UTLA19CD']]
# 1. Extract England
commuting = commuting[commuting['From'].str.startswith('E') & commuting['To'].str.startswith('E')]
# 1. Merge in lt_map on 'From' field
def merge(left, right, left_on, right_on, new_cols):
merged = left.merge(right, how='left', left_on=left_on, right_on=right_on)
colnames = merged.columns.to_numpy()
colnames[-len(new_cols):] = new_cols
merged.columns = pd.Index(colnames)
return merged
commuting = merge(commuting, lt_map, 'From', 'LTLA19CD', ['from_ltla', 'from_utla'])
commuting = merge(commuting, lt_map, 'To', 'LTLA19CD', ['to_ltla', 'to_utla'])
# 2. Fix up collapsed UTLAs
commuting.loc[commuting['From'].str.contains(','), 'from_utla'] = commuting.loc[
commuting['From'].str.contains(','), 'From']
commuting.loc[commuting['To'].str.contains(','), 'to_utla'] = commuting.loc[
commuting['To'].str.contains(','), 'To']
# 3. Collapse data
collapsed = commuting.groupby(['from_utla', 'to_utla']).agg({'Flow': sum})
collapsed.sort_index(inplace=True)
collapsed.reset_index(inplace=True)
# 4. Pivot to return a matrix
commute_matrix = collapsed.pivot(index='to_utla', columns='from_utla', values='Flow')
commute_matrix[commute_matrix.isna()] = 0.0
return commute_matrix
def collapse_pop(pop_file):
"""Aggregates LTLA2019 population data to UTLA2019 and 5-year age groups to 80+"""
filedir = os.path.dirname(os.path.abspath(__file__))
pop = pd.read_csv(pop_file)
pop = pop[pop['lad19cd'].str.startswith('E')]
lt_map = pd.read_csv(filedir + '/../data/Lower_Tier_Local_Authority_to_Upper_Tier_Local_Authority_April_2019_Lookup_in_England_and_Wales.csv')
# 1. Merge LADs
pop = pop.merge(lt_map['UTLA19CD'], how='left', left_on='lad19cd', right_on=lt_map['LTLA19CD'])
# 2. Fill in merged utla codes
pop.loc[pop['lad19cd'].str.contains(','), 'UTLA19CD'] = pop.loc[pop['lad19cd'].str.contains(','), 'lad19cd']
pop.index = pd.MultiIndex.from_frame(pop[['UTLA19CD', 'lad19cd']])
pop.drop(columns=['lad19cd', 'UTLA19CD'], inplace=True)
pop.columns = np.arange(pop.shape[1]) * 5 # 5 year age groups
pop.sort_index(inplace=True)
# 3. Aggregate by UTLA19CD
pop = pop.sum(level=0)
pop.iloc[:, -3] = pop.iloc[:, -3:].sum(axis=1)
pop = pop.iloc[:, :-2]
# 4. Long format
pop = pop.reset_index().melt(id_vars=['UTLA19CD'], value_name='n', var_name='age_group')
pop.index = pd.MultiIndex.from_frame(pop[['UTLA19CD', 'age_group']])
pop.drop(columns=['UTLA19CD', 'age_group'], inplace=True)
pop.sort_index(level=0, inplace=True)
def phe_case_data(linelisting_file, date_range=None):
return pop
ll = pd.read_excel(linelisting_file)
date = ll["specimen_date"]
ltla_region = ll["LTLA_code"]
if __name__=='__main__':
# Merged regions
london = ["E09000001", "E09000033"]
corn_scilly = ["E06000052", "E06000053"]
ltla_region.loc[ltla_region.isin(london)] = ",".join(london)
ltla_region.loc[ltla_region.isin(corn_scilly)] = ",".join(corn_scilly)
ts = phe_linelist_timeseries('/home/jewellcp/Insync/jewellcp@lancaster.ac.uk/OneDrive Biz - Shared/covid19/data/PHE_2020-04-01/Anonymised Line List 20200401.csv')
print(ts)
ts = linelist2timeseries(date, ltla_region, date_range)
return ts.reset_index().pivot(index="region_code", columns="date", values="count")
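For change 3, the new `phe_case_data` / `linelist2timeseries` pipeline above reduces the anonymised line list to a region-by-date count matrix. A rough pandas sketch of the same steps on made-up data (the column names match the diff; the records themselves are invented):

import pandas as pd

# Stand-in for the PHE anonymised line list (one row per case)
ll = pd.DataFrame({
    "specimen_date": pd.to_datetime(
        ["2020-03-01", "2020-03-01", "2020-03-03", "2020-03-03"]),
    "LTLA_code": ["E06000001", "E06000002", "E06000001", "E06000001"],
})

# 1. Daily counts per region
counts = ll.groupby(["specimen_date", "LTLA_code"]).size().rename("count")

# 2. Re-index so every date/region combination is present, missing = 0
dates = pd.date_range(counts.index.levels[0].min(), counts.index.levels[0].max())
full = pd.MultiIndex.from_product(
    [dates, counts.index.levels[1]], names=["date", "region_code"])
counts = counts.reindex(full, fill_value=0)

# 3. Pivot to the [region, date] matrix consumed by the inference script
cases = counts.reset_index().pivot(
    index="region_code", columns="date", values="count")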
@@ -366,3 +366,63 @@ def plot_event_posterior(posterior, simulation, metapopulation=0):
ax[1][0].set_ylabel("E->I")
return fig, ax
def distribute_geom(events, rate, delta_t=1.0):
"""Given a tensor `events`, returns a tensor of shape `events.shape + [t]`
representing the events distributed over a number of days given geometric
waiting times with success probability `1 - exp(-rate * delta_t)`"""
events = tf.convert_to_tensor(events)
rate = tf.convert_to_tensor(rate, dtype=events.dtype)
accum = tf.TensorArray(events.dtype, size=0, dynamic_size=True)
prob = 1.0 - tf.exp(-rate * delta_t)
def body(i, events_, accum_):
rv = tfd.Binomial(total_count=events_, probs=prob)
failures = rv.sample()
accum_ = accum_.write(i, failures)
i += 1
return i, events_ - failures, accum_
def cond(_1, events_, _2):
return tf.reduce_sum(events_) > tf.constant(0, dtype=events.dtype)
_1, _2, accum = tf.while_loop(cond, body, loop_vars=[1, events, accum])
return tf.transpose(accum.stack(), perm=(1, 0, 2))
def reduce_diagonals(m):
def fn(m_):
idx = (
tf.range(m_.shape[-1])
- tf.range(m_.shape[-2])[:, tf.newaxis]
+ m_.shape[-2]
- 1
)
idx = tf.expand_dims(idx, axis=-1)
return tf.scatter_nd(idx, m_, [m_.shape[-2] + m_.shape[-1] - 1])
return tf.vectorized_map(fn, m)
def impute_previous_cases(events, rate, delta_t=1.0):
"""Imputes previous numbers of cases by using a geometric distribution
:param events: a [M, T] tensor
:param rate: the failure rate per `delta_t`
:param delta_t: the size of the time step
:returns: a tuple containing the matrix of events and the maximum
number of timesteps into the past to allow padding of `events`.
"""
prev_case_distn = distribute_geom(events, rate, delta_t)
prev_cases = reduce_diagonals(prev_case_distn)
# Trim preceding zero days
total_events = tf.reduce_sum(prev_cases, axis=-2)
num_zero_days = total_events.shape[-1] - tf.math.count_nonzero(
tf.cumsum(total_events, axis=-1)
)
return prev_cases[..., num_zero_days:], prev_case_distn.shape[-2] - num_zero_days
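`reduce_diagonals` above converts the per-day lag distribution back into a calendar-time series by summing anti-diagonals: an event observed on day t with imputed lag l belongs to day t - l, shifted so that imputed days before day 0 get their own columns. A small NumPy check of that index mapping (the values and the [day, lag] orientation are purely illustrative; the library code operates on [lag, day] slices batched over metapopulations):

import numpy as np

# m[t, l]: events observed on day t whose imputed earlier event is l days back
m = np.array([[2, 1, 0],
              [0, 3, 1],
              [4, 0, 0]])
T, L = m.shape

out = np.zeros(T + L - 1, dtype=int)      # L - 1 extra days of imputed "past"
for t in range(T):
    for l in range(L):
        out[t - l + L - 1] += m[t, l]     # day t, lag l -> calendar day t - l

# out[0] is two days before the first observation day; out[L - 1] aligns with day 0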
@@ -3,6 +3,7 @@ import optparse
import os
import pickle as pkl
from collections import OrderedDict
from time import perf_counter
import h5py
import numpy as np
@@ -13,7 +14,8 @@ import yaml
from covid import config
from covid.model import load_data, CovidUKStochastic
from covid.util import sanitise_parameter, sanitise_settings
from covid.pydata import phe_case_data
from covid.util import sanitise_parameter, sanitise_settings, impute_previous_cases
from covid.impl.mcmc import UncalibratedLogRandomWalk, random_walk_mvnorm_fn
from covid.impl.event_time_mh import UncalibratedEventTimesUpdate
from covid.impl.occult_events_mh import UncalibratedOccultUpdate
@@ -36,60 +38,45 @@ else:
print("Using CPU")
# Read in settings
parser = optparse.OptionParser()
parser.add_option(
"--config",
"-c",
dest="config",
default="ode_config.yaml",
help="configuration file",
)
options, args = parser.parse_args()
print("Loading config file:", options.config)
with open(options.config, "r") as f:
config = yaml.load(f)
# parser = optparse.OptionParser()
# parser.add_option(
# "--config",
# "-c",
# dest="config",
# default="ode_config.yaml",
# help="configuration file",
# )
# options, cmd_args = parser.parse_args()
# print("Loading config file:", options.config)
# with open(options.config, "r") as f:
with open("ode_config.yaml", "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
print("Config:", config)
settings = sanitise_settings(config["settings"])
param = sanitise_parameter(config["parameter"])
param = {k: tf.constant(v, dtype=DTYPE) for k, v in param.items()}
settings = sanitise_settings(config["settings"])
covar_data = load_data(config["data"], settings, DTYPE)
cases = phe_case_data(config["data"]["reported_cases"], settings["inference_period"])
ei_events, lag_ei = impute_previous_cases(cases, 0.25)
se_events, lag_se = impute_previous_cases(ei_events, 0.25)
ir_events = np.pad(cases, ((0, 0), (lag_ei + lag_se - 2, 0)))
ei_events = np.pad(ei_events, ((0, 0), (lag_se - 1, 0)))
data = load_data(config["data"], settings, DTYPE)
data["pop"] = data["pop"].sum(level=0)
model = CovidUKStochastic(
C=data["C"],
N=data["pop"]["n"].to_numpy(),
W=data["W"],
C=covar_data["C"],
N=covar_data["pop"],
W=covar_data["W"],
date_range=settings["inference_period"],
holidays=settings["holiday"],
lockdown=settings["lockdown"],
time_step=1.0,
)
# Load data
with open("stochastic_sim_covid1.pkl", "rb") as f:
example_sim = pkl.load(f)
event_tensor = example_sim["events"] # shape [T, M, S, S]
event_tensor = event_tensor[:60, ...]
num_times = event_tensor.shape[0]
num_meta = event_tensor.shape[1]
state_init = example_sim["state_init"]
se_events = event_tensor[:, :, 0, 1] # [T, M, X]
ei_events = event_tensor[:, :, 1, 2] # [T, M, X]
ir_events = event_tensor[:, :, 2, 3] # [T, M, X]
ir_events = np.pad(ir_events, ((4, 0), (0, 0)), mode="constant", constant_values=0.0)
ei_events = np.roll(ir_events, shift=-2, axis=0)
se_events = np.roll(ir_events, shift=-4, axis=0)
ei_events[-2:, ...] = 0.0
se_events[-4:, ...] = 0.0
##########################
# Log p and MCMC kernels #
##########################
@@ -159,7 +146,7 @@ def make_occults_step(target_event_id):
target_log_prob_fn=logp,
target_event_id=target_event_id,
nmax=config["mcmc"]["occult_nmax"],
t_range=[se_events.shape[0] - 21, se_events.shape[0]],
t_range=[se_events.shape[1] - 21, se_events.shape[1]],
),
name="occult_update",
)
@@ -194,7 +181,7 @@ def forward_results(prev_results, next_results):
def sample(n_samples, init_state, par_scale, num_event_updates):
with tf.name_scope("main_mcmc_sample_loop"):
init_state = init_state.copy()
par_func = make_parameter_kernel(par_scale, 0.95)
par_func = make_parameter_kernel(par_scale, 0.0)
se_func = make_events_step(0, None, 1)
ei_func = make_events_step(1, 0, 2)
se_occult = make_occults_step(0)
@@ -229,8 +216,9 @@ def sample(n_samples, init_state, par_scale, num_event_updates):
state[0] = par_state # close over state from outer scope
return logp(*state)
state[0], results[0] = par_func(par_logp).one_step(
state[0], forward_results(results[2], results[0])
par_kernel = par_func(par_logp)
state[0], results[0] = par_kernel.one_step(
state[0], par_kernel.bootstrap_results(state[0])
)
# States
@@ -254,7 +242,7 @@ def sample(n_samples, init_state, par_scale, num_event_updates):
state[2], results[3] = se_occult(occult_logp).one_step(
state[2], forward_results(results[2], results[3])
)
# results[3] = forward_results(results[2], results[3])
# results[3] = forward_results(results[2], results[3])
state[2], results[4] = ei_occult(occult_logp).one_step(
state[2], forward_results(results[3], results[4])
)
@@ -303,15 +291,21 @@ NUM_EVENT_TIME_UPDATES = config["mcmc"]["num_event_time_updates"]
tf.random.set_seed(2)
# Initial state. NB [M, T, X] layout for events.
events = tf.transpose(
tf.stack([se_events, ei_events, ir_events], axis=-1), perm=(1, 0, 2)
)
current_state = [np.array([0.6, 0.25], dtype=DTYPE), events, tf.zeros_like(events)]
events = tf.stack([se_events, ei_events, ir_events], axis=-1)
state_init = tf.concat([model.N[:, tf.newaxis], events[:, 0, :]], axis=-1)
events = events[:, 1:, :]
current_state = [
np.array([0.85, 0.25], dtype=DTYPE),
events,
tf.zeros_like(events),
]
# Output Files
posterior = h5py.File(
os.path.expandvars(config["output"]["posterior"]), "w", rdcc_nbytes=1024 ** 3 * 2,
os.path.expandvars(config["output"]["posterior"]),
"w",
rdcc_nbytes=1024 ** 2 * 400,
rdcc_nslots=100000,
)
event_size = [NUM_BURSTS * NUM_BURST_SAMPLES] + list(current_state[1].shape)
# event_chunk = (10, 1, 1, 1)
@@ -325,17 +319,15 @@ event_samples = posterior.create_dataset(
"samples/events",
event_size,
dtype=DTYPE,
chunks=(10,) + tuple(current_state[1].shape),
compression="gzip",
compression_opts=1,
chunks=(1024, 64, 64, current_state[1].shape[-1]),
compression="lzf",
)
occult_samples = posterior.create_dataset(
"samples/occults",
event_size,
dtype=DTYPE,
chunks=(10,) + tuple(current_state[1].shape),
compression="gzip",
compression_opts=1,
chunks=(1024, 64, 64, current_state[1].shape[-1]),
compression="lzf",
)
output_results = [
@@ -362,7 +354,7 @@ output_results = [
print("Initial logpi:", logp(*current_state))
par_scale = tf.linalg.diag(
tf.ones(current_state[0].shape, dtype=current_state[0].dtype) * 1.0
tf.ones(current_state[0].shape, dtype=current_state[0].dtype) * 0.1
)
# We loop over successive calls to sample because we have to dump results
@@ -385,14 +377,18 @@ for i in tqdm.tqdm(range(NUM_BURSTS), unit_scale=NUM_BURST_SAMPLES):
print(current_state[0].numpy())
print(cov)
if np.all(np.isfinite(cov)):
par_scale = 2.0 ** 2 * cov / 2.0
if (i * NUM_BURST_SAMPLES) > 1000 and np.all(np.isfinite(cov)):
par_scale = 2.38 ** 2 * cov / 2.0
start = perf_counter()
event_samples[s, ...] = samples[1].numpy()
occult_samples[s, ...] = samples[2].numpy()
end = perf_counter()
for i, ro in enumerate(output_results):
ro[s, ...] = results[i]
print("Storage time:", end - start, "seconds")
print("Acceptance par:", tf.reduce_mean(tf.cast(results[0][:, 1], tf.float32)))
print(