# Source code for SSINS.match_filter

"""
Match Filter class
"""

import numpy as np
import warnings
from collections import namedtuple
import yaml
from copy import deepcopy
from functools import reduce
import os

# Record describing one flagged event, as produced by MF.match_test and the
# broadcast methods of MF.
Event = namedtuple(
    "Event",
    (
        "time_slice",  # slice into the time axis of the spectrum
        "freq_slice",  # slice into the frequency axis of the spectrum
        "shape",       # name of the shape (or broadcast label) that was flagged
        "sig",         # significance of the event; None for broadcast events
    ),
)


class MF():
    """
    Defines the Match Filter (MF) class.

    The filter holds a set of named frequency "shapes" (subbands), converts
    them to channel slices of ``freq_array``, and iteratively flags the most
    significant outlier of an incoherent noise spectrum (INS) until nothing
    exceeds the per-shape significance thresholds.
    """

    def __init__(self, freq_array, sig_thresh, shape_dict=None, tb_aggro=0,
                 N_samp_thresh=None, narrow=True, streak=True,
                 broadcast_dict=None, broadcast_streak=False):
        """
        Args:
            freq_array: Sets the freq_array attribute of the filter
            sig_thresh (dict or number): If dictionary, the keys are the
                desired shapes to flag. The values are the desired
                significance thresholds for each shape. If streak or narrow
                are True, thresholds for these must be included in this
                dictionary, although they should not be included in the
                shape_dict keyword input. If passing a number, this number is
                used as the threshold for all shapes.
            shape_dict (dict): A dictionary of shapes to flag. Keys are shapes
                other than 'streak' and 'narrow'. Values are frequency limits
                of corresponding shape. Defaults to an empty dictionary.
            tb_aggro (float): Sets the tb_aggro attribute of the filter. Must
                satisfy 0 <= tb_aggro < 1.
            N_samp_thresh (int): Deprecated; passing anything other than None
                raises a ValueError. See tb_aggro.
            narrow (bool): If True, search for narrowband (single channel) RFI
            streak (bool): If True, search for broad RFI streaks that occupy
                the entire observing band
            broadcast_dict (dict): Optional. Describes how to partition the
                band when broadcasting over frequencies. The keys should be
                the names of each subband to broadcast over, and the values
                should be the edges of each subband (i.e. two-element lists).
                Defaults to an empty dictionary.
            broadcast_streak (bool): Optional, broadcast flags over whole
                observing band.

        Raises:
            ValueError: If no shapes are enabled at all, if tb_aggro is out of
                range, or if the deprecated N_samp_thresh is supplied.
            KeyError: If sig_thresh is a dictionary that lacks a threshold for
                one of the active shapes.
        """
        # None sentinels instead of mutable default arguments: the dicts are
        # stored on the instance, so a shared default would leak state between
        # every MF built without these keywords.
        if shape_dict is None:
            shape_dict = {}
        if broadcast_dict is None:
            broadcast_dict = {}

        if (not shape_dict) and (not narrow) and (not streak):
            raise ValueError("There are not shapes in the shape_dict and narrow/streak shapes are disabled. Check keywords.")

        self.freq_array = freq_array
        """A 1-d array of frequencies (in hz) for the filter to operate with"""
        self.shape_dict = shape_dict
        """A dictionary of shapes. Keys are a shape name, values are the lower and upper frequency bounds in Hz."""
        self.sig_thresh = sig_thresh
        """A dictionary of significance thresholds to flag per shape.
        Keys are shapes and values are thresholds."""
        self.tb_aggro = tb_aggro
        """The threshold for flagging an entire channel when some flags exist and apply_samp_thresh() is called.
        Must be between 0 and 1. Represents a fraction of unflagged data remaining."""
        if (self.tb_aggro >= 1) or (self.tb_aggro < 0):
            raise ValueError("tb_aggro parameter must be between 0 and 1.")
        if N_samp_thresh is not None:
            raise ValueError("The N_samp_thresh parameter is now deprected. See"
                             " the tb_aggro parameter for its replacement.")

        self.slice_dict = self._shape_slicer(narrow, streak, "shape_dict")
        """A dictionary whose keys are the same as shape_dict and whose values are corresponding slices into the freq_array attribute"""

        if isinstance(self.sig_thresh, dict):
            # Every active shape (including narrow/streak) needs a threshold.
            for shape in self.slice_dict:
                if shape not in self.sig_thresh:
                    raise KeyError("%s shape has no sig_thresh. Check sig_thresh input." % shape)
        else:
            # A single number was passed: use it for every active shape.
            self.sig_thresh = {shape: self.sig_thresh for shape in self.slice_dict}

        self.broadcast_dict = broadcast_dict
        """A dictionary of subbands. Keys are a subband name, values are the lower and upper frequency bounds in Hz."""
        self.broadcast_slc_dict = self._shape_slicer(False, broadcast_streak, input_dict="broadcast_dict")
        """A dictionary whose keys are the same as broadcast_dict and whose values are corresponding slices into the freq_array attribute"""

    def _shape_slicer(self, narrow, streak, input_dict="shape_dict"):
        """
        Converts the frequency limits stored in one of the dictionary
        attributes into slice objects over the channels of freq_array. The
        narrow and streak shapes require special slices.

        Args:
            narrow (bool): If True, add the narrow shape to the dictionary
            streak (bool): If True, add the streak shape to the dictionary
            input_dict (str): Name of the dict attribute to operate on.

        Returns:
            slice_dict: Dictionary mapping each retained shape name to a slice
                into freq_array. 'narrow' maps to None and 'streak' maps to a
                slice over the whole array.
        """
        slice_dict = {}
        shapes = getattr(self, input_dict)
        if shapes:
            freqs = np.asarray(self.freq_array)
            band_lo, band_hi = freqs.min(), freqs.max()
            for shape, limits in shapes.items():
                shape_lo, shape_hi = min(limits), max(limits)
                # NOTE(review): as written, this only rejects shapes that
                # strictly enclose the whole band. A shape lying entirely
                # outside the band still passes and is mapped to the nearest
                # edge channel -- confirm that this is intended.
                if band_lo <= shape_lo or band_hi >= shape_hi:
                    # Nearest channel to each limit.
                    min_chan = int(np.argmin(np.abs(freqs - shape_lo)))
                    # Extend by 1 so that it is inclusive at the upper boundary
                    max_chan = int(np.argmin(np.abs(freqs - shape_hi))) + 1
                    slice_dict[shape] = slice(min_chan, max_chan)

        if narrow:
            # The narrow search runs over every single channel, so there is no
            # one slice to store.
            slice_dict['narrow'] = None
        if streak:
            slice_dict['streak'] = slice(0, len(self.freq_array))

        return slice_dict

    def match_test(self, INS):
        """
        The primary test that the filter is used with. The greatest outlier in
        each shape is put forth and a champion among them is chosen.

        Args:
            INS: An INS to test

        Returns:
            event: An Event namedtuple (time_slice, freq_slice, shape, sig)
                for the strongest outlier that exceeds its shape's threshold.
                If there is no such outlier, time_slice, freq_slice and shape
                are None and sig is -np.inf. A winning 'narrow' event has its
                shape renamed to 'narrow_<freq>MHz' using INS.freq_array.
        """
        sig_max = -np.inf
        t_max = None
        f_max = None
        shape_max = None

        for shape in self.slice_dict:
            if shape == 'narrow':
                # Largest single-sample outlier anywhere in the spectrum.
                t, f, p = np.unravel_index(np.absolute(INS.metric_ms).argmax(),
                                           INS.metric_ms.shape)
                sig = np.absolute(INS.metric_ms[t, f, p])
                t = slice(t, t + 1)
                f = slice(f, f + 1)
            else:
                shape_slc = self.slice_dict[shape]
                subband = INS.metric_ms[:, shape_slc]
                # Number of unmasked channels contributing at each time/pol.
                N = np.count_nonzero(np.logical_not(subband.mask), axis=1)
                # Mean over the subband, scaled by sqrt(N).
                sliced_arr = np.absolute(subband.mean(axis=1)) * np.sqrt(N)
                t, p = np.unravel_index((sliced_arr / self.sig_thresh[shape]).argmax(),
                                        sliced_arr.shape)
                t = slice(t, t + 1)
                f = shape_slc
                # Pull out the number instead of a sliced arr
                sig = sliced_arr[t, p][0]

            if sig > self.sig_thresh[shape]:
                if sig > sig_max:
                    t_max, f_max, shape_max, sig_max = (t, f, shape, sig)

        if shape_max == "narrow":
            shape_max = "narrow_%.3fMHz" % (INS.freq_array[f_max][0] * 10**(-6))

        return Event(t_max, f_max, shape_max, sig_max)

    def apply_match_test(self, INS, event_record=True, apply_samp_thresh=None,
                         freq_broadcast=False, time_broadcast=False):
        """
        A method that uses the match_test() method to flag RFI. The champion
        from match_test() is flagged and the mean-subtracted spectrum is
        recalculated. This repeats until there are no more outliers greater
        than sig_thresh. Flags can optionally be broadcast in time and/or
        frequency between iterations.

        Args:
            INS: The INS to flag
            event_record (bool): If True, append events to INS.match_events
            apply_samp_thresh (bool): Deprecated in favor of the
                time_broadcast keyword; anything other than None raises.
            freq_broadcast (bool): If True, broadcast flags between iterations
                using the broadcast_dict
            time_broadcast (bool): If True, broadcasts flags in time if
                significant flagging in channels. Set tb_aggro parameter for
                aggression.

        Raises:
            ValueError: If apply_samp_thresh is supplied.
        """
        if apply_samp_thresh is not None:
            raise ValueError("apply_samp_thresh has been deprecated in favor of"
                             " the time_broadcast keyword.")

        while True:
            event = self.match_test(INS)
            # match_test reports sig == -np.inf when nothing is significant.
            if not (event.sig > -np.inf):
                break

            INS.metric_array[event[:2]] = np.ma.masked
            # Only adjust those values in the sig_array that are not already assigned
            nonmask = np.logical_not(INS.metric_ms.mask[event[:2]])
            INS.sig_array[event[:2]][nonmask] = INS.metric_ms[event[:2]][nonmask]
            if event_record:
                INS.match_events.append(event)
            if time_broadcast:
                event = self.time_broadcast(INS, event, event_record=event_record)
            if freq_broadcast:
                event = self.freq_broadcast(INS, event, event_record=event_record)

            # Recompute the mean-subtracted data over the affected channels,
            # unless nothing is left unflagged there.
            if not np.all(INS.metric_array[:, event[1]].mask):
                INS.metric_ms[:, event[1]] = INS.mean_subtract(freq_slice=event[1])
            else:
                INS.metric_ms[:, event[1]] = np.ma.masked

        nonmask_all = np.logical_not(INS.metric_ms.mask)
        INS.sig_array[nonmask_all] = INS.metric_ms[nonmask_all]

    def time_broadcast(self, INS, event, event_record=False):
        """
        Broadcasts flags in time for a subband (determined by the passed
        event) if the fraction of unflagged samples in the subband is less
        than or equal to the tb_aggro parameter. If so, the entire subband is
        flagged for the whole object.

        Args:
            INS: An INS to test
            event: The event to check.
            event_record (bool): If true, append events to INS.match_events.

        Returns:
            new_event: The time-broadcast event if more flagging happened
                (whether or not it was recorded). Returns old event if no
                additional flagging happened.
        """
        subband = INS.metric_array[:, event[1]]

        # Find how many channels are already fully flagged, so we can ignore them
        num_chans_all_flag = np.sum(np.all(INS.metric_array.mask[:, event[1], :], axis=(0, -1)))

        # Find the total data volume and subtract off the data volume in
        # channels that are totally flagged
        total = np.prod(subband.shape)
        total_invalid = num_chans_all_flag * INS.Ntimes * INS.Npols
        total_valid = total - total_invalid
        if total_valid == 0:
            # Every channel in the subband is already fully flagged; there is
            # nothing left to broadcast over (and the fraction is undefined).
            return event

        # Find the total flagged data volume and subtract off the invalid data
        total_flag = np.sum(subband.mask)
        total_flag_valid = total_flag - total_invalid

        # Find the flag fraction, unflagged fraction, compare to aggro parameter
        flag_frac = total_flag_valid / total_valid
        unflag_frac = 1 - flag_frac

        if unflag_frac > self.tb_aggro:
            return event

        INS.metric_array[:, event[1]] = np.ma.masked
        # Always build the event so there is something to return even when it
        # is not being recorded.
        new_event = Event(slice(0, INS.Ntimes), event[1], f"time_broadcast_{event[2]}", None)
        if event_record:
            INS.match_events.append(new_event)

        return new_event

    def freq_broadcast(self, INS, event, event_record=False):
        """
        Broadcast flags in frequency, regardless of flagging fraction.
        Determined by the broadcast_dict: every subband that overlaps the
        event's channels is flagged over the event's time slice.

        Args:
            INS: The incoherent noise spectrum being worked on.
            event: The event to broadcast flags for.
            event_record (bool): Whether to record a new event for this
                flagging entry in INS.match_events.

        Returns:
            final_event: The frequency-broadcast event if any subband
                overlapped the event, otherwise the event that was passed in.

        Raises:
            ValueError: If the filter has no broadcast subbands.
        """
        if not self.broadcast_slc_dict:
            raise ValueError("MF object does not have a broadcast_dict, but is "
                             " being asked to broadcast flags. Check "
                             " initialization of MF object.")

        # Channel numbers covered by the event.
        event_set = set(np.arange(event[1].stop)[event[1]])

        new_event_set = set()
        sbs = []
        for sb, sb_slc in self.broadcast_slc_dict.items():
            broad_set = set(np.arange(sb_slc.stop)[sb_slc])
            if not event_set.isdisjoint(broad_set):
                new_event_set |= broad_set
                sbs.append(sb)

        if not new_event_set:
            return event

        sb_string = "_".join(sbs)
        # They should all be contiguous until discontiguous shapes are allowed
        new_event_slc = slice(min(new_event_set), max(new_event_set) + 1)
        INS.metric_array[event[0], new_event_slc] = np.ma.masked
        final_event = Event(event[0], new_event_slc, f"freq_broadcast_{sb_string}", None)
        if event_record:
            INS.match_events.append(final_event)

        return final_event

    def write(self, prefix, sep="_", clobber=False):
        """
        Writes out a yaml file with the important information about the filter.

        Args:
            prefix: The filepath prefix for the output file. Output file will
                be named f'{prefix}{sep}SSINS{sep}matchfilter.yml'
            sep: The separator character between the prefix and the rest of
                the output filepath.
            clobber: Whether to overwrite an identically named file. True
                overwrites.

        Raises:
            ValueError: If the output file already exists and clobber is False.
        """
        outpath = f"{prefix}{sep}SSINS{sep}matchfilter.yml"
        yaml_dict = self._make_yaml_dict()

        if os.path.exists(outpath) and not clobber:
            raise ValueError(f"matchfilter file with prefix {prefix} already exists and clobber is False.")

        with open(outpath, 'w') as outfile:
            yaml.safe_dump(yaml_dict, outfile)

    def _make_yaml_dict(self):
        """
        Helper function for MF.write that sets up the dictionary for the yaml
        output.

        Returns:
            yaml_dict: Dictionary with the keys 'freqs', 'shape_dict',
                'sig_thresh', 'tb_aggro', 'freq_broadcast' and 'version'.
        """
        from . import __version__

        band_edges = [self.freq_array[0], self.freq_array[-1]]

        broadcast_dict = deepcopy(self.broadcast_dict)
        # Include additional shape if in the slc_dict which may be missing
        # from the broadcast_dict
        if "streak" in self.broadcast_slc_dict:
            broadcast_dict.update({"streak": list(band_edges)})

        shape_dict = deepcopy(self.shape_dict)
        if "streak" in self.slice_dict:
            shape_dict.update({"streak": list(band_edges)})
        if "narrow" in self.slice_dict:
            # Placeholder values. "narrow" really refers to Nfreqs different shapes.
            shape_dict.update({"narrow (vals are placeholders)": list(band_edges)})

        # Cast to plain floats before the values are handed to the yaml dumper.
        yaml_dict = {"freqs": [float(freq) for freq in self.freq_array],
                     "shape_dict": {shape: [float(shape_dict[shape][0]), float(shape_dict[shape][1])]
                                    for shape in shape_dict},
                     "sig_thresh": self.sig_thresh,
                     "tb_aggro": self.tb_aggro,
                     "freq_broadcast": {shape: [float(broadcast_dict[shape][0]), float(broadcast_dict[shape][1])]
                                        for shape in broadcast_dict},
                     "version": __version__}

        return yaml_dict