diff --git a/pycode/memilio-epidata/memilio/epidata/getContactData.py b/pycode/memilio-epidata/memilio/epidata/getContactData.py index 186f02e989..bae6cf2a43 100644 --- a/pycode/memilio-epidata/memilio/epidata/getContactData.py +++ b/pycode/memilio-epidata/memilio/epidata/getContactData.py @@ -24,9 +24,10 @@ Prem et al., 2017 (DOI: https://doi.org/10.1371/journal.pcbi.1005697). The module can download the supporting ZIP from https://doi.org/10.1371/journal.pcbi.1005697.s002 (contains the -``MUestimates_all_locations_1.xlsx`` workbook) or read a defined local -workbook path. By default, downloads are done in memory and no -files are written. +``MUestimates_all_locations_1.xlsx`` and +``MUestimates_all_locations_2.xlsx`` workbooks) or read a defined local +workbook path. By default, downloads are done in memory and no files are +written. """ import io @@ -43,7 +44,10 @@ "https://journals.plos.org/ploscompbiol/article/file" "?id=10.1371/journal.pcbi.1005697.s002&type=supplementary" ) -CONTACT_WORKBOOK_NAME = "MUestimates_all_locations_1.xlsx" +CONTACT_WORKBOOK_NAMES = ( + "MUestimates_all_locations_1.xlsx", + "MUestimates_all_locations_2.xlsx", +) AGE_GROUP_LABELS = [ "0-4", @@ -84,59 +88,64 @@ def _normalize_country_name(country: str): return "".join(ch for ch in country.casefold() if ch.isalnum()) -def _download_contact_workbook( - url: str = CONTACT_ZIP_URL, target_filename: str = CONTACT_WORKBOOK_NAME): +def _download_contact_workbooks(): """ - Download the ZIP from the url and return the workbook. + Download the ZIP and return all contact workbooks. - :param url: URL to download the ZIP from. - :param target_filename: Name of the workbook file within the ZIP. - :returns: Content of the workbook. + :returns: List of workbook contents. """ - response = requests.get(url, timeout=30) + response = requests.get(CONTACT_ZIP_URL, timeout=30) response.raise_for_status() + + workbooks = [] with zipfile.ZipFile(io.BytesIO(response.content)) as zf: - candidates = [name for name in zf.namelist() - if name.endswith(target_filename)] - if not candidates: - raise FileNotFoundError( - f"'{target_filename}' not found in downloaded workbook.") - with zf.open(candidates[0]) as f: - return f.read() + for target_filename in CONTACT_WORKBOOK_NAMES: + candidates = [name for name in zf.namelist() + if name.endswith(target_filename)] + if not candidates: + raise FileNotFoundError( + f"'{target_filename}' not found in downloaded workbook.") + with zf.open(candidates[0]) as f: + workbooks.append(f.read()) + return workbooks -def _load_workbook_bytes( - contact_path: Optional[str], - url: str = CONTACT_ZIP_URL, - target_filename: str = CONTACT_WORKBOOK_NAME): + +def _load_workbooks_bytes( + contact_path: Optional[str]): """ - Return workbook either from a user path or by downloading the ZIP. + Return one explicit workbook or, by default, all downloaded workbooks. - :param contact_path: Optional local path to the workbook. - :param url: Url to download the ZIP from if no path is provided. - :param target_filename: Name of the workbook file within the ZIP. - :returns: Content of the workbook. + :param contact_path: Optional local path to a single workbook. + :returns: List of workbook contents. """ if contact_path: if not os.path.exists(contact_path): raise FileNotFoundError( f"Contact matrix file not found at {contact_path}") with open(contact_path, "rb") as f: - return f.read() - return _download_contact_workbook(url=url, target_filename=target_filename) + return [f.read()] + return _download_contact_workbooks() def list_available_contact_countries( contact_path: Optional[str] = None): """ - List all country names available in the contact matrix workbook. + List all country names available in the contact matrix workbooks. :param contact_path: Optional local path to the workbook. :returns: List of all country names. """ - xls_bytes = _load_workbook_bytes(contact_path) - xls = pd.ExcelFile(io.BytesIO(xls_bytes)) - return xls.sheet_names + countries = [] + seen = set() + for xls_bytes in _load_workbooks_bytes(contact_path): + xls = pd.ExcelFile(io.BytesIO(xls_bytes)) + for sheet_name in xls.sheet_names: + key = _normalize_country_name(sheet_name) + if key not in seen: + countries.append(sheet_name) + seen.add(key) + return countries def _select_sheet_name(country: str, sheet_names: Iterable[str]): @@ -157,6 +166,58 @@ def _select_sheet_name(country: str, sheet_names: Iterable[str]): return lookup[key] +def _read_contact_sheet(xls: pd.ExcelFile, sheet_name: str, country: str): + """ + Read a contact sheet and extract its numeric 16x16 matrix. + + Some source workbooks contain an explicit header row, others contain only + the matrix. Reading without headers and selecting the numeric block handles + both formats. + + :param xls: Opened Excel workbook containing contact matrix sheets. + :param sheet_name: Name of the sheet to read. + :param country: Country name used for error messages. + :returns: DataFrame containing the extracted 16x16 contact matrix. + """ + df = pd.read_excel( + xls, + sheet_name=sheet_name, + engine="openpyxl", + header=None) + + return _extract_contact_matrix(df, country) + + +def _extract_contact_matrix(df: pd.DataFrame, country: str): + """ + Extract the 16x16 numeric contact matrix from a raw Excel sheet. + + :param df: Raw sheet data read from the workbook without headers. + :param country: Country name used for error messages. + :returns: DataFrame with age-group labels as index and columns. + """ + matrix_size = len(AGE_GROUP_LABELS) + numeric = df.apply(pd.to_numeric, errors="coerce") + max_row_start = numeric.shape[0] - matrix_size + max_col_start = numeric.shape[1] - matrix_size + + for row_start in range(max_row_start, -1, -1): + for col_start in range(max_col_start, -1, -1): + matrix = numeric.iloc[ + row_start:row_start + matrix_size, + col_start:col_start + matrix_size] + if matrix.shape == (matrix_size, matrix_size): + if not matrix.isnull().any().any(): + matrix = matrix.copy() + matrix.columns = AGE_GROUP_LABELS + matrix.index = AGE_GROUP_LABELS + return matrix + + raise ValueError( + f"Contact matrix for '{country}' does not contain a numeric " + f"{matrix_size}x{matrix_size} block. Raw shape: {df.shape}") + + def load_contact_matrix( country: str, contact_path: Optional[str] = None, @@ -165,31 +226,31 @@ def load_contact_matrix( """ Load the all-locations contact matrix for the given country. If ``contact_path`` is not provided, the function downloads the - ``MUestimates_all_locations_1.xlsx`` workbook from Prem et al., 2017. + ``MUestimates_all_locations_*.xlsx`` workbooks from Prem et al., 2017. :param country: Country name as listed in the workbook (case-insensitive). - :param contact_path: Optional path to ``MUestimates_all_locations_1.xlsx``. + :param contact_path: Optional path to one ``MUestimates_all_locations`` + workbook. :param reduce_to_rki_groups: If True, aggregate to the six RKI age groups (0-4, 5-14, 15-34, 35-59, 60-79, 80+ years). Default True. - :param population: An iterable of 16 float values representing the population - size for each original age group. Required if reduce_to_rki_groups is True. + :param population: An iterable of 16 float values representing the + population size for each original age group. Required if + reduce_to_rki_groups is True. :returns: DataFrame indexed by age group with floats. """ - xls_bytes = _load_workbook_bytes(contact_path) - xls = pd.ExcelFile(io.BytesIO(xls_bytes)) - sheet_names = xls.sheet_names - sheet = _select_sheet_name(country, sheet_names) - df = pd.read_excel(xls, sheet_name=sheet, engine="openpyxl") - - # Ensure numeric values and trim potential trailing rows/cols. - matrix = df.apply(pd.to_numeric, errors="coerce") - matrix = matrix.iloc[:len(AGE_GROUP_LABELS), :len(AGE_GROUP_LABELS)] - matrix.columns = AGE_GROUP_LABELS[:matrix.shape[1]] - matrix.index = AGE_GROUP_LABELS[:matrix.shape[0]] - - if matrix.isnull().any().any(): - raise ValueError( - f"Contact matrix for '{country}' contains non-numeric entries.") + all_sheet_names = [] + for xls_bytes in _load_workbooks_bytes(contact_path): + xls = pd.ExcelFile(io.BytesIO(xls_bytes)) + sheet_names = xls.sheet_names + all_sheet_names.extend(sheet_names) + try: + sheet = _select_sheet_name(country, sheet_names) + except ValueError: + continue + matrix = _read_contact_sheet(xls, sheet, country) + break + else: + _select_sheet_name(country, all_sheet_names) if matrix.shape[0] != matrix.shape[1]: raise ValueError( @@ -210,8 +271,8 @@ def load_contact_matrix( def _aggregate_to_rki_age_groups( matrix: pd.DataFrame, population: Iterable[float]): """ - Aggregate an age-structured 16x16 contact matrix to the 6-group RKI scheme using - population-weighted averages. + Aggregate an age-structured 16x16 contact matrix to the 6-group RKI + scheme using population-weighted averages. Assumes the original columns/rows follow AGE_GROUP_LABELS order. Note: The source only provides data up to 70-74 and a 75+ group. We map 60-74 to the 60-79 RKI group and 75+ to the 80-99 RKI group. diff --git a/pycode/memilio-epidata/tests/test_epidata_get_contact_data.py b/pycode/memilio-epidata/tests/test_epidata_get_contact_data.py index 715b7e6c5b..5c769478df 100644 --- a/pycode/memilio-epidata/tests/test_epidata_get_contact_data.py +++ b/pycode/memilio-epidata/tests/test_epidata_get_contact_data.py @@ -29,7 +29,7 @@ from memilio.epidata.getContactData import (AGE_GROUP_LABELS, AGE_GROUP_LABELS_RKI, - CONTACT_WORKBOOK_NAME, + CONTACT_WORKBOOK_NAMES, list_available_contact_countries, load_contact_matrix) @@ -62,15 +62,31 @@ def _create_workbook(self, sheets: dict): return path def _create_zip_with_workbook(self, sheets: dict): - """Create a ZIP that contains MUestimates_all_locations_1.xlsx.""" + """Create a ZIP that contains all default contact workbooks.""" + return self._create_zip_with_workbooks({ + workbook_name: sheets + for workbook_name in CONTACT_WORKBOOK_NAMES + }) + + def _create_zip_with_workbooks( + self, workbook_sheets: dict, workbook_headers: dict = None): + """Create a ZIP that contains the given contact workbooks.""" + if workbook_headers is None: + workbook_headers = {} with io.BytesIO() as buf_zip: with zipfile.ZipFile(buf_zip, mode="w") as zf: - with io.BytesIO() as buf_xlsx: - with pd.ExcelWriter(buf_xlsx, engine="openpyxl") as writer: - for sheet_name, df in sheets.items(): - df.to_excel( - writer, sheet_name=sheet_name, index=False) - zf.writestr(CONTACT_WORKBOOK_NAME, buf_xlsx.getvalue()) + for workbook_name, sheets in workbook_sheets.items(): + header = workbook_headers.get(workbook_name, True) + with io.BytesIO() as buf_xlsx: + with pd.ExcelWriter( + buf_xlsx, engine="openpyxl") as writer: + for sheet_name, df in sheets.items(): + df.to_excel( + writer, + sheet_name=sheet_name, + index=False, + header=header) + zf.writestr(workbook_name, buf_xlsx.getvalue()) return buf_zip.getvalue() def test_list_available_explicit_path(self): @@ -83,11 +99,13 @@ def test_list_available_explicit_path(self): @patch('memilio.epidata.getContactData.requests.get') def test_list_available_downloads_zip(self, mock_get): - """When no path is given, the workbook is downloaded from the ZIP and the available countries are listed. + """The country list is read from all downloaded workbooks. """ data = pd.DataFrame(np.ones((16, 16))) - zip_bytes = self._create_zip_with_workbook( - {"Germany": data, "Spain": data}) + zip_bytes = self._create_zip_with_workbooks({ + CONTACT_WORKBOOK_NAMES[0]: {"Germany": data}, + CONTACT_WORKBOOK_NAMES[1]: {"Spain": data}, + }) mock_resp = Mock() mock_resp.content = zip_bytes mock_resp.raise_for_status = Mock() @@ -99,7 +117,7 @@ def test_list_available_downloads_zip(self, mock_get): mock_get.assert_called_once() def test_load_contact_matrix(self): - """Contact matrix loads full 16x16 matrix (no reduction, no pop needed).""" + """Contact matrix loads full 16x16 matrix.""" matrix_values = np.arange(16 * 16).reshape(16, 16) df = pd.DataFrame(matrix_values) contact_path = self._create_workbook({"Germany": df}) @@ -115,7 +133,7 @@ def test_load_contact_matrix(self): self.assertEqual(matrix.iloc[-1, -1], 255) def test_load_contact_matrix_rki_groups_missing_population(self): - """Raises ValueError if reduce_to_rki_groups=True but population is missing.""" + """Raises ValueError if RKI aggregation misses population.""" matrix_values = np.arange(16 * 16).reshape(16, 16) df = pd.DataFrame(matrix_values) contact_path = self._create_workbook({"Germany": df}) @@ -126,7 +144,7 @@ def test_load_contact_matrix_rki_groups_missing_population(self): reduce_to_rki_groups=True) def test_load_contact_matrix_rki_groups(self): - """Aggregate to the 6-group RKI age groups using a uniform dummy population.""" + """Aggregate to the 6-group RKI age groups.""" matrix_values = np.arange(16 * 16).reshape(16, 16) df = pd.DataFrame(matrix_values) contact_path = self._create_workbook({"Germany": df}) @@ -141,28 +159,31 @@ def test_load_contact_matrix_rki_groups(self): # entries from group [0-4] to [0-4] remain the same self.assertEqual(matrix.iloc[0, 0], 0) - # groups [5-9] and [10-14] (row 1,2) are aggregated into the [5-14] RKI group (row 1). + # Groups [5-9] and [10-14] are aggregated into [5-14]. # weighted avg with uniform pop is (16*100 + 32*100)/200 = 24. self.assertEqual(matrix.iloc[1, 0], 24) # rows 3,4,5,6 and cols 3,4,5,6 ([15-34] with [15-34]) - # sum across the columns, then average across the rows (due to uniform population) + # Sum across columns, then average rows due to uniform population. sub_expected = matrix_values[3:7, 3:7].sum(axis=1).mean() self.assertEqual(matrix.iloc[2, 2], sub_expected) @patch('memilio.epidata.getContactData.requests.get') def test_load_contact_matrix_downloads_no_path(self, mock_get): - """When no path is given, the workbook is downloaded from the DOI ZIP.""" + """When no path is given, all contact workbooks are searched.""" matrix_values = np.arange(16 * 16).reshape(16, 16) df = pd.DataFrame(matrix_values) - zip_bytes = self._create_zip_with_workbook({"Germany": df}) + zip_bytes = self._create_zip_with_workbooks({ + CONTACT_WORKBOOK_NAMES[0]: {"Germany": df}, + CONTACT_WORKBOOK_NAMES[1]: {"Spain": df}, + }) mock_resp = Mock() mock_resp.content = zip_bytes mock_resp.raise_for_status = Mock() mock_get.return_value = mock_resp matrix = load_contact_matrix( - "Germany", contact_path=None, population=self.dummy_pop) + "Spain", contact_path=None, population=self.dummy_pop) self.assertEqual(matrix.shape, (6, 6)) self.assertEqual(list(matrix.columns), AGE_GROUP_LABELS_RKI) self.assertEqual(list(matrix.index), AGE_GROUP_LABELS_RKI) @@ -172,6 +193,34 @@ def test_load_contact_matrix_downloads_no_path(self, mock_get): mock_resp.raise_for_status.assert_called_once() mock_get.assert_called_once() + @patch('memilio.epidata.getContactData.requests.get') + def test_load_contact_matrix_downloads_headerless_second_workbook( + self, mock_get): + """Contact matrices can be read without an Excel header row.""" + matrix_values = np.arange(16 * 16).reshape(16, 16) + df = pd.DataFrame(matrix_values) + zip_bytes = self._create_zip_with_workbooks( + { + CONTACT_WORKBOOK_NAMES[0]: {"Germany": df}, + CONTACT_WORKBOOK_NAMES[1]: {"United States of America": df}, + }, + workbook_headers={ + CONTACT_WORKBOOK_NAMES[1]: False, + }) + mock_resp = Mock() + mock_resp.content = zip_bytes + mock_resp.raise_for_status = Mock() + mock_get.return_value = mock_resp + + matrix = load_contact_matrix( + "United States of America", + contact_path=None, + reduce_to_rki_groups=False) + + self.assertEqual(matrix.shape, (16, 16)) + self.assertEqual(matrix.iloc[0, 0], 0) + self.assertEqual(matrix.iloc[-1, -1], 255) + if __name__ == '__main__': unittest.main()