--------------------------------------------------------------------------------
-- Amethyst: Neural Network Verification in Agda
--
-- This module contains an implementation of naive linearisation, which
-- approximates non-linear functions using piecewise-linear approximations,
-- i.e., with a sequence of connected line segments.
-- The function `linearise` approximates a function `f : Float → Float` over
-- an interval `(lower, upper)` using `pieces` line segments:
--
-- 1. We compute a step size `step` which divides the interval `(lower,
-- upper)` into `pieces` subintervals of size `step`.
-- 2. For each sub-interval `(lowerᵢ, upperᵢ)`, where `upperᵢ` is `lowerᵢ +
-- step`, we pick the line segment `fᵢ(x) = mᵢ * x + bᵢ`, with slope `mᵢ`
-- and y-intercept `bᵢ`, from `(lowerᵢ, f(lowerᵢ))` to `(upperᵢ, f(upperᵢ))`.
-- 3. Finally, we connect all line segments `fᵢ`. The result is a piecewise-
-- linear approximation for `f` over the interval `(lower, upper)`.
--
-- The number of segments determines the granularity of the approximation,
-- though the approximations do not necessarily becomes less precise with an
-- increased number of segments: e.g., for the exponential function and tanh,
-- we observe that approximations which use an odd number `pieces` outperform
-- approximations which use `pieces + 1` segments.
--
-- Exports:
--
-- - LineSegment (slope; intercept)
-- - LineSegments ([]; _∷_)
-- - PiecewiseLinearFn ([]; _∷_)
-- - OutOfBoundsStrategy (constant; nearest; extrapolate)
-- - linearise
--
-- NOTE: The module also exports the following definitions, but these should be
-- considered "package private", and should not be relied upon.
--
-- - head
-- - _+[_*_]
-- - last
--
--------------------------------------------------------------------------------
module Amethyst.PiecewiseLinear.Base where
open import Data.Bool as Bool using (Bool; true; false; if_then_else_)
open import Data.Float as Float using (Float; _+_; _-_; _*_; _÷_; _≤ᵇ_)
open import Data.Nat as Nat using (ℕ; suc; zero; pred; NonZero)
open import Data.Vec as Vec using (Vec)
-- |Repeated addition of floating-point numbers, by recursion on a natural number.
--
-- NOTE: The reason this is done by recusion on the natural `z`, instead of by just
-- evaluating the floating-point expression, is that (at the time of writing)
-- it's not possible to prove properties about floating-point arithmetic in
-- Agda. By recursing on a natural, we can construct exactly the structure
-- that arises from the LineSegments construction.
--
_+[_*_] : (x y : Float) (z : ℕ) → Float
x +[ y * zero ] = x
x +[ y * suc z ] = (x + y) +[ y * z ]
-- |A line segment is defined by a lower and upper bound on the x-values, a slope,
-- and and intercept.
record LineSegment (lower upper : Float) : Set where
field
slope : Float
intercept : Float
-- |A sequence of contiguous line segments.
data LineSegments (lower step : Float) : (pieces : ℕ) → Set where
[] : LineSegments lower step 0
_∷_ : ∀ {pieces : ℕ}
(ls : LineSegment lower (lower + step))
→ (pl : LineSegments (lower + step) step pieces)
→ LineSegments lower step (suc pieces)
-- |Different strategies for handling out of bounds values
-- How should a piecewise-linear approximation behave outside of the interval?
-- We have three simple options:
--
-- 1. `constant`: specify a constant value
-- 2. `extrapolate`: extrapolate the last line segment beyond the interval boundaries.
-- 3. `nearest`: return the value of the nearest point within the extrapolated intervals
--
-- The second option is unsound, as it may result in cases where the codomain of
-- the piecewise-linear approximation is not a subset of the codomain of the
-- approximated function. For instance, the piecewise-linear approximation of
-- the exp-function may return values <0 for a sufficiently small input.
-- However, we have found that it works well in practice. The third option is
-- sound, albeit a bit crude.
data OutOfBoundsStrategy : Set where
constant : Float → OutOfBoundsStrategy
extrapolate : OutOfBoundsStrategy
nearestValue : OutOfBoundsStrategy
-- |A piecewise linear function
record PiecewiseLinearFn : Set where
field
lowerOOBStrat : OutOfBoundsStrategy
{lower step} : Float
{pieces} : ℕ
.{{pieces≢0}} : NonZero pieces
lineSegments : LineSegments lower step pieces
upperOOBStrat : OutOfBoundsStrategy
upper : Float
upper = lower +[ step * suc pieces ]
private
variable
pieces : ℕ
lower : Float
step : Float
upper : Float
-- |Return the first line segment in a piecewise-linear function.
first : .{{NonZero pieces}} → (pl : LineSegments lower step pieces)
→ LineSegment lower (lower + step)
first (l ∷ ls) = l
-- |Return the last line segment in a piecewise-linear function.
last : .{{NonZero pieces}} → (pl : LineSegments lower step pieces)
→ LineSegment (lower +[ step * (pred pieces) ]) (lower +[ step * pieces ])
last (l ∷ []) = l
last (_ ∷ ls@(_ ∷ _)) = last ls
private
-- |Approximate the function f between `lower` and `lower + step` using one line segment.
lineSegment : (f : Float → Float) (lower step : Float) → LineSegment lower (lower + step)
lineSegment f lower step =
let upper = lower + step
slope = (f upper - f lower) ÷ (upper - lower)
intercept = f lower - (slope * lower)
in record { slope = slope ; intercept = intercept }
-- |Approximate the function f from lower, using line segments of length step.
lineSegments : (f : Float → Float) (lower step : Float) (pieces : ℕ)
→ LineSegments lower step pieces
lineSegments f lower step zero = []
lineSegments f lower step (suc pieces) = lineSegment f lower step ∷ lineSegments f (lower + step) step pieces
-- |Approximate the function f between lower and upper using line segments.
linearise : (f : Float → Float) (lower upper : Float)
(pieces : ℕ) → .{{NonZero pieces}}
→ (lowerOOB upperOOB : OutOfBoundsStrategy)
→ PiecewiseLinearFn
linearise f lower upper pieces@(suc _) lowerOOB upperOBB = record
{ lowerOOBStrat = lowerOOB
; lineSegments = lineSegments f lower ((upper - lower) ÷ Float.fromℕ pieces) pieces
; upperOOBStrat = upperOBB
}