rdata.py 1.32 KB
Newer Older
1
2
3
4
"""Loads R data.frame data structures"""

import numpy
import pyreadr as pyr
Chris Jewell's avatar
Chris Jewell committed
5
import numpy as np
6
7
8
9
10
11
12
13


def load_age_mixing(rds_file: str):
    """Read an age mixing matrix stored as an R data.frame.

    :param rds_file: a .rds file containing an R data.frame with mixing matrix
    :returns: a tuple of the matrix as a float32 numpy array and the
              data.frame's column labels
    """
    frames = list(pyr.read_r(rds_file).values())
    # Only the first object read from the file is used.
    mixing_df = frames[0]
    return mixing_df.to_numpy(dtype=np.float32), mixing_df.columns
17
18
19
20
21
22
23
24
25
26
27
28
29


def load_mobility_matrix(rds_file: str):
    """Build a dense mobility matrix from coordinate-format data in an RDS file.

    :param rds_file: a .rds file containing an R data.frame whose first column
                     holds the values, with columns 'Residence' and 'Workplace'
                     giving the matrix coordinates
    :returns: a tuple of the float32 matrix (one row per 'Workplace', one column
              per 'Residence') and a numpy array of the row ('Workplace') labels
    """
    coo = list(pyr.read_r(rds_file).values())[0]
    value_column = coo.columns[0]
    dense = coo.pivot(index='Workplace', columns='Residence', values=value_column)
    # Coordinate pairs absent from the input come out of the pivot as NaN.
    dense = dense.fillna(0.)
    return dense.to_numpy(dtype=np.float32), dense.index.to_numpy()
31
32
33
34
35
36
37
38
39
40


def load_population(rds_file: str):
    """Read population counts from an RDS file.

    :param rds_file: an RDS file containing a data.frame with columns 'age',
                     'LA.code', 'n', 'name' and 'Area.name.2'
    :returns: a tuple of the 'n' column as a float32 numpy array, ordered by
              'LA.code' then 'age', and a DataFrame of the 'name' and
              'Area.name.2' columns in the same row order
    """
    contents = pyr.read_r(rds_file)
    population = list(contents.values())[0].sort_values(by=['LA.code', 'age'])
    counts = population['n'].to_numpy(dtype=np.float32)
    labels = population[['name', 'Area.name.2']]
    return counts, labels