Working with proofs involving CmpNat and singletons in Haskell

Submitted by 允我心安 on 2019-12-24 00:39:33

Question


I'm trying to create some functions to work with the following type. The following code uses the singletons and constraints libraries on GHC-8.4.1:

{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Constraint ((:-))
import Data.Singletons (sing)
import Data.Singletons.Prelude (Sing(SEQ, SGT, SLT), (%+), sCompare)
import Data.Singletons.Prelude.Num (PNum((+)))
import Data.Singletons.TypeLits (SNat)
import GHC.TypeLits (CmpNat, Nat)

data Foo where
  Foo
    :: forall (index :: Nat) (len :: Nat).
       (CmpNat index len ~ 'LT)
    => SNat index
    -> SNat len
    -> Foo

This is a GADT that contains a length and an index. The index is guaranteed to be less than the length.

It is easy enough to write a function to create a Foo:

createFoo :: Foo
createFoo = Foo (sing :: SNat 0) (sing :: SNat 1)

However, I'm having trouble writing a function that increments the len while keeping the index the same:

incrementLength :: Foo -> Foo
incrementLength (Foo index len) = Foo index (len %+ (sing :: SNat 1))

This is failing with the following error:

file.hs:34:34: error:
    • Could not deduce: CmpNat index (len GHC.TypeNats.+ 1) ~ 'LT
        arising from a use of ‘Foo’
      from the context: CmpNat index len ~ 'LT
        bound by a pattern with constructor:
                   Foo :: forall (index :: Nat) (len :: Nat).
                          (CmpNat index len ~ 'LT) =>
                          SNat index -> SNat len -> Foo,
                 in an equation for ‘incrementLength’
        at what5.hs:34:17-29
    • In the expression: Foo index (len %+ (sing :: SNat 1))
      In an equation for ‘incrementLength’:
          incrementLength (Foo index len)
            = Foo index (len %+ (sing :: SNat 1))
    • Relevant bindings include
        len :: SNat len (bound at what5.hs:34:27)
        index :: SNat index (bound at what5.hs:34:21)
   |
34 | incrementLength (Foo index len) = Foo index (len %+ (sing :: SNat 1))
   |                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This makes sense, since the compiler knows that CmpNat index len ~ 'LT (from the definition of Foo), but doesn't know that CmpNat index (len + 1) ~ 'LT.
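The distinction can be seen in a small sketch that needs only base (GHC.TypeLits and Data.Type.Equality, not singletons); the names zeroLtOne, zeroLtOnePlusOne and bump are illustrative. With literal arguments GHC evaluates CmpNat itself, but with abstract variables there is nothing to evaluate and no lemma to apply.

```haskell
{-# LANGUAGE DataKinds, TypeFamilies, TypeOperators #-}
import Data.Type.Equality ((:~:)(Refl))
import GHC.TypeLits (CmpNat, type (+))

-- Concrete literals: GHC evaluates CmpNat (and +) itself, so Refl is accepted.
zeroLtOne :: CmpNat 0 1 :~: 'LT
zeroLtOne = Refl

zeroLtOnePlusOne :: CmpNat 0 (1 + 1) :~: 'LT
zeroLtOnePlusOne = Refl

-- Abstract n and m: there is nothing to evaluate, and GHC has no built-in
-- lemma relating CmpNat n m to CmpNat n (m + 1). Uncommenting this should
-- give a "Could not deduce" error much like the one for incrementLength:
--
-- bump :: (CmpNat n m ~ 'LT) => proxy n -> proxy m -> CmpNat n (m + 1) :~: 'LT
-- bump _ _ = Refl
```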

Is there any way to get something like this to work?

It is possible to use sCompare to do something like this:

incrementLength :: Foo -> Foo
incrementLength (Foo index len) =
  case sCompare index (len %+ (sing :: SNat 1)) of
    SLT -> Foo index (len %+ (sing :: SNat 1))
    SEQ -> error "not eq"
    SGT -> error "not gt"

However, it seems unfortunate that I have to write cases for SEQ and SGT, when I know they will never be matched.

Also, I thought it might be possible to write a proof with a type like the following:

plusOneLTProof :: (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
plusOneLTProof = undefined

However, this gives an error like the following (the function was still named bar when this error was produced):

file.hs:40:8: error:
    • Couldn't match type ‘CmpNat n0 m0’ with ‘CmpNat n m’
      Expected type: (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
        Actual type: (CmpNat n0 m0 ~ 'LT) :- (CmpNat n0 (m0 + 1) ~ 'LT)
      NB: ‘CmpNat’ is a non-injective type family
      The type variables ‘n0’, ‘m0’ are ambiguous
    • In the ambiguity check for ‘bar’
      To defer the ambiguity check to use sites, enable AllowAmbiguousTypes
      In the type signature:
        bar :: (CmpNat n m ~  'LT) :- (CmpNat n (m + 1) ~  'LT)
   |
40 | bar :: (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
   |        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This makes sense, I guess, since CmpNat is non-injective. However, I know that this implication is true, so I'd like to be able to write this function.
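Non-injectivity can be checked directly with a short base-only sketch (sameResult is a hypothetical name): two different argument pairs reduce to the same result, so a constraint like CmpNat n0 m0 ~ 'LT cannot pin down n0 and m0.

```haskell
{-# LANGUAGE DataKinds, TypeOperators #-}
import Data.Type.Equality ((:~:)(Refl))
import GHC.TypeLits (CmpNat)

-- Two different argument pairs reduce to the same result, so CmpNat is not
-- injective: a constraint CmpNat n0 m0 ~ 'LT on its own cannot tell GHC
-- which n0 and m0 were meant, hence the "ambiguous" report.
sameResult :: CmpNat 0 1 :~: CmpNat 2 5
sameResult = Refl
```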


I'd like an answer to the following two questions:

  1. Is there a way to write incrementLength where you'd only have to match on SLT? I'd be fine with changing the definition of Foo to make this easier.

  2. Is there a way to write plusOneLTProof, or at least something similar?


Update: I ended up using the suggestion from Li-yao Xia to write plusOneLTProof and incrementLength like the following:

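-- Add these to the import list at the top of the file:
import Data.Constraint (Dict(Dict), (:-)(Sub))
import Data.Singletons.TypeLits (Sing(SNat))
import Unsafe.Coerce (unsafeCoerce)

incrementLength :: Foo -> Foo
incrementLength (Foo index len) =
  -- Matching on Sub Dict brings CmpNat index (len + 1) ~ 'LT into scope.
  case plusOneLTProof index len of
    Sub Dict -> Foo index (len %+ (sing :: SNat 1))

plusOneLTProof :: forall n m. SNat n -> SNat m -> (CmpNat n m ~ 'LT) :- (CmpNat n (m + 1) ~ 'LT)
plusOneLTProof SNat SNat = Sub axiom
  where
    -- Not checked by GHC: the evidence is forged with unsafeCoerce.
    axiom :: CmpNat n m ~ 'LT => Dict (CmpNat n (m + 1) ~ 'LT)
    axiom = unsafeCoerce (Dict :: Dict (a ~ a))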

This requires that you pass in two SNats to plusOneLTProof, but it doesn't require AllowAmbiguousTypes.


Answer 1:


The compiler is rejecting plusOneLTProof because its type is ambiguous. We can disable that check with the AllowAmbiguousTypes extension. I would recommend using it together with ExplicitForAll (which is implied by ScopedTypeVariables, which we'll certainly need anyway, and by RankNTypes). That takes care of defining it; a definition with an ambiguous type can then be used with TypeApplications.
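To illustrate the two options on something smaller than plusOneLTProof, here is a base-only sketch with hypothetical names natAmbiguous and natProxied. The first mentions its type variable only in a constraint, so it needs AllowAmbiguousTypes and is called with TypeApplications; the second takes a proxy argument (the same trick the update in the question uses with SNat) and needs neither.

```haskell
{-# LANGUAGE AllowAmbiguousTypes, DataKinds, ScopedTypeVariables, TypeApplications #-}
import Data.Proxy (Proxy(Proxy))
import GHC.TypeLits (KnownNat, natVal)

-- n appears only in the constraint, never in an argument or the result,
-- so this signature is ambiguous and needs AllowAmbiguousTypes.
-- Callers choose n with a type application, e.g. natAmbiguous @3.
natAmbiguous :: forall n. KnownNat n => Integer
natAmbiguous = natVal (Proxy :: Proxy n)

-- With a proxy argument, n is fixed by the argument's type,
-- so the signature is unambiguous and no extension is needed for it.
natProxied :: forall n proxy. KnownNat n => proxy n -> Integer
natProxied = natVal
```

For example, natAmbiguous @3 and natProxied (Proxy :: Proxy 3) both evaluate to 3.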

However, GHC still can't reason about abstract type-level naturals well enough to prove this, so we can't safely define plusOneLTProof = Sub Dict, let alone incrementLength.

But we can still create the proof out of thin air with unsafeCoerce. This is in fact how the Data.Constraint.Nat module in constraints is implemented; unfortunately, at the time of writing it doesn't contain any facts about CmpNat. The coercion works because type equalities have no runtime content. Even so, asserting incoherent facts this way can still make programs go wrong, even if the runtime values look fine.

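-- Requires {-# LANGUAGE AllowAmbiguousTypes #-} and these imports:
import Data.Constraint (Dict(Dict), (:-)(Sub))
import Unsafe.Coerce (unsafeCoerce)

plusOneLTProof :: forall n m. (CmpNat n m ~ 'LT) :- (CmpNat n (m+1) ~ 'LT)
plusOneLTProof = Sub axiom
  where
    -- Not checked by GHC: the evidence is forged with unsafeCoerce.
    axiom :: (CmpNat n m ~ 'LT) => Dict (CmpNat n (m+1) ~ 'LT)
    axiom = unsafeCoerce (Dict :: Dict (a ~ a))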

To use the proof, we specialize it (with TypeApplications) and pattern match on it to introduce the RHS in the context, checking that the LHS holds.

incrementLength :: Foo -> Foo
incrementLength (Foo (n :: SNat n) (m :: SNat m)) =
  case plusOneLTProof @n @m of
    Sub Dict -> Foo n (m %+ (sing :: SNat 1))
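For experimenting without the singletons and constraints packages, here is a self-contained sketch of the same unsafeCoerce technique using only base. It makes several substitutions that are assumptions, not part of the answer: SN is a hypothetical stand-in for SNat (a runtime Integer tagged with its type-level value), known and plusOne stand in for sing and %+, plusOneLT and fooValues are illustrative names, and the forged evidence is a :~: equality instead of a Dict wrapped in :-. As in the update to the question, the lemma takes its two singletons as arguments, so AllowAmbiguousTypes is not needed.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, ScopedTypeVariables,
             TypeApplications, TypeOperators #-}
import Data.Proxy (Proxy(Proxy))
import Data.Type.Equality ((:~:)(Refl))
import GHC.TypeLits (CmpNat, KnownNat, Nat, natVal, type (+))
import Unsafe.Coerce (unsafeCoerce)

-- Hypothetical stand-in for SNat: a runtime number tagged with its type.
-- Only 'known' and 'plusOne' build SN values, and both keep them in sync.
newtype SN (n :: Nat) = SN Integer

known :: forall n. KnownNat n => SN n
known = SN (natVal (Proxy :: Proxy n))

plusOne :: SN n -> SN (n + 1)
plusOne (SN k) = SN (k + 1)

-- Same shape as Foo in the question, with SN in place of SNat.
data Foo where
  Foo :: (CmpNat index len ~ 'LT) => SN index -> SN len -> Foo

createFoo :: Foo
createFoo = Foo (known @0) (known @1)

-- The lemma "n < m implies n < m + 1", asserted rather than proved: GHC
-- cannot check it, so the evidence is forged from a genuine 'LT :~: 'LT.
-- The SN arguments fix n and m, so the type is not ambiguous.
plusOneLT :: (CmpNat n m ~ 'LT) => SN n -> SN m -> CmpNat n (m + 1) :~: 'LT
plusOneLT _ _ = unsafeCoerce (Refl :: 'LT :~: 'LT)

-- Matching on Refl brings CmpNat index (len + 1) ~ 'LT into scope.
incrementLength :: Foo -> Foo
incrementLength (Foo i l) = case plusOneLT i l of
  Refl -> Foo i (plusOne l)

-- Read the runtime values back out, e.g. for testing.
fooValues :: Foo -> (Integer, Integer)
fooValues (Foo (SN i) (SN l)) = (i, l)
```

Here fooValues (incrementLength createFoo) evaluates to (0, 2). As with the constraints version, nothing checks the lemma; if it were false, GHC would still accept the program.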


Source: https://stackoverflow.com/questions/49663871/working-with-proofs-involving-cmpnat-and-singletons-in-haskell
