"""The code in this module contains BFBrain's default algorithm for
labelling whether a point is bounded from below. It is designed to work with
the methods in Active_Learning.py. The fundamental strategy of this
labelling method is to take the quartic part of the potential function with
fixed quartic coefficients, and attempt many consecutive local
optimizations with respect to the scalar vev with random initial starting
points. If it can find a point where the potential is negative, it will
label the set of quartic coefficients as not bounded from below, while
if after some user-specified number of local minimization iterations the
found minima are always positive, the set will be labelled as bounded
from below.
"""
import jax
from jax import random as jrandom
import jax.numpy as jnp
import jax.scipy as jsp
from functools import partial
from bfbrain.Hypersphere_Formulas import jax_convert_to_polar
import numpy as np
import jaxopt
import os
# Stop JAX from preallocating a large fraction of the GPU memory up front
# (its default behaviour; the exact fraction depends on the JAX version) so
# that it only grabs memory as it needs it, leaving room for TensorFlow in
# the same process.
# NOTE(review): this assignment runs after `import jax` above; it presumably
# only takes effect if no JAX GPU backend has been initialised by the time
# this module is imported -- confirm.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
def take_step(x0, key, polar):
    """Draw a random starting vev on the unit hypersphere.

    Parameters
    ----------
    x0 : jnp.array(jnp.float32)
        Template array. Only its shape is used, to fix the shape of the
        generated vev.
    key : Jax PRNGKey
        Random key. It is split once: one half is consumed here and the
        other half is returned for later use.
    polar : bool
        If True, the point is converted with ``jax_convert_to_polar``
        (polar coordinates on the unit hypersphere). If False, its
        Cartesian components are returned unchanged.

    Returns
    -------
    rand_phi : jnp.array(jnp.float32)
        1-D array holding the random vev point: a Cartesian unit vector,
        or its polar-coordinate form.
    new_key : Jax PRNGKey
        Unconsumed key that replaces ``key`` for future draws.
    """
    new_key, subkey = jrandom.split(key)
    # Normalising an isotropic Gaussian sample yields a direction that is
    # uniformly distributed on the unit hypersphere.
    gaussian = jrandom.normal(key=subkey, shape=jnp.shape(x0))
    unit_vec = gaussian / jnp.linalg.norm(gaussian)
    point = jax_convert_to_polar(unit_vec) if polar else unit_vec
    return point, new_key
def run_one_step(key, minimizer, stepper, lam, polar):
    """Run one randomly-initialised local minimization of the quartic potential.

    A starting vev is drawn with ``stepper``. ``minimizer`` is then run from
    that point, and the potential is evaluated at the point it returns.

    Parameters
    ----------
    key : Jax PRNGKey
        Random key handed to ``stepper``. It is consumed, and a replacement
        key is returned.
    minimizer : JaxOpt Projected Gradient optimizer object.
        Its ``run`` method is called with ``hyperparams_proj`` and ``lam``
        keywords. Its ``fun`` is expected to return a sequence whose first
        element is the potential value, e.g. a (value, gradient) pair.
    stepper : callable
        Called as ``stepper(key=key)``. Must return
        ``(start_point, new_key)``; in this module it is ``take_step`` with
        its other arguments bound.
    lam : jnp.array(jnp.float32)
        1-D array of quartic coefficients in the scalar potential.
    polar : bool
        If True, the start point is treated as hyperspherical angles. The
        minimizer then gets box bounds: [0, pi] for every angle except the
        last, which gets [0, 2*pi]. If False, ``hyperparams_proj=1.`` is
        passed instead.

    Returns
    -------
    energy_after_quench : jnp.float32
        Potential value at the point returned by the minimizer.
    new_key : Jax PRNGKey
        Key that replaces the consumed ``key``.
    """
    x_start, new_key = stepper(key=key)

    if not polar:
        proj_params = 1.
    else:
        n_angles = jnp.shape(x_start)[0]
        # The final angle spans a full turn; all others span half a turn.
        upper = (jnp.pi * jnp.ones(n_angles)).at[-1].multiply(2.)
        proj_params = (jnp.zeros(n_angles), upper)

    params, _state = minimizer.run(x_start, hyperparams_proj=proj_params, lam=lam)
    value_and_grad = minimizer.fun(params, lam)
    return value_and_grad[0], new_key
@partial(jax.jit, static_argnums=[3,5])
def jax_basinhopping(lam, rng_key, x0, minimizer, niter, polar):
    """Search for a negative minimum of the quartic potential by repeated
    randomly-started local minimizations.

    ``run_one_step`` is called up to ``niter`` times, each time from a fresh
    random starting vev, and the running minimum of the returned potential
    values is tracked. The loop stops early as soon as that minimum becomes
    negative.

    The function is jit-compiled with ``minimizer`` and ``polar`` as static
    arguments. ``niter`` is traced.

    Parameters
    ----------
    lam : jnp.array(jnp.float32)
        A 1-D Jax Numpy array representing a set of quartic coefficients
        in the scalar potential.
    rng_key : Jax PRNGKey
        Initial random key. Each iteration replaces it with the new key
        returned by ``run_one_step``.
    x0 : jnp.array(jnp.float32)
        Template array. Only its shape is used, to fix the vev shape.
    minimizer : JaxOpt Projected Gradient optimizer object.
        Must be hashable, since it is a static jit argument.
    niter : int
        The maximum number of run_one_step iterations to perform in an
        attempt to find a negative local minimum of the potential.
    polar : bool
        A flag denoting whether the labeller should use Cartesian
        coordinates for the vev (False) or convert them into polar
        coordinates on the phi_len-dimensional unit hypersphere (True).

    Returns
    -------
    new_min_energy : jnp.float32
        The smallest potential value found. Because the search stops at
        the first negative value, a negative result is the first negative
        minimum encountered, not necessarily the global one. If no
        minimization runs (``niter <= 0``), the result is ``+inf``.
    """
    run_step = partial(run_one_step, minimizer=minimizer,
                       stepper=partial(take_step, x0=x0, polar=polar),
                       lam=lam, polar=polar)

    def body(carry):
        # Perform one more minimization and keep the running minimum.
        nstep, min_energy, key = carry
        trial_energy, new_key = run_step(key=key)
        # select (rather than jnp.minimum) keeps the existing NaN handling:
        # a NaN running minimum is overwritten by the next trial value.
        min_energy = jax.lax.select(trial_energy > min_energy, min_energy, trial_energy)
        return nstep + jnp.array(1), min_energy, new_key

    def cond_fun(carry):
        # Continue while no negative minimum has been seen and iterations remain.
        nstep, min_energy, _ = carry
        found_negative = min_energy < jnp.array(0.)
        return (~found_negative) & (nstep < niter)

    # Start the running minimum at +inf rather than a finite sentinel (it
    # used to be 1.0). A finite sentinel capped the result, so it was not
    # the smallest minimum found whenever every local minimum exceeded the
    # sentinel. The sign, which is all the labelling uses, is unchanged.
    init_val = (jnp.array(0), jnp.array(jnp.inf), rng_key)
    _, new_min_energy, _ = jax.lax.while_loop(cond_fun, body, init_val=init_val)
    return new_min_energy
@partial(jax.jit, static_argnums=[0,6])
@partial(jax.vmap, in_axes=(None, 0, 0, None, None, None, None))
def vectorized_minTest(func, lam, key, x0, niter, tol, polar):
    """Label a batch of quartic-coefficient sets with jax_basinhopping.

    The function is mapped with ``jax.vmap`` over axis 0 of ``lam`` and
    ``key``, giving one coefficient set and one PRNG key per row. Every
    other argument is shared across the batch. The mapped function is then
    jit-compiled with ``func`` and ``polar`` as static arguments.

    Parameters
    ----------
    func : callable
        A Jax Numpy function that returns the quartic potential and its
        gradient with respect to the vev parameters (it is handed to a
        JaxOpt solver built with ``value_and_grad=True``). This function
        will be generated by a BFBrain.DataManager object. Must be
        hashable, since it is a static jit argument.
    lam : jnp.array(jnp.float32, jnp.float32)
        2-D array; each row is one set of quartic coefficients.
    key : Jax PRNGKey
        Batch of PRNG keys, one per row of ``lam``.
    x0 : jnp.array(jnp.float32)
        Template array whose shape fixes the shape of the vev input.
    niter : int
        Maximum number of local minimizations per coefficient set while
        searching for a negative minimum of the potential.
    tol : jnp.float32
        Stopping tolerance passed to the local minimizer.
    polar : bool
        If True, minimize over polar coordinates on the unit hypersphere
        with box constraints (``projection_box``). If False, minimize over
        Cartesian vevs with ``projection_l2_sphere``.

    Returns
    -------
    jnp.array(bool)
        One label per row of ``lam``. True where the smallest potential
        value found is strictly positive; False otherwise, i.e. where a
        negative (or exactly zero) minimum was found.
    """
    # This executes only while Python traces the function, i.e. on
    # (re)compilation, so it flags expensive retracing to the user.
    print('recompiling vectorized_minTest...')
    projection = (jaxopt.projection.projection_box if polar
                  else jaxopt.projection.projection_l2_sphere)
    minimizer = jaxopt.ProjectedGradient(func, projection, tol=tol, value_and_grad = True)
    # The jax.vmap decorator vectorizes this call over the batch.
    min_energy = jax_basinhopping(lam, key, x0, minimizer = minimizer, niter = niter, polar = polar)
    return min_energy > 0
def label_func(func, phi_len, polar, rng, lam, niter=100, tol=0.001, cutoff=150000):
    """Label quartic coefficient sets as bounded from below or not.

    This is the entry point used by BFBrain.DataManager's oracle-handling
    methods. Inputs larger than ``cutoff`` are processed in several batches
    so that no single transfer to the GPU is too large.

    Parameters
    ----------
    func : callable
        The Jax Numpy function that returns the quartic part of the
        potential and its gradient.
    phi_len : int
        The number of real parameters necessary to uniquely specify a vev.
    polar : bool
        A flag denoting whether the labeller should use Cartesian
        coordinates for the vev (False) or convert them into polar
        coordinates on the phi_len-dimensional unit hypersphere (True).
    rng : np.random.Generator
        NumPy generator used to seed a fresh Jax PRNGKey for every batch.
    lam : np.array(np.float32,np.float32)
        A 2-D Numpy array of quartic potential coefficients.
    niter : int, default=100
        The number of local minimizations to perform on the potential
        before declaring it to be bounded-from-below.
    tol : float, default=0.001
        The tolerance for the local minimizer.
    cutoff : int, default=150000
        Maximum number of coefficient sets sent to the GPU at once. Larger
        inputs are split into ``len(lam) // cutoff + 1`` nearly equal
        batches, each no longer than ``cutoff``.

    Returns
    -------
    np.array(bool)
        A 1-D NumPy array with one label per set of quartic coefficients in
        lam: False where the labeller found a negative local minimum, True
        otherwise.
    """
    # Small inputs are handled in a single batch.
    if len(lam) <= cutoff:
        return label_func_do_batch(func, phi_len, polar, rng, lam, niter, tol)

    n_batches = len(lam) // cutoff + 1
    # Batches are labelled in order, so rng is drawn from in the same
    # sequence as the chunks appear in lam.
    batch_labels = [
        label_func_do_batch(func, phi_len, polar, rng, chunk, niter, tol)
        for chunk in np.array_split(lam, n_batches)
    ]
    return np.concatenate(batch_labels)
def label_func_do_batch(func, phi_len, polar, rng, lam, niter, tol):
    """Label one batch of coefficient sets with the jit-compiled Jax oracle.

    This is the method used by label_func to transfer the coefficient data
    to the GPU and run vectorized_minTest on it.

    Parameters
    ----------
    func : callable
        The Jax Numpy function that returns the quartic part of the
        potential and its gradient.
    phi_len : int
        The number of real parameters necessary to uniquely specify a vev.
    polar : bool
        A flag denoting whether the labeller should use Cartesian
        coordinates for the vev (False) or convert them into polar
        coordinates on the phi_len-dimensional unit hypersphere (True).
    rng : np.random.Generator
        NumPy generator that supplies the integer seed of this batch's
        initial PRNGKey.
    lam : np.array(np.float32,np.float32)
        A 2-D Numpy array of quartic potential coefficients.
    niter : int
        The number of local minimizations to perform on the potential
        before declaring it to be bounded-from-below.
    tol : float
        The tolerance for the local minimizer.

    Returns
    -------
    np.array(bool)
        A 1-D NumPy array with one label per set of quartic coefficients in
        lam: False where the labeller found a negative local minimum, True
        otherwise.
    """
    # Seed a Jax key from the NumPy generator with an integer drawn from
    # [0, 2**63 - 2].
    seed = rng.integers((2**63) - 1)
    # Give every coefficient set its own key for its minimization problem.
    per_point_keys = jrandom.split(jrandom.PRNGKey(seed), len(lam))
    device_lam = jnp.array(lam)
    vev_template = jnp.zeros(phi_len)
    # Arguments stay positional: vmap's in_axes and jit's static_argnums
    # on vectorized_minTest are both position-based.
    labels = vectorized_minTest(func, device_lam, per_point_keys, vev_template,
                                jnp.array(niter), jnp.array(tol), polar)
    return np.array(labels)
def test_labeller(func, phi_len, polar, rng, lam, niter = 100, tol = 0.001, cutoff = 150000, niter_step = 50, count_success = 5, max_iter = 20, verbose = False):
    """Test the accuracy of the oracle by checking label stability in niter.

    label_func is called repeatedly on the same 2-D NumPy array of quartic
    coefficients, with niter increased by niter_step each round. This stops
    once the labels of some round have been reproduced unchanged by the
    next count_success rounds, or once max_iter rounds have been completed
    (or success has become impossible within that budget).

    Parameters
    ----------
    func : callable
        The Jax Numpy function that returns the quartic part of the
        potential and its gradient.
    phi_len : int
        The number of real parameters necessary to uniquely specify a vev.
    polar : bool
        A flag denoting whether the labeller should use Cartesian
        coordinates for the vev (False) or convert them into polar
        coordinates on the phi_len-dimensional unit hypersphere (True).
    rng : np.random.Generator
        The NumPy random number generator which will generate the
        PRNGKeys used by the oracle.
    lam : np.array(np.float32,np.float32)
        A 2-D Numpy array of quartic potential coefficients.
    niter : int, default=100
        The initial number of local minimizations to perform on the
        potential before declaring it to be bounded-from-below. Round i
        uses niter + i*niter_step.
    tol : float, default=0.001
        The tolerance for the local minimizer.
    cutoff : int, default=150000
        The maximum size of a batch of coefficient values to pass to the
        GPU at one time (forwarded to label_func).
    niter_step : int, default=50
        The amount to increment the niter parameter of label_func with
        each successive attempt at labelling lam.
    count_success : int, default=5
        The number of consecutive additional labelling rounds that must
        reproduce an earlier round's labels exactly for the function to
        declare that increasing niter no longer affects the results.
    max_iter : int, default=20
        The maximum number of labelling rounds. If no consistent results
        are found within that budget, the test ends in failure.
    verbose : bool, default=False
        If True, print out statements informing the user of the progress
        of the method.

    Returns
    -------
    int
        The niter value of the earliest round whose labels were reproduced
        by the following count_success rounds. Returns -1 if that never
        happened within max_iter rounds.
    """
    # Number of consecutive rounds since round current_ind whose labels
    # matched current_res.
    count = 0
    # Labels of the latest round that differed from its predecessor. This
    # is only a placeholder until round 0 runs.
    current_res = np.zeros(shape=len(lam), dtype=bool)
    # Index of the earliest round giving the same labels as the current one.
    current_ind = 0
    for i in range(max_iter):
        if verbose:
            print('doing round ' + str(i))
        new_res = label_func(func, phi_len, polar, rng, lam, niter + i*niter_step, tol, cutoff)
        if verbose:
            print('done!')
        # Round 0 always establishes the baseline. The placeholder is not a
        # real labelling, so an all-False first result must not be counted
        # as a repeat (that would declare success one round early).
        if i == 0 or np.any(current_res != new_res):
            if verbose:
                print('updating the res array from ' + str(np.count_nonzero(current_res)) + ' positives to ' + str(np.count_nonzero(new_res)) + ' positives...')
            count = 0
            current_res = new_res
            current_ind = i
        else:
            count += 1
        # Stop on success, or when the remaining rounds can no longer reach
        # count_success repeats.
        remaining = max_iter - i - 1
        if count >= count_success or count + remaining < count_success:
            break
    if count < count_success:
        if verbose:
            print('failed to find consistent results after ' + str(max_iter) + ' iterations. Recommend decreasing tol or starting again with niter = ' + str(niter + niter_step*(max_iter - 1)))
        return -1
    if verbose:
        print('Found consistent results for niter >= ' + str(niter + current_ind*niter_step))
    return niter + current_ind*niter_step