Scala: Get sum of nth element from tuple array/RDD

Submitted by 一个人想着一个人 on 2019-12-22 06:41:28

Question


I have an array of tuples like this:

val a = Array((1,2,3), (2,3,4))

I want to generalize a method like the one below:

def sum2nd(aa: Array[(Int, Int, Int)]) = {
  aa.map { a => a._2 }.sum
}

So what I am looking for is a method like:

def sumNth(aa: Array[(Int, Int, Int)], n: Int)

Answer 1:


There are a few ways you can go about this. The simplest is to use productElement:

def unsafeSumNth[P <: Product](xs: Seq[P], n: Int): Int =
  xs.map(_.productElement(n).asInstanceOf[Int]).sum

And then (note that indexing starts at zero, so n = 1 gives us the second element):

scala> val a = Array((1, 2, 3), (2, 3, 4))
a: Array[(Int, Int, Int)] = Array((1,2,3), (2,3,4))

scala> unsafeSumNth(a, 1)
res0: Int = 5

This implementation can crash at runtime in two different ways, though:

scala> unsafeSumNth(List((1, 2), (2, 3)), 3)
java.lang.IndexOutOfBoundsException: 3
  at ...

scala> unsafeSumNth(List((1, "a"), (2, "b")), 1)
java.lang.ClassCastException: java.lang.String cannot be cast to java.lang.Integer
  at ...

That is, it fails if the tuple doesn't have enough elements, or if the element you're asking for isn't an Int.

You can write a version that doesn't crash at runtime:

import scala.util.Try

def saferSumNth[P <: Product](xs: Seq[P], n: Int): Try[Int] = Try(
  xs.map(_.productElement(n).asInstanceOf[Int]).sum
)

And then:

scala> saferSumNth(a, 1)
res4: scala.util.Try[Int] = Success(5)

scala> saferSumNth(List((1, 2), (2, 3)), 3)
res5: scala.util.Try[Int] = Failure(java.lang.IndexOutOfBoundsException: 3)

scala> saferSumNth(List((1, "a"), (2, "b")), 1)
res6: scala.util.Try[Int] = Failure(java.lang.ClassCastException: ...

This is an improvement, since it forces callers to address the possibility of failure, but it's also kind of annoying, since it forces callers to address the possibility of failure.
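If you would rather not use exceptions for control flow at all, a minimal sketch (not part of the original answer; the names EitherSum and sumNthEither are hypothetical) checks productArity up front and matches on the element's runtime type, reporting either failure as a Left with a message:

```scala
object EitherSum {
  // Sums element n (zero-based) of every tuple. Instead of throwing, returns
  // Left(message) when a tuple is too short or the element is not an Int.
  // Once a Left is produced, flatMap skips the remaining tuples' checks.
  def sumNthEither[P <: Product](xs: Seq[P], n: Int): Either[String, Int] =
    xs.foldLeft[Either[String, Int]](Right(0)) { (acc, p) =>
      acc.flatMap { total =>
        if (n < 0 || n >= p.productArity)
          Left(s"index $n is out of range for a tuple of arity ${p.productArity}")
        else
          p.productElement(n) match {
            case i: Int => Right(total + i)
            case other  => Left(s"element $n is not an Int: $other")
          }
      }
    }
}
```

Like the Try version, this still only detects problems at runtime, but the caller gets a descriptive message rather than an exception object.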

If you're willing to use Shapeless you can have the best of both worlds:

import shapeless._, shapeless.ops.tuple.At

def sumNth[P <: Product](xs: Seq[P], n: Nat)(implicit
  atN: At.Aux[P, n.N, Int]
): Int = xs.map(p => atN(p)).sum

And then:

scala> sumNth(a, 1)
res7: Int = 5

But the bad ones don't even compile:

scala> sumNth(List((1, 2), (2, 3)), 3)
<console>:17: error: could not find implicit value for parameter atN: ...

This still isn't perfect, though: the second argument has to be a literal number, since it needs to be known at compile time:

scala> val x = 1
x: Int = 1

scala> sumNth(a, x)
<console>:19: error: Expression x does not evaluate to a non-negative Int literal
       sumNth(a, x)
                 ^

In many cases that's not a problem, though.
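When the index is only known at runtime but the tuple type is fixed, as in the question's sumNth(aa: Array[(Int, Int, Int)], n: Int) signature, a plain match on n can pick a statically typed accessor, so neither casts nor Shapeless are needed. This is a minimal sketch, not taken from the answer; FixedArity, Triple and accessor are hypothetical names, and an out-of-range index yields None instead of a compile error:

```scala
object FixedArity {
  type Triple = (Int, Int, Int)

  // Maps a runtime index to a statically typed accessor for (Int, Int, Int);
  // indices the tuple does not have give None.
  def accessor(n: Int): Option[Triple => Int] = n match {
    case 0 => Some((t: Triple) => t._1)
    case 1 => Some((t: Triple) => t._2)
    case 2 => Some((t: Triple) => t._3)
    case _ => None
  }

  // The question's signature, made total by returning Option[Int].
  def sumNth(aa: Array[Triple], n: Int): Option[Int] =
    accessor(n).map(get => aa.iterator.map(get).sum)
}
```

Because n is an ordinary Int here, something like val x = 1; FixedArity.sumNth(a, x) compiles, which the Shapeless version above rejects. The trade-off is that the method works for one tuple type only.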

To sum up: if you're willing to take responsibility for reasonable-looking code crashing your program, use productElement. If you want a little more safety (at the cost of some inconvenience), use productElement with Try. If you want compile-time safety (with some limitations), use Shapeless.




Answer 2:


You could do something like this, though it's not really type safe:

  def sumNth(aa: Array[Product], n: Int)= {
    aa.map { a =>
      a.productElement(n) match {
        case i:Int => i
        case _ => 0
      }
    }.sum
  }

sumNth(Array((1,2,3), (2,3,4)), 2) // 7
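Because Array is invariant, the method above will not accept a value already typed as Array[(Int, Int, Int)], such as a from the question; the literal in the example compiles only because the expected type Array[Product] drives inference. It also still throws if n is out of range. A minimal sketch of a more lenient variant (LenientSum and sumNthLenient are hypothetical names) takes a covariant Seq[P] and skips tuples that are too short:

```scala
object LenientSum {
  // Accepts any Seq of tuples (or other Products). Tuples without an element n
  // are skipped, and non-Int elements are dropped, which gives the same sum as
  // counting them as 0.
  def sumNthLenient[P <: Product](xs: Seq[P], n: Int): Int =
    xs.iterator
      .filter(p => n >= 0 && n < p.productArity) // skip tuples that are too short
      .map(_.productElement(n))
      .collect { case i: Int => i }              // keep only Int elements
      .sum
}
```

With the array from the question, LenientSum.sumNthLenient(a.toSeq, 2) gives 7. Like the original, it silently ignores bad input, so it suits cases where a missing or non-numeric value really should count as nothing.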



Answer 3:


Another type-safe way to do it without using Shapeless is to pass in a function that extracts the element you need:

def sumNth[T, E: Numeric](array: Array[T])(extract: T => E): E =
  array.iterator.map(extract).sum // .iterator avoids Array.map, which needs a ClassTag[E] on Scala 2.13+

Then you can define sum2nd like this:

def sum2nd(array: Array[(Int, Int, Int)]): Int = sumNth(array)(_._2)

Or like this:

val sum2nd: Array[(Int, Int, Int)] => Int = sumNth(_)(_._2)
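The extractor approach is not tied to Int or to tuples. A minimal sketch, assuming an Iterable input instead of an Array and a hypothetical Reading case class (ExtractSum, readings, total and longs are also illustrative names), sums Double and Long fields with the same shape of signature:

```scala
object ExtractSum {
  // The caller supplies the extractor; Numeric[E] supplies zero and addition,
  // so E may be Int, Long, Double, BigDecimal, etc. Summing through an
  // iterator builds no intermediate collection of E values.
  def sumNth[T, E: Numeric](xs: Iterable[T])(extract: T => E): E =
    xs.iterator.map(extract).sum

  // Hypothetical record type: any T works, not only tuples.
  final case class Reading(sensor: String, value: Double)

  val readings = List(Reading("a", 1.5), Reading("b", 2.25))
  val total: Double = sumNth(readings)(_.value)
  val longs: Long = sumNth(Vector((1L, "x"), (2L, "y")))(_._1)
}
```

Taking Iterable[T] rather than Array[T] is a design choice in this sketch; an array can be passed via a.toSeq.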


Source: https://stackoverflow.com/questions/35091354/scala-get-sum-of-nth-element-from-tuple-array-rdd
