Source code for alphadia.search.scoring.features.fragment_features

"""Feature extraction for fragment ions."""

import numba as nb
import numpy as np

from alphadia.search.scoring.features.features_utils import (
    cosine_similarity_a1,
    weighted_center_mean_2d,
)
from alphadia.search.scoring.utils import (
    fragment_correlation,
    fragment_correlation_different,
    tile,
)
from alphadia.utils import USE_NUMBA_CACHING

nb_float32_array = nb.types.Array(nb.types.float32, 1, "C")


@nb.njit(cache=USE_NUMBA_CACHING)
def weighted_center_of_mass(
    single_dense_representation,
):
    """Intensity-weighted centre and variance of the positive entries of a 2D array.

    Only entries strictly greater than zero contribute; their values act as
    weights for their own (scan, frame) indices.

    Parameters
    ----------
    single_dense_representation : np.ndarray
        2D array. The first axis is treated as the scan axis and the second
        as the frame axis.

    Returns
    -------
    tuple
        ``(scan_mean, frame_mean, scan_var_weighted, frame_var_weighted)``.
        All four values are zero when the array has no positive entry.

    """
    scans, frames = np.nonzero(single_dense_representation > 0)
    n_points = len(scans)
    if n_points == 0:
        return 0, 0, 0, 0

    # gather the weight of every positive entry as float64
    weights = np.empty(n_points, dtype=np.float64)
    for k in range(n_points):
        weights[k] = single_dense_representation[scans[k], frames[k]]

    total_weight = np.sum(weights)

    if total_weight > 0:
        scan_mean = np.sum(scans * weights) / total_weight
        frame_mean = np.sum(frames * weights) / total_weight
        scan_var_weighted = np.sum((scans - scan_mean) ** 2 * weights) / total_weight
        frame_var_weighted = (
            np.sum((frames - frame_mean) ** 2 * weights) / total_weight
        )
    else:
        scan_mean = 0
        frame_mean = 0
        scan_var_weighted = 0
        frame_var_weighted = 0

    return scan_mean, frame_mean, scan_var_weighted, frame_var_weighted
@nb.njit(cache=USE_NUMBA_CACHING)
def weighted_center_of_mass_1d(
    dense_representation,
):
    """Apply ``weighted_center_of_mass`` to every slice along the first axis.

    Parameters
    ----------
    dense_representation : np.ndarray
        Array whose first axis enumerates the slices; each
        ``dense_representation[i]`` is passed to ``weighted_center_of_mass``.

    Returns
    -------
    tuple of np.ndarray
        ``(scan, frame, scan_var_weighted, frame_var_weighted)``, each of
        length ``dense_representation.shape[0]``.

    """
    n_items = dense_representation.shape[0]

    scan = np.zeros(n_items)
    frame = np.zeros(n_items)
    scan_var_weighted = np.zeros(n_items)
    frame_var_weighted = np.zeros(n_items)

    for idx in range(n_items):
        center = weighted_center_of_mass(dense_representation[idx])
        scan[idx] = center[0]
        frame[idx] = center[1]
        scan_var_weighted[idx] = center[2]
        frame_var_weighted[idx] = center[3]

    return scan, frame, scan_var_weighted, frame_var_weighted
@nb.njit(inline="always", cache=USE_NUMBA_CACHING)
def _odd_center_envelope(x: np.ndarray):
    """
    Applies an interference correction envelope to a collection of odd-length 1D arrays.
    Numba function which operates in place.

    Walking outwards from the center column, every value is clipped to the
    mean of its two inner neighbours (after those have been clipped).

    Parameters
    ----------

    x: np.ndarray
        Array of shape (a, b) where a is the number of arrays and b is the length of each array.
        It is mandatory that dimension b is odd. Arrays with fewer than three
        columns are left untouched.

    """

    center_index = x.shape[1] // 2

    # With fewer than three columns there is nothing to clip, and the seed
    # values below would read column center_index + 1 out of bounds, which
    # nopython mode does not check by default.
    if center_index >= 1:
        for a0 in range(x.shape[0]):
            left_intensity = (x[a0, center_index - 1] + x[a0, center_index]) * 0.5
            right_intensity = (x[a0, center_index + 1] + x[a0, center_index]) * 0.5

            for i in range(1, center_index + 1):
                x[a0, center_index - i] = min(left_intensity, x[a0, center_index - i])
                left_intensity = (
                    x[a0, center_index - i] + x[a0, center_index - i + 1]
                ) * 0.5

                x[a0, center_index + i] = min(right_intensity, x[a0, center_index + i])
                right_intensity = (
                    x[a0, center_index + i] + x[a0, center_index + i - 1]
                ) * 0.5


@nb.njit(inline="always", cache=USE_NUMBA_CACHING)
def _even_center_envelope(x: np.ndarray):
    """
    Applies an interference correction envelope to a collection of even-length 1D arrays.
    Numba function which operates in place.

    The two central columns are kept as they are; walking outwards, every
    other value is clipped, first to the adjacent central value and from then
    on to the mean of its two inner neighbours.

    Parameters
    ----------

    x: np.ndarray
        Array of shape (a, b) where a is the number of arrays and b is the length of each array.
        It is mandatory that dimension b is even. Arrays without columns are
        left untouched.

    """

    center_index_right = x.shape[1] // 2
    center_index_left = center_index_right - 1

    # Without any column there is no central pair; skipping the rows avoids
    # reading x[a0, -1] and x[a0, 0] out of bounds, which nopython mode does
    # not check by default.
    if center_index_left >= 0:
        for a0 in range(x.shape[0]):
            left_intensity = x[a0, center_index_left]
            right_intensity = x[a0, center_index_right]

            for i in range(1, center_index_left + 1):
                x[a0, center_index_left - i] = min(
                    left_intensity, x[a0, center_index_left - i]
                )
                left_intensity = (
                    x[a0, center_index_left - i] + x[a0, center_index_left - i + 1]
                ) * 0.5

                x[a0, center_index_right + i] = min(
                    right_intensity, x[a0, center_index_right + i]
                )
                right_intensity = (
                    x[a0, center_index_right + i] + x[a0, center_index_right + i - 1]
                ) * 0.5
@nb.njit(cache=USE_NUMBA_CACHING)
def center_envelope_1d(x: np.ndarray):
    """
    Applies an interference correction envelope to a collection of 1D arrays.
    Numba function which operates in place.

    Dispatches on the parity of the second dimension: odd lengths are handled
    by ``_odd_center_envelope``, even lengths by ``_even_center_envelope``.

    Parameters
    ----------

    x: np.ndarray
        Array of shape (a, b) where a is the number of arrays and b is the length of each array.
        Dimension b may be odd or even.

    """

    if x.shape[1] % 2 == 1:
        _odd_center_envelope(x)
    else:
        _even_center_envelope(x)
@nb.njit(cache=USE_NUMBA_CACHING)
def weighted_mean_a1(array, weight_mask):
    """
    Row-wise weighted mean of a 2D array.

    For every row only the positions with a strictly positive weight are
    used; the selected weights are rescaled to sum to one before averaging.
    Rows without any positive weight yield 0.

    Parameters
    ----------

    array: np.ndarray
        array of shape (a, b)

    weight_mask: np.ndarray
        array of shape (a, b)

    Returns
    -------

    np.ndarray
        array of shape (a)

    """

    n_rows = array.shape[0]
    mean = np.zeros(n_rows)
    positive = weight_mask > 0

    for row in range(n_rows):
        keep = positive[row]
        values = array[row][keep]

        # rows without positive weights keep their initial value of zero
        if len(values) == 0:
            continue

        weights = weight_mask[row][keep]
        mean[row] = np.sum(values * (weights / np.sum(weights)))

    return mean
@nb.njit(cache=USE_NUMBA_CACHING)
def fragment_features(
    dense_fragments: np.ndarray,
    fragments_frame_profile: np.ndarray,
    frame_rt: np.ndarray,
    observation_importance: np.ndarray,
    template: np.ndarray,
    fragments: np.ndarray,
    feature_array: nb_float32_array,
    quant_window: nb.uint32 = 3,
    quant_all: nb.boolean = False,
):
    """Compute fragment-level features and write them into ``feature_array``.

    Parameters
    ----------
    dense_fragments : np.ndarray
        Dense fragment data; ``dense_fragments[0]`` is used as intensity and
        ``dense_fragments[1]`` as m/z. ``dense_fragments.shape[1]`` is the
        number of fragments.
    fragments_frame_profile : np.ndarray
        Frame profiles, indexed as (fragment, observation, frame).
    frame_rt : np.ndarray
        Retention time of each frame; sliced with the same window as the
        frame profile.
    observation_importance : np.ndarray
        Weight of each observation, shape (n_observations).
    template : np.ndarray
        Template passed to ``weighted_center_of_mass_1d`` to obtain the
        expected scan and frame centers per observation.
    fragments :
        Fragment container exposing the attributes ``intensity``, ``mz``,
        ``type`` and ``position`` (one entry per fragment).
    feature_array : np.ndarray
        1D float32 output array, modified in place (see Notes).
    quant_window : int, default 3
        Half-width, in frames, of the window around the profile center used
        for quantification. It is reduced for short profiles.
    quant_all : bool, default False
        If True the frame profiles are summed over all observations,
        otherwise only the observation with the highest importance is used.

    Returns
    -------
    tuple of np.ndarray
        ``(observed_fragment_mz_mean, mass_error, observed_fragment_height,
        fragment_area_norm)``, each with one entry per fragment.

    Notes
    -----
    Indices written in ``feature_array``:

    - 17: number of observations
    - 18: correlation of ``fragment_area_norm`` with the normalized library
      intensity (only written if any fragment has a positive height)
    - 19: correlation of the observed height with the normalized library
      intensity (only written if the summed height is positive)
    - 20 / 21: fraction of fragments with positive profile intensity / height
    - 22 / 23: summed normalized library intensity of those fragments
    - 24: mean of ``cosine_similarity_a1`` between template and fragment
      intensities (only written if any fragment has positive intensity)
    - 25 / 26 / 27: log intensity of type-98 fragments, of type-121
      fragments, and their difference
    - 41 / 42: mean mass error (ppm) of the three most intense library
      fragments / of all fragments
    - 43 / 44 / 45: number, mean area and mean mass error of overlapping
      b/y fragments

    NOTE(review): ``n_fragments == 0`` is not handled; features 20 and 21
    divide by ``n_fragments``.

    """
    n_observations = observation_importance.shape[0]
    n_fragments = dense_fragments.shape[1]
    feature_array[17] = float(n_observations)

    # (1, n_observations)
    observation_importance_reshaped = observation_importance.reshape(1, -1)

    # (n_fragments)
    fragment_intensity_norm = fragments.intensity / np.sum(fragments.intensity)

    # (n_observations); the variances returned alongside are not used here
    (
        expected_scan_center,
        expected_frame_center,
        _expected_scan_variance,
        _expected_frame_variance,
    ) = weighted_center_of_mass_1d(template)

    # expand the expected center of mass to the number of fragments
    # (n_fragments, n_observations)
    f_expected_scan_center = tile(expected_scan_center, n_fragments).reshape(
        n_fragments, -1
    )
    f_expected_frame_center = tile(expected_frame_center, n_fragments).reshape(
        n_fragments, -1
    )

    if quant_all:
        best_profile = np.sum(fragments_frame_profile, axis=1)
    else:
        # most intense observation across all observations
        best_observation = np.argmax(observation_importance)

        # (n_fragments, n_frames)
        best_profile = fragments_frame_profile[:, best_observation]

    # NOTE(review): center_envelope_1d works in place; if best_profile is a
    # view (quant_all=False) this presumably also modifies the caller's
    # fragments_frame_profile - confirm this is intended.
    center_envelope_1d(best_profile)

    # handle rare case where the best observation is at the edge of the frame
    # NOTE(review): for profiles with fewer than 2 frames this yields -1 and
    # the slices below come out empty - confirm such profiles cannot occur.
    quant_window = min((best_profile.shape[1] // 2) - 1, quant_window)

    # center the profile around the expected frame center
    center = best_profile.shape[1] // 2

    # (n_fragments, quant_window * 2 + 1)
    best_profile = best_profile[:, center - quant_window : center + quant_window + 1]

    # (quant_window * 2 + 1)
    frame_rt_quant = frame_rt[center - quant_window : center + quant_window + 1]

    # (quant_window * 2)
    delta_rt = frame_rt_quant[1:] - frame_rt_quant[:-1]

    # trapezoidal area under the profile inside the quantification window
    # (n_fragments)
    fragment_area = np.sum(
        (best_profile[:, 1:] + best_profile[:, :-1]) * delta_rt.reshape(1, -1) * 0.5,
        axis=-1,
    )
    fragment_area_norm = fragment_area * quant_window

    # (n_fragments)
    observed_fragment_intensity = np.sum(best_profile, axis=-1)

    # create fragment masks for filtering
    fragment_profiles = np.sum(dense_fragments[0], axis=-1)

    # (n_fragments, n_observations)
    sum_fragment_intensity = np.sum(fragment_profiles, axis=-1)

    # (n_observations)
    sum_template_intensity = np.sum(np.sum(template, axis=-1), axis=-1)

    # get the observed fragment mz and intensity
    # (n_fragments, n_observations)
    observed_fragment_mz = weighted_center_mean_2d(
        dense_fragments[1], f_expected_scan_center, f_expected_frame_center
    )

    # (n_fragments, n_observations)
    o_fragment_height = weighted_center_mean_2d(
        dense_fragments[0], f_expected_scan_center, f_expected_frame_center
    )

    # (n_fragments, n_observations)
    fragment_height_mask_2d = o_fragment_height > 0

    # (n_fragments)
    fragment_height_mask_1d = np.sum(fragment_height_mask_2d, axis=-1) > 0

    # (n_fragments, n_observations)
    fragment_height_weights_2d = (
        fragment_height_mask_2d * observation_importance_reshaped
    )

    # (n_fragments, n_observations)
    # normalize rows to 1; the small constant avoids division by zero for
    # rows without any positive height
    fragment_height_weights_2d = fragment_height_weights_2d / (
        np.sum(fragment_height_weights_2d, axis=-1).reshape(-1, 1) + 1e-20
    )

    # (n_fragments)
    observed_fragment_mz_mean = weighted_mean_a1(
        observed_fragment_mz, fragment_height_weights_2d
    )

    # (n_fragments)
    observed_fragment_height = weighted_mean_a1(
        o_fragment_height, fragment_height_weights_2d
    )

    if np.sum(fragment_height_mask_1d) > 0.0:
        feature_array[18] = np.corrcoef(fragment_area_norm, fragment_intensity_norm)[
            0, 1
        ]

    if np.sum(observed_fragment_height) > 0.0:
        feature_array[19] = np.corrcoef(
            observed_fragment_height, fragment_intensity_norm
        )[0, 1]

    feature_array[20] = np.sum(observed_fragment_intensity > 0.0) / n_fragments
    feature_array[21] = np.sum(observed_fragment_height > 0.0) / n_fragments

    feature_array[22] = np.sum(
        fragment_intensity_norm[observed_fragment_intensity > 0.0]
    )
    feature_array[23] = np.sum(fragment_intensity_norm[observed_fragment_height > 0.0])

    fragment_mask = observed_fragment_intensity > 0

    if np.sum(fragment_mask) > 0:
        sum_template_intensity_expanded = sum_template_intensity.reshape(1, -1)
        observation_score = cosine_similarity_a1(
            sum_template_intensity_expanded, sum_fragment_intensity[fragment_mask]
        ).astype(np.float32)
        feature_array[24] = np.mean(observation_score)

    # ============= FRAGMENT TYPE FEATURES =============

    # fragment types are stored as character codes: 98 == ord("b"), 121 == ord("y")
    b_ion_mask = fragments.type == 98
    y_ion_mask = fragments.type == 121

    weighted_b_ion_intensity = observed_fragment_intensity[b_ion_mask]
    weighted_y_ion_intensity = observed_fragment_intensity[y_ion_mask]

    feature_array[25] = (
        np.log(np.sum(weighted_b_ion_intensity) + 1)
        if len(weighted_b_ion_intensity) > 0
        else 0.0
    )
    feature_array[26] = (
        np.log(np.sum(weighted_y_ion_intensity) + 1)
        if len(weighted_y_ion_intensity) > 0
        else 0.0
    )
    feature_array[27] = feature_array[25] - feature_array[26]

    # ============= FRAGMENT FEATURES =============

    # relative mass deviation in parts per million
    mass_error = (observed_fragment_mz_mean - fragments.mz) / fragments.mz * 1e6

    # fragment indices ordered by decreasing library intensity
    fragment_idx_sorted = np.argsort(fragments.intensity)[::-1]
    top_3_idxs = fragment_idx_sorted[:3]

    # top_3_ms2_mass_error
    feature_array[41] = mass_error[top_3_idxs].mean()

    # mean_ms2_mass_error
    feature_array[42] = mass_error.mean()

    # ============= FRAGMENT intersection =============

    if np.sum(b_ion_mask) > 0 and np.sum(y_ion_mask) > 0:
        min_y = fragments.position[y_ion_mask].min()
        max_b = fragments.position[b_ion_mask].max()
        overlapping = (y_ion_mask & (fragments.position < max_b)) | (
            b_ion_mask & (fragments.position > min_y)
        )

        # n_overlapping
        feature_array[43] = overlapping.sum()

        # NOTE(review): the else branch is assumed to belong to this check;
        # the indentation was lost in the extracted source - confirm.
        if feature_array[43] > 0:
            # mean_overlapping_intensity
            feature_array[44] = np.mean(fragment_area_norm[overlapping])

            # mean_overlapping_mass_error
            feature_array[45] = np.mean(mass_error[overlapping])
        else:
            feature_array[44] = 0
            feature_array[45] = 15

    return (
        observed_fragment_mz_mean,
        mass_error,
        observed_fragment_height,
        fragment_area_norm,
    )
@nb.njit(cache=USE_NUMBA_CACHING)
def fragment_mobility_correlation(
    fragments_scan_profile,
    template_scan_profile,
    observation_importance,
    fragment_intensity,
):
    """Correlate fragment scan profiles with each other and with the template.

    Only fragments whose scan profile has a positive total are used. If fewer
    than three such fragments exist, ``(0, 0)`` is returned.

    Parameters
    ----------
    fragments_scan_profile : np.ndarray
        Scan profiles with the fragment on the first axis; the two trailing
        axes are summed to decide whether a fragment carries signal.
    template_scan_profile : np.ndarray
        Template scan profile, reshaped to (1, n_observations, -1) before use.
    observation_importance : np.ndarray
        Weight of each observation, shape (n_observations).
    fragment_intensity : np.ndarray
        Intensity of each fragment; normalized over the retained fragments
        and used as weights.

    Returns
    -------
    tuple
        ``(fragment_scan_correlation, template_scan_correlation)``.

    """
    n_observations = len(observation_importance)

    # fragments with any signal across all observations and scans
    has_signal = np.sum(np.sum(fragments_scan_profile, axis=-1), axis=-1) > 0
    if np.sum(has_signal) < 3:
        return 0, 0

    retained_intensity = fragment_intensity[has_signal]
    intensity_weights = retained_intensity / np.sum(retained_intensity)

    # (n_observations, n_fragments, n_fragments)
    pairwise_correlation = fragment_correlation(
        fragments_scan_profile[has_signal],
    )

    # collapse the observation axis using the observation weights
    # (n_fragments, n_fragments)
    pairwise_correlation_reduced = np.sum(
        pairwise_correlation * observation_importance.reshape(-1, 1, 1),
        axis=0,
    )

    # (n_fragments)
    correlation_per_fragment = np.dot(pairwise_correlation_reduced, intensity_weights)

    # fragment_scan_correlation
    fragment_scan_correlation = np.mean(correlation_per_fragment)

    # (n_observation, n_fragments)
    template_correlation = fragment_correlation_different(
        fragments_scan_profile[has_signal],
        template_scan_profile.reshape(1, n_observations, -1),
    ).reshape(n_observations, -1)

    # (n_fragments)
    template_correlation_reduced = np.sum(
        template_correlation * observation_importance.reshape(-1, 1),
        axis=0,
    )

    # template_scan_correlation
    template_scan_correlation = np.dot(template_correlation_reduced, intensity_weights)

    return fragment_scan_correlation, template_scan_correlation