Profiling the model evaluation
Within the test suite for mlgw_bns the evaluation time for
a prediction by a trained model is checked thanks to pytest-benchmark;
however one might wish to get more granular information
about which functions are taking the longest to run.
In order to achieve this, we can make use of the cProfile module, combined
with the snakeviz visualization tool.
Both of these dependencies should be installed when running the command
poetry install
without the --no-dev option.
We write a script with the following imports:
from mlgw_bns import Model
from mlgw_bns.model import ParametersWithExtrinsic
import cProfile
import pstats
and define a utility function for the profilation of a function’s execution:
def profile(func, *args, **kwargs):
with cProfile.Profile() as pr:
func(*args, **kwargs)
stats = pstats.Stats(pr)
stats.sort_stats(pstats.SortKey.TIME)
return stats
Now, we are ready to start profiling:
we need a trained model, which we can obtain as outlined in Making a new model.
Let us call our model m.
We also need a set of parameters for the generated waveform: an example set is given here.
params = ParametersWithExtrinsic(
mass_ratio=1.0,
lambda_1=500,
lambda_2=50,
chi_1=0.1,
chi_2=-0.1,
distance_mpc=1.0,
inclination=0.0,
total_mass=2.8,
)
We also need an array of frequencies at which the new waveform will need to be computed:
the standard, FFT-grid frequency set is provided by
m.dataset.frequencies_hz, so we can use this as a baseline.
Its only issue is that it is very dense, so we can downsample by a certain amount,
for example defining
frequencies = m.dataset.frequencies_hz[::1024]
We are almost ready to profile:
one issue is the fact that several functions in the
prediction pipeline are decorated to make use of numba’s just-in-time compilation,
or they employ caching, therefore the first run of the prediction function will
be very slow (on the order of a few seconds) compared to the speed
the pipeline can achieve.
To account for this, we need to make a dry run of the prediction function before profiling it. The profilation code will therefore look like:
m.predict(frequencies, params)
stats = profile(m.predict, frequencies, params)
stats.dump_stats('prediction_profile.prof')
This will generate a profile file which is not human-readable,
but it can be nicely visualized with the snakeviz utility:
simply run
poetry run snakeviz prediction_profile.prof
from the command line, and a local webpage containing an interactive visualization of the execution times of the various internal functions.
The expected result, if the length of the required frequencies is of the order of a few thousands (as is needed, for example, with reduced order quadratures), is that:
most of the time will be taken by interpolation: the internal function is called
model.cartesian_waveforms_at_frequencies, which callsdownsampling_interpolation.resampletwice (once for amplitude, once for phase), and which in turns calls the scipy low-level functionssplrepandsplev;some time is taken by the prediction of the waveforms (
model.Model.predict_waveforms_bulk), of whichsome time (about half) is taken by the prediction of the residuals with the neural network and PCA:
model.Model.predict_residuals_bulk;some time (about half) is taken by the evaluation of the post-Newtonian amplitude and phase:
dataset_generation.Dataset.recompose_residuals;
any remaining time required should be comparatively small.