"""Functionality for the generation of a set of indices
which accurately represent a waveform.
The default implementation is a greedy one, as defined in
:class:`GreedyDownsamplingTraining`.
To provide an alternate method, just subclass
:class:`DownsamplingTraining`.
"""
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import Optional, Tuple
import h5py
import numpy as np
from scipy import interpolate # type: ignore
from sortedcontainers import SortedList # type: ignore
from .data_management import DownsamplingIndices
from .dataset_generation import Dataset
class DownsamplingTraining(ABC):
    """Selection of the downsampling indices.

    This is an abstract base: subclasses provide the selection
    algorithm by implementing :meth:`train`.

    Parameters
    ----------
    dataset : Dataset
        Dataset to which to refer for the generation
        of training waveforms for the downsampling.
    tol : float
        Tolerance for the interpolation error.
        Defaults to ``1e-5``.

    Attributes
    ----------
    degree : int
        Degree of the interpolating spline used by :meth:`resample`.
        This is a class attribute (not a constructor argument) and
        defaults to 3; override it in a subclass to change it.
    """

    degree: int = 3

    def __init__(self, dataset: Dataset, tol: float = 1e-5):
        self.dataset = dataset
        self.tol = tol

    @abstractmethod
    def train(self, training_dataset_size: int) -> DownsamplingIndices:
        """Calculate the downsampling indices with a generic algorithm,
        training on a dataset with the given size."""

    def validate_downsampling(
        self, training_dataset_size: int, validating_dataset_size: int
    ) -> tuple[list[float], list[float]]:
        r"""Check that the downsampling is working by looking at the
        reconstruction error on residuals generated for validation.

        Parameters
        ----------
        training_dataset_size : int
            How many waveforms to train the downsampling on.
        validating_dataset_size : int
            How many waveforms to validate on.

        Returns
        -------
        tuple[list[float], list[float]]
            Amplitude and phase validation errors;
            these are reported as :math:`L_\infty` errors:
            the absolute maximum of the difference.
        """
        amp_indices, phi_indices = self.train(training_dataset_size)

        frequencies, _, residuals = self.dataset.generate_residuals(
            size=validating_dataset_size
        )
        amp_residuals, phi_residuals = residuals

        # Only the last ``validating_dataset_size`` residuals are validated.
        amp_validation = self.validate_indices(
            amp_indices, frequencies, amp_residuals[-validating_dataset_size:]
        )
        phi_validation = self.validate_indices(
            phi_indices, frequencies, phi_residuals[-validating_dataset_size:]
        )

        return amp_validation, phi_validation

    @classmethod
    def resample(
        cls, x_ds: np.ndarray, new_x: np.ndarray, y_ds: np.ndarray
    ) -> np.ndarray:
        """Resample a function :math:`y(x)` from its values
        at certain points :math:`y_{ds} = y(x_{ds})`.

        The interpolant is a spline of degree ``cls.degree`` passing
        through all the given points (no smoothing, ``s=0``).

        Parameters
        ----------
        x_ds : np.ndarray
            Old, sparse :math:`x` values.
        new_x : np.ndarray
            New :math:`x` coordinates at which to evaluate the function.
        y_ds : np.ndarray
            Old, sparse :math:`y` values.

        Returns
        -------
        new_y : np.ndarray
            Function evaluated at the coordinates ``new_x``.

        Raises
        ------
        ValueError
            If ``x_ds`` and ``y_ds`` do not have the same shape.
        """
        if x_ds.shape != y_ds.shape:
            raise ValueError(
                "Shape mismatch in the downsampling arrays!\n"
                f"The shape of x_ds is {x_ds.shape} "
                f"while the shape of y_ds is {y_ds.shape}."
            )

        spline = interpolate.splrep(x_ds, y_ds, s=0, k=cls.degree)
        return interpolate.splev(new_x, tck=spline, der=0)

    def validate_indices(
        self, indices: list[int], x_val: np.ndarray, ys_val: list[np.ndarray]
    ) -> list[float]:
        """Compute the reconstruction error of a set of indices.

        Each array in ``ys_val`` is downsampled at ``indices``,
        resampled onto the full ``x_val`` grid with :meth:`resample`
        and compared to the original.

        Parameters
        ----------
        indices : list[int]
            Downsampling indices to validate.
        x_val : np.ndarray
            Full :math:`x` grid.
        ys_val : list[np.ndarray]
            Arrays of :math:`y` values sampled on ``x_val``.

        Returns
        -------
        list[float]
            Maximum absolute reconstruction error for each array
            in ``ys_val``. A NaN in the difference yields a NaN error
            instead of being silently skipped.
        """
        # The sparse grid is the same for every array: slice it once.
        x_ds = x_val[indices]

        validation = []
        for y_val in ys_val:
            ypred = self.resample(x_ds, x_val, y_val[indices])
            # np.max is vectorized and, unlike the builtin max,
            # propagates NaNs wherever they appear in the array.
            validation.append(float(np.max(np.abs(y_val - ypred))))
        return validation
class GreedyDownsamplingTraining(DownsamplingTraining):
    """Greedy selection of the downsampling indices.

    Starting from a few equally spaced seed indices, new indices are
    inserted where the interpolation error between two neighbouring
    indices is largest and exceeds ``self.tol``.
    """

    def indices_error(
        self, ytrue: np.ndarray, ypred: np.ndarray, current_indices: SortedList
    ) -> tuple[list[int], list[float]]:
        """Find new indices to add to the sampling.

        For each pair of consecutive current indices, the point of
        largest ``abs(ytrue - ypred)`` in the half-open range between
        them is a candidate; it is kept only if its error exceeds
        ``self.tol``.

        Arguments
        ---------
        ytrue : np.ndarray
            True values of y.
        ypred : np.ndarray
            Predicted values of y through interpolation.
        current_indices : SortedList
            Indices to which the algorithm should add.

        Returns
        -------
        new_indices : list[int]
            Indices to insert among the current ones.
        errors : list[float]
            Errors (``abs(ytrue - ypred)``) at the points where the
            algorithm inserted the new indices.
        """
        abs_error = np.abs(ytrue - ypred)
        nodes = list(current_indices)

        new_indices = []
        errors = []
        # Walk over the intervals delimited by neighbouring indices.
        for left, right in zip(nodes, nodes[1:]):
            worst = np.argmax(abs_error[left:right]) + left
            worst_error = abs_error[worst]
            if worst_error > self.tol:
                new_indices.append(worst)
                errors.append(worst_error)

        return new_indices, errors

    def find_indices(
        self,
        x_train: np.ndarray,
        ys_train: list[np.ndarray],
        seeds_number: int = 4,
    ) -> list[int]:
        """Greedily downsample y(x) so that the reconstruction error
        of each of the ys (instances of y(x)) is brought below
        ``self.tol``.

        The ys are visited in turn; one is marked as done the first
        time :meth:`indices_error` proposes no new index for it, and
        is not checked again afterwards. The loop ends when all of
        them are marked.

        Arguments
        ---------
        x_train : np.ndarray
            x array.
        ys_train : list[np.ndarray]
            A list of y arrays.
        seeds_number : int, optional
            Number of "seed" indices. Defaults to 4.
            These are placed as equally spaced along the array.
            Note: this should always be larger than the degree
            for the interpolation.

        Returns
        -------
        indices : list[int]
            Sorted indices found by the greedy procedure
            on the training dataset.
        """
        last_position = len(x_train) - 1
        seeds = np.linspace(0, last_position, num=seeds_number, dtype=int)
        indices = SortedList(list(seeds))

        # Running value used only for the log message below.
        err = self.tol + 1
        finished = np.zeros(len(ys_train), dtype=bool)

        logging.info("Starting interpolation")

        while not finished.all():
            for wf_number, y in enumerate(ys_train):
                if finished[wf_number]:
                    continue

                ypred = self.resample(x_train[indices], x_train, y[indices])
                indices_batch, errs = self.indices_error(y, ypred, indices)

                if len(errs) > 0:
                    indices.update(set(indices_batch))
                    err = min(max(errs), err)
                else:
                    finished[wf_number] = True

            logging.info(
                "%i indices, error = %f = %f times the tol",
                len(indices),
                err,
                err / self.tol,
            )

        return list(indices)

    def train(self, training_dataset_size: int) -> DownsamplingIndices:
        """Compute a set of indices at which to sample waveforms,
        so that the reconstruction stays below a certain tolerance.

        Waveforms are generated from the dataset, and the amplitude
        and phase indices are found separately with
        :meth:`find_indices`.

        Parameters
        ----------
        training_dataset_size : int
            Number of waveforms to generate and with which to train.

        Returns
        -------
        DownsamplingIndices
            Indices for amplitude and phase, respectively.
        """
        dataset = self.dataset

        parameter_generator = dataset.make_parameter_generator()
        parameters = dataset.parameter_set_cls.from_parameter_generator(
            parameter_generator, training_dataset_size
        )
        waveforms = dataset.generate_waveforms_from_params(parameters)
        frequencies = dataset.frequencies

        amplitudes = list(waveforms.amplitudes)
        amp_indices = self.find_indices(frequencies, amplitudes)

        phases = list(waveforms.phases)
        phi_indices = self.find_indices(frequencies, phases)

        return DownsamplingIndices(amp_indices, phi_indices)
class GreedyDownsamplingTrainingWithResiduals(GreedyDownsamplingTraining):
    """Greedy downsampling trained on the residuals provided by the
    dataset's ``generate_residuals`` method."""

    def train(self, training_dataset_size: int) -> DownsamplingIndices:
        """Compute a set of indices at which to sample waveforms,
        so that the reconstruction stays below a certain tolerance.

        The amplitude and phase indices are found separately, each with
        :meth:`find_indices`, from the first ``training_dataset_size``
        residuals of the corresponding kind.

        Parameters
        ----------
        training_dataset_size : int
            Number of waveforms to generate and with which to train.

        Returns
        -------
        DownsamplingIndices
            Indices for amplitude and phase, respectively.
        """
        # The second returned element is not needed here.
        frequencies, _, (amp_residuals, phi_residuals) = (
            self.dataset.generate_residuals(size=training_dataset_size)
        )

        # Amplitude first, then phase: the generator is consumed in order.
        amp_indices, phi_indices = (
            self.find_indices(frequencies, residuals[:training_dataset_size])
            for residuals in (amp_residuals, phi_residuals)
        )

        return DownsamplingIndices(amp_indices, phi_indices)