plot_sparsity (generic function with 1 method)
Tracking, counters and custom callbacks for Frank Wolfe
In this example we will run the standard Frank-Wolfe algorithm while tracking the number of calls to the different oracles, namely function, gradient evaluations, and LMO calls. In order to track each of these metrics, a "Tracking" version of the Gradient, LMO and Function methods have to be supplied to the frank_wolfe algorithm, which are wrapping a standard one.
using FrankWolfe
using Test
using LinearAlgebra
using FrankWolfe: ActiveSet
The trackers for primal objective, gradient and LMO.
In order to count the number of function calls, a TrackingObjective
is built from a standard objective function f
, which will act in the same way as the original function does, but with an additional .counter
field which tracks the number of calls.
f(x) = norm(x)^2
tf = FrankWolfe.TrackingObjective(f)
@show tf.counter
tf(rand(3))
@show tf.counter
# Resetting the counter
tf.counter = 0;
tf.counter = 0
tf.counter = 1
Similarly, the tgrad!
function tracks the number of gradient calls:
function grad!(storage, x)
return storage .= 2x
end
tgrad! = FrankWolfe.TrackingGradient(grad!)
@show tgrad!.counter;
tgrad!.counter = 0
The tracking LMO operates in a similar fashion and tracks the number of compute_extreme_point
calls.
lmo_prob = FrankWolfe.ProbabilitySimplexOracle(1)
tlmo_prob = FrankWolfe.TrackingLMO(lmo_prob)
@show tlmo_prob.counter;
tlmo_prob.counter = 0
The tracking LMO can be applied for all types of LMOs and even in a nested way, which can be useful to track the number of calls to a lazified oracle. We can now pass the tracking versions tf
, tgrad
and tlmo_prob
to frank_wolfe
and display their call counts after the optimization process.
x0 = FrankWolfe.compute_extreme_point(tlmo_prob, ones(5))
fw_results = FrankWolfe.frank_wolfe(
tf,
tgrad!,
tlmo_prob,
x0,
max_iteration=1000,
line_search=FrankWolfe.Agnostic(),
callback=nothing,
)
@show tf.counter
@show tgrad!.counter
@show tlmo_prob.counter;
tf.counter = 1
tgrad!.counter = 1002
tlmo_prob.counter = 1003
Adding a custom callback
A callback is a user-defined function called at every iteration of the algorithm with the current state passed as a named tuple.
We can implement our own callback, for example with:
- Extended trajectory logging, similar to the
trajectory = true
option - Stop criterion after a certain number of calls to the primal objective function
To reuse the same tracking functions, Let us first reset their counters:
tf.counter = 0
tgrad!.counter = 0
tlmo_prob.counter = 0;
The storage
variable stores in the trajectory array the number of calls to each oracle at each iteration.
storage = []
Any[]
Now define our own trajectory logging function that extends the five default logged elements (iterations, primal, dual, dual_gap, time)
with ".counter" field arguments present in the tracking functions.
function push_tracking_state(state, storage)
base_tuple = FrankWolfe.callback_state(state)
if state.lmo isa FrankWolfe.CachedLinearMinimizationOracle
complete_tuple = tuple(
base_tuple...,
state.gamma,
state.f.counter,
state.grad!.counter,
state.lmo.inner.counter,
)
else
complete_tuple = tuple(
base_tuple...,
state.gamma,
state.f.counter,
state.grad!.counter,
state.lmo.counter,
)
end
return push!(storage, complete_tuple)
end
push_tracking_state (generic function with 1 method)
In case we want to stop the frank_wolfe algorithm prematurely after a certain condition is met, we can return a boolean stop criterion false
. Here, we will implement a callback that terminates the algorithm if the primal objective function is evaluated more than 500 times.
function make_callback(storage)
return function callback(state, args...)
push_tracking_state(state, storage)
return state.f.counter < 500
end
end
callback = make_callback(storage)
callback (generic function with 1 method)
We can show the difference between this standard run and the lazified conditional gradient algorithm which does not call the LMO at each iteration.
FrankWolfe.lazified_conditional_gradient(
tf,
tgrad!,
tlmo_prob,
x0,
max_iteration=1000,
traj_data=storage,
line_search=FrankWolfe.Agnostic(),
callback=callback,
)
total_iterations = storage[end][1]
@show total_iterations
@show tf.counter
@show tgrad!.counter
@show tlmo_prob.counter;
┌ Warning: Lazification is not known to converge with open-loop step size strategies.
└ @ FrankWolfe ~/work/FrankWolfe.jl/FrankWolfe.jl/src/fw_algorithms.jl:315
total_iterations = 500
tf.counter = 501
tgrad!.counter = 501
tlmo_prob.counter = 13
This page was generated using Literate.jl.