# Source code for alphadia.search.selection.selection

"""Main candidate selection implementation for DIA data analysis."""

import logging

import numba as nb
import numpy as np
import pandas as pd
from alpharaw.utils.pjit import pjit
from alpharaw.utils.pjit import set_threads as set_pjit_threads

from alphadia import utils
from alphadia.constants.keys import CalibCols
from alphadia.raw_data import DiaData, DiaDataJIT
from alphadia.search.jitclasses.fragment_container import FragmentContainer
from alphadia.search.selection import fft
from alphadia.search.selection.config_df import (
    CandidateContainer,
    CandidateSelectionConfig,
    CandidateSelectionConfigJIT,
    PrecursorFlatContainer,
    candidate_container_to_df,
)
from alphadia.search.selection.kernel import GaussianKernel
from alphadia.search.selection.utils import (
    amean1,
    assemble_isotope_mz,
    astd1,
    find_peaks_1d,
    find_peaks_2d,
    slice_manual,
    symetric_limits_2d,
    wrap0,
    wrap1,
)
from alphadia.utils import USE_NUMBA_CACHING
from alphadia.validation.schemas import fragments_flat_schema, precursors_flat_schema

# Module logger. ``getLogger`` called with ``None`` (no name) returns the root logger.
logger = logging.getLogger(None)


@nb.njit(cache=USE_NUMBA_CACHING)
def _is_valid(
    dense_fragments: np.ndarray, dense_precursors: np.ndarray, kernel: np.ndarray
) -> bool:
    """Check that both dense matrices can be processed with ``kernel``.

    The following must hold for ``dense_fragments`` as well as
    ``dense_precursors``:

    - axis 0 is not empty,
    - axis 2 has an even number of entries,
    - axes 2 and 3 are at least as large as axes 0 and 1 of ``kernel``.

    Returns False as soon as any of these conditions is violated, True otherwise.
    """
    frag_shape = dense_fragments.shape
    prec_shape = dense_precursors.shape
    kernel_rows = kernel.shape[0]
    kernel_cols = kernel.shape[1]

    # an empty dense matrix cannot contain a candidate
    if frag_shape[0] == 0 or prec_shape[0] == 0:
        return False

    # axis 2 must be divisible by 2 in both matrices
    if frag_shape[2] % 2 != 0 or prec_shape[2] % 2 != 0:
        return False

    # the convolution kernel has to fit into both matrices
    kernel_fits_precursors = prec_shape[2] >= kernel_rows and prec_shape[3] >= kernel_cols
    kernel_fits_fragments = frag_shape[2] >= kernel_rows and frag_shape[3] >= kernel_cols
    return kernel_fits_precursors and kernel_fits_fragments


@pjit(cache=USE_NUMBA_CACHING)
def _select_candidates_pjit(
    i: int,  # pjit decorator changes the passed argument from an iterable to single index
    jit_data: DiaDataJIT,
    precursor_container: PrecursorFlatContainer,
    fragment_container: FragmentContainer,
    config: CandidateSelectionConfigJIT,
    kernel: np.ndarray,
    candidate_container: CandidateContainer,
) -> None:
    """Select candidates for the precursor at index ``i``.

    Steps:

    1. Assemble the isotope m/z values of the precursor from its first
       ``config.top_k_precursors`` isotope intensities.
    2. Slice the fragments belonging to the precursor out of
       ``fragment_container``, filter them by cardinality 1 if
       ``config.exclude_shared_ions`` is set, and sort them by m/z.
    3. Extract dense precursor and fragment intensity matrices within the
       configured RT and mobility tolerances around
       ``precursor_container.rt[i]`` and ``precursor_container.mobility[i]``.
    4. Hand both matrices to ``_build_candidates``, which writes the
       candidates into ``candidate_container``.

    Nothing is written if three or fewer fragments remain after step 2, or if
    ``_is_valid`` rejects the dense matrices.

    Parameters
    ----------
    i : int
        Index of the precursor to process.
    jit_data : DiaDataJIT
        JIT-compiled data object containing the raw data.
    precursor_container : PrecursorFlatContainer
        Container holding precursor information.
    fragment_container : FragmentContainer
        Container holding fragment information.
    config : CandidateSelectionConfigJIT
        Configuration object containing parameters for candidate selection.
    kernel : np.ndarray
        Convolution kernel for smoothing the precursor and fragment data.
    candidate_container : CandidateContainer
        Container to store the selected candidates.

    Returns
    -------
    None, results are stored in `candidate_container`.

    """
    # prepare precursor isotope intensity
    # (n_isotopes)
    isotope_intensity = precursor_container.isotopes[i][: config.top_k_precursors]
    # (n_isotopes)
    isotope_mz = assemble_isotope_mz(
        precursor_container.mz[i], precursor_container.charge[i], isotope_intensity
    )

    # single slice selecting the fragments of this precursor;
    # presumably [start, stop, step] for slice_manual — confirm
    fragment_idx_slices = np.array(
        [
            [
                precursor_container.frag_start_idx[i],
                precursor_container.frag_stop_idx[i],
                1,
            ]
        ],
        dtype=np.uint32,
    )

    fragment_container_slice = slice_manual(fragment_container, fragment_idx_slices)
    if config.exclude_shared_ions:
        fragment_container_slice.filter_by_cardinality(1)
    fragment_container_slice.sort_by_mz()

    # too few fragments: return without writing any candidate for this precursor
    if len(fragment_container_slice.precursor_idx) <= 3:
        return

    # start extraction of raw data
    rt = precursor_container.rt[i]
    mobility = precursor_container.mobility[i]

    frame_limits = jit_data.get_frame_indices_tolerance(
        rt, config.rt_tolerance, min_size=config.kernel_size
    )
    scan_limits = jit_data.get_scan_indices_tolerance(
        mobility, config.mobility_tolerance
    )

    # m/z range spanned by the first and last isotope; passed as the
    # quadrupole window for the fragment extraction below
    quadrupole_mz = np.array([[isotope_mz[0], isotope_mz[-1]]], dtype=np.float32)

    # NOTE(review): the [-1.0, -1.0] quadrupole window presumably disables
    # quadrupole selection for the precursor extraction — confirm in
    # DiaDataJIT.get_dense_intensity
    dense_precursors, _ = jit_data.get_dense_intensity(
        frame_limits,
        scan_limits,
        isotope_mz,
        config.precursor_mz_tolerance,
        np.array([[-1.0, -1.0]], dtype=np.float32),
    )

    # shape = (2, n_fragments, n_observations, n_scans, n_frames), dtype = np.float32
    # NOTE(review): _is_valid and _build_features index these arrays as 4D
    # (axes 2 and 3 are compared against the kernel), so the shape above may
    # be stale — confirm
    dense_fragments, _ = jit_data.get_dense_intensity(
        frame_limits,
        scan_limits,
        fragment_container_slice.mz,
        config.fragment_mz_tolerance,
        quadrupole_mz,
        custom_cycle=jit_data.cycle,
    )

    if not _is_valid(dense_fragments, dense_precursors, kernel):
        return

    # with the weighted score, the feature statistics and weights stored in the
    # config are used; with None, _build_candidates uses unit weights and the
    # mean/std of the current feature matrix
    if config.use_weighted_score:
        mean = config.feature_mean
        std = config.feature_std
        weights = config.feature_weight

    else:
        mean = None
        std = None
        weights = None

    _build_candidates(
        precursor_container.precursor_idx[i],
        candidate_container,
        precursor_container.candidate_start_idx[i],
        dense_precursors,
        dense_fragments,
        kernel,
        jit_data,
        config,
        scan_limits,
        frame_limits,
        candidate_count=config.candidate_count,
        weights=weights,
        mean=mean,
        std=std,
    )


@nb.njit(fastmath=True, cache=USE_NUMBA_CACHING)
def _build_features(
    smooth_precursor: np.ndarray, smooth_fragment: np.ndarray
) -> np.ndarray:
    """Build the feature stack that is later normalized and summed into a score.

    Currently a single feature is produced. Take index 0 along the first axis
    of each smoothed matrix, apply ``log(x + 1)`` elementwise, and sum over the
    ions (the remaining first axis). The fragment and precursor terms are then
    added together.

    Parameters
    ----------
    smooth_precursor : np.ndarray
        Smoothed dense precursor matrix.
    smooth_fragment : np.ndarray
        Smoothed dense fragment matrix.

    Returns
    -------
    np.ndarray
        float32 array of shape
        ``(1, smooth_precursor.shape[2], smooth_fragment.shape[3])``.
    """
    n_rows = smooth_precursor.shape[2]
    n_cols = smooth_fragment.shape[3]
    out = np.zeros((1, n_rows, n_cols), dtype=np.float32)

    fragment_term = np.sum(np.log(smooth_fragment[0] + 1), axis=0)
    precursor_term = np.sum(np.log(smooth_precursor[0] + 1), axis=0)
    out[0] = fragment_term + precursor_term

    return out


@nb.njit(cache=USE_NUMBA_CACHING)
def _join_close_peaks(
    peak_scan_list: np.ndarray,
    peak_cycle_list: np.ndarray,
    peak_score_list: np.ndarray,
    scan_tolerance: int,
    cycle_tolerance: int,
) -> np.ndarray:
    """Join peaks that are close in scan and cycle space.

    Two peaks are close if they are at most ``scan_tolerance`` scans and at
    most ``cycle_tolerance`` cycles apart. Of two close peaks, only the one
    with the higher score is kept. On equal scores, the peak appearing later
    in the lists is kept. A peak that has already been discarded does not
    suppress any further peaks.

    Parameters
    ----------
    peak_scan_list : np.ndarray
        List of scan indices for each peak

    peak_cycle_list : np.ndarray
        List of cycle indices for each peak

    peak_score_list : np.ndarray
        List of scores for each peak

    scan_tolerance : int
        Maximum number of scans that two peaks can be apart to be considered close

    cycle_tolerance : int
        Maximum number of cycles that two peaks can be apart to be considered close

    Returns
    -------
    peak_mask : np.ndarray, dtype=np.bool_
        True for peaks that are kept, False for peaks that were joined.

    """
    n_peaks = peak_scan_list.shape[0]
    peak_mask = np.ones(n_peaks, dtype=np.bool_)
    for peak_idx in range(n_peaks):
        if not peak_mask[peak_idx]:
            continue
        scan = peak_scan_list[peak_idx]
        cycle = peak_cycle_list[peak_idx]
        score = peak_score_list[peak_idx]
        for other_peak_idx in range(peak_idx + 1, n_peaks):
            if not peak_mask[other_peak_idx]:
                continue
            other_scan = peak_scan_list[other_peak_idx]
            other_cycle = peak_cycle_list[other_peak_idx]
            other_score = peak_score_list[other_peak_idx]
            if (
                abs(scan - other_scan) <= scan_tolerance
                and abs(cycle - other_cycle) <= cycle_tolerance
            ):
                if score > other_score:
                    peak_mask[other_peak_idx] = False
                else:
                    peak_mask[peak_idx] = False
                    # the current peak was discarded in favour of a
                    # stronger neighbour; it must not suppress any more peaks
                    break

    return peak_mask


@nb.njit(cache=USE_NUMBA_CACHING)
def _join_overlapping_candidates(
    scan_limits_list: np.ndarray,
    cycle_limits_list: np.ndarray,
    p_scan_overlap: float = 0.01,
    p_cycle_overlap: float = 0.6,
) -> np.ndarray:
    """Identify overlapping candidates and join them into a single candidate.

    Candidates are processed in list order. Each still-active candidate ``i``
    absorbs every later active candidate ``j`` whose overlap with ``i``
    exceeds both thresholds:

    - cycle overlap greater than ``p_cycle_overlap``,
    - scan overlap greater than ``p_scan_overlap``.

    Both overlaps are measured as a fraction of the current extent of ``i``.
    Absorbing widens the limits of ``i`` in-place to the union of both
    candidates and marks ``j`` as joined.

    Parameters
    ----------
    scan_limits_list : np.ndarray
        (n_candidates, 2) array of [start, stop] scan limits per candidate,
        updated in-place.

    cycle_limits_list : np.ndarray
        (n_candidates, 2) array of [start, stop] cycle limits per candidate,
        updated in-place.

    p_scan_overlap : float
        Minimum fraction of scan overlap to join two candidates

    p_cycle_overlap : float
        Minimum fraction of cycle overlap to join two candidates

    Returns
    -------
    joined_mask : np.ndarray, dtype=np.bool_
        True for candidates that remain, False for candidates that were joined
        into an earlier candidate.

    """
    n_candidates = len(scan_limits_list)
    joined_mask = np.ones(n_candidates, dtype=np.bool_)

    for i in range(n_candidates):
        # skip candidates that were already joined into an earlier one
        if not joined_mask[i]:
            continue

        for j in range(i + 1, n_candidates):
            if not joined_mask[j]:
                continue

            # The extent of i is re-read for every j because i grows whenever
            # it absorbs a candidate. It is clamped to 1 so that a zero-width
            # candidate does not raise ZeroDivisionError.
            cycle_len = max(cycle_limits_list[i, 1] - cycle_limits_list[i, 0], 1)
            cycle_overlap = (
                min(cycle_limits_list[i, 1], cycle_limits_list[j, 1])
                - max(cycle_limits_list[i, 0], cycle_limits_list[j, 0])
            ) / cycle_len

            scan_len = max(scan_limits_list[i, 1] - scan_limits_list[i, 0], 1)
            scan_overlap = (
                min(scan_limits_list[i, 1], scan_limits_list[j, 1])
                - max(scan_limits_list[i, 0], scan_limits_list[j, 0])
            ) / scan_len

            # overlap must be non-negative in both dimensions
            if scan_overlap < 0 or cycle_overlap < 0:
                continue

            if cycle_overlap > p_cycle_overlap and scan_overlap > p_scan_overlap:
                # widen candidate i to the union of i and j
                scan_limits_list[i, 0] = min(
                    scan_limits_list[i, 0], scan_limits_list[j, 0]
                )
                scan_limits_list[i, 1] = max(
                    scan_limits_list[i, 1], scan_limits_list[j, 1]
                )
                cycle_limits_list[i, 0] = min(
                    cycle_limits_list[i, 0], cycle_limits_list[j, 0]
                )
                cycle_limits_list[i, 1] = max(
                    cycle_limits_list[i, 1], cycle_limits_list[j, 1]
                )
                joined_mask[j] = False

    return joined_mask


@nb.njit(fastmath=True, cache=USE_NUMBA_CACHING)
def _build_candidates(
    precursor_idx: int,
    candidate_container: CandidateContainer,
    candidate_start_idx: int,
    dense_precursors: np.ndarray,
    dense_fragments: np.ndarray,
    kernel: np.ndarray,
    jit_data: DiaDataJIT,
    config: CandidateSelectionConfigJIT,
    scan_limits: np.ndarray,
    frame_limits: np.ndarray,
    candidate_count: int = 3,
    weights: np.ndarray | None = None,
    mean: np.ndarray | None = None,
    std: np.ndarray | None = None,
) -> None:
    """Score the extracted region of one precursor and store its candidates.

    Steps:

    1. Smooth the dense precursor and fragment matrices with ``kernel``
       (``fft.convolve_fourier``) and build a feature stack with
       ``_build_features``.
    2. Normalize every feature as ``weight * (x - mean) / (std + 1e-6)`` and
       sum the features into a score map.
    3. Search peaks in the score map (``_find_peaks`` with ``candidate_count``)
       and drop peaks close to a stronger one (``_join_close_peaks``, 3 scans
       and 3 cycles tolerance).
    4. Derive relative scan/cycle limits for each remaining peak with
       ``symetric_limits_2d``. If ``config.join_close_candidates`` is set,
       merge overlapping candidates.
    5. Convert the remaining candidates to absolute scan/frame indices and
       write them to ``candidate_container`` starting at
       ``candidate_start_idx``.

    Parameters
    ----------
    precursor_idx : int
        Precursor index written to every candidate.
    candidate_container : CandidateContainer
        Output container, modified in-place.
    candidate_start_idx : int
        First slot in ``candidate_container`` reserved for this precursor.
    dense_precursors, dense_fragments : np.ndarray
        Dense intensity matrices extracted for this precursor.
    kernel : np.ndarray
        Smoothing kernel.
    jit_data : DiaDataJIT
        Raw data; provides the cycle layout and the maximum scan/frame indices.
    config : CandidateSelectionConfigJIT
        Candidate selection configuration.
    scan_limits, frame_limits : np.ndarray
        Extraction limits; element ``[0, 0]`` is the absolute offset of the
        dense matrices in scan/frame space.
    candidate_count : int
        Passed as ``top_n`` to the peak search.
    weights, mean, std : np.ndarray or None
        Feature weights and normalization statistics. If ``None``, unit
        weights and the mean/std of the current feature matrix are used.

    Raises
    ------
    ValueError
        If the feature mean, std and weight arrays do not have matching shapes.
    """
    # number of frames per cycle; used to convert cycle indices to frame indices
    cycle_length = jit_data.cycle.shape[1]

    feature_weights = np.ones(1) if weights is None else weights
    feature_weights = feature_weights.reshape(-1, 1, 1)

    smooth_precursor = fft.convolve_fourier(dense_precursors, kernel)
    smooth_fragment = fft.convolve_fourier(dense_fragments, kernel)

    # shape mismatches are only reported, processing continues
    if smooth_precursor.shape != dense_precursors.shape:
        print(
            f"smooth_precursor shape does not match dense_precursors shape {smooth_precursor.shape} != {dense_precursors.shape}"
        )
    if smooth_fragment.shape != dense_fragments.shape:
        print(
            f"smooth_fragment shape does not match dense_fragments shape {smooth_fragment.shape} != {dense_fragments.shape}"
        )

    feature_matrix = _build_features(smooth_precursor, smooth_fragment).astype(
        "float32"
    )

    # get mean and std to normalize features
    # if trained, use the mean and std from training, otherwise calculate the mean and std from the current data
    feature_mean = (
        amean1(feature_matrix).reshape(-1, 1, 1)
        if mean is None
        else mean.reshape(-1, 1, 1)
    )

    feature_std = (
        astd1(feature_matrix).reshape(-1, 1, 1)
        if std is None
        else std.reshape(-1, 1, 1)
    )

    if not (feature_std.shape == feature_mean.shape == feature_weights.shape):
        raise ValueError(
            f"feature_mean.shape={feature_mean.shape}, feature_std.shape={feature_std.shape} and feature_weights.shape={feature_weights.shape} must be equal"
        )

    # weighted, standardized features; 1e-6 guards against a zero std
    feature_matrix_norm = (
        feature_weights * (feature_matrix - feature_mean) / (feature_std + 1e-6)
    )

    # score map: sum over the feature axis
    score = np.sum(feature_matrix_norm, axis=0)

    peak_scan_list, peak_cycle_list, peak_score_list = _find_peaks(
        score, candidate_count
    )

    # drop peaks within 3 scans and 3 cycles of a stronger peak
    peak_mask = _join_close_peaks(
        peak_scan_list, peak_cycle_list, peak_score_list, 3, 3
    )

    peak_scan_list = peak_scan_list[peak_mask]
    peak_cycle_list = peak_cycle_list[peak_mask]
    peak_score_list = peak_score_list[peak_mask]

    # (n_peaks, 2) [start, stop] limits relative to the dense matrices
    scan_limits_list = np.zeros((peak_scan_list.shape[0], 2), dtype="int32")
    cycle_limits_list = np.zeros((peak_cycle_list.shape[0], 2), dtype="int32")

    for candidate_rank, (scan_relative, cycle_relative) in enumerate(
        zip(peak_scan_list, peak_cycle_list)  # ('strict' not supported by numba yet)
    ):
        scan_limits_relative, cycle_limits_relative = symetric_limits_2d(
            score,
            scan_relative,
            cycle_relative,
            f_mobility=config.f_mobility,
            f_rt=config.f_rt,
            center_fraction=config.center_fraction,
            min_size_mobility=config.min_size_mobility,
            min_size_rt=config.min_size_rt,
            max_size_mobility=config.max_size_mobility,
            max_size_rt=config.max_size_rt,
        )

        scan_limits_list[candidate_rank] = scan_limits_relative
        cycle_limits_list[candidate_rank] = cycle_limits_relative

    # optionally merge candidates whose regions overlap
    if config.join_close_candidates:
        mask = _join_overlapping_candidates(
            scan_limits_list,
            cycle_limits_list,
            p_scan_overlap=config.join_close_candidates_scan_threshold,
            p_cycle_overlap=config.join_close_candidates_cycle_threshold,
        )

        peak_scan_list = peak_scan_list[mask]
        peak_cycle_list = peak_cycle_list[mask]
        peak_score_list = peak_score_list[mask]
        scan_limits_list = scan_limits_list[mask]
        cycle_limits_list = cycle_limits_list[mask]

    # (n_candidates)
    # ranks are renumbered 0..n-1 after filtering
    candidate_rank_array = np.arange(peak_scan_list.shape[0], dtype=np.uint8)

    for (
        candidate_rank,
        scan_relative,
        cycle_relative,
        candidate_score,
        scan_limits_relative,
        cycle_limits_relative,
    ) in zip(
        candidate_rank_array,
        peak_scan_list,
        peak_cycle_list,
        peak_score_list,
        scan_limits_list,
        cycle_limits_list,
    ):  # ('strict' not supported by numba yet)
        # NOTE(review): an earlier comment here read "does not work anymore";
        # it is unclear what it referred to — confirm

        # convert relative positions to absolute scan / frame indices:
        # scans are offset by the scan extraction start, cycles are scaled by
        # the cycle length and offset by the frame extraction start
        scan_limits_absolute = wrap1(
            scan_limits_relative + scan_limits[0, 0], jit_data.scan_max_index
        )
        frame_limits_absolute = wrap1(
            cycle_limits_relative * cycle_length + frame_limits[0, 0],
            jit_data.frame_max_index,
        )

        scan_absolute = wrap0(
            scan_relative + scan_limits[0, 0], jit_data.scan_max_index
        )
        frame_absolute = wrap0(
            cycle_relative * cycle_length + frame_limits[0, 0], jit_data.frame_max_index
        )

        # slot of this candidate within the block reserved for the precursor
        candidate_index = candidate_start_idx + candidate_rank

        candidate_container.precursor_idx[candidate_index] = precursor_idx
        candidate_container.rank[candidate_index] = candidate_rank
        candidate_container.score[candidate_index] = candidate_score

        candidate_container.scan_center[candidate_index] = scan_absolute
        candidate_container.scan_start[candidate_index] = scan_limits_absolute[0]
        candidate_container.scan_stop[candidate_index] = scan_limits_absolute[1]

        candidate_container.frame_center[candidate_index] = frame_absolute
        candidate_container.frame_start[candidate_index] = frame_limits_absolute[0]
        candidate_container.frame_stop[candidate_index] = frame_limits_absolute[1]


@nb.njit(cache=USE_NUMBA_CACHING)
def _find_peaks(
    score: np.ndarray,
    candidate_count: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Identify distinct peaks in the score map.

    If axis 0 of ``score`` has at most two entries, the data is treated as
    having no real ion-mobility dimension and ``find_peaks_1d`` is used.
    Otherwise ``find_peaks_2d`` is used. Both receive
    ``top_n=candidate_count``.

    Returns
    -------
    tuple[np.ndarray, np.ndarray, np.ndarray]
        Scan indices, cycle indices and scores of the peaks, as returned by the
        selected helper.
    """
    has_mobility_dimension = score.shape[0] > 2
    if has_mobility_dimension:
        scans, cycles, scores = find_peaks_2d(score, top_n=candidate_count)
    else:
        scans, cycles, scores = find_peaks_1d(score, top_n=candidate_count)
    return scans, cycles, scores


class CandidateSelection:
    """Select precursor candidates for MS2 extraction from DIA data.

    The flattened precursor and fragment libraries are converted into
    jitclass-compatible containers, and candidates are selected for every
    precursor in parallel when the instance is called.
    """

    def __init__(
        self,
        dia_data: DiaData,
        precursors_flat: pd.DataFrame,
        fragments_flat: pd.DataFrame,
        config: CandidateSelectionConfig,
        rt_column: str,
        mobility_column: str,
        precursor_mz_column: str,
        fragment_mz_column: str,
        fwhm_rt: float = 5.0,
        fwhm_mobility: float = 0.012,
    ) -> None:
        """Select candidates for MS2 extraction based on MS1 features

        Parameters
        ----------
        dia_data : DiaData
            dia data object
        precursors_flat : pd.DataFrame
            flattened precursor dataframe
        fragments_flat : pd.DataFrame
            flattened fragment dataframe. It is stored by reference, and a
            ``cardinality`` column is (re)assigned on it when the fragment
            container is assembled.
        config : CandidateSelectionConfig
            config object
        rt_column : str
            name of the rt column in the precursor dataframe
        mobility_column : str
            name of the mobility column in the precursor dataframe
        precursor_mz_column : str
            name of the precursor mz column in the precursor dataframe
        fragment_mz_column : str
            name of the fragment mz column in the fragment dataframe
        fwhm_rt : float, optional
            full width at half maximum in RT dimension for the GaussianKernel, by default 5.0
        fwhm_mobility : float, optional
            full width at half maximum in mobility dimension for the GaussianKernel, by default 0.012

        """
        self.dia_data_jit: DiaDataJIT = dia_data.to_jitclass()
        # sort_values returns a new dataframe; the caller's dataframe is not reordered
        self.precursors_flat = precursors_flat.sort_values("precursor_idx").reset_index(
            drop=True
        )
        self.fragments_flat = fragments_flat
        self.config_jit = config.to_jitclass()
        self.rt_column = rt_column
        self.precursor_mz_column = precursor_mz_column
        self.fragment_mz_column = fragment_mz_column
        self.mobility_column = mobility_column

        gaussian_filter = GaussianKernel(
            self.dia_data_jit,
            fwhm_rt=fwhm_rt,
            sigma_scale_rt=self.config_jit.sigma_scale_rt,
            fwhm_mobility=fwhm_mobility,
            sigma_scale_mobility=self.config_jit.sigma_scale_mobility,
            kernel_width=self.config_jit.kernel_size,
            # the kernel height cannot exceed the number of available scans
            kernel_height=min(
                self.config_jit.kernel_size, self.dia_data_jit.scan_max_index + 1
            ),
        )
        self.kernel = gaussian_filter.get_dense_matrix()

    def __call__(self, thread_count: int = 10, debug: bool = False) -> pd.DataFrame:
        """Perform the candidate selection workflow.

        1. The flattened precursor and fragment dataframes are converted into
           a ``PrecursorFlatContainer`` and a ``FragmentContainer``. An output
           ``CandidateContainer`` with ``candidate_count`` slots per precursor
           is allocated.
        2. Candidates are selected for every precursor in parallel by the
           ``pjit``-decorated ``_select_candidates_pjit`` function.
        3. The candidates are converted into a pd.DataFrame. The
           ``elution_group_idx`` and ``decoy`` columns of the precursors are
           merged in on ``precursor_idx``.

        Parameters
        ----------
        thread_count : int, optional
            Number of threads used by ``pjit``, by default 10.
        debug : bool, optional
            If True, only the first (at most) 10 precursors are processed on a
            single thread, by default False.

        Returns
        -------
        pd.DataFrame
            dataframe containing the extracted candidates

        """
        logging.info("Starting candidate selection")

        precursor_container = self._assemble_precursor_container(self.precursors_flat)
        fragment_container = self._assemble_fragment_container()

        # initialize output container
        candidate_container = CandidateContainer(
            len(self.precursors_flat) * self.config_jit.candidate_count
        )

        iterator_len = len(self.precursors_flat)

        if debug:
            iterator_len = min(10, len(self.precursors_flat))
            thread_count = 1

        set_pjit_threads(thread_count)

        _select_candidates_pjit(
            range(iterator_len),  # type: ignore  # noqa: PGH003  # function is wrapped by pjit -> will be turned into single index and passed to the method
            self.dia_data_jit,
            precursor_container,
            fragment_container,
            self.config_jit,
            self.kernel,
            candidate_container,
        )

        candidate_df = candidate_container_to_df(candidate_container)
        candidate_with_precursors_df = candidate_df.merge(
            self.precursors_flat[["precursor_idx", "elution_group_idx", "decoy"]],
            on="precursor_idx",
            how="left",
        )

        return candidate_with_precursors_df

    def _assemble_fragment_container(self) -> FragmentContainer:
        """Build a ``FragmentContainer`` from ``self.fragments_flat``.

        If the dataframe has no ``cardinality`` column, one filled with 1
        (uint8) is added in-place. This also affects the dataframe passed to
        ``__init__``, because it is stored by reference.
        """
        # set cardinality to 1 if not present
        if "cardinality" in self.fragments_flat.columns:
            cardinality_values = self.fragments_flat["cardinality"].values
        else:
            logging.warning(
                "Fragment cardinality column not found in fragment dataframe. Setting cardinality to 1."
            )
            cardinality_values = np.ones(len(self.fragments_flat), dtype=np.uint8)
        self.fragments_flat["cardinality"] = cardinality_values

        # prepare jitclass compatible dtypes
        fragments_flat_schema.validate(
            self.fragments_flat, warn_on_critical_values=True
        )

        return FragmentContainer(
            self.fragments_flat[CalibCols.MZ_LIBRARY].values,
            self.fragments_flat[self.fragment_mz_column].values,
            self.fragments_flat["intensity"].values,
            self.fragments_flat["type"].values,
            self.fragments_flat["loss_type"].values,
            self.fragments_flat["charge"].values,
            self.fragments_flat["number"].values,
            self.fragments_flat["position"].values,
            self.fragments_flat["cardinality"].values,
        )

    def _assemble_precursor_container(
        self, precursors_flat: pd.DataFrame
    ) -> PrecursorFlatContainer:
        """Build a ``PrecursorFlatContainer`` from ``precursors_flat``.

        Each precursor row gets a contiguous block of ``candidate_count``
        slots in the candidate container, described by the candidate
        start/stop indices. All isotope intensity columns (``i_<n>``) found in
        the dataframe are included.
        """
        # prepare jitclass compatible dtypes
        precursors_flat_schema.validate(precursors_flat, warn_on_critical_values=True)

        available_isotopes = utils.get_isotope_columns(precursors_flat.columns)
        available_isotope_columns = [f"i_{i}" for i in available_isotopes]

        # row k owns candidate slots [k * candidate_count, (k + 1) * candidate_count)
        candidate_start_index = np.arange(
            0,
            len(precursors_flat) * self.config_jit.candidate_count,
            self.config_jit.candidate_count,
            dtype=np.uint32,
        )
        candidate_stop_index = (
            candidate_start_index + self.config_jit.candidate_count
        ).astype(np.uint32)

        return PrecursorFlatContainer(
            precursors_flat["precursor_idx"].values,
            precursors_flat["flat_frag_start_idx"].values,
            precursors_flat["flat_frag_stop_idx"].values,
            candidate_start_index,
            candidate_stop_index,
            precursors_flat["charge"].values,
            precursors_flat[self.rt_column].values,
            precursors_flat[self.mobility_column].values,
            precursors_flat[self.precursor_mz_column].values,
            precursors_flat[available_isotope_columns].values.copy(),
        )