Shape-dependent computations in Scala ... and Agda!

Published: Oct 24, 2018 by Juan Manuel Serrano

In this post we will solve a little programming problem, mainly as an excuse to talk about dependent types. As usual, Scala will be our programming language of choice. This time, however, we will also use Agda, a programming language which boasts full-fledged support for dependent types. The ultimate goal of this post is to compare both implementations and to share our experiences with dependently typed programming.

You can find all the code in this post in the following GitHub repository.

## Our little problem

Let’s consider the following type of (non-empty) binary trees, implemented in Scala as a common algebraic data type:

```scala
sealed abstract class Tree[A]
case class Leaf[A](a : A) extends Tree[A]
case class Node[A](left: Tree[A], root: A, right: Tree[A]) extends Tree[A]
``````

We want to implement two functions that allow us to get and update the leaves of a given tree. As a first attempt (there will be several more attempts before we reach the solution, so be patient!), we may come up with the following signatures:

```scala
class Leaves[A] {
  def get(tree: Tree[A]): List[A] = ???
  def update(tree: Tree[A]): List[A] => Tree[A] = ???
}
``````

The `get` function poses no problem: there may be one or several leaves in the input tree, and the resulting list can cope with that. The `update` function, however, while essentially being what we want, is problematic. This method returns a function which updates the leaves of the tree, given a list of new values for them. Ideally, we would receive a list with exactly as many values as there are leaves in the tree. But nothing in this signature guarantees that: we may receive fewer values or more. In the former case, we are forced to make a choice: either return the original tree or throw an exception (abandoning purity). In the latter, it would be fair to return the surplus values, besides the updated tree. In sum, the following signature seems better suited to the problem at hand:

```scala
class Leaves[A] {
  def get(tree: Tree[A]): List[A] = ???
  def update(tree: Tree[A]): List[A] => Option[(List[A], Tree[A])] = ???
}
``````

Essentially, the `update` method now returns a stateful computation, i.e. a value of the famous `StateT` monad (here, `StateT[Option, List[A], Tree[A]]`). This computation is run by giving it an initial list of values, and it finishes either with `None` (meaning that the computation couldn’t be completed) or with `Some((l, t))`, i.e. the updated tree `t` together with the list `l` of surplus values (possibly empty). We won’t show the implementation of these methods here, but you can find it in the repository of this post.
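For reference, here is one possible implementation of these Option-based signatures. It is a minimal sketch, not necessarily the code in the repository, and the wrapper object `PlainLeaves` is a hypothetical name. `update` consumes one value per leaf from left to right and threads the remaining values through the recursion, which is exactly the state-passing pattern that `StateT` captures.

```scala
// Hypothetical sketch of the Option-based get/update on the plain Tree type.
object PlainLeaves {
  sealed abstract class Tree[A]
  case class Leaf[A](a: A) extends Tree[A]
  case class Node[A](left: Tree[A], root: A, right: Tree[A]) extends Tree[A]

  class Leaves[A] {
    // Collect the leaves from left to right.
    def get(tree: Tree[A]): List[A] = tree match {
      case Leaf(a)       => List(a)
      case Node(l, _, r) => get(l) ++ get(r)
    }

    // Consume one value per leaf, left to right. Returns the surplus values
    // together with the updated tree, or None if the values run out.
    def update(tree: Tree[A]): List[A] => Option[(List[A], Tree[A])] =
      values => tree match {
        case Leaf(_) =>
          values match {
            case x :: rest => Some((rest, Leaf(x)))
            case Nil       => None
          }
        case Node(l, root, r) =>
          for {
            (rest1, newL) <- update(l)(values)  // left subtree takes what it needs
            (rest2, newR) <- update(r)(rest1)   // right subtree gets the remainder
          } yield (rest2, Node(newL, root, newR))
      }
  }
}
```

For the tree `Node(Node(Leaf(1), 2, Leaf(3)), 4, Leaf(5))`, passing `List(5, 3, 1, 9)` yields `Some((List(9), Node(Node(Leaf(5), 2, Leaf(3)), 4, Leaf(1))))`, while passing `List(5)` yields `None`.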

Ok, this is nice, but we are stubborn and keep insisting on finding a way to prevent the user from passing a wrong number of values to the `update` method. That is, we want to write the signature in such a way that the compiler reports an error if the programmer tries to call our function with fewer or more values than needed. Is that possible?

## Solving the problem with dependent types

A possible signature that solves our problem is the following one:

```scala
def update(tree: Tree[A]): Vec[A, n_leaves(tree)] => Tree[A]
``````

where `n_leaves: Tree[A] => Integer` is a function that returns the number of leaves of the specified tree, and the `Vec` type represents lists of a fixed size. This signature would give the Scala compiler the information it needs to accept the following call:

```scala
scala> update(Node(Leaf(1), 2, Leaf(3)))(Vec(3, 1))
res11: Tree[Int] = Node(Leaf(3), 2, Leaf(1))
``````

and block the following one instead, with a nice compiler error:

```scala
scala> update(Node(Leaf(1), 2, Leaf(3)))(Vec(3))
<console>:18: error: type mismatch;
found : Vec[Int, 1]
required: Vec[Int, 2]
update(Node(Leaf(1), 2, Leaf(3)))(Vec(3))
``````

… wouldn’t this be beautiful?

Alas, the above signature is not legal Scala 2.12. The problem lies in the `Vec[? , ? : Nat]` type constructor, which takes two parameters. There is no problem with the first one: type constructors in Scala do indeed receive types as arguments. Another way of saying this is that types in Scala can be parameterised with respect to types; yet another is that types in Scala can be made dependent on types. But the second parameter of the `Vec` constructor is not a type, it’s a value! And in Scala we can’t parameterise types with respect to values, only with respect to types.

A type whose definition refers to values is called a dependent type. Strictly speaking, the type `List[A]` in Scala also depends on something, namely the type `A`, so in a sense we could rightfully call it a dependent type as well. However, the “dependent” qualifier is conventionally reserved for types that are parameterised with respect to values.

Can’t we solve our problem in Scala, then? Yes, we will see that we can indeed solve this problem in Scala, albeit in a different way. But before delving into the Scala solution, let’s see how we can solve this problem in a language with full-fledged dependent types, in line with the solution sketched at the beginning of this section.

## The solution in Agda

First, we must define the tree data type:

```agda
module Trees where
data Tree (A : Set) : Set where
  leaf : A -> Tree A
  node : Tree A -> A -> Tree A -> Tree A
``````

This is a common algebraic data type definition, with constructors `leaf` and `node`. The definition is parameterised with respect to `A`, which is declared to be a regular type, i.e. `Set`. The resulting type `Tree A` is also a regular type (i.e. not a type constructor, which would be declared as `Set -> Set`). Next, we have to define the following function:

```agda
open import Data.Nat

n_leaves : {A : Set} -> Tree A -> ℕ
n_leaves (leaf _) = 1
n_leaves (node l _ r) = n_leaves l + n_leaves r
``````

The `n_leaves` function returns the number of leaves held by a given tree (as a natural number ℕ declared in the `Data.Nat` module). The implementation is based on pattern matching, using the same underscore symbol that we use in Scala whenever we are not interested in some value.

Let’s now implement the promised `get` and `update` functions, which will be part of a module named `Leaves`:

```agda
module Leaves where

open import Data.Vec
open Trees

get : {A : Set} -> (s : Tree A) -> Vec A (n_leaves s)
get = ?

update : {A : Set} -> (s : Tree A) -> Vec A (n_leaves s) -> Tree A
update = ?
``````

As you can see, we can now use the `n_leaves s` value in a type definition! Indeed, the `Vec (A : Set) (n : ℕ)` type is a truly dependent type. It represents lists of values of a fixed size `n`. Moreover, the size does not need to be a constant such as 1, 2, 3, etc. It can be the result of a function, as this example shows. The implications of this are huge, as we will soon realise.

Let’s expand the definition of the `get` function:

```agda
get : {A : Set} -> (s : Tree A) -> Vec A (n_leaves s)
get (leaf x) = x ∷ []
get (node l _ r) = get l ++ get r
``````

If the tree is a leaf, we just return its value in a vector of length one. Otherwise, we recursively collect the leaves of the left and right subtrees and return their concatenation. What would happen if we implemented the first clause of the pattern matching as `get (leaf x) = []` (i.e. if we attempted to return the empty vector for a leaf tree)? The compiler would complain with the following error:

```
0 != 1 of type .Agda.Builtin.Nat.Nat
when checking that the expression [] has type
Vec .A (n_leaves (leaf x))
``````

This error says that 0, i.e. the length of the empty vector `[]`, does not equal 1, i.e. the number of leaves of the input tree `leaf x`. It arises while checking that the proposed output `[]`, whose type is `Vec A 0`, has the required type `Vec .A (n_leaves (leaf x))`, i.e. `Vec A 1`. Similarly, in the second clause, the compiler itself checks that `n_leaves l + n_leaves r`, which is the length of the concatenation `get l ++ get r`, equals `n_leaves (node l _ r)`, which, according to the definition of the `n_leaves` function, is indeed the case. In sum, we can’t cheat the compiler and return a vector whose length differs from the number of leaves in the input tree. This property is hardwired in the signature of the function, thanks to the expressiveness of the Agda type system. And to guarantee it, Agda needs to be able to perform computations on values at compile time.

The implementation of the `update` function is similarly beautiful:

```agda
update : {A : Set} -> (s : Tree A) -> Vec A (n_leaves s) -> Tree A
update (leaf _) (x ∷ []) = leaf x
update (node l x r) v = node updatedL x updatedR
  where
    updatedL = update l (take (n_leaves l) v)
    updatedR = update r (drop (n_leaves l) v)
``````

Note that in the first clause we were able to deconstruct the input vector into the shape `x ∷ []` without the compiler complaining about a missing clause for other vector shapes (such as `[]`) in the `leaf` case. This is because Agda knows, by evaluating the `n_leaves` function, that every leaf tree has exactly one leaf. In the second clause, the input vector has type `v : Vec A (n_leaves (node l x r))`, which Agda knows to be `v : Vec A (n_leaves l + n_leaves r)` by partially evaluating the `n_leaves` function. This is what makes the subsequent calls that update the left and right subtrees typesafe. Indeed, to update the left subtree `l` we need a vector whose number of elements equals its number of leaves, `n_leaves l`. This vector has to be a subvector of the input vector `v`, which Agda knows to have length `n_leaves l + n_leaves r`, as we mentioned before. So the expression `take (n_leaves l) v` compiles without problems. Similarly, Agda knows that the length of the vector `drop (n_leaves l) v` is `n_leaves r` (from the type of `drop`, which takes a `Vec A (m + n)` and returns a `Vec A n`), which is precisely what the call `update r` needs.
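For comparison, the same splitting logic can be written in plain Scala on the original, non-shape-aware trees and a `Vector` of new values. This is a minimal sketch; the object `SplitLeaves` and the function `nLeaves` (a value-level counterpart of `n_leaves`) are hypothetical names. Without dependent types, the length invariant that Agda enforces statically can only be checked at runtime, here with `require`.

```scala
// Hypothetical sketch: Agda's take/drop-based update, with runtime checks only.
object SplitLeaves {
  sealed abstract class Tree[A]
  case class Leaf[A](a: A) extends Tree[A]
  case class Node[A](left: Tree[A], root: A, right: Tree[A]) extends Tree[A]

  // Value-level counterpart of Agda's n_leaves.
  def nLeaves[A](t: Tree[A]): Int = t match {
    case Leaf(_)       => 1
    case Node(l, _, r) => nLeaves(l) + nLeaves(r)
  }

  // Mirrors the Agda clauses: split the values at nLeaves(l) with take/drop.
  // The length invariant is checked at runtime, not by the compiler.
  def update[A](t: Tree[A], v: Vector[A]): Tree[A] = {
    require(v.length == nLeaves(t), s"expected ${nLeaves(t)} values, got ${v.length}")
    t match {
      case Leaf(_) => Leaf(v.head)
      case Node(l, x, r) =>
        val n = nLeaves(l)
        Node(update(l, v.take(n)), x, update(r, v.drop(n)))
    }
  }
}
```

Here `update(t, Vector(5))` on a three-leaf tree compiles fine and only fails at runtime with an `IllegalArgumentException`, which is precisely the kind of error that the Agda signature rules out at compile time.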

Let’s exercise these definitions in the following module:

```agda
module TestLeaves where

open import Data.Nat
open import Data.Vec
open Trees

t1 : Tree ℕ
t1 = node (node (leaf 1) 2 (leaf 3)) 4 (leaf 5)

l1 : Vec ℕ 3
l1 = Leaves.get t1

t2 : Tree ℕ
t2 = Leaves.update t1 (5 ∷ 3 ∷ 1 ∷ [])

-- CHECK

open import Relation.Binary.PropositionalEquality

eq1 : l1 ≡ (1 ∷ 3 ∷ 5 ∷ [])
eq1 = refl

eq2 : t2 ≡ (node (node (leaf 5) 2 (leaf 3)) 4 (leaf 1))
eq2 = refl

-- WON'T COMPILE

{- Error: 3 != 4 of type ℕ
when checking that the expression get t1 has type Vec ℕ 4

l2 : Vec ℕ 4
l2 = Leaves.get t1
-}

{- Error: 0 != 2 of type ℕ
when checking that the expression [] has type Vec ℕ 2

t3 : Tree ℕ
t3 = Leaves.update t1 (5 ∷ [])
-}
``````

The `l1` variable represents the leaves of the sample tree `t1`, namely values 1, 3 and 5. Accordingly, the type of the variable is `Vec ℕ 3`. The variable `t2` is the result of updating the tree with a new collection of leaves. In both cases, we make reference to the functions `get` and `update` declared in the module `Leaves`.

The next lines prove that the values of these variables are the expected ones, making use of the equality type constructor `_≡_` and its `refl` constructor (note that `_≡_` is parameterised with respect to two values, so it’s a dependent type). The proof is plain `refl`exivity, i.e. `x ≡ x`, since `l1` and `t2` evaluate to exactly the expected values.

Note that the mere fact that this code compiles shows that the tests pass. We don’t need to run anything! Beyond such tests, Agda allows us to check that our functions work as expected by proving much more expressive properties. We will leave that for another post.

Let’s come back to Scala.

## The solution in Scala

We can’t perform computations on values in Scala at compile time, but we can on types! And this suffices to solve our problem, albeit in a different form from Agda’s. We will reconcile both approaches in the next section.

Type-level computation in Scala proceeds through the implicits mechanism. But before we can exploit implicits, we first need to re-implement our `Tree` data type so that we don’t lose the shapes of trees:

```scala
sealed abstract class Tree[A]
case class Leaf[A](value: A) extends Tree[A]
case class Node[L <: Tree[A], A, R <: Tree[A]](left: L, root: A, right: R) extends Tree[A]
``````

This new implementation differs from the previous one in the types of the recursive arguments of the `Node` constructor. Now they are generic parameters `L` and `R`, declared to be subtypes of `Tree[A]`, i.e. either leaves or nodes. This allows us to preserve the exact type of the tree, which we will call its shape. In essence, this is the same trick commonly used to implement heterogeneous lists in Scala (see, e.g., their implementation in the shapeless framework). For instance, let’s compare both implementations in the REPL, with the old implementation of the `Tree` data type located in the `P` module, and the new one in the current scope:

```scala
scala> val p_tree = P.Node(P.Node(P.Leaf(1), 2, P.Leaf(3)), 4, P.Leaf(5))
p_tree: P.Node[Int] = ...

scala> val tree = Node(Node(Leaf(1), 2, Leaf(3)), 4, Leaf(5))
tree: Node[Node[Leaf[Int], Int, Leaf[Int]], Int, Leaf[Int]] = ...
``````

As we can see, the type of `p_tree` is simply `Node[Int]`, whereas the type of `tree` is much more informative: we don’t only know that it is a node tree; we know that it holds exactly five elements, three of which are leaves. Its shape has not been lost.

We can apply the same trick to the `List` type, in order to preserve information about the shape of list instances (essentially, how many values it stores). This is the resulting definition:

```scala
sealed abstract class List[A]
case class Nil[A]() extends List[A]
case class Cons[A, T <: List[A]](head: A, tail: T) extends List[A]
``````

Let’s now see how we can exploit these shape-aware algebraic data types to support shape-dependent, type-level computations, and finally solve our little problem. Recall the original signatures of the `get`/`update` functions, which were built on the common, non-shape-aware definitions of the `Tree` and `List` data types:

```scala
class Leaves[A] {
  def get(tree: Tree[A]): List[A] = ???
  def update(tree: Tree[A]): List[A] => Tree[A] = ???
}
``````

Now we can explain their limitations more precisely. For instance, consider the function returned by `update`. Its input is declared to be any `List[A]`, not a list of a particular shape. That matters for our problem, because we want the compiler to be able to reject invocations with lists of an undesired shape, i.e. length. But how can we represent the shape of an algebraic data type in the Scala type system? The answer is subtyping: we can declare the input of that function to be some `L <: List[A]`, instead of a plain `List[A]`. There is a one-to-one correspondence between the subtypes of the (shape-aware) algebraic data type `List[A]` and its possible shapes.
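To make the correspondence between subtypes and shapes concrete, here is a small sketch using the shape-aware `List` defined above. The object `ListShapes` and the function `swapPair` are hypothetical names: `swapPair` accepts only lists of exactly two elements, because its parameter type is one particular subtype of `List[Int]`.

```scala
// Hypothetical sketch: each list shape is a distinct subtype of List[A].
object ListShapes {
  sealed abstract class List[A]
  case class Nil[A]() extends List[A]
  case class Cons[A, T <: List[A]](head: A, tail: T) extends List[A]

  // The precise shape is a subtype of the plain List[Int] ...
  val two: Cons[Int, Cons[Int, Nil[Int]]] = Cons(1, Cons(2, Nil[Int]()))
  val asList: List[Int] = two

  // ... so a signature can demand one particular shape: exactly two elements.
  def swapPair(l: Cons[Int, Cons[Int, Nil[Int]]]): Cons[Int, Cons[Int, Nil[Int]]] =
    Cons(l.tail.head, Cons(l.head, Nil[Int]()))

  // swapPair(Cons(1, Nil[Int]()))  // rejected by the compiler: wrong shape
}
```

No runtime check is needed inside `swapPair`: `l.tail.head` is always defined, because any list of a different length has a different type.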

Similarly, the input trees of `get` and `update` are declared to be any `Tree[A]`, instead of trees of a particular shape `T <: Tree[A]`. This is bad, because then we can’t determine the exact list shape that must be returned for a given tree. Ok, but how can we determine the list shape corresponding to a given tree shape? The answer is to use type-level functions, which operate on input/output types that represent shapes.

These shape-dependent functions are declared as traits and defined through the implicits mechanism. For instance, the declaration of the type-level function between trees and lists is as follows:

```scala
trait LeavesShape[In <: Tree[A]] {
  type Out <: List[A]
  def get(t: In): Out
  def update(t: In): Out => In
}
``````

The `LeavesShape` trait is parameterised with respect to any shape of tree. Its instance for a particular shape will give us the list shape that we can use to store the current leaves of the tree, or the new values required for those leaves. Moreover, for that particular shape of tree we also obtain its corresponding get and update implementations.

Concerning the implementation of the shape-dependent function `LeavesShape`, i.e. how we compute the list shape corresponding to a given tree shape, we proceed through implicits defined in its companion object. The following signatures (not for the faint of heart …) suffice:

```scala
object LeavesShape {
  type Output[T <: Tree[A], _Out <: List[A]] = LeavesShape[T] { type Out = _Out }

  implicit def leafCase: Output[Leaf[A], Cons[A, Nil[A]]] = ???

  implicit def nodeCase[
      L <: Tree[A],
      R <: Tree[A],
      LOut <: List[A],
      ROut <: List[A]](implicit
      ShapeL: Output[L, LOut],
      ShapeR: Output[R, ROut],
      Conc: Concatenate[A, LOut, ROut]
  ): Output[Node[L, A, R], Conc.Out] = ???
}
``````

We omit the implementations of the `get` and `update` functions to focus on the list shape computation, which is expressed through the type alias `Output`. The first case is easy: the list shape needed to hold the leaves of a tree of type `Leaf[A]` is the one that stores a single element of type `A`, i.e. `Cons[A, Nil[A]]`. For arbitrary node trees, the situation appears more complicated, though it is conceptually simple. Given a tree of shape `Node[L, A, R]`, we first need the list shapes for the left and right subtrees `L` and `R`. The implicit arguments `ShapeL` and `ShapeR` provide us with the `LOut` and `ROut` shapes. The resulting list shape is precisely their concatenation, which we obtain through an auxiliary type-level function `Concatenate` (not shown for brevity, but implemented in a similar way). The concatenated shape is accessible through the `Out` type member of that function. The `Conc.Out` type is an example of a path-dependent type, a truly dependent type since it depends on the value `Conc` obtained through the implicits mechanism.
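The post leaves `Concatenate` out; the following is one possible way to write it, following the same pattern (a sketch for Scala 2, not the repository’s code). The names `ShapeConcat`, `Aux`, `nilCase`, `consCase`, `concat` and `ConcatOps` are hypothetical. The two implicits encode the usual recursive definition of list concatenation: `Nil ++ L2 = L2` and `Cons(h, T) ++ L2 = Cons(h, T ++ L2)`.

```scala
// Hypothetical sketch of the Concatenate type-level function.
object ShapeConcat {
  // Shape-aware lists, as defined earlier in the post.
  sealed abstract class List[A]
  case class Nil[A]() extends List[A]
  case class Cons[A, T <: List[A]](head: A, tail: T) extends List[A]

  // Maps the shapes L1 and L2 to the shape Out of their concatenation,
  // together with the matching value-level append.
  trait Concatenate[A, L1 <: List[A], L2 <: List[A]] {
    type Out <: List[A]
    def apply(l1: L1, l2: L2): Out
  }

  object Concatenate {
    type Aux[A, L1 <: List[A], L2 <: List[A], O <: List[A]] =
      Concatenate[A, L1, L2] { type Out = O }

    // Base case: Nil ++ L2 = L2
    implicit def nilCase[A, L2 <: List[A]]: Aux[A, Nil[A], L2, L2] =
      new Concatenate[A, Nil[A], L2] {
        type Out = L2
        def apply(l1: Nil[A], l2: L2): L2 = l2
      }

    // Inductive case: Cons(h, T) ++ L2 = Cons(h, T ++ L2)
    implicit def consCase[A, T <: List[A], L2 <: List[A]](
        implicit rest: Concatenate[A, T, L2]
    ): Aux[A, Cons[A, T], L2, Cons[A, rest.Out]] =
      new Concatenate[A, Cons[A, T], L2] {
        type Out = Cons[A, rest.Out]
        def apply(l1: Cons[A, T], l2: L2): Out = Cons(l1.head, rest(l1.tail, l2))
      }
  }

  // Helper: A is passed explicitly, because it only appears in the bounds
  // of L1 and L2 and Scala's inference does not reliably recover it there.
  def concat[A]: ConcatOps[A] = new ConcatOps[A]
  class ConcatOps[A] {
    def apply[L1 <: List[A], L2 <: List[A]](l1: L1, l2: L2)(
        implicit c: Concatenate[A, L1, L2]): c.Out = c(l1, l2)
  }
}
```

With this helper, `concat[Int](Cons(1, Nil[Int]()), Cons(2, Cons(3, Nil[Int]())))` should have the static type `Cons[Int, Cons[Int, Cons[Int, Nil[Int]]]]`: the result shape is computed at compile time, much like `n_leaves l + n_leaves r` in the Agda version.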

We are almost done. The only thing left is some way to call the `get` and `update` member functions of the `LeavesShape` type-level function for a given tree value. We achieve that with two auxiliary definitions, located in a final `Leaves` module (where the type-level function and its companion object are also implemented):

```scala
class Leaves[A] {
  def get[In <: Tree[A]](t: In)(implicit S: LeavesShape[In]): S.Out = S.get(t)
  def update[In <: Tree[A]](t: In)(implicit S: LeavesShape[In]): S.Out => In = S.update(t)

  trait LeavesShape[In <: Tree[A]] { ... }
  object LeavesShape { ... }
}
``````

The auxiliary functions `get` and `update` are the typesafe counterparts of the original signatures. The first difference that we may emphasise is that the type of input trees is not a plain, uninformative `Tree[A]`, but a particular shape of tree `In`. The compiler can then use this shape as input to the type-level function `LeavesShape`, to compute the shape of the resulting list `S.Out`. The output of these functions is thus declared as a path-dependent type. Last, note that the implementation of these functions is wholly delegated to the corresponding implementations of the inferred type-level function. Let’s see how this works in the following REPL session:

```scala
scala> val tree = Node(Node(Leaf(1), 2, Leaf(3)), 4, Leaf(5))
tree: Node[Node[Leaf[Int], Int, Leaf[Int]], Int, Leaf[Int]] = ...

scala> get(tree)
res2: Cons[Int, Cons[Int, Cons[Int, Nil[Int]]]] = Cons(1,Cons(3,Cons(5,Nil())))

scala> update(tree).apply(Cons(5, Cons(3, Cons(1, Nil[Int]()))))
res1: Node[Node[Leaf[Int], Int, Leaf[Int]], Int, Leaf[Int]] =
Node(Node(Leaf(5), 2, Leaf(3)), 4, Leaf(1))

scala> update(tree).apply(Cons(5, Nil[Int]()))
<console>:22: error: type mismatch;
found : Nil[Int]
required: Cons[Int, Cons[Int, Nil[Int]]]
update(tree).apply(Cons(5, Nil[Int]()))
^
``````

As expected, when we pass a list of the right shape, everything works. By contrast, as the last example shows, if we pass a list of the wrong size, the compiler complains. The function returned by `update(tree)` expects a list of size three, i.e. of type `Cons[Int, Cons[Int, Cons[Int, Nil[Int]]]]`, and we only pass a list of size one. Its head `5` is accepted, but its tail is a `Nil[Int]` where a list of size two, `Cons[Int, Cons[Int, Nil[Int]]]`, is required, which is exactly what the error message reports. This is the same behaviour that we got with the Agda implementation.
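Putting the pieces together, the following is one possible self-contained implementation of the parts the post omits. It is a sketch for Scala 2, not the code in the post’s repository: the wrapper object `ShapeLeaves` is a hypothetical name, the `???` bodies of `leafCase` and `nodeCase` are filled in one way among several, and `Concatenate` gets an extra `split` method (an assumption of this sketch) so that `update` can hand each subtree its own part of the new values.

```scala
// Hypothetical sketch: a complete shape-aware get/update.
object ShapeLeaves {
  sealed abstract class Tree[A]
  case class Leaf[A](value: A) extends Tree[A]
  case class Node[L <: Tree[A], A, R <: Tree[A]](left: L, root: A, right: R) extends Tree[A]

  sealed abstract class List[A]
  case class Nil[A]() extends List[A]
  case class Cons[A, T <: List[A]](head: A, tail: T) extends List[A]

  // Concatenation of shapes; `split` is the inverse of `apply`.
  trait Concatenate[A, L1 <: List[A], L2 <: List[A]] {
    type Out <: List[A]
    def apply(l1: L1, l2: L2): Out
    def split(out: Out): (L1, L2)
  }

  object Concatenate {
    type Aux[A, L1 <: List[A], L2 <: List[A], O <: List[A]] =
      Concatenate[A, L1, L2] { type Out = O }

    implicit def nilCase[A, L2 <: List[A]]: Aux[A, Nil[A], L2, L2] =
      new Concatenate[A, Nil[A], L2] {
        type Out = L2
        def apply(l1: Nil[A], l2: L2): L2 = l2
        def split(out: L2): (Nil[A], L2) = (Nil[A](), out)
      }

    implicit def consCase[A, T <: List[A], L2 <: List[A]](
        implicit rest: Concatenate[A, T, L2]
    ): Aux[A, Cons[A, T], L2, Cons[A, rest.Out]] =
      new Concatenate[A, Cons[A, T], L2] {
        type Out = Cons[A, rest.Out]
        def apply(l1: Cons[A, T], l2: L2): Out = Cons(l1.head, rest(l1.tail, l2))
        def split(out: Out): (Cons[A, T], L2) = {
          val (t, l2) = rest.split(out.tail)
          (Cons(out.head, t), l2)
        }
      }
  }

  class Leaves[A] {
    def get[In <: Tree[A]](t: In)(implicit S: LeavesShape[In]): S.Out = S.get(t)
    def update[In <: Tree[A]](t: In)(implicit S: LeavesShape[In]): S.Out => In = S.update(t)

    trait LeavesShape[In <: Tree[A]] {
      type Out <: List[A]
      def get(t: In): Out
      def update(t: In): Out => In
    }

    object LeavesShape {
      type Output[T <: Tree[A], O <: List[A]] = LeavesShape[T] { type Out = O }

      // A single leaf needs a one-element list.
      implicit def leafCase: Output[Leaf[A], Cons[A, Nil[A]]] =
        new LeavesShape[Leaf[A]] {
          type Out = Cons[A, Nil[A]]
          def get(t: Leaf[A]): Out = Cons(t.value, Nil[A]())
          def update(t: Leaf[A]): Out => Leaf[A] = l => Leaf(l.head)
        }

      // A node needs the concatenation of its subtrees' list shapes.
      implicit def nodeCase[L <: Tree[A], R <: Tree[A], LOut <: List[A], ROut <: List[A]](
          implicit shapeL: Output[L, LOut],
          shapeR: Output[R, ROut],
          conc: Concatenate[A, LOut, ROut]
      ): Output[Node[L, A, R], conc.Out] =
        new LeavesShape[Node[L, A, R]] {
          type Out = conc.Out
          def get(t: Node[L, A, R]): Out = conc(shapeL.get(t.left), shapeR.get(t.right))
          def update(t: Node[L, A, R]): Out => Node[L, A, R] = out => {
            val (newL, newR) = conc.split(out) // split the new values back into two shapes
            Node(shapeL.update(t.left)(newL), t.root, shapeR.update(t.right)(newR))
          }
        }
    }
  }
}
```

As in the REPL session above, the new values must be passed with `.apply(...)`: writing `update(tree)(...)` would instead pass the argument to the implicit parameter list.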

## Reconciling Scala and Agda

The Scala and Agda implementations seem very different. In Scala, we exploit the expressiveness of its type system to preserve the shape of algebraic data type values, and perform type-level, shape-dependent computations at compile time. In Agda, we exploit its capability to declare full-fledged dependent types, and perform value-level computations at compile time.

Nonetheless, let’s recall the signatures of both implementations and try to reconcile their differences:

```agda
-- AGDA VERSION

module Leaves where
open import Data.Vec
open Trees

get : {A : Set} -> (s : Tree A) -> Vec A (n_leaves s)
get = ?

update : {A : Set} -> (s : Tree A) -> Vec A (n_leaves s) -> Tree A
update = ?
``````
```scala
// SCALA VERSION

class Leaves[A] {
  def get[In <: Tree[A]](t: In)(implicit S: LeavesShape[In]): S.Out = S.get(t)
  def update[In <: Tree[A]](t: In)(implicit S: LeavesShape[In]): S.Out => In = S.update(t)

  trait LeavesShape[In <: Tree[A]] {
    type Out <: List[A]
    def get(t: In): Out
    def update(t: In): Out => In
  }

  object LeavesShape { ... }
}
``````

In a sense, the Scala signature is simpler: there is no need for a separate type `Vec (A : Set) (n : ℕ)`. The very same algebraic data type `List[A]` (albeit implemented in a shape-aware fashion) and subtyping suffice to represent shapes. In Agda, the vector type is introduced precisely to represent the shapes of lists, which are in one-to-one correspondence with the natural numbers.

The `n_leaves` function is then used to compute the required shape for a given tree. In Scala, there is no need for a separate function, since the shape is computed alongside the implementations of the `get` and `update` functions in the type-level function `LeavesShape`.

The downside of the Scala implementation is, evidently, its verbosity and the amount of techniques and tricks involved: path-dependent types, traits, subtyping, implicits, auxiliary functions, … This is a recognised problem which is being tackled for the future Scala 3.0 version.

## Conclusion

We could have mimicked the Agda implementation style in Scala. In the `shapeless` framework, for instance, the `Sized` and `Nat` types are available to represent lists of a fixed size (see the implementation here), and we could even use literal types to work around the limitation on using values in type declarations. Instead, we proposed an implementation based entirely on shape-aware algebraic data types. In our opinion, this version is more idiomatic for solving our particular problem in Scala. It also allows us to grasp the idiosyncrasy of Scala with respect to competing approaches like the one in Agda. In this regard, we found the notion of shape to be extremely useful.

In future posts we will likely go on exploring Agda in one of its most characteristic applications: certified programming. For instance, we may generalise the example shown in this post and talk about traversals (a kind of optic, like lenses) and their laws. One of these laws, applied to our example, tells us that if you update the leaves of a tree with its current leaf values, you obtain the same tree. Using Agda, we can state that law and prove that our implementation satisfies it, with no need to enumerate test cases or to test the property empirically (e.g., as in ScalaCheck). Till the next post!
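For contrast, this is what the empirical side looks like in Scala: the law can be stated as a predicate and checked on sample trees (or fed to a property-based testing library), but not proved. A minimal sketch, assuming the Option-based `get`/`update` from the beginning of the post, re-implemented here in a hypothetical object `LeavesLaws`; `updateGetLaw` is a second, companion law added for illustration.

```scala
// Hypothetical sketch: traversal laws as runtime predicates, not proofs.
object LeavesLaws {
  sealed abstract class Tree[A]
  case class Leaf[A](a: A) extends Tree[A]
  case class Node[A](left: Tree[A], root: A, right: Tree[A]) extends Tree[A]

  def get[A](t: Tree[A]): List[A] = t match {
    case Leaf(a)       => List(a)
    case Node(l, _, r) => get(l) ++ get(r)
  }

  def update[A](t: Tree[A])(vs: List[A]): Option[(List[A], Tree[A])] = t match {
    case Leaf(_) =>
      vs match {
        case x :: rest => Some((rest, Leaf(x)))
        case Nil       => None
      }
    case Node(l, x, r) =>
      for {
        (rest1, newL) <- update(l)(vs)
        (rest2, newR) <- update(r)(rest1)
      } yield (rest2, Node(newL, x, newR))
  }

  // Writing back the current leaves gives the same tree, with no surplus values.
  def getUpdateLaw[A](t: Tree[A]): Boolean =
    update(t)(get(t)) == Some((Nil, t))

  // After writing exactly as many values as leaves, reading them back returns them.
  def updateGetLaw[A](t: Tree[A], vs: List[A]): Boolean =
    update(t)(vs) match {
      case Some((Nil, t2)) => get(t2) == vs
      case _               => true // wrong number of values: the law does not apply
    }
}
```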
