# Source code for alphadia.workflow.config

"""This module is responsible for creating and storing the configuration.

It allows updating the default configuration with one or more other configuration objects.
The order of configs holds significance, with configurations later in the sequence overwriting previous values.
Lists are always overwritten completely.

On demand, the current config can be visualized in a tree-like structure.
"""

import json
import logging
from collections import UserDict, defaultdict
from copy import deepcopy
from typing import Any

import numpy as np
import yaml

from alphadia.constants.keys import ConfigKeys
from alphadia.exceptions import KeyAddedConfigError, TypeMismatchConfigError

# Root logger: getLogger() is called without a name.
# NOTE(review): presumably intentional so this module logs through the application's
# root handlers -- confirm before switching to logging.getLogger(__name__).
logger = logging.getLogger()

# Names identifying where config values came from. A Config's `name` is what the
# pretty-printed config shows as the source of a changed value. DEFAULT is the default
# `name` of a Config; presumably the others are passed as `name` by callers building
# user-defined / CLI / multistep-search configs -- verify at the call sites.
DEFAULT = "default"
USER_DEFINED = "user defined"
USER_DEFINED_CLI_PARAM = "user defined (cli)"
MULTISTEP_SEARCH = "multistep search"

# delimiters to parse config
# (presumably for string settings holding several items, e.g. channel or modification
# lists -- confirm at the call sites)
MULTIPLEXING_CHANNELS_DELIM = ","
MODIFICATIONS_DELIM = ";"


class Config(UserDict):
    """Dict-like config class that can read from and write to yaml and json files and allows updating with other config objects.

    TODO: this class should be read-only, but currently mutable value elements can be mutated.
    """

    def __init__(self, data: dict = None, name: str = DEFAULT) -> None:
        """Create a config from a dict.

        Parameters
        ----------
        data : dict, optional
            Initial config values. Only the top-level mapping is copied; nested values are shared
            with the passed-in dict.
        name : str, optional
            Name of this config. It is logged and recorded as the origin of every value this
            config changes when it is passed to `update()` of another config.
        """
        # super class deliberately not called as this calls "update" (which we overwrite)
        # this needs to be called 'data' as we inherit from UserDict
        self.data = {**data} if data is not None else {}
        self.name = name

    def from_yaml(self, path: str) -> None:
        """Replace the config contents with the contents of the yaml file at `path`.

        An empty yaml file results in an empty config (yaml.safe_load returns None for it).
        """
        with open(path, encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
        self.data = loaded if loaded is not None else {}

    def from_json(self, path: str) -> None:
        """Replace the config contents with the contents of the json file at `path`."""
        with open(path, encoding="utf-8") as f:
            self.data = json.load(f)

    def to_yaml(self, path: str) -> None:
        """Write the config to a yaml file at `path`, preserving key order.

        Numpy scalars/arrays are converted to native Python types first.

        Raises
        ------
        NotImplementedError
            If the config contains tuples or sets.
        """
        with open(path, "w", encoding="utf-8") as f:
            yaml.dump(_convert_numpy_types(self.data), f, sort_keys=False)

    def to_json(self, path: str) -> None:
        """Write the config to a json file at `path`.

        Numpy scalars/arrays are converted to native Python types first (as in `to_yaml`);
        json.dump cannot serialize them.

        Raises
        ------
        NotImplementedError
            If the config contains tuples or sets.
        """
        with open(path, "w", encoding="utf-8") as f:
            json.dump(_convert_numpy_types(self.data), f)

    def __setitem__(self, key: str, item: Any):
        """Forbid setting keys directly.

        Use set_value() to set certain (nested) keys, or update() for other keys.
        """
        raise NotImplementedError("Use set_value() or update() to update the config.")

    def set_value(self, key: str | tuple[str, ...], path: str | list[str]) -> None:
        """Set a config key directly. Only certain keys are allowed to be set.

        Use a tuple key for nested access, e.g. ("library_prediction", "peptdeep_model_path").

        Parameters
        ----------
        key : str or tuple of str
            Top-level key, or the sequence of keys leading to a nested value.
            All but the last key of a tuple must already exist.
        path : str or list of str
            The value to store under `key` (the parameter name is kept for backward compatibility).

        Raises
        ------
        NotImplementedError
            If `key` is not one of the keys allowed to be set directly.
        """
        allowed_keys = [
            ConfigKeys.VERSION,
            ConfigKeys.OUTPUT_DIRECTORY,
            ConfigKeys.LIBRARY_PATH,
            ConfigKeys.QUANT_DIRECTORY,
            ConfigKeys.RAW_PATHS,
            ConfigKeys.FASTA_PATHS,
            (
                ConfigKeys.LIBRARY_PREDICTION,
                ConfigKeys.LIBRARY_PREDICTION.PEPTDEEP_MODEL_PATH,
            ),
        ]
        if key not in allowed_keys:
            raise NotImplementedError(
                "Only certain values may be set directly, use update() to update the config otherwise."
            )

        if isinstance(key, tuple):
            # nested access: walk down to the parent dict, then set the last key there
            target = self.data
            for k in key[:-1]:
                target = target[k]
            target[key[-1]] = path
        else:
            self.data[key] = path

    def __delitem__(self, key):
        """Forbid deleting keys; use update() instead."""
        raise NotImplementedError("Use update() to update the config.")

    def copy(self):
        """Forbid shallow copies, which would share nested (mutable) values; use deepcopy() instead."""
        raise NotImplementedError("Use deepcopy() to copy the config.")

    def update(self, configs: list["Config"], do_print: bool = False):
        """
        Updates the config with one or more other config objects.

        The order of configs holds significance, with configurations later in the sequence
        taking precedence in terms of their impact on changes.
        All changes to the default config are tracked and stored in a separate dictionary
        to enable convenient visualization of the changes.

        Parameters
        ----------
        configs : list of configs
            List of config objects to update the current config with. The order of the configs
            is important (last one wins).
        do_print : bool, optional
            Whether to print the modified config. Default is False.
        """
        # we assume that self.data holds the default config
        default_config = deepcopy(self.data)

        def _recursive_defaultdict():
            """Allow initialization of an infinitely nested dictionary to be able to map arbitrary structures."""
            return defaultdict(_recursive_defaultdict)

        tracking_dict = defaultdict(_recursive_defaultdict)

        # Work on a copy: self.data is only replaced after all configs were applied without error.
        current_config = deepcopy(self.data)
        for config in configs:
            logger.info(f"Updating config with '{config.name}'")
            _update(
                current_config,
                config.data,
                tracking_dict,
                config.name,
            )

        self.data = current_config

        if do_print:
            try:
                self._pretty_print(current_config, default_config, tracking_dict)
            except Exception as e:
                # printing is best-effort: fall back to a plain yaml dump
                logger.warning(f"Could not print config: {e}")
                logger.info(f"{(yaml.dump(current_config))}")

    @staticmethod
    def _pretty_print(config: dict, default_config: dict, tracking_dict: dict) -> None:
        """
        Pretty print a configuration dictionary in a tree-like structure.

        Parameters
        ----------
        config:
            The configuration dictionary to print
        default_config:
            The default configuration dictionary to print
        tracking_dict:
            A dictionary with the same structure as config, whose leaf values contain the name
            of the config object that last updated the value
        """
        _pretty_print(
            config, default_config=default_config, tracking_dict=tracking_dict
        )
# keys that have been removed from the config but are still tolerated # Note: if multiple levels have been removed, multiple entries are needed, e.g. ["removed_key_level1, removed_key_level1.removed_key_level2"] TOLERATED_KEYS = [ # supported until 2.0.2: "general.astral_ms1", "general.mmap_detector_events", "fdr.enable_two_step_classifier", "fdr.two_step_classifier_max_iterations", # supported until 2.0.1: "scoring_config", "scoring_config.score_grouped", "scoring_config.top_k_isotopes", "scoring_config.reference_channel", "scoring_config.precursor_mz_tolerance", "scoring_config.fragment_mz_tolerance", "selection_config", "selection_config.peak_len_rt", "selection_config.sigma_scale_rt", "selection_config.peak_len_mobility", "selection_config.sigma_scale_mobility", "selection_config.top_k_precursors", "selection_config.kernel_size", "selection_config.f_mobility", "selection_config.f_rt", "selection_config.center_fraction", "selection_config.min_size_mobility", "selection_config.min_size_rt", "selection_config.max_size_mobility", "selection_config.max_size_rt", "selection_config.group_channels", "selection_config.use_weighted_score", "selection_config.join_close_candidates", "selection_config.join_close_candidates_scan_threshold", "selection_config.join_close_candidates_cycle_threshold", # supported until 1.10.4: "calibration.norm_rt_mode", ] def _convert_numpy_types(data: Any) -> Any: """Recursively convert numpy types to native Python types for YAML serialization. These could come from dynamically set values, e.g. calculated mass tolerances in multi-step searches. Note: no need to handle tuples or sets since YAML doesn't have native tuple/set types. 
Parameters ---------- data : any Data structure potentially containing numpy types Returns ------- any Same structure with numpy types converted to native Python types """ if isinstance(data, dict): return {key: _convert_numpy_types(value) for key, value in data.items()} elif isinstance(data, list): return [_convert_numpy_types(item) for item in data] elif isinstance(data, np.generic): return data.item() elif isinstance(data, np.ndarray): return data.tolist() elif isinstance(data, tuple | set): raise NotImplementedError( "Tuples and sets are not supported in config serialization." ) else: return data def _update( target_config: dict, update_config: dict, tracking_dict: dict, config_name: str, parent_keys: str = "", ) -> None: """ Recursively update target_dict in-place with values from update_dict, following specific rules for different types. For each value that gets updated, the corresponding value in tracking_dict is updated with config_name. Parameters ---------- target_config: The config dictionary to be modified update_config: The config dictionary containing update values tracking_dict: A dictionary of nested dictionaries. If a value target_config gets overwritten, the same value in tracking_dict will be overwritten with `config_name`. config_name: The name of the current config object parent_keys: Names of the parent keys, separated by dots. Used only for exception messages. Notes ----- - Nested dictionaries are recursively updated - Only updates existing keys (adding new keys not allowed) - lists are always overwritten Raises ------ - KeyAddedConfigError: a key is not found in the target_config - ValueTypeMismatchConfigError: the type of the update value does not match the type of the target value """ for key, update_value in update_config.items(): full_key = f"{parent_keys}.{key}" if parent_keys else key if key not in target_config: if full_key in TOLERATED_KEYS: logger.warning( f"Key '{full_key}' has been removed from AlphaDIA. 
Please update your config.yaml as this will be an error in future versions." ) continue raise KeyAddedConfigError(full_key, update_value, config_name) target_value = target_config[key] tracking_value = tracking_dict[key] # Convert string "true"/"false" to boolean to avoid type mismatch errors (especially from --config-dict CLI parameter) if isinstance(update_value, str): if update_value.lower() == "true": update_value = True elif update_value.lower() == "false": update_value = False if ( # exception: None -> something allowed target_value is not None # exception: something -> None allowed and update_value is not None # exception: int <-> float allowed and not ( isinstance(target_value, int | float) and isinstance(update_value, int | float) ) # actual type check and type(target_value) is not type(update_value) ): raise TypeMismatchConfigError( full_key, update_value, config_name, f"{type(update_value)} != {type(target_value)}", ) if isinstance(target_value, dict): _update( target_value, update_value, tracking_value, config_name, parent_keys=full_key, ) elif isinstance(target_value, list): # overwrite lists completely target_config[key] = update_value tracking_dict[key] = config_name # handle simple values else: target_config[key] = update_value tracking_dict[key] = config_name def _pretty_print( config: dict, *, default_config: dict | list | None, tracking_dict: dict | str, prefix: str = "", ): """Recursively pretty print a configuration dictionary in a tree-like structure. Note: all special unicode characters used here must be explicitly escaped in the GUI (cf. 
replaceConfigFormatUnicodeEscapes()) """ for i, (key, value) in enumerate(config.items()): is_last_item = i == len(config.items()) - 1 # determine the current line's prefix current_prefix = "└──" if is_last_item else "├──" # determine the next level's prefix next_prefix = prefix + (" " if is_last_item else "│ ") if default_config is None: # in case something was added default_config_value = None elif isinstance(default_config, dict): try: default_config_value = default_config[key] except KeyError: # in case a key was added default_config_value = None else: # we can assume it's a list here, as simple types are printed right away default_config_value = default_config[i] if isinstance(tracking_dict, str): # we have a leaf node (e.g. "default") tracking_dict_value = tracking_dict else: # tracking values are either dict or str (not lists: those are overwritten by config_name) tracking_dict_value = tracking_dict[key] if isinstance(value, dict): logger.info(f"{prefix}{current_prefix}{key}") _pretty_print( value, default_config=default_config_value, tracking_dict=tracking_dict_value, prefix=next_prefix, ) elif isinstance(value, list): logger.info(f"{prefix}{current_prefix}{key}:") for j, value_ in enumerate(value): default_value = ( default_config_value[j] if default_config_value is not None and j < len(default_config_value) else None ) # complex lists if isinstance(value_, dict): next_prefix = prefix + (" " if is_last_item else "│ ") _pretty_print( value_, default_config=default_value, tracking_dict=tracking_dict_value, prefix=next_prefix, ) # simple lists else: color_on, color_off = _get_color_tokens(value_, default_value) logger.info( f"{next_prefix}{color_on}- {_expand(value_, default_value, tracking_dict_value)}{color_off}" ) # simple value else: color_on, color_off = _get_color_tokens(value, default_config_value) logger.info( f"{prefix}{color_on}{current_prefix}{key}: {_expand(value, default_config_value, tracking_dict_value)}{color_off}" ) def _get_color_tokens( 
actual_value: str | int | float | None, default_value: str | int | float | None, ) -> tuple[str, str]: """Get color on/off tokens if values differ, else empty strings.""" if default_value != actual_value: style = "\x1b[32;20m" reset = "\x1b[0m" return style, reset return "", "" def _expand( actual_value: str | int | float, default_value: str | int | float, tracking_value: str, ) -> str: """Create an expanded string representation of a configuration value in case it differs from the default.""" msg = str(actual_value) if default_value != actual_value: return f"{msg} [{tracking_value}, default: {default_value}]" return msg