Source code for mrtool.core.data

# -*- coding: utf-8 -*-
"""
    data
    ~~~~

    `data` module for `mrtool` package.
"""
from typing import Dict, List, Union, Any
import warnings
from dataclasses import dataclass, field
import numpy as np
import pandas as pd
from .utils import empty_array, to_list, is_numeric_array, expand_array


@dataclass
class MRData:
    """Data for simple linear mixed effects model.

    Attributes:
        obs (np.ndarray): Observations, numeric.
        obs_se (np.ndarray): Standard errors of the observations, numeric.
        covs (Dict[str, np.ndarray]): Covariates keyed by name. An
            ``'intercept'`` covariate of ones is always added (overwriting
            any user supplied ``'intercept'``).
        study_id (np.ndarray): Study label for each data point.
        data_id (np.ndarray): Unique identifier for each data point.
        cov_scales (Dict[str, float]): Largest absolute value of each
            covariate, computed in ``__post_init__`` (not an init argument).

    After construction the data points with a nan in any covariate are
    dropped and, when there is more than one study, the remaining points are
    sorted by ``study_id``.
    """
    obs: np.ndarray = field(default_factory=empty_array)
    obs_se: np.ndarray = field(default_factory=empty_array)
    covs: Dict[str, np.ndarray] = field(default_factory=dict)
    study_id: np.ndarray = field(default_factory=empty_array)
    data_id: np.ndarray = field(default_factory=empty_array)
    cov_scales: Dict[str, float] = field(init=False, default_factory=dict)

    def __post_init__(self):
        """Validate, complete, clean and sort the data.

        Also called by ``load_df`` and ``load_xr`` after they overwrite the
        raw attributes.
        """
        self._check_attr_type()

        # Work on a private shallow copy: the steps below add 'intercept' and
        # rebind entries to filtered/sorted arrays, which must not leak into
        # the dictionary object owned by the caller.
        self.covs = dict(self.covs)

        # `num_points` is re-evaluated on every access, so the statement order
        # below matters.
        # NOTE(review): `expand_array` is defined in `.utils`; presumably it
        # fills a missing (empty) array with the given default - confirm.
        self.obs = expand_array(self.obs, (self.num_points,), np.nan, 'obs')
        self.obs_se = expand_array(self.obs_se, (self.num_points,), 1.0, 'obs_se')
        self.study_id = expand_array(self.study_id, (self.num_points,), 'Unknown', 'study_id')
        self.data_id = expand_array(self.data_id, (self.num_points,),
                                    np.arange(self.num_points), 'data_id')
        assert len(np.unique(self.data_id)) == self.num_points, \
            "data_id must be unique for each data point."

        self.covs.update({'intercept': np.ones(self.num_points)})
        for cov_name, cov in self.covs.items():
            assert len(cov) == self.num_points, f"covs[{cov_name}], inconsistent shape."

        self._remove_nan_in_covs()
        self._get_study_structure()
        self._get_cov_scales()

    @property
    def num_points(self) -> int:
        """Number of data points.

        Longest length among ``obs``, ``obs_se``, ``study_id`` and all
        covariates.
        """
        return max([len(self.obs), len(self.obs_se), len(self.study_id)] +
                   [len(cov) for cov in self.covs.values()])

    @property
    def num_obs(self) -> int:
        """Number of observations.
        """
        return len(self.obs)

    @property
    def num_covs(self) -> int:
        """Number of covariates.
        """
        return len(self.covs)

    @property
    def num_studies(self) -> int:
        """Number of studies.
        """
        return len(self.studies)

    def _check_attr_type(self):
        """Check the type of the attributes.

        Raises:
            AssertionError: When an attribute has the wrong container type or
                ``obs``, ``obs_se`` or a covariate is not a numeric array.
        """
        assert isinstance(self.obs, np.ndarray)
        assert is_numeric_array(self.obs)
        assert isinstance(self.obs_se, np.ndarray)
        assert is_numeric_array(self.obs_se)
        assert isinstance(self.study_id, np.ndarray)
        assert isinstance(self.data_id, np.ndarray)
        assert isinstance(self.covs, dict)
        for cov in self.covs.values():
            assert isinstance(cov, np.ndarray)
            assert is_numeric_array(cov)

    def _get_cov_scales(self):
        """Compute the covariate scale.

        The scale of a covariate is its largest absolute value, or nan when
        the object is empty. A warning is issued for every all-zero covariate.
        """
        if self.is_empty():
            self.cov_scales = {cov_name: np.nan for cov_name in self.covs.keys()}
        else:
            self.cov_scales = {cov_name: np.max(np.abs(cov))
                               for cov_name, cov in self.covs.items()}
            zero_covs = [cov_name for cov_name, cov_scale in self.cov_scales.items()
                         if cov_scale == 0.0]
            for cov_name in zero_covs:
                warnings.warn(f"numbers in covariate[{cov_name}] are all zero. "
                              f"Please use this in spline range exposure or when preidct.")

    def _get_study_structure(self):
        """Get the study structure.

        Sets ``studies`` (unique study ids) and ``study_sizes`` (number of
        points per study), then sorts the data by study id.
        """
        self.studies, self.study_sizes = np.unique(self.study_id, return_counts=True)
        self._sort_by_study_id()

    def _sort_data(self, index: np.ndarray):
        """Sort the object.

        Args:
            index (np.ndarray): Sorting index, a permutation of
                ``0, ..., num_obs - 1``.
        """
        index = np.array(index)
        assert (np.sort(index) == np.arange(self.num_obs)).all(), \
            "Sorting index must go from 0 to num_obs - 1."
        self.obs = self.obs[index]
        self.obs_se = self.obs_se[index]
        for cov_name, cov in self.covs.items():
            self.covs[cov_name] = cov[index]
        self.study_id = self.study_id[index]
        self.data_id = self.data_id[index]

    def _sort_by_study_id(self):
        """Sort data by study_id.

        Nothing is done for an empty object or when there is a single study.
        """
        if not self.is_empty() and self.num_studies != 1:
            sort_index = np.argsort(self.study_id)
            self._sort_data(sort_index)

    def _sort_by_data_id(self):
        """Sort data by data_id.
        """
        if not self.is_empty():
            sort_index = np.argsort(self.data_id)
            self._sort_data(sort_index)

    def _remove_nan_in_covs(self):
        """Remove potential nans in covariates.

        Every data point with a nan in at least one covariate is dropped; a
        warning reports the number of nans per affected covariate.
        """
        if not self.is_empty():
            index = np.full(self.num_obs, False)
            for cov_name, cov in self.covs.items():
                cov_index = np.isnan(cov)
                if cov_index.any():
                    warnings.warn(f"There are {cov_index.sum()} nans in covaraite {cov_name}.")
                index = index | cov_index
            self._remove_data(index)

    def _remove_data(self, index: np.ndarray):
        """Remove the data point by index.

        Args:
            index (np.ndarray): Bool array, when ``True`` delete corresponding data.
        """
        index = np.asarray(index)
        assert len(index) == self.num_obs
        # A dtype check validates the whole mask at once instead of testing
        # each element in a Python loop.
        assert index.dtype == bool
        keep_index = ~index
        self.obs = self.obs[keep_index]
        self.obs_se = self.obs_se[keep_index]
        for cov_name, cov in self.covs.items():
            self.covs[cov_name] = cov[keep_index]
        self.study_id = self.study_id[keep_index]
        self.data_id = self.data_id[keep_index]

    def _get_data(self, index: np.ndarray) -> 'MRData':
        """Get the data point by index.

        Args:
            index (np.ndarray): Indices of the data we want to get.

        Returns:
            MRData: data object contains the data from indices.
        """
        obs = self.obs[index].copy()
        obs_se = self.obs_se[index].copy()
        covs = {}
        for cov_name, cov in self.covs.items():
            covs[cov_name] = cov[index].copy()
        study_id = self.study_id[index].copy()
        data_id = self.data_id[index].copy()
        return MRData(obs, obs_se, covs, study_id, data_id)

    def is_empty(self) -> bool:
        """Return true when object contains no data points.
        """
        return self.num_points == 0

    def _assert_not_empty(self):
        """Raise ValueError when object is empty.
        """
        if self.is_empty():
            raise ValueError("MRData object is empty.")

    def is_cov_normalized(self, covs: Union[List[str], str, None] = None) -> bool:
        """Return true when covariates are normalized.

        A covariate counts as normalized when its largest absolute value is
        exactly one; non-numeric covariates are ignored. An empty object is
        never considered normalized.

        Args:
            covs (Union[List[str], str, None], optional): Covariate name(s) to
                check, all covariates when ``None``.

        Returns:
            bool: If all requested covariates are normalized return `True`.
        """
        if covs is None:
            covs = list(self.covs.keys())
        else:
            covs = to_list(covs)
        assert self.has_covs(covs)
        ok = not self.is_empty()
        for cov_name in covs:
            # `and` short-circuits, so np.max is never evaluated on an empty
            # array.
            ok = ok and ((not is_numeric_array(self.covs[cov_name])) or
                         (np.max(np.abs(self.covs[cov_name])) == 1.0))
        return ok

    def reset(self):
        """Reset all the attributes to default values.

        The derived attributes (``studies``, ``study_sizes`` and
        ``cov_scales``) are recomputed as well so that they do not keep
        describing the data that was just discarded.
        """
        self.obs = empty_array()
        self.obs_se = empty_array()
        self.covs = dict()
        self.covs['intercept'] = np.ones(0)
        self.study_id = empty_array()
        self.data_id = empty_array()
        self._get_study_structure()
        self._get_cov_scales()

    def load_df(self, data: pd.DataFrame,
                col_obs: Union[str, None] = None,
                col_obs_se: Union[str, None] = None,
                col_covs: Union[List[str], None] = None,
                col_study_id: Union[str, None] = None,
                col_data_id: Union[str, None] = None):
        """Load data from data frame.

        Current content is discarded. Every argument left as ``None`` leaves
        the corresponding attribute empty, to be completed by
        ``__post_init__``.

        Args:
            data (pd.DataFrame): Data frame holding the columns.
            col_obs (Union[str, None], optional): Column of observations.
            col_obs_se (Union[str, None], optional):
                Column of observation standard errors.
            col_covs (Union[List[str], None], optional): Covariate columns.
            col_study_id (Union[str, None], optional): Column of study ids.
            col_data_id (Union[str, None], optional): Column of data ids.
        """
        self.reset()

        self.obs = empty_array() if col_obs is None else data[col_obs].to_numpy()
        self.obs_se = empty_array() if col_obs_se is None else data[col_obs_se].to_numpy()
        self.study_id = empty_array() if col_study_id is None else data[col_study_id].to_numpy()
        self.data_id = empty_array() if col_data_id is None else data[col_data_id].to_numpy()
        self.covs = dict() if col_covs is None else {
            cov_name: data[cov_name].to_numpy() for cov_name in col_covs
        }

        self.__post_init__()

    def load_xr(self, data,
                var_obs: Union[str, None] = None,
                var_obs_se: Union[str, None] = None,
                var_covs: Union[List[str], None] = None,
                coord_study_id: Union[str, None] = None):
        """Load data from xarray.

        Current content is discarded and every variable is flattened to one
        dimension. ``data_id`` is not read from ``data``; it is left empty for
        ``__post_init__`` to complete.

        Args:
            data: xarray object holding the variables and coordinates.
            var_obs (Union[str, None], optional): Variable of observations.
            var_obs_se (Union[str, None], optional):
                Variable of observation standard errors.
            var_covs (Union[List[str], None], optional): Covariate variables.
            coord_study_id (Union[str, None], optional):
                Coordinate used as study id.
        """
        self.reset()

        self.obs = empty_array() if var_obs is None else data[var_obs].data.flatten()
        self.obs_se = empty_array() if var_obs_se is None else data[var_obs_se].data.flatten()
        if coord_study_id is None:
            self.study_id = empty_array()
        else:
            index = data.coords.to_index().to_frame(index=False)
            self.study_id = index[coord_study_id].to_numpy()
        self.covs = dict() if var_covs is None else {
            cov_name: data[cov_name].data.flatten() for cov_name in var_covs
        }

        self.__post_init__()

    def to_df(self) -> pd.DataFrame:
        """Convert data object to data frame.

        Returns:
            pd.DataFrame: Data frame with columns ``obs``, ``obs_se``,
                ``study_id`` followed by one column per covariate.
        """
        df = pd.DataFrame({
            'obs': self.obs,
            'obs_se': self.obs_se,
            'study_id': self.study_id
        })
        for cov_name in self.covs:
            df[cov_name] = self.covs[cov_name]

        return df

    def has_covs(self, covs: Union[List[str], str]) -> bool:
        """If the data has the provided covariates.

        Args:
            covs (Union[List[str], str]):
                List of covariate names or one covariate name.

        Returns:
            bool: If has covariates return `True`.
        """
        covs = to_list(covs)
        if len(covs) == 0:
            return True
        else:
            return all([cov in self.covs for cov in covs])

    def has_studies(self, studies: Union[List[Any], Any]) -> bool:
        """If the data has provided study_id

        Args:
            studies Union[List[Any], Any]:
                List of studies or one study.

        Returns:
            bool: If has studies return `True`.
        """
        studies = to_list(studies)
        if len(studies) == 0:
            return True
        else:
            return all([study in self.studies for study in studies])

    def _assert_has_covs(self, covs: Union[List[str], str]):
        """Assert has covariates otherwise raise ValueError.
        """
        if not self.has_covs(covs):
            covs = to_list(covs)
            missing_covs = [cov for cov in covs if cov not in self.covs]
            raise ValueError(f"MRData object do not contain covariates: {missing_covs}.")

    def _assert_has_studies(self, studies: Union[List[Any], Any]):
        """Assert has studies otherwise raise ValueError.
        """
        if not self.has_studies(studies):
            studies = to_list(studies)
            missing_studies = [study for study in studies if study not in self.studies]
            raise ValueError(f"MRData object do not contain studies: {missing_studies}.")

    def get_covs(self, covs: Union[List[str], str]) -> np.ndarray:
        """Get covariate matrix.

        Args:
            covs (Union[List[str], str]):
                List of covariate names or one covariate name.

        Returns:
            np.ndarray: Covariates matrix, in the column fashion.

        Raises:
            ValueError: When a requested covariate is missing.
        """
        covs = to_list(covs)
        self._assert_has_covs(covs)
        if len(covs) == 0:
            return np.array([]).reshape(self.num_obs, 0)
        else:
            return np.hstack([self.covs[cov_names][:, None] for cov_names in covs])

    def get_study_data(self, studies: Union[List[Any], Any]) -> 'MRData':
        """Get study specific data.

        Args:
            studies (Union[List[Any], Any]): List of studies or one study.

        Returns
            MRData: Data object contains the study specific data.

        Raises:
            ValueError: When a requested study is missing.
        """
        self._assert_has_studies(studies)
        studies = to_list(studies)
        # Force a boolean dtype: without it an empty comprehension yields a
        # float64 array, which numpy rejects as an index.
        index = np.array([study in studies for study in self.study_id], dtype=bool)
        return self._get_data(index)

    def normalize_covs(self, covs: Union[List[str], str, None] = None):
        """Normalize covariates by the largest absolute value for each covariate.

        The scales stored in ``cov_scales`` (computed at construction/load
        time) are used and are not updated here, so calling this method a
        second time divides by the same scale again. All-zero covariates have
        a zero scale and are left untouched instead of being turned into nans.

        Args:
            covs (Union[List[str], str, None], optional): Covariate name(s) to
                normalize, all covariates when ``None``.

        Raises:
            ValueError: When a requested covariate is missing.
        """
        if covs is None:
            covs = list(self.covs.keys())
        else:
            covs = to_list(covs)
        self._assert_has_covs(covs)
        if not self.is_empty():
            for cov_name in covs:
                if not is_numeric_array(self.covs[cov_name]):
                    continue
                scale = self.cov_scales[cov_name]
                if scale == 0.0:
                    # 0/0 would replace every entry with nan.
                    continue
                self.covs[cov_name] = self.covs[cov_name]/scale

    def __repr__(self):
        """Summary of the object.
        """
        dimension_summary = [
            "number of observations: %i" % self.num_obs,
            "number of covariates : %i" % self.num_covs,
            "number of studies : %i" % self.num_studies,
        ]
        return "\n".join(dimension_summary)