# Source code for alphadia.fragcomp.fragcomp

"""The fragment competition module contains functionality to maintain the exclusive assignment of signal to identifications."""

import logging
import warnings

import numba as nb
import numpy as np
import pandas as pd
from alpharaw.utils.pjit import pjit
from alpharaw.utils.pjit import set_threads as set_pjit_threads
from pandas.errors import SettingWithCopyWarning

from alphadia.constants.keys import CalibCols
from alphadia.fragcomp.utils import add_frag_start_stop_idx, candidate_hash
from alphadia.utils import USE_NUMBA_CACHING

logger = logging.getLogger(__name__)


@nb.njit(cache=USE_NUMBA_CACHING)
def _get_fragment_overlap(
    frag_mz_1: np.ndarray,
    frag_mz_2: np.ndarray,
    mass_tol_ppm: float = 10,
) -> int:
    """Count the fragment pairs of two spectra that agree within a ppm tolerance.

    Every fragment of the first spectrum is compared against every fragment of
    the second spectrum. The deviation of a pair is expressed in ppm relative
    to the m/z of the fragment taken from the first spectrum.

    Parameters
    ----------
    frag_mz_1: np.ndarray
        The m/z values of the first spectrum. Also used as the reference for
        the ppm calculation.

    frag_mz_2: np.ndarray
        The m/z values of the second spectrum.

    mass_tol_ppm: float
        The mass tolerance in ppm. A pair counts when its deviation is strictly
        smaller than this value.

    Returns
    -------
    int
        The number of (first, second) fragment pairs within tolerance.

    """
    # column vector against row vector broadcasts to the full pairwise matrix
    mz_column = frag_mz_1.reshape(-1, 1)
    mz_row = frag_mz_2.reshape(1, -1)

    pairwise_ppm = np.abs(mz_column - mz_row) / mz_column * 1e6
    within_tolerance = pairwise_ppm < mass_tol_ppm
    return np.sum(within_tolerance)


@pjit(cache=USE_NUMBA_CACHING)
def _compete_for_fragments(  # noqa: PLR0913 # Too many arguments
    thread_idx: int,  # pjit decorator changes the passed argument from an iterable to single index
    precursor_start_idxs: np.ndarray,
    precursor_stop_idxs: np.ndarray,
    rt: np.ndarray,
    frag_start_idx: np.ndarray,
    frag_stop_idx: np.ndarray,
    fragment_mz: np.ndarray,
    rt_tol_seconds: float,
    mass_tol_ppm: float,
    valid: np.ndarray,
) -> None:
    """Invalidate PSMs that share fragments with another PSM of the same dia window.

    One call handles exactly one dia window. Within that window every PSM that
    is still valid is compared with every other valid PSM; if two PSMs elute
    within `rt_tol_seconds` of each other and share at least three fragments,
    the second PSM of the pair is marked invalid.

    The pjit decorator thread-parallelizes over the first argument index and additionally wraps with
    numba.njit(nogil=True). Make sure to read and understand the pjit decorator, especially how it changes the type
    of the first argument.

    Parameters
    ----------
    thread_idx: int
        The thread index, selecting the dia window handled by this call.
        The pjit decorator effectively changes the type of this argument to `np.ndarray` and thread-parallelizes
        over it.

    precursor_start_idxs: np.ndarray
        Array of length n_windows. The start indices of the precursors in the PSM dataframe.

    precursor_stop_idxs: np.ndarray
        Array of length n_windows. The stop indices (exclusive) of the precursors in the PSM dataframe.

    rt: np.ndarray
        The retention times of the precursors.

    frag_start_idx: np.ndarray
        Array of length n_psms. The start indices of the fragments in the fragment dataframe.

    frag_stop_idx: np.ndarray
        Array of length n_psms. The stop indices of the fragments in the fragment dataframe.

    fragment_mz: np.ndarray
        The m/z values of the fragments.

    rt_tol_seconds: float
        The retention time tolerance in seconds.

    mass_tol_ppm: float
        The mass tolerance in ppm.

    valid: np.ndarray
        Array of length n_psms. The validity of each PSM. This is where the method output will be stored.

    Returns
    -------
        None, but modifies the `valid` array in place.

    """
    window_start = precursor_start_idxs[thread_idx]
    window_stop = precursor_stop_idxs[thread_idx]

    window_rt = rt[window_start:window_stop]
    window_valid = valid[window_start:window_stop]
    n_window_psms = len(window_rt)

    for i in range(n_window_psms):
        # a PSM that already lost a competition must not eliminate others
        if not window_valid[i]:
            continue

        for j in range(n_window_psms):
            if j == i or not window_valid[j]:
                continue

            # only PSMs eluting close to each other compete for fragments
            if not (abs(window_rt[i] - window_rt[j]) < rt_tol_seconds):
                continue

            psm_i = window_start + i
            psm_j = window_start + j
            frag_mz_i = fragment_mz[frag_start_idx[psm_i] : frag_stop_idx[psm_i]]
            frag_mz_j = fragment_mz[frag_start_idx[psm_j] : frag_stop_idx[psm_j]]

            n_shared = _get_fragment_overlap(
                frag_mz_i,
                frag_mz_j,
                mass_tol_ppm=mass_tol_ppm,
            )
            if n_shared >= 3:  # noqa: PLR2004
                window_valid[j] = False

    # store the outcome of this window in the shared output array
    valid[window_start:window_stop] = window_valid


class FragmentCompetition:
    """Fragment competition class to remove PSMs that share fragments with other PSMs.

    Calling an instance assigns every PSM to a dia window, sorts the PSMs by
    window and `proba`, and lets the PSMs of each window compete for their
    fragments. Only the PSMs that stay valid are returned.
    """

    def __init__(
        self, rt_tol_seconds: int = 3, mass_tol_ppm: int = 15, thread_count: int = 8
    ):
        """Remove PSMs that share fragments with other PSMs.

        Parameters
        ----------
        rt_tol_seconds: int
            The retention time tolerance in seconds.

        mass_tol_ppm: int
            The mass tolerance in ppm.

        thread_count: int
            The number of threads to use.

        """
        self.rt_tol_seconds = rt_tol_seconds
        self.mass_tol_ppm = mass_tol_ppm
        self.thread_count = thread_count

    @staticmethod
    def _add_window_idx(psm_df: pd.DataFrame, cycle: np.ndarray) -> pd.DataFrame:
        """Add the window index to the PSM dataframe.

        If a `window_idx` column already exists, a warning is logged and the
        dataframe is returned untouched.

        Parameters
        ----------
        psm_df: pd.DataFrame
            The PSM dataframe. Must contain the `CalibCols.MZ_OBSERVED` column.
            The `window_idx` column is added in place.

        cycle: np.ndarray
            The cycle array. It is indexed as `cycle[0, :, :, 0]` and
            `cycle[0, :, :, 1]`; presumably the second axis enumerates the dia
            windows and the last axis holds their lower/upper m/z limits -
            TODO confirm against the caller.

        Returns
        -------
        pd.DataFrame
            The PSM dataframe with the window index.

        """
        if "window_idx" in psm_df.columns:
            logger.warning("Window index already present in PSM dataframe. Skipping.")
            return psm_df

        # one lower and one upper limit per window, laid out as a row vector
        lower_limit = np.min(cycle[0, :, :, 0], axis=1, keepdims=True).T
        upper_limit = np.max(cycle[0, :, :, 1], axis=1, keepdims=True).T

        # column vector of observed m/z broadcasts against the window limits
        mz_observed = np.expand_dims(psm_df[CalibCols.MZ_OBSERVED].values, axis=-1)
        idx = (mz_observed >= lower_limit) & (mz_observed < upper_limit)

        # NOTE(review): argmax returns 0 for rows without any matching window,
        # so such PSMs silently end up in window 0 - confirm this is intended.
        psm_df["window_idx"] = np.argmax(idx, axis=1)
        return psm_df

    @staticmethod
    def _get_thread_plan_df(psm_df: pd.DataFrame) -> pd.DataFrame:
        """Expects a dataframe sorted by window idxs and qvals. Returns a dataframe with start and stop indices of the threads.

        The positions are computed on a small scratch frame, so `psm_df` itself
        is not modified.

        Parameters
        ----------
        psm_df: pd.DataFrame
            The PSM dataframe. Must contain the `window_idx` column.

        Returns
        -------
        pd.DataFrame
            The thread plan dataframe with one row per window and the columns
            `window_idx`, `start_idx` (inclusive) and `stop_idx` (exclusive).

        """
        positions = psm_df[["window_idx"]].assign(_thread_idx=np.arange(len(psm_df)))
        index_df = positions.groupby("window_idx", as_index=False).agg(
            start_idx=pd.NamedAgg("_thread_idx", "min"),
            stop_idx=pd.NamedAgg("_thread_idx", "max"),
        )
        # turn the inclusive maximum position into an exclusive stop index
        index_df["stop_idx"] += 1
        return index_df

    def __call__(
        self, psm_df: pd.DataFrame, frag_df: pd.DataFrame, cycle: np.ndarray
    ) -> pd.DataFrame:
        """Remove PSMs that share fragments with other PSMs.

        Parameters
        ----------
        psm_df: pd.DataFrame
            The PSM dataframe. Read columns: `precursor_idx`, `rank`, `proba`,
            `CalibCols.MZ_OBSERVED` and `CalibCols.RT_OBSERVED`. A
            `_candidate_idx` column is added in place.

        frag_df: pd.DataFrame
            The fragment dataframe. Read columns: `precursor_idx`, `rank` and
            `CalibCols.MZ_OBSERVED`. A `_candidate_idx` column is added in place.

        cycle: np.ndarray
            DIA cycle.

        Returns
        -------
        pd.DataFrame
            The rows of the PSM dataframe whose `valid` flag is still True,
            including the `valid` column.

        """
        # TODO: this method raises SettingWithCopyWarning. Resolve without increasing memory usage.
        # catch_warnings restores the caller's warning filters on exit, also
        # when an exception is raised while the warning is suppressed.
        with warnings.catch_warnings():
            warnings.simplefilter(action="ignore", category=SettingWithCopyWarning)

            # NOTE(review): `_candidate_idx` stays on both frames after the
            # call - confirm that downstream code relies on it.
            psm_df["_candidate_idx"] = candidate_hash(
                psm_df["precursor_idx"].values, psm_df["rank"].values
            )
            frag_df["_candidate_idx"] = candidate_hash(
                frag_df["precursor_idx"].values, frag_df["rank"].values
            )

            psm_df = add_frag_start_stop_idx(psm_df, frag_df)
            psm_df = self._add_window_idx(psm_df, cycle)

            # important to sort by window_idx and proba
            psm_df.sort_values(
                by=["window_idx", "proba", "precursor_idx"], inplace=True
            )  # last sort to break ties

            # every PSM starts out valid; the kernel clears the flag in place
            valid = np.ones(len(psm_df), dtype=bool)

            set_pjit_threads(self.thread_count)

            thread_plan_df = self._get_thread_plan_df(psm_df)

            _compete_for_fragments(
                np.arange(len(thread_plan_df)),  # type: ignore # noqa: PGH003 # function is wrapped by pjit -> will be turned into single index and passed to the method
                thread_plan_df["start_idx"].values,
                thread_plan_df["stop_idx"].values,
                psm_df[CalibCols.RT_OBSERVED].values,
                psm_df["_frag_start_idx"].values,
                psm_df["_frag_stop_idx"].values,
                frag_df[CalibCols.MZ_OBSERVED].values,
                self.rt_tol_seconds,
                self.mass_tol_ppm,
                valid,
            )

            psm_df["valid"] = valid

            # clean up
            psm_df.drop(
                columns=["_frag_start_idx", "_frag_stop_idx", "window_idx"],
                inplace=True,
            )

        return psm_df[psm_df["valid"]]