--------------------------------------------------------------------------------
-- Amethyst: Neural Network Verification in Agda
--
-- This module exports the basic types for representing neural networks.
--
-- Exports:
--
--   - Activation (linear; relu; sigmoid; softmax; tanh)
--   - Layer      (activation; weights; biases)
--   - LayerSpec
--   - ∣_₀∣; ∣_ₙ∣
--   - Network    ([]; _∷_)
--
--------------------------------------------------------------------------------

module Amethyst.Network.Base where

import Data.Vec.Extra as Vec

open import Data.Fin.Base as Fin using (Fin)
open import Data.Float as Float using (Float)
open import Data.Nat.Base as Nat using (ℕ; suc; zero)
open import Data.Vec.Base as Vec using (Vec; []; _∷_)
open import Function using (id)

-- Generalised variables: when one of these names occurs free in a type
-- signature in this module, Agda binds it implicitly at the type declared
-- here. They are private, so they are not visible to importing modules.
private
  variable
    A : Set
    n : ℕ
    hidden  : ℕ
    layers  : ℕ

-- Tags naming the activation function attached to a layer. The constructors
-- carry no arguments; this definition gives them no numeric interpretation.
data Activation : Set where
  linear relu sigmoid softmax tanh : Activation

-- A single layer over scalars of type `A`, indexed by its number of inputs
-- and its number of outputs.
record Layer (A : Set) (inputs outputs : ℕ) : Set where
  field
    -- Tag selecting the activation function for this layer.
    activation : Activation
    -- Weights, stored as `inputs` vectors of length `outputs`.
    weights    : Vec (Vec A outputs) inputs
    -- One bias value per output.
    biases     : Vec A outputs
-- A layer specification of length `n`: a vector of `n` natural numbers.
-- Vectors of this shape are used as the index of `Network`.
LayerSpec : ℕ → Set
LayerSpec = Vec ℕ

-- Delegates to `Vec.headOr` with default 0 — presumably the first entry of
-- the specification, or 0 when it is empty (confirm against Data.Vec.Extra).
∣_₀∣ : LayerSpec n → ℕ
∣ xs ₀∣ = Vec.headOr xs 0

-- Delegates to `Vec.lastOr` with default 0 — presumably the last entry of
-- the specification, or 0 when it is empty (confirm against Data.Vec.Extra).
∣_ₙ∣ : LayerSpec n → ℕ
∣ xs ₙ∣ = Vec.lastOr xs 0

infixr 5 _∷_

-- A network over scalars of type `A`, indexed by a vector of natural numbers.
-- Each constructor below extends the index by one entry per layer added.
data Network (A : Set) : ∀ {#layers : ℕ} → Vec ℕ #layers → Set where
  -- The network with no layers; its index holds the single entry `n`.
  []  : ∀ {n} → Network A (n ∷ [])
  -- Prepend a layer with |lᵢ| inputs and |lᵢ₊₁| outputs to a network whose
  -- index starts with |lᵢ₊₁|; the result's index starts with |lᵢ|.
  _∷_ : ∀ {n |lᵢ| |lᵢ₊₁|} {xs : LayerSpec n}
      → Layer A |lᵢ| |lᵢ₊₁| → Network A (|lᵢ₊₁| ∷ xs) → Network A (|lᵢ| ∷ |lᵢ₊₁| ∷ xs)