package com.uncarved.helpers

/**
 *  Factory object for producing generalized memoized functions of up to four
 *  arguments.
 *
 *  Usage:
 *  <code><pre>
 *  val circArea =
 *      Memoize((r:Double)=>math.Pi * math.pow(r,2))
 *
 *  val area = circArea(2)
 *
 *  val cylVol = Memoize {
 *      (r:Double, h:Double) => 
 *        val vol = math.Pi * math.pow(r,2) * h
 *        println("Radius: " + r + " Height: " +
 *            h + " Volume: " + vol)
 *        vol
 *  }
 *
 *  val vol = cylVol(2, 3.5)
 *  val twoRCyl = cylVol.curried(2) //We can partially apply as normal
 *  val vol2 = twoRCyl(3.5)         //...and the memoization cache is shared
 *
 *  def sphereVol(r: Double) = 4.0 / 3.0 * math.Pi * math.pow(r,3)
 *
 *  val memoSV = Memoize(sphereVol _)
 *  </pre></code>
 *
 *  @see http://www.itl.nist.gov/div897/sqg/dads/HTML/memoize.html
 *  @see http://www.uncarved.com/blog/memoization.mrk
 *
 **/
object Memoize {
    import scala.collection.mutable

    // Notes that apply to every wrapper below:
    //  - Results are stored with `getOrElseUpdate`, which mutates the cache in
    //    place. (`cache + (k -> v)` on a mutable map only builds a new map and
    //    leaves `cache` untouched, so it must not be used for storing.)
    //  - The cache is declared `private[this]` so that the invariant mutable
    //    map may mention the contravariant/covariant type parameters.
    //  - The cache is a plain `mutable.Map`: it is never evicted and is not
    //    synchronized.
    //  - If the wrapped function throws, nothing is stored for that key and
    //    the exception propagates to the caller.

    /**
     * A memoized single-argument function.
     *
     * The first call with a given argument evaluates `f` and stores the
     * result; later calls with an equal argument return the stored result
     * without calling `f` again. Arguments are used as map keys, so they
     * should have well-behaved `equals`/`hashCode`.
     *
     * @tparam T argument type
     * @tparam R result type
     * @param f the function to memoize
     **/
    class MemoizedFunction1[-T, +R](f: T => R) extends (T => R) {
        private[this] val cache = mutable.Map.empty[T, R]

        /**
         * @param x the argument
         * @return the cached result for `x`, computing and storing it via `f`
         *         on the first call
         **/
        def apply(x: T): R = cache.getOrElseUpdate(x, f(x))
    }

    /**
     * A memoized two-argument function.
     *
     * Results are cached per argument pair `(x, y)`.
     *
     * @tparam T1 first argument type
     * @tparam T2 second argument type
     * @tparam R  result type
     * @param f the function to memoize
     **/
    class MemoizedFunction2[-T1, -T2, +R](f: (T1, T2) => R)
        extends ((T1, T2) => R) {
        private[this] val cache = mutable.Map.empty[(T1, T2), R]

        /**
         * @return the cached result for `(x, y)`, computing and storing it
         *         via `f` on the first call
         **/
        def apply(x: T1, y: T2): R =
            cache.getOrElseUpdate((x, y), f(x, y))
    }

    /**
     * A memoized three-argument function.
     *
     * Results are cached per argument triple `(x, y, z)`.
     *
     * @tparam T1 first argument type
     * @tparam T2 second argument type
     * @tparam T3 third argument type
     * @tparam R  result type
     * @param f the function to memoize
     **/
    class MemoizedFunction3[-T1, -T2, -T3, +R](f: (T1, T2, T3) => R)
        extends ((T1, T2, T3) => R) {
        private[this] val cache = mutable.Map.empty[(T1, T2, T3), R]

        /**
         * @return the cached result for `(x, y, z)`, computing and storing it
         *         via `f` on the first call
         **/
        def apply(x: T1, y: T2, z: T3): R =
            cache.getOrElseUpdate((x, y, z), f(x, y, z))
    }

    /**
     * A memoized four-argument function.
     *
     * Results are cached per argument tuple `(a, b, c, d)`.
     *
     * @tparam T1 first argument type
     * @tparam T2 second argument type
     * @tparam T3 third argument type
     * @tparam T4 fourth argument type
     * @tparam R  result type
     * @param f the function to memoize
     **/
    class MemoizedFunction4[-T1, -T2, -T3, -T4, +R](f: (T1, T2, T3, T4) => R)
        extends ((T1, T2, T3, T4) => R) {
        private[this] val cache = mutable.Map.empty[(T1, T2, T3, T4), R]

        /**
         * @return the cached result for `(a, b, c, d)`, computing and storing
         *         it via `f` on the first call
         **/
        def apply(a: T1, b: T2, c: T3, d: T4): R =
            cache.getOrElseUpdate((a, b, c, d), f(a, b, c, d))
    }

    /** Wraps a one-argument function in a new, empty memoization cache. */
    def apply[T, R](f: T => R): MemoizedFunction1[T, R] =
        new MemoizedFunction1(f)

    /** Wraps a two-argument function in a new, empty memoization cache. */
    def apply[T1, T2, R](f: (T1, T2) => R): MemoizedFunction2[T1, T2, R] =
        new MemoizedFunction2(f)

    /** Wraps a three-argument function in a new, empty memoization cache. */
    def apply[T1, T2, T3, R](f: (T1, T2, T3) => R): MemoizedFunction3[T1, T2, T3, R] =
        new MemoizedFunction3(f)

    /** Wraps a four-argument function in a new, empty memoization cache. */
    def apply[T1, T2, T3, T4, R](f: (T1, T2, T3, T4) => R): MemoizedFunction4[T1, T2, T3, T4, R] =
        new MemoizedFunction4(f)
}

// vim: set ts=4 sw=4 noet: