Scala: generic weighted average function

旧时难觅i 2020-12-11 17:00

I want to implement a generic weighted average function that relaxes the requirement that the values and the weights be of the same type, i.e., I want to support sequences o

2 answers
  •  夕颜 (original poster)
    2020-12-11 17:09

    First of all, your template is wrong. (Sorry if "template" is the wrong term; I am a novice in Scala.) Your functions expect tuples whose elements both have the same type ([A: Numeric], e.g. (Float, Float)) instead of tuples whose elements may have different types ([A: Numeric, B: Numeric], e.g. (Int, Float)).

    Anyway, the code below compiles and should work once you fill in the calculation you want.

    import scala.collection._
    
    // Sums each tuple component independently; replace the body with the weighted sum you need.
    def weightedSum[A: Numeric, B: Numeric](weightedValues: GenSeq[(A,B)]): (A,B) = {
      val numA = implicitly[Numeric[A]]
      val numB = implicitly[Numeric[B]]
      weightedValues.foldLeft((numA.zero, numB.zero)) { (z, t) =>
        (numA.plus(z._1, t._1), numB.plus(z._2, t._2))
      }
    }
    
    def weightedAverage[A: Numeric, B: Numeric](weightedValues: GenSeq[(A,B)]): A = {
      val (weightSum, weightedValueSum) = weightedSum(weightedValues)
      // Placeholder: every branch returns zero until the actual division is filled in.
      implicitly[Numeric[A]] match {
        case num: Fractional[A] => num.zero
        case num: Integral[A] => num.zero
        case _ => sys.error("Indivisible numeric!")
      }
    }
    
    val values1: Seq[(Float, Float)] = List((1, 2f), (1, 3f))
    val values2: Seq[(Int, Float)] = List((1, 2f), (1, 3f))
    
    val wa1 = weightedAverage(values1)
    val wa2 = weightedAverage(values2)
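
One possible way to fill in the placeholders above is sketched here. It rests on assumptions that are not in the answer: each tuple is read as (value, weight), everything is converted to Double through Numeric.toDouble, and the result is an Option[Double] rather than an A. The name weightedAverageDouble is hypothetical, chosen only for this illustration.

```scala
// Sketch only. Assumptions: tuples are (value, weight); the result is a Double.
def weightedAverageDouble[A, B](weightedValues: Seq[(A, B)])(
    implicit numA: Numeric[A], numB: Numeric[B]): Option[Double] = {
  // Accumulate sum(value * weight) and sum(weight) in a single pass.
  val (weightedValueSum, weightSum) =
    weightedValues.foldLeft((0.0, 0.0)) { case ((vs, ws), (value, weight)) =>
      val w = numB.toDouble(weight)
      (vs + numA.toDouble(value) * w, ws + w)
    }
  // No average is defined when the weights sum to zero (this covers the empty sequence).
  if (weightSum == 0.0) None else Some(weightedValueSum / weightSum)
}

val ints: Seq[(Int, Float)] = List((1, 2f), (3, 2f))
val floats: Seq[(Float, Float)] = List((1f, 1f), (4f, 3f))

val avgInts = weightedAverageDouble(ints)     // Some(2.0)
val avgFloats = weightedAverageDouble(floats) // Some(3.25)
```

Converting to Double sidesteps the Fractional/Integral distinction at the cost of precision for types such as BigDecimal; keeping the result in A would need the pattern match from the answer instead.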
    
