Commit 3dbeaf7a authored by Chris Jewell's avatar Chris Jewell
Browse files

Improved PEP8 conformance in interface, with knock-ons to other files.

parent 82dd88f4
Pipeline #269 failed with stage
in 4 minutes and 16 seconds
......@@ -21,4 +21,6 @@
from gem.interface import GEM
import gem.version as version
"""GEM domain-specific language and analysis for epidemic modelling."""
__version__ = f"{version.MAJOR}.{version.MINOR}.{version.MAINTENANCE}.{version.RELEASE}"
......@@ -18,6 +18,8 @@
# #
#
"""The main GEM Python interface"""
import time
from collections import OrderedDict
......@@ -35,24 +37,24 @@ from gem.gemlang.semantics.symbol_declare import SymbolDeclarer
from gem.model.edward2_extn import make_value_setter, TransformedRVBijector
def as_numpy_dict(d):
def as_numpy_dict(dictionary): # Todo move to a utils module
"""Converts values in a dictionary to numpy objects.
:param d a dict where values are array-like objects.
:param dictionary a dict where values are array-like objects.
:return a dict where values are np.array objects.
"""
rv = {}
for k, v in d.items():
if hasattr(v, '__iter__'):
v = np.array(v)
rv[k] = v
return rv
ret_val = {}
for key, val in dictionary.items():
if hasattr(val, '__iter__'):
val = np.array(val)
ret_val[key] = val
return ret_val
class GEM:
"""Represents a GEM model."""
def __init__(self, gemprog, const_data={}, tf_session=None):
def __init__(self, gemprog, const_data=None, tf_session=None):
"""Builds a GEM model given a GEM program and data to attach.
The actual model function is attached to the GEM instance as the method
......@@ -67,26 +69,33 @@ class GEM:
if not tf.executing_eagerly():
self.tf_session = tf.compat.v1.Session() if tf_session is None else tf_session
self.gem_external = const_data
self.__symtab = None
self.ast = None
self.__pyprog = "# Undefined"
self.log_prob_fn = lambda x: 1.0
self.model_impl = lambda *args, **kwargs: {}
self.gem_external = {} if const_data is None else const_data
self.__parse__(gemprog)
self.log_prob_fn = ed.make_log_joint_fn(self.model_impl)
@property
def pyprog(self):
"""Returns the serialized Python representation of the model."""
return self.__pyprog
@property
def symtab(self):
"""Returns the model symbol table."""
return self.__symtab
@property
def variables(self):
"""Returns a dictionary of variables declared in the model."""
return self.model_impl()
@property
def random_variables(self):
"""Returns a dictionary of random variables declared in the model."""
rv_names = get_random_vars(self.__symtab)
return {k: v for k, v in self.model_impl().items() if k in rv_names}
......@@ -103,33 +112,37 @@ class GEM:
SymbolResolver().visit(ast, self.__symtab)
self.ast = ast
self.__pyprog = CodeGenerator().generate(ast, self.__symtab, self.gem_external)
except Exception as e:
print(f"An error occurred during model compilation: {e}")
raise e
except Exception as exception:
print(f"An error occurred during model compilation: {exception}")
raise exception
try:
exec(self.__pyprog,
exec(self.__pyprog, # pylint: disable=exec-used
self.__dict__) # Adds dynamically built self.model_impl()
except Exception as e:
except Exception as exception:
print("An error occurred executing the compiled model")
print(self.__pyprog)
print(self.__symtab)
raise e
raise exception
def log_prob(self, **kwargs):
"""Returns the log posterior of the model evaluated at kwargs.
:param kwargs: a dictionary of parameter values at which to evaluate the log posterior.
:returns: the value of the log posterior"""
if tf.executing_eagerly():
return self.log_prob_fn(**kwargs).numpy()
return self.tf_session.run(self.log_prob_fn(**kwargs))
def sample(self, n_samples=1, condition_vars={}):
def sample(self, n_samples=1, condition_vars=None):
"""Draws samples from the model prior distribution.
:param n_samples: the number of independent samples to draw.
:param condition_vars: a dict of random variable values to condition on {'varname': value}
:return a dict of realisations of random variables in the model
"""
condition_vars = as_numpy_dict(condition_vars)
condition_vars = {} if condition_vars is None else as_numpy_dict(condition_vars)
with ed.interception(make_value_setter(**condition_vars)):
if tf.executing_eagerly():
res_dict = self.model_impl()
......@@ -137,14 +150,16 @@ class GEM:
res_dict = self.tf_session.run(self.model_impl())
return res_dict
def fit(self, observed, n_samples, init=None, tune=None, burnin=0,
clip_burnin=False, transform=False):
def fit(self, observed, n_samples, init=None, burnin=0, # Todo move MCMC implementation to external func
clip_burnin=False):
"""Fits a model to observed data
:param observed: a dict of observed values, with keys denoting variables within the GEM model.
:param n_samples: number of samples to draw from the posterior distribution
:return a dict of posterior samples for each unobserved variable
:param init: dict of initialisation values
:param burnin: the number of burnin samples to take
:param clip_burnin: whether to clip the burnin or not
:returns: a dict of posterior samples for each unobserved variable
"""
# 1. Work out the free rvs
observed = as_numpy_dict(observed)
......@@ -182,10 +197,9 @@ class GEM:
inits = init or [v.value for v in free_rvs.values()]
def trace_fn():
def fn(s, r):
return r.inner_results.inner_results.is_accepted
return fn
def func(sample, result): # pylint: disable=unused-argument
return result.inner_results.inner_results.is_accepted
return func
# 5. Run!
post_, kernel_results_ = tfp.mcmc.sample_chain(
......
......@@ -224,8 +224,7 @@ class TestInference(unittest.TestCase):
model.log_prob(beta=0.00014, gamma=0.14, epi=epidata), -129.847,
places=3)
post, accept = model.fit(observed={'epi': epidata}, n_samples=5000,
tune=[0.00002, 0.03],
init=[0.1, 0.1], burnin=1000, transform=True)
init=[0.1, 0.1], burnin=1000)
traceplot(post)
plt.show()
self.is_within_credibility_interval([0.00014, 0.14], post.values())
......@@ -276,7 +275,6 @@ class TestInference(unittest.TestCase):
end = time.perf_counter()
print("Time log_prob hom_sinr_1000:", end - start, "seconds")
post, accept = model.fit(observed={'epi': epidata}, n_samples=10000,
tune=[0.00002, 0.03],
init=[0.1, 0.5, 0.1, 0.1], burnin=2000,
clip_burnin=True)
traceplot(post)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment