Each time a function is called, if its result for a given set of argument values is not yet memoized, I'd like to put the result into an in-memory table. One column is mea
When using a mutable map for memoization, keep in mind that it is exposed to the usual concurrency problems, e.g. a get being issued while a write has not yet completed. However, the thread-safe attempt at memoization below suggests that making it thread-safe is of little value, if any.
The following thread-safe code creates a memoized Fibonacci function and starts a few threads (named 'a' through 'd') that call it. Run the code a few times (in the REPL) and you can easily see "f(2) set" printed more than once. This means that thread A has started calculating f(2) while thread B, completely unaware of it, starts its own copy of the same calculation. Such ignorance is pervasive while the cache is being built, because every thread sees that no sub-solution has been stored yet and enters the else clause.
object ScalaMemoizationMultithread {

  // not a case class, because it holds a mutable member
  class Memo[-T, +R](f: T => R) extends (T => R) {
    // ConcurrentHashMap keeps individual reads and writes safe under concurrent access
    private[this] val cache = new java.util.concurrent.ConcurrentHashMap[T, R]

    def apply(x: T): R =
      // no synchronized block: entries are never removed during memoization
      if (cache containsKey x) {
        Console.println(Thread.currentThread().getName() + ": f(" + x + ") get")
        cache.get(x)
      } else {
        // several threads can reach this branch for the same x before any of them stores a result
        val res = f(x)
        Console.println(Thread.currentThread().getName() + ": f(" + x + ") set")
        cache.putIfAbsent(x, res) // atomic
        res
      }
  }

  object Memo {
    def apply[T, R](f: T => R): T => R = new Memo(f)

    // fixed-point combinator: recursive calls made by F go through the memoized function
    def Y[T, R](F: (T => R) => T => R): T => R = {
      lazy val yf: T => R = Memo(F(yf)(_))
      yf
    }
  }

  val fibonacci: Int => BigInt = {
    def fiboF(f: Int => BigInt)(n: Int): BigInt = {
      if (n <= 1) 1
      else f(n - 1) + f(n - 2)
    }
    Memo.Y(fiboF)
  }

  def main(args: Array[String]): Unit = {
    // start four threads, each calling fibonacci(5) twice
    ('a' to 'd').foreach { ch =>
      new Thread(new Runnable {
        def run(): Unit = {
          Thread.currentThread().setName("Thread " + ch)
          (1 to 2).foreach(_ => fibonacci(5))
        }
      }).start()
    }
  }
}
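
Because the duplicated "set" lines depend on thread scheduling, they do not show up on every run. The sketch below forces the same interleaving deterministically. It is only an illustration: `CheckThenActRace`, `memoSquare`, `demo` and the `bothInside` latch are hypothetical names that are not part of the code above. The latch keeps each thread inside the else branch until the other thread has also missed the cache.

```scala
import java.util.concurrent.{ConcurrentHashMap, CountDownLatch}
import java.util.concurrent.atomic.AtomicInteger

object CheckThenActRace {
  // counts how many times the "expensive" computation actually runs
  val computations = new AtomicInteger(0)

  // hypothetical latch: opens only once both threads are inside the else branch
  private val bothInside = new CountDownLatch(2)

  private val cache = new ConcurrentHashMap[Int, Int]

  // same check-then-act shape as Memo.apply: containsKey, compute, putIfAbsent
  def memoSquare(x: Int): Int =
    if (cache.containsKey(x)) cache.get(x)
    else {
      computations.incrementAndGet()
      bothInside.countDown()
      bothInside.await() // wait until the other thread has missed the cache too
      val res = x * x
      cache.putIfAbsent(x, res) // atomic, but the work has already been duplicated
      res
    }

  // runs two threads on the same argument and returns the number of computations
  def demo(): Int = {
    val threads = (1 to 2).map { _ =>
      new Thread(new Runnable {
        def run(): Unit = { memoSquare(7); () }
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
    computations.get()
  }
}
```

Calling `CheckThenActRace.demo()` reports two computations for a single argument, even though every individual map operation is atomic. The cache ends up correct, but the check and the write are separate steps, so nothing stops a second thread from starting the same work in between.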