"""Dedispersion routines."""
import math
import numpy as np
import astropy.units as u
import pulsarbat as pb
import dask
import dask.array as da
# Public API of this module: the DM quantity class (and its short alias)
# plus the coherent and incoherent dedispersion functions.
__all__ = [
    "DispersionMeasure",
    "DM",
    "coherent_dedispersion",
    "incoherent_dedispersion",
]
def _transfer_function(coeff, N, dt, center_freq, ref_freq):
    """Evaluate the coherent-dedispersion transfer function on an FFT grid.

    Absolute frequencies are ``center_freq`` plus the ``N`` FFT sample
    frequencies for a sample spacing ``dt`` (in ``np.fft.fftfreq`` order).
    The phase, in cycles, is ``coeff * f * (1 / ref_freq - 1 / f) ** 2`` and
    the returned array is ``exp(-i * phase)``.

    Parameters
    ----------
    coeff : Quantity
        Dispersion constant multiplied by the DM (units of s MHz^2 when
        built by ``DispersionMeasure.chirp_function``).
    N : int
        Number of FFT samples.
    dt : Quantity
        Sample spacing in time.
    center_freq : Quantity
        Frequency corresponding to the zero-offset FFT bin.
    ref_freq : Quantity
        Reference frequency at which the phase vanishes.

    Returns
    -------
    numpy.ndarray
        Complex64 array of length ``N``.
    """
    bin_offsets = np.fft.fftfreq(N, dt).to(u.Hz)
    freqs = center_freq.to(u.Hz) + bin_offsets
    phase = coeff * freqs * u.cycle * (1 / ref_freq - 1 / freqs) ** 2
    response = np.exp(-1j * phase.to_value(u.rad))
    return response.astype(np.complex64)
class DispersionMeasure(u.SpecificTypeQuantity):
    """Dispersion Measure class (with default units of pc / cm^3).

    An astropy quantity restricted to units equivalent to pc / cm^3, with
    helpers for dispersive time/sample delays and coherent-dedispersion
    chirp functions.
    """

    _equivalent_unit = _default_unit = u.pc / u.cm ** 3
    # D = 1 / 2.41e-4 s MHz^2 cm^3 / pc (~4149); D * DM is the delay
    # coefficient (s MHz^2) used by every method below.
    dispersion_constant = u.s * u.MHz ** 2 * u.cm ** 3 / u.pc / 2.41e-4

    def time_delay(self, f, ref_freq):
        """Time delay of frequencies relative to reference frequency.

        Evaluates ``D * DM * (1 / f**2 - 1 / ref_freq**2)``; for a positive
        DM the delay is positive when ``f`` lies below ``ref_freq``.

        Parameters
        ----------
        f : Quantity
            Frequency or frequencies at which to evaluate the delay.
        ref_freq : Quantity
            Reference frequency (zero delay).

        Returns
        -------
        Quantity
            Delay in seconds.
        """
        inverse_sq_diff = 1 / f ** 2 - 1 / ref_freq ** 2
        return (self.dispersion_constant * self * inverse_sq_diff).to(u.s)

    def sample_delay(self, f, ref_freq, sample_rate):
        """Sample delay of frequencies relative to reference frequency.

        Parameters
        ----------
        f : Quantity
            Frequency or frequencies at which to evaluate the delay.
        ref_freq : Quantity
            Reference frequency (zero delay).
        sample_rate : Quantity
            Sampling rate used to convert the time delay into samples.

        Returns
        -------
        float or numpy.ndarray
            Dimensionless (possibly fractional) number of samples.
        """
        return (self.time_delay(f, ref_freq) * sample_rate).to_value(u.one)

    def chirp_function(self, N, dt, center_freq, ref_freq, use_dask=False):
        """Chirp function for coherent dedispersion.

        Parameters
        ----------
        N : int
            Number of samples (FFT length).
        dt : Quantity
            Sample spacing.
        center_freq : Quantity
            Frequency of the zero-offset FFT bin.
        ref_freq : Quantity
            Reference frequency for dedispersion.
        use_dask : bool, optional
            If True, return a lazy single-chunk dask array instead of a
            NumPy array.

        Returns
        -------
        numpy.ndarray or dask.array.Array
            Complex64 array of shape ``(N,)``.
        """
        tf_args = (self.dispersion_constant * self, N, dt, center_freq, ref_freq)
        if not use_dask:
            return _transfer_function(*tf_args)
        lazy_tf = dask.delayed(_transfer_function, pure=True)(*tf_args)
        return da.from_delayed(lazy_tf, dtype=np.complex64, shape=(N,))

    def chirp_from_signal(self, z, /, *, ref_freq=None):
        """Returns chirp function to dedisperse given baseband signal.

        One chirp is computed per entry of ``z.channel_freqs`` and the
        chirps are stacked along axis 1. A dask-backed signal yields a dask
        chirp.

        Parameters
        ----------
        z : BasebandSignal
            Signal whose length, sample spacing and channel frequencies
            define the chirp.
        ref_freq : Quantity, optional
            Reference frequency. Defaults to ``z.center_freq``.

        Returns
        -------
        array
            Shape ``(len(z), len(z.channel_freqs))`` followed by one
            singleton axis per extra dimension of ``z``.

        Raises
        ------
        TypeError
            If ``z`` is not a BasebandSignal.
        """
        if not isinstance(z, pb.BasebandSignal):
            raise TypeError("Signal must be a BasebandSignal object.")

        if ref_freq is None:
            ref_freq = z.center_freq

        n_samples, spacing = len(z), z.dt
        lazy = isinstance(z.data, da.Array)
        per_channel = [
            self.chirp_function(n_samples, spacing, chan_freq, ref_freq, lazy)
            for chan_freq in z.channel_freqs
        ]

        # Keep the (time, channel) axes and append singleton axes so the
        # chirp broadcasts against higher-dimensional signal data.
        expand = tuple(
            None if axis >= 2 else slice(None) for axis in range(z.ndim)
        )
        return np.stack(per_channel, axis=1)[expand]


# Short alias for DispersionMeasure.
DM = DispersionMeasure
def coherent_dedispersion(z, DM, /, *, ref_freq=None, chirp=None):
    """Coherently dedisperses a baseband signal.

    The given signal will be coherently dedispersed by a given dispersion
    measure (DM). If a reference frequency (``ref_freq``) is not given, the
    center frequency of the signal will be used as reference.

    Optionally, a pre-computed chirp function (``chirp``) can be provided
    as an array. If a chirp is provided, it will not be checked against
    the given DM and reference frequency for correctness.

    The output signal will be cropped on both ends to avoid wrap-around
    artifacts caused by dedispersion. How much is cropped from each end
    depends on where the reference frequency (``ref_freq``) lies relative
    to the band of the signal.

    Parameters
    ----------
    z : BasebandSignal
        The signal to be transformed.
    DM : DispersionMeasure
        Dispersion measure by which to dedisperse ``z``.
    ref_freq : Quantity, optional
        Reference frequency for dedispersion. If None (default), uses
        the center frequency from signal.
    chirp : array-like, optional
        A pre-computed chirp function, typically a 2-D array with shape
        ``z.shape[:2]``. Trailing singleton axes are appended so that it
        broadcasts against ``z.data``.

    Returns
    -------
    BasebandSignal
        The dedispersed signal.

    Raises
    ------
    TypeError
        If ``z`` is not a BasebandSignal.
    ValueError
        If the signal is too short to leave any samples after cropping
        the wrap-around regions.
    """
    if not isinstance(z, pb.BasebandSignal):
        raise TypeError("Signal must be a BasebandSignal object.")

    if ref_freq is None:
        ref_freq = z.center_freq

    # Circular (FFT) dedispersion wraps samples around the ends: the band
    # edge above ref_freq (negative delay) contaminates the start, the edge
    # below it (positive delay) contaminates the end.
    delay_top = DM.sample_delay(z.max_freq, ref_freq, z.sample_rate)
    delay_bot = DM.sample_delay(z.min_freq, ref_freq, z.sample_rate)
    n_samples = len(z)
    start = math.ceil(-min(0, delay_top, delay_bot))
    stop = n_samples - math.ceil(+max(0, delay_top, delay_bot))
    if stop <= start:
        # A negative ``stop`` would otherwise be interpreted as an index
        # from the end and silently return contaminated samples.
        raise ValueError(
            f"Signal is too short to be coherently dedispersed: "
            f"{start + (n_samples - stop)} samples must be cropped but the "
            f"signal only has {n_samples}."
        )

    if chirp is None:
        chirp = DM.chirp_from_signal(z, ref_freq=ref_freq)
    elif not hasattr(chirp, "ndim"):
        # Accept plain array-likes (e.g. nested lists) as documented.
        chirp = np.asarray(chirp)

    # Append singleton axes so the chirp broadcasts over extra dimensions.
    chirp = chirp[(slice(None),) * chirp.ndim + (None,) * (z.ndim - chirp.ndim)]

    x = pb.fft.ifft(pb.fft.fft(z.data, axis=0) * chirp, axis=0)
    return type(z).like(z, x)[start:stop]
def incoherent_dedispersion(z, DM, /, *, ref_freq=None):
    """Incoherently dedisperses a signal by a given dispersion measure.

    Each channel is shifted by its dispersive delay (rounded to the
    nearest sample) relative to ``ref_freq``. The output signal is cropped
    on both ends to avoid wrap-around artifacts. How much is cropped
    depends on where the reference frequency (``ref_freq``) lies relative
    to the band of the signal. The output start time is advanced by the
    number of samples cropped from the beginning.

    Parameters
    ----------
    z : RadioSignal
        The signal to be transformed.
    DM : DispersionMeasure
        Dispersion measure by which to dedisperse ``z``.
    ref_freq : Quantity, optional
        Reference frequency for dedispersion. If None (default), uses
        the center frequency from signal.

    Returns
    -------
    RadioSignal
        The dedispersed signal.

    Raises
    ------
    TypeError
        If ``z`` is not a RadioSignal.
    ValueError
        If the signal is shorter than the span of channel delays, so no
        samples would remain after cropping.
    """
    if not isinstance(z, pb.RadioSignal):
        raise TypeError("Signal must be a RadioSignal object.")

    if ref_freq is None:
        ref_freq = z.center_freq

    delays = DM.sample_delay(z.channel_freqs, ref_freq, z.sample_rate)
    delays = delays.round().astype(np.int64)

    # Shift delays so the smallest is zero. Using the minimum over all
    # channels (not just the first/last) keeps this correct for any
    # channel ordering; a negative offset would index from the array end.
    crop_before = max(0, -int(delays.min()))
    delays += crop_before

    N = len(z) - int(delays.max())
    if N <= 0:
        raise ValueError(
            f"Signal is too short to be incoherently dedispersed: channel "
            f"delays span {int(delays.max())} samples but the signal only "
            f"has {len(z)}."
        )

    x = np.stack([z.data[j : j + N, i] for i, j in enumerate(delays)], axis=1)

    # Output sample 0 corresponds to input sample ``crop_before``.
    new_start = z.start_time
    if crop_before and new_start is not None:
        # Rebind rather than ``+=`` so the input's start_time is never
        # mutated in place.
        new_start = new_start + crop_before * z.dt

    return type(z).like(z, x, start_time=new_start)