Idris - map an operation on a n-dimensional vector

Posted by 那年仲夏 on 2021-01-27 07:35:22

Question


I defined n-dimensional vectors in Idris as follows:

import Data.Vect

NDVect : (Num t) => (rank : Nat) -> (shape : Vect rank Nat) -> (t : Type) -> Type
NDVect Z     []      t = t
NDVect (S n) (x::xs) t = Vect x (NDVect n xs t)

Then I defined the following function, which applies a function f to every entry in the tensor:

iterateT : (f : t -> t') -> (v : NDVect r s t) -> NDVect r s t'
iterateT {r = Z}   {s = []}    f v = f v
iterateT {r = S n} {s = x::xs} f v = map (iterateT f) v

But when I try to call iterateT in the following function:

scale : Num t => (c : t) -> (v : NDVect rank shape t) -> NDVect rank shape t
scale c v = iterateT (*c) v

I get the following error message reporting a type mismatch, even though the two types look identical to me:

 When checking right hand side of scale with expected type
         NDVect rank shape t

 When checking argument v to function Main.iterateT:
         Type mismatch between
                 NDVect rank shape t (Type of v)
         and
                 NDVect r s t (Expected type)

         Specifically:
                 Type mismatch between
                         NDVect rank shape t
                 and
                         NDVect r s t             
         Specifically:
                 Type mismatch between
                         NDVect rank shape t
                 and
                         NDVect r s t

Answer 1:


I have also been wondering how to express n-dimensional vectors (i.e. tensors) in Idris. I had a play with the type definition in the question, but encountered various issues, so I expressed the NDVect function as a data type:

data NDVect : (rank : Nat) -> (shape : Vect rank Nat) -> Type -> Type where
  NDVZ : (value : t) -> NDVect Z [] t
  NDV  : (values : Vect n (NDVect r s t)) -> NDVect (S r) (n::s) t

And implemented map as follows:

nmap : (t -> u) -> (NDVect r s t) -> NDVect r s u
nmap f (NDVZ value) = NDVZ (f value)
nmap f (NDV values) = NDV (map (nmap f) values)

The following now works:

*Main> NDVZ 5
NDVZ 5 : NDVect 0 [] Integer
*Main> nmap (+4) (NDVZ 5)
NDVZ 9 : NDVect 0 [] Integer
*Main> NDV [NDVZ 1, NDVZ 2, NDVZ 3]
NDV [NDVZ 1, NDVZ 2, NDVZ 3] : NDVect 1 [3] Integer
*Main> nmap (+4) (NDV [NDVZ 1, NDVZ 2, NDVZ 3])
NDV [NDVZ 5, NDVZ 6, NDVZ 7] : NDVect 1 [3] Integer
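To connect this back to the question, here is a minimal sketch of how scale might be written on top of nmap. It assumes the same Idris 1 setup as the rest of this answer. The signature is adapted from the question's scale, with the implicit names r and s chosen here for illustration. Because NDVect is now a data type rather than a type-level function, the rank and shape should be inferable from the argument:

```idris
import Data.Vect

-- Data-type encoding of an n-dimensional vector, as defined above.
data NDVect : (rank : Nat) -> (shape : Vect rank Nat) -> Type -> Type where
  NDVZ : (value : t) -> NDVect Z [] t
  NDV  : (values : Vect n (NDVect r s t)) -> NDVect (S r) (n::s) t

-- Apply f to every scalar entry, preserving rank and shape.
nmap : (t -> u) -> NDVect r s t -> NDVect r s u
nmap f (NDVZ value) = NDVZ (f value)
nmap f (NDV values) = NDV (map (nmap f) values)

-- Multiply every entry by the constant c (adapted from the question).
scale : Num t => (c : t) -> (v : NDVect r s t) -> NDVect r s t
scale c v = nmap (* c) v
```

With these definitions, scale 2 (NDV [NDVZ 1, NDVZ 2, NDVZ 3]) should evaluate to NDV [NDVZ 2, NDVZ 4, NDVZ 6].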

Unfortunately, having to write out all the data constructors (NDVZ and NDV) makes things a bit ugly. I'd love to know if there's a cleaner way to solve this.

Edit:

Here's a slightly shorter type signature that doesn't explicitly encode the tensor rank in the type:

data NDVect : (shape : List Nat) -> Type -> Type where
  NDVZ : (value : t) -> NDVect [] t
  NDV  : (values : Vect n (NDVect s t)) -> NDVect (n::s) t

nmap : (t -> u) -> (NDVect s t) -> NDVect s u
nmap f (NDVZ value) = NDVZ (f value)
nmap f (NDV values) = NDV (map (nmap f) values)
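One possible way to tidy up the call sites for this shape-only version is to give NDVect a Functor implementation that delegates to nmap, so the ordinary map can be used. This is only a sketch under the assumption of Idris 1 and the definitions above. The Functor implementation and the scale function are additions for illustration and are not part of the original answer:

```idris
import Data.Vect

-- Shape-only encoding, as defined above.
data NDVect : (shape : List Nat) -> Type -> Type where
  NDVZ : (value : t) -> NDVect [] t
  NDV  : (values : Vect n (NDVect s t)) -> NDVect (n::s) t

nmap : (t -> u) -> NDVect s t -> NDVect s u
nmap f (NDVZ value) = NDVZ (f value)
nmap f (NDV values) = NDV (map (nmap f) values)

-- Illustrative addition: reuse nmap as the Functor map for any fixed shape.
Functor (NDVect s) where
  map f v = nmap f v

-- Illustrative addition: the question's scale, written with the generic map.
scale : Num t => (c : t) -> (v : NDVect s t) -> NDVect s t
scale c v = map (* c) v
```

This does not remove the NDVZ and NDV constructors from literal values, but it lets functions such as scale be written without mentioning them.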


Source: https://stackoverflow.com/questions/47558158/idris-map-an-operation-on-a-n-dimensional-vector
