Source code for pypesto.startpoint.prior

"""Prior-based sampling."""

import logging

import numpy as np

from ..objective import NegLogParameterPriors, NegLogPriors
from .base import CheckedStartpoints
from .uniform import uniform

logger = logging.getLogger(__name__)


class PriorStartpoints(CheckedStartpoints):
    """Draw startpoints from the parameter priors.

    When priors are supplied, startpoints are sampled from them; without
    priors the method delegates to uniform sampling. Samples that violate
    the bounds are redrawn so that the result lies within ``[lb, ub]``.
    """

    def __init__(
        self,
        use_guesses: bool = True,
        check_fval: bool = False,
        check_grad: bool = False,
    ):
        """Set up the startpoint generator.

        All arguments are forwarded unchanged to the base class.

        Parameters
        ----------
        use_guesses:
            Whether guesses provided in the problem should be used.
        check_fval:
            Whether function values at the startpoint are checked, with
            resampling if they are not finite.
        check_grad:
            Whether gradients at the startpoint are checked, with
            resampling if they are not finite.
        """
        super().__init__(
            use_guesses=use_guesses,
            check_fval=check_fval,
            check_grad=check_grad,
        )

    def sample(
        self,
        n_starts: int,
        lb: np.ndarray,
        ub: np.ndarray,
        priors: NegLogParameterPriors | NegLogPriors | None = None,
        max_iterations: int = 100,
    ) -> np.ndarray:
        """Draw startpoints from the priors, or uniformly without priors.

        Parameters
        ----------
        n_starts:
            Number of startpoints to generate.
        lb:
            Lower parameter bound.
        ub:
            Upper parameter bound.
        priors:
            Parameter priors. If ``None``, uniform sampling within the
            bounds is used instead. Otherwise ``priors.sample`` is expected
            to return a mapping from parameter index to an array of
            samples; parameters missing from that mapping are drawn
            uniformly between their bounds.
        max_iterations:
            Number of sampling passes after which any values still outside
            the bounds are clipped to the bounds instead of being redrawn.

        Returns
        -------
        xs:
            Startpoints, shape (n_starts, n_par).
        """
        # Without priors there is nothing to sample from: use uniform draws.
        if priors is None:
            return uniform(n_starts=n_starts, lb=lb, ub=ub)

        n_par = len(lb)
        # Row vectors so the bounds broadcast against (n_starts, n_par).
        lower = lb.reshape(1, -1)
        upper = ub.reshape(1, -1)

        xs = np.zeros((n_starts, n_par))
        # Every row is pending on the first pass.
        pending = np.ones(n_starts, dtype=bool)
        n_passes = 0

        while (n_pending := pending.sum()) != 0:
            batch = np.zeros((n_pending, n_par))

            # Columns for which the priors provide samples.
            drawn = priors.sample(n_samples=n_pending)
            for par_idx, values in drawn.items():
                batch[:, par_idx] = values

            # Remaining columns are drawn uniformly between their bounds.
            for par_idx in range(n_par):
                if par_idx in drawn:
                    continue
                batch[:, par_idx] = np.random.uniform(
                    lb[par_idx], ub[par_idx], n_pending
                )

            # Whole rows are replaced, not just the offending entries.
            xs[pending] = batch

            violations = (xs < lower) | (xs > upper)
            if not violations.any():
                break

            n_passes += 1
            if n_passes >= max_iterations:
                # Give up on redrawing and force the stragglers into bounds.
                xs = np.clip(xs, lower, upper)
                logger.warning(
                    "Maximum startpoint resampling iterations reached. "
                    "Some samples were clipped to bounds."
                )
                break

            # Next pass only redraws rows with at least one violation.
            pending = violations.any(axis=1)

        return xs