Source code for alphadia.fdr.utils

"""Utility functions for FDR classification tasks."""

import functools
import logging
from collections.abc import Callable
from typing import Any

import numpy as np
import torch
from sklearn.model_selection import train_test_split

from alphadia.exceptions import TooFewPSMError

# Shared logger for this module. getLogger() is called without a name, which
# returns the root logger, so messages go through the root logger's handlers.
logger = logging.getLogger()


def train_test_split_(
    X: np.ndarray,
    y: np.ndarray,
    *,
    exception: type[Exception] = TooFewPSMError,
    **kwargs,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Split features and targets, translating split failures into `exception`.

    Thin wrapper around `sklearn.model_selection.train_test_split` that also
    splits a positional index array, so callers can map every row of the
    result back to its position in `X`.

    Parameters
    ----------
    X : np.ndarray
        The input features.
    y : np.ndarray
        The target values.
    exception : type[Exception], default=TooFewPSMError
        Exception type raised in place of a `ValueError` coming out of
        `train_test_split`. It is instantiated with the original message.
    **kwargs
        Forwarded unchanged to sklearn's `train_test_split`.

    Returns
    -------
    X_train, X_test, y_train, y_test, indices_train, indices_test : np.ndarray
        The split data, followed by the positions (into `X`) of the rows that
        ended up in the train and test partitions.

    Raises
    ------
    exception
        If `train_test_split` raises a `ValueError`; the original error is
        attached as `__cause__`.

    """
    # Positions 0..len(X)-1 are split alongside X and y so the caller learns
    # which original rows landed in which partition.
    row_positions = np.arange(len(X))

    try:
        parts = train_test_split(X, y, row_positions, **kwargs)
    except ValueError as err:
        raise exception(str(err)) from err

    # Three input arrays yield six outputs: (train, test) per array, in order.
    X_train, X_test, y_train, y_test, indices_train, indices_test = parts
    return X_train, X_test, y_train, y_test, indices_train, indices_test


def manage_torch_threads(max_threads: int = 2) -> Callable[..., Any]:
    """Decorator to manage torch thread count during method execution.

    While the decorated callable runs, the torch thread count is lowered to
    `max_threads` if it is currently higher. The previous value is restored
    afterwards, even if the callable raises. If the current thread count is
    already at or below `max_threads`, nothing is changed.

    Parameters
    ----------
    max_threads : int, default=2
        Maximum number of threads to use during method execution

    Returns
    -------
    Callable
        Decorator that wraps a function with the thread management described
        above. The wrapper keeps the wrapped function's metadata (`__name__`,
        `__doc__`, `__wrapped__`, ...).

    Notes
    -----
    The thread count is read and written through `torch.get_num_threads` /
    `torch.set_num_threads`. NOTE(review): this looks like process-wide state,
    so concurrent callers may observe the temporarily lowered value — confirm
    before using from multiple threads.

    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        # Preserve the wrapped function's name/docstring for introspection,
        # documentation tooling and tracebacks.
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
            is_threads_changed = False
            original_threads = torch.get_num_threads()

            # Only ever lower the thread count; a smaller current value is kept.
            if original_threads > max_threads:
                torch.set_num_threads(max_threads)
                is_threads_changed = True
                logger.info(
                    f"Setting torch num_threads to {max_threads} for FDR classification task"
                )

            try:
                # Execute the wrapped function
                return func(*args, **kwargs)
            finally:
                # Restore the previous value only if it was changed here, so
                # an untouched setting is never rewritten.
                if is_threads_changed:
                    logger.info(f"Resetting torch num_threads to {original_threads}")
                    torch.set_num_threads(original_threads)

        return wrapper

    return decorator