# Tail Recursive Functions (in Scala)

Turning imperative algorithms into tail-recursive functions isn’t necessarily obvious. In this article (and video) I’m showing you the trick you need, and in doing so, we’ll discover the Zen of Functional Programming.

## The Trick

Let’s start with a simple function that calculates the length of a list:

    def len(l: List[_]): Int =
      l match {
        case Nil => 0
        case _ :: tail => len(tail) + 1
      }


It’s a recursive function with a definition that is mathematically correct. However, if we try to test it, this will fail with a StackOverflowError:

    len(List.fill(100000)(1))


The problem is that the input list is too big for a function that consumes one stack frame per element. Because the JVM still has work to do after the recursive call returns (the + 1), the call isn’t in “tail position”, so every call needs its own frame on the call-stack. A StackOverflowError is a memory error, and in this case it’s also a correctness issue, because the function fails on reasonable input.

First let’s describe it as a dirty while loop instead:

    def len(l: List[_]): Int = {
      var count = 0
      var cursor = l

      while (cursor != Nil) {
        count += 1
        cursor = cursor.tail
      }
      count
    }


THE TRICK for turning such functions into tail-recursions is to turn those variables, holding state, into function parameters.

    import scala.annotation.tailrec

    def len(l: List[_]): Int = {
      // Using an inner function to encapsulate this implementation
      @tailrec
      def loop(cursor: List[_], count: Int): Int =
        cursor match {
          // Our end condition, copied after that while
          case Nil => count
          case _ :: tail =>
            // Copying the same logic from that while statement
            loop(cursor = tail, count = count + 1)
        }
      // Go, go, go
      loop(l, 0)
    }


Now this version is fine. Note the use of the @tailrec annotation: all it does is make the compiler report an error if the function is not actually tail-recursive. It’s worth adding because it’s easy to accidentally move the recursive call out of tail position and, to repeat, this is an issue of correctness.
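To see the same trick on another function, here is a minimal sketch that sums a list of integers. The names `SumDemo`, `sumList` and `acc` are illustrative and not part of the examples above; `acc` plays the role of the `var` that a while loop would mutate.

```scala
import scala.annotation.tailrec

object SumDemo {
  // Hypothetical example: sums a list by threading the running total
  // through an accumulator parameter instead of a mutable variable.
  def sumList(list: List[Int]): Long = {
    @tailrec
    def loop(cursor: List[Int], acc: Long): Long =
      cursor match {
        case Nil => acc
        case head :: tail => loop(tail, acc + head)
      }
    loop(list, 0L)
  }
  // Writing the step as `head + loop(tail)` instead would leave work to do
  // after the recursive call, and @tailrec would reject it at compile time.
}
```

Calling `SumDemo.sumList(List.fill(100000)(1))` should complete without touching the call-stack limit, since the recursive call is the last thing `loop` does.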

Let’s do a more complex example to really internalize this. Let’s calculate the N-th number in the Fibonacci sequence. Here’s the memory-unsafe recursive version:

    def fib(n: Int): BigInt =
      if (n <= 0) 0
      else if (n == 1) 1
      else fib(n - 1) + fib(n - 2)

    fib(0) // 0
    fib(1) // 1
    fib(2) // 1
    fib(3) // 2
    fib(4) // 3
    fib(5) // 5

    fib(100000) // StackOverflowError (also, really slow)


First turn this into a dirty while loop:

    def fib(n: Int): BigInt = {
      // Kids, don't do this at home 😅
      if (n <= 0) return 0
      // Going from 0 to n, instead of vice-versa
      var a: BigInt = 0 // instead of fib(n - 2)
      var b: BigInt = 1 // instead of fib(n - 1)
      var i = n

      while (i > 1) {
        val tmp = a
        a = b
        b = tmp + b
        i -= 1
      }
      b
    }


Then turn its 3 variables into function parameters:

    import scala.annotation.tailrec

    def fib(n: Int): BigInt = {
      @tailrec
      def loop(a: BigInt, b: BigInt, i: Int): BigInt =
        // first condition
        if (i <= 0) 0
        // end of while loop
        else if (i == 1) b
        // logic inside while loop statement
        else loop(a = b, b = a + b, i = i - 1)

      loop(0, 1, n)
    }
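As a variation, and not something the original code does, the two base cases can be collapsed into one by returning `a` instead of `b` once the counter runs out. The sketch below uses the hypothetical names `FibDemo` and `fibAlt`; it assumes the same convention as above, where non-positive inputs yield 0.

```scala
import scala.annotation.tailrec

object FibDemo {
  // Hypothetical variant: `a` holds fib(k) and `b` holds fib(k + 1),
  // where k is the number of steps taken so far.
  def fibAlt(n: Int): BigInt = {
    @tailrec
    def loop(a: BigInt, b: BigInt, i: Int): BigInt =
      if (i <= 0) a // after n steps, `a` is fib(n); also covers n <= 0
      else loop(b, a + b, i - 1)

    loop(BigInt(0), BigInt(1), n)
  }
}
```

The state is still just three parameters evolving on each call; only the point at which the result is read has changed.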


## (Actual) Recursion

Tail-recursions are just loops. But some algorithms are actually recursive, and can’t be described via a while loop that uses constant memory. What makes an algorithm actually recursive is usage of a stack. In imperative programming, for low-level implementations, that’s how you can tell if recursion is required … does it use a manually managed stack or not?

But even in such cases we can use a while loop, or a @tailrec function. Doing so has some advantages. Let’s start with a Tree data-structure:

    sealed trait Tree[+A]

    case class Node[+A](value: A, left: Tree[A], right: Tree[A])
      extends Tree[A]
    case object Empty
      extends Tree[Nothing]


Defining a fold, which we could use to sum up all values, for example, will be more challenging:

    def foldTree[A, R](tree: Tree[A], seed: R)(f: (R, A) => R): R =
      tree match {
        case Empty => seed
        case Node(value, left, right) =>
          // Recursive call for the left child
          val leftR = foldTree(left, f(seed, value))(f)
          // Recursive call for the right child
          foldTree(right, leftR)(f)
      }


This is the simple version, and it should be clear that the size of the call-stack will be directly proportional to the height of the tree. Turning it into a @tailrec version means we need to manually manage a stack:

    import scala.annotation.tailrec

    def foldTree[A, R](tree: Tree[A], seed: R)(f: (R, A) => R): R = {
      @tailrec def loop(stack: List[Tree[A]], state: R): R =
        stack match {
          // End condition, nothing left to do
          case Nil => state
          // Ignore empty elements
          case Empty :: tail => loop(tail, state)
          // Step in our loop
          case Node(value, left, right) :: tail =>
            // Adds left and right nodes to stack, evolves the state
            loop(left :: right :: tail, f(state, value))
        }
      // Go, go, go!
      loop(List(tree), seed)
    }
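To try the stack-based fold on an input that would be risky for the simple version, here is a small sketch that sums a very deep, degenerate tree. The wrapper object `DeepTreeDemo`, the helper `leftLeaning` and the chosen depth are assumptions made for illustration; the data types and `foldTree` are repeated from above only to keep the snippet self-contained.

```scala
import scala.annotation.tailrec

object DeepTreeDemo {
  sealed trait Tree[+A]
  case class Node[+A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]
  case object Empty extends Tree[Nothing]

  // Same stack-based fold as above
  def foldTree[A, R](tree: Tree[A], seed: R)(f: (R, A) => R): R = {
    @tailrec def loop(stack: List[Tree[A]], state: R): R =
      stack match {
        case Nil => state
        case Empty :: tail => loop(tail, state)
        case Node(value, left, right) :: tail =>
          loop(left :: right :: tail, f(state, value))
      }
    loop(List(tree), seed)
  }

  // Hypothetical helper: builds a tree in which every node has only a left
  // child, so its height equals the number of nodes. Built with foldLeft,
  // so construction itself does not recurse.
  def leftLeaning(depth: Int): Tree[Int] =
    (1 to depth).foldLeft(Empty: Tree[Int])((child, i) => Node(i, child, Empty))

  // Sums the values 1 to 100000 stored in a tree of height 100000
  val sum: Long = foldTree(leftLeaning(100000), 0L)(_ + _)
}
```

Both versions visit nodes in the same order (value, then left subtree, then right subtree), so for any `f` they should produce the same result whenever the simple version manages to finish.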


If you want to internalize this notion (recursion == usage of a stack), a great exercise is to implement a backtracking algorithm. Implement it with recursive functions, or with dirty loops and a manually managed stack, and compare. The plot thickens for backtracking solutions using 2 stacks 🙂
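As one possible take on that exercise, the sketch below counts solutions to the N-Queens puzzle with a manually managed stack of partial placements. The puzzle choice and the names `QueensDemo`, `Placement`, `isSafe` and `countSolutions` are assumptions made for illustration, not something prescribed above.

```scala
import scala.annotation.tailrec

object QueensDemo {
  // A partial solution: one column index per placed row,
  // with the most recently placed queen at the head.
  type Placement = List[Int]

  // Can a queen go in `col` on the next row without being attacked?
  def isSafe(col: Int, placed: Placement): Boolean =
    placed.zipWithIndex.forall { case (otherCol, idx) =>
      val rowDistance = idx + 1
      otherCol != col && math.abs(otherCol - col) != rowDistance
    }

  def countSolutions(n: Int): Int = {
    @tailrec
    def loop(stack: List[Placement], count: Int): Int =
      stack match {
        // Nothing left to explore
        case Nil => count
        // A full board: count it and move on
        case placed :: rest if placed.length == n =>
          loop(rest, count + 1)
        // Push every safe extension; a dead end pushes nothing,
        // which is the "backtrack" step
        case placed :: rest =>
          val next = (0 until n).toList.collect {
            case col if isSafe(col, placed) => col :: placed
          }
          loop(next ::: rest, count)
      }
    loop(List(Nil), 0)
  }
}
```

Here the explicit `stack` parameter holds exactly what the call-stack would hold in a recursive version: the partial boards still waiting to be extended.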

Does this manually managed stack buy us anything?

Well yes. If you need such recursive algorithms, a manually managed stack lives on the heap, which is usually much larger than a thread’s call-stack, so the function can handle a bigger input. But note that with the right input, your process can still blow up, this time with an out-of-memory error (OOM).

NOTE — in real life, shining examples of algorithms using manually managed stacks are Cats-Effect’s IO and Monix’s Task, since they literally replace the JVM’s call-stack 😄

## Zen of Functional Programming?

In FP, you turn variables into (immutable) function parameters. And state gets evolved via function calls 💡

That’s it, that’s all there is to FP (plus the design patterns, and the pain of dealing with I/O 🙂).

Enjoy!

Tags: Algorithms | FP | Scala | Video