How to do automatic differentiation on complex datatypes?

余生分开走 2020-12-10 08:46

Given a very simple Matrix definition based on Vector:

import Numeric.AD
import qualified Data.Vector as V

newtype Mat a = Mat { unMat :: V.Vector a }

scal         


        
1 Answer
  • 2020-12-10 08:58

    The gradientDescent function from ad has the type

    gradientDescent :: (Traversable f, Fractional a, Ord a) =>
                       (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) ->
                       f a -> [f a]
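    As a minimal sketch of what that rank-2 argument accepts (assuming the ad package is installed; bowl and minimiseBowl are hypothetical names used only for illustration), a function written against Num alone can be passed straight to gradientDescent, because it can be instantiated at Reverse s Double:

```haskell
import Numeric.AD (gradientDescent)

-- Hypothetical objective: a bowl with its minimum at (3, -1).
-- It only needs Num, so it can be used at Reverse s Double, which is
-- what gradientDescent's first argument requires.
bowl :: Num a => [a] -> a
bowl [x, y] = (x - 3) ^ (2 :: Int) + (y + 1) ^ (2 :: Int)
bowl _      = error "bowl: expected exactly two coordinates"

-- gradientDescent returns a stream of iterates that may end early
-- (e.g. when the gradient becomes exactly zero), so take a bounded
-- prefix and keep its last element, falling back to the start point.
minimiseBowl :: [Double] -> [Double]
minimiseBowl start = last (start : take 500 (gradientDescent bowl start))
```

    In GHCi, minimiseBowl [0, 0] should return values close to [3, -1]. Taking a bounded prefix is safer than indexing the stream with (!!), which would fail if the stream stops early.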
    

    Its first argument must be a function of type f (Reverse s a) -> Reverse s a that works for every s (the forall s). go has the type [a] -> a, where a is the type bound in the signature of diffTest. The a in both signatures is the same type, but Reverse s a isn't the same as a, so go can't be passed to gradientDescent as it stands.
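    To make the mismatch concrete, here is a small sketch (assuming the ad package; scaledSum, bad and gradScaled are hypothetical names). A function that is polymorphic in its number type can be instantiated at Reverse s Double, while a lambda that captures a plain Double cannot:

```haskell
import Numeric.AD (grad)

-- Compiles: the body only uses Num operations, so scaledSum can be
-- instantiated at Reverse s Double as well as at Double.
scaledSum :: Num a => a -> [a] -> a
scaledSum k xs = k * sum xs

-- Does NOT compile (kept as a comment): k is fixed to Double, so the
-- lambda has type [Double] -> Double rather than
-- [Reverse s Double] -> Reverse s Double.
--   bad :: Double -> [Double] -> [Double]
--   bad k = grad (\xs -> k * sum xs)

-- The literal 2 is interpreted directly at Reverse s Double, so the
-- gradient of 2 * (x1 + x2 + x3) is 2 in every coordinate.
gradScaled :: [Double] -> [Double]
gradScaled = grad (scaledSum 2)
```

    gradScaled [1, 2, 3] should give [2.0, 2.0, 2.0].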

    The Reverse type has instances for a number of type classes that could allow us to convert an a into a Reverse s a or back. The most obvious is Fractional a => Fractional (Reverse s a), which would allow us to convert a values into Reverse s a values with realToFrac.
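    A minimal sketch of that conversion (assuming the ad package; gradScaledBy is a hypothetical name): a Double captured from outside the differentiated function is converted with realToFrac, which goes through Rational and produces a constant of type Reverse s Double:

```haskell
import Numeric.AD (grad)

-- k is a plain Double from outside the lambda. realToFrac turns it into
-- a Reverse s Double constant (via Rational), so the lambda now has the
-- type [Reverse s Double] -> Reverse s Double that grad expects.
gradScaledBy :: Double -> [Double] -> [Double]
gradScaledBy k = grad (\xs -> realToFrac k * sum xs)
```

    gradScaledBy 0.5 [1, 2] should give [0.5, 0.5]. realToFrac needs a Real instance on the source type, which Double has.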

    To do so, we'll need to be able to map a function a -> b over a Mat a to obtain a Mat b. The easiest way to do this will be to derive a Functor instance for Mat.

    {-# LANGUAGE DeriveFunctor #-}
    
    newtype Mat a = Mat { unMat :: V.Vector a }
        deriving Functor
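    A quick sketch of what the derived instance provides (assuming the vector package; halves is a hypothetical helper, and Show is derived only so results can be printed): fmap changes the element type of a Mat while leaving its shape alone.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import qualified Data.Vector as V

newtype Mat a = Mat { unMat :: V.Vector a }
  deriving (Show, Functor)

-- The derived fmap applies the function to every element of the
-- underlying Vector, here turning a Mat Int into a Mat Double.
halves :: Mat Int -> Mat Double
halves = fmap (\n -> fromIntegral n / 2)
```

    For example, V.toList (unMat (halves (Mat (V.fromList [1, 2, 3])))) evaluates to [0.5,1.0,1.5].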
    

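    We can convert m, and each matrix in fs, into any Fractional a' => Mat a' with fmap realToFrac. Because realToFrac :: (Real a, Fractional b) => a -> b, this also requires a Real a constraint on diffTest; Fractional a alone is not enough.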

    diffTest m fs as0 = gradientDescent go as0
      where go xs = eq3' (fmap realToFrac m) xs (fmap (fmap realToFrac) fs)
    

    But there's a better way hiding in the ad package. In Reverse s a, the s is universally quantified (forall s), but the a is the same a as the one bound in the type signature of diffTest. So we really only need a function a -> (forall s. Reverse s a). That function is auto from the Mode class, for which Reverse s a has an instance. auto has the slightly weird type Mode t => Scalar t -> t, but type Scalar (Reverse s a) = a, so specialized to Reverse, auto has the type

    auto :: (Reifies s Tape, Num a) => a -> Reverse s a
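    The same constant-lifting problem can be sketched with auto instead of realToFrac (assuming the ad package, whose Numeric.AD module exports auto as the Mode method; gradScaledByAuto is a hypothetical name):

```haskell
import Numeric.AD (auto, grad)

-- auto lifts the plain Double k into Reverse s Double as a constant
-- (its derivative is zero), with no round trip through Rational.
gradScaledByAuto :: Double -> [Double] -> [Double]
gradScaledByAuto k = grad (\xs -> auto k * sum xs)
```

    gradScaledByAuto 0.5 [1, 2] should give [0.5, 0.5]. Unlike realToFrac, auto does not need a Real instance; for Reverse it only needs Num on the scalar type.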
    

    This allows us to convert our Mat a values into Mat (Reverse s a) values without messing around with conversions to and from Rational.

    {-# LANGUAGE ScopedTypeVariables #-}
    {-# LANGUAGE TypeFamilies #-}
    
    diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
    diffTest m fs as0 = gradientDescent go as0
      where
        go :: forall t. (Scalar t ~ a, Mode t) => [t] -> t
        go xs = eq3' (fmap auto m) xs (fmap (fmap auto) fs)
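    For a version that can be loaded and run end to end, here is a sketch that fills in the pieces this answer takes from the question. The eq3' below is a hypothetical stand-in (the real definition is not shown), chosen as the squared error between m and the combination of the matrices in fs weighted by xs; the ad and vector packages are assumed, and Show is derived only for convenience.

```haskell
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
import Numeric.AD (Mode, Scalar, auto, gradientDescent)
import qualified Data.Vector as V

newtype Mat a = Mat { unMat :: V.Vector a }
  deriving (Show, Functor)

-- HYPOTHETICAL stand-in for the question's eq3':
-- sum of squares of (m - sum_i x_i * f_i), computed elementwise.
eq3' :: Num t => Mat t -> [t] -> [Mat t] -> t
eq3' (Mat m) xs fs = V.sum (V.map (^ (2 :: Int)) (V.zipWith (-) m combo))
  where
    zeros = V.map (const 0) m
    combo = foldr (V.zipWith (+)) zeros [V.map (x *) f | (x, Mat f) <- zip xs fs]

diffTest :: forall a. (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m fs as0 = gradientDescent go as0
  where
    -- go works for any AD mode t whose scalar type is a; auto lifts the
    -- constant matrices from Mat a to Mat t. The explicit signature is
    -- needed because TypeFamilies turns on MonoLocalBinds.
    go :: forall t. (Scalar t ~ a, Mode t) => [t] -> t
    go xs = eq3' (fmap auto m) xs (fmap (fmap auto) fs)
```

    With m = Mat (V.fromList [3, 4]), fs = [Mat (V.fromList [1, 0]), Mat (V.fromList [0, 1])] and a starting point of [0, 0], the iterates should approach [3, 4], the weights that reproduce m exactly under this stand-in eq3'.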
    