--------------------------------------------------------------------------------
-- Amethyst: Neural Network Verification in Agda
--
-- This module contains an implementation of naive linearisation, which
-- approximates non-linear functions using piecewise-linear approximations,
-- i.e., with a sequence of connected line segments.
-- The function `linearise` approximates a function `f : Float →  Float` over
-- an interval `(lower, upper)` using `pieces` line segments:
--
--   1. We compute a step size `step` which divides the interval `(lower,
--      upper)` into `pieces` subintervals of size `step`.
--   2. For each sub-interval `(lowerᵢ, upperᵢ)`, where `upperᵢ` is `lowerᵢ +
--      step`, we pick the line segment `fᵢ(x) = mᵢ * x + bᵢ`, with slope `mᵢ`
--      and y-intercept `bᵢ`, from `(lowerᵢ, f(lowerᵢ))` to `(upperᵢ, f(upperᵢ))`.
--   3. Finally, we connect all line segments `fᵢ`. The result is a piecewise-
--      linear approximation for `f` over the interval `(lower, upper)`.
--
-- The number of segments determines the granularity of the approximation,
-- though the approximations do not necessarily become more precise with an
-- increased number of segments: e.g., for the exponential function and tanh,
-- we observe that approximations which use an odd number `pieces` outperform
-- approximations which use `pieces + 1` segments.
--
-- Exports:
--
--   - LineSegment (slope; intercept)
--   - LineSegments ([]; _∷_)
--   - PiecewiseLinearFn (lowerOOBStrat; lower; step; pieces; lineSegments;
--                        upperOOBStrat; upper)
--   - OutOfBoundsStrategy (constant; extrapolate; nearestValue)
--   - linearise
--
-- NOTE: The module also exports the following definitions, but these should be
--       considered "package private", and should not be relied upon.
--
--   - first
--   - _+[_*_]
--   - last
--
--------------------------------------------------------------------------------
module Amethyst.PiecewiseLinear.Base where

open import Data.Bool as Bool using (Bool; true; false; if_then_else_)
open import Data.Float as Float using (Float; _+_; _-_; _*_; _÷_; _≤ᵇ_)
open import Data.Nat as Nat using (ℕ; suc; zero; pred; NonZero)
open import Data.Vec as Vec using (Vec)

-- |Repeated addition of floating-point numbers, by recursion on a natural number.
--
--  `x +[ y * z ]` adds `y` to `x` exactly `z` times, associating to the left,
--  e.g., `x +[ y * 2 ]` unfolds to `(x + y) + y`.
--
--  NOTE: The reason this is done by recursion on the natural `z`, instead of by just
--        evaluating the floating-point expression, is that (at the time of writing)
--        it's not possible to prove properties about floating-point arithmetic in
--        Agda. By recursing on a natural, we can construct exactly the structure
--        that arises from the LineSegments construction.
--
_+[_*_] : (x y : Float) (z : ℕ) → Float
x +[ y * zero  ] = x
x +[ y * suc z ] = (x + y) +[ y * z ]

-- |A single line segment.
--
--  The type is indexed by the lower and upper bounds on the x-values that the
--  segment covers; the record itself stores the slope and the y-intercept.
record LineSegment (lower upper : Float) : Set where
  field
    slope intercept : Float

-- |A sequence of contiguous line segments.
--
--  `LineSegments lower step pieces` holds `pieces` segments, each spanning an
--  x-range of width `step`. The first segment starts at `lower`, and every
--  subsequent segment starts where the previous one ends (`lower + step`).
data LineSegments (lower step : Float) : (pieces : ℕ) → Set where
  -- The empty sequence covers no segments.
  []  : LineSegments lower step 0
  -- Prepend a segment over `(lower, lower + step)` to a sequence that
  -- starts at `lower + step`.
  _∷_ : ∀ {pieces : ℕ}
      → (ls : LineSegment lower (lower + step))
      → (pl : LineSegments (lower + step) step pieces)
      → LineSegments lower step (suc pieces)

-- |Different strategies for handling out of bounds values
-- How should a piecewise-linear approximation behave outside of the interval?
-- We have three simple options:
--
--   1. `constant`: specify a constant value
--   2. `extrapolate`: extrapolate the last line segment beyond the interval boundaries.
--   3. `nearestValue`: return the value of the nearest point within the extrapolated intervals
--
-- The second option is unsound, as it may result in cases where the codomain of
-- the piecewise-linear approximation is not a subset of the codomain of the
-- approximated function. For instance, the piecewise-linear approximation of
-- the exp-function may return values <0 for a sufficiently small input.
-- However, we have found that it works well in practice. The third option is
-- sound, albeit a bit crude.
data OutOfBoundsStrategy : Set where
  constant      : Float → OutOfBoundsStrategy
  extrapolate   : OutOfBoundsStrategy
  nearestValue  : OutOfBoundsStrategy

-- |A piecewise linear function
--
--  Bundles a non-empty sequence of contiguous line segments, starting at
--  `lower` and each of width `step`, together with one out-of-bounds strategy
--  for each side of the covered interval.
record PiecewiseLinearFn : Set where
  field
    -- Strategy for inputs below the covered interval.
    lowerOOBStrat : OutOfBoundsStrategy
    -- Start of the interval and the width of each segment.
    {lower step}  : Float
    -- Number of segments, which must be non-zero.
    {pieces}      : ℕ
    .{{pieces≢0}} : NonZero pieces
    lineSegments  : LineSegments lower step pieces
    -- Strategy for inputs above the covered interval.
    upperOOBStrat : OutOfBoundsStrategy

  -- |The upper bound of the covered interval: `step` added to `lower` once per
  --  segment. This matches the upper bound of the segment returned by `last`.
  upper : Float
  upper = lower +[ step * pieces ]


-- Generalised variables, implicitly bound in the type signatures below.
private
  variable
    pieces : ℕ
    lower  : Float
    step   : Float
    upper  : Float

-- |Return the first line segment in a piecewise-linear function.
--
--  The `NonZero pieces` instance rules out the empty sequence, so the single
--  clause for `_∷_` is exhaustive.
first : .{{NonZero pieces}} → (pl : LineSegments lower step pieces)
      → LineSegment lower (lower + step)
first (l ∷ _) = l

-- |Return the last line segment in a piecewise-linear function.
--
--  The bounds of the result are expressed with `_+[_*_]`, so that they unfold
--  by recursion on `pieces` in the same way as the `LineSegments` structure.
last : .{{NonZero pieces}} → (pl : LineSegments lower step pieces)
     → LineSegment (lower +[ step * (pred pieces) ]) (lower +[ step * pieces ])
last (l ∷ [])         = l
last (_ ∷ ls@(_ ∷ _)) = last ls

private
  -- |Approximate the function f between `lower` and `lower + step` using one line segment.
  --
  --  The segment is the straight line through `(lower, f lower)` and
  --  `(lower + step, f (lower + step))`.
  lineSegment : (f : Float → Float) (lower step : Float) → LineSegment lower (lower + step)
  lineSegment f lower step =
    let upper     = lower + step
        slope     = (f upper - f lower) ÷ (upper - lower)
        -- Solve `f lower = slope * lower + intercept` for the intercept.
        intercept = f lower - (slope * lower)
    in record { slope = slope ; intercept = intercept }

  -- |Approximate the function f from lower, using line segments of length step.
  --
  --  Produces `pieces` contiguous segments, each one starting where the
  --  previous one ends.
  lineSegments : (f : Float → Float) (lower step : Float) (pieces : ℕ)
               → LineSegments lower step pieces
  lineSegments f lower step zero = []
  lineSegments f lower step (suc pieces) =
    lineSegment f lower step ∷ lineSegments f (lower + step) step pieces

-- |Approximate the function f between lower and upper using line segments.
--
--  The interval `(lower, upper)` is divided into `pieces` segments of equal
--  width `(upper - lower) ÷ pieces`. `lowerOOB` and `upperOOB` are stored as the
--  out-of-bounds strategies for the lower and upper side of the interval.
linearise : (f : Float → Float) (lower upper : Float)
            (pieces : ℕ) → .{{NonZero pieces}}
          → (lowerOOB upperOOB : OutOfBoundsStrategy)
          → PiecewiseLinearFn
linearise f lower upper pieces@(suc _) lowerOOB upperOOB = record
  { lowerOOBStrat = lowerOOB
  ; lineSegments  = lineSegments f lower ((upper - lower) ÷ Float.fromℕ pieces) pieces
  ; upperOOBStrat = upperOOB
  }