I know about the method rdd.first(), which gives me the first element of an RDD.
There is also the method rdd.take(num), which gives me the first num elements.
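For reference, here is a minimal, self-contained sketch of those two methods using the Java API; the class name, master setting, and sample data are my own assumptions for illustration:

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class FirstAndTakeDemo {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("FirstAndTakeDemo").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(10, 20, 30, 40));
        System.out.println(rdd.first());  // prints 10
        System.out.println(rdd.take(2));  // prints [10, 20]
        sc.stop();
    }
}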
I tried the class below to fetch an item by index. First, when you construct new IndexedFetcher(rdd, itemClass), it counts the number of elements in each partition of the RDD. Then, when you call indexedFetcher.get(n), it runs a job on only the partition that contains that index. (A short usage sketch follows the class.)
Note that I needed to compile this using Java 1.7 instead of 1.8; as of Spark 1.1.0, the bundled org.objectweb.asm within com.esotericsoftware.reflectasm cannot read Java 1.8 classes yet (throws IllegalStateException when you try to runJob a Java 1.8 function).
import java.io.Serializable;

import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.RDD;

import scala.reflect.ClassTag;

public static class IndexedFetcher<E> implements Serializable {
    private static final long serialVersionUID = 1L;

    public final RDD<E> rdd;
    public Integer[] elementsPerPartitions;
    private Class<?> clazz;

    public IndexedFetcher(RDD<E> rdd, Class<?> clazz) {
        this.rdd = rdd;
        this.clazz = clazz;
        SparkContext context = this.rdd.context();
        ClassTag<Integer> intClassTag = scala.reflect.ClassTag$.MODULE$.<Integer>apply(Integer.class);
        // One job over all partitions, counting the elements in each.
        elementsPerPartitions = (Integer[]) context.<E, Integer>runJob(rdd, IndexedFetcher.<E>countFunction(), intClassTag);
    }

    // Counts the elements of one partition by draining its iterator.
    public static class IteratorCountFunction<E>
            extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, Integer>
            implements Serializable {
        private static final long serialVersionUID = 1L;
        @Override public Integer apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
            int count = 0;
            while (iterator.hasNext()) {
                count++;
                iterator.next();
            }
            return count;
        }
    }

    static <E> scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> countFunction() {
        scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> function = new IteratorCountFunction<E>();
        return function;
    }

    public E get(long index) {
        // Walk the per-partition counts to find the partition that owns this global index.
        long remaining = index;
        long totalCount = 0;
        for (int partition = 0; partition < elementsPerPartitions.length; partition++) {
            if (remaining < elementsPerPartitions[partition]) {
                return getWithinPartition(partition, remaining);
            }
            remaining -= elementsPerPartitions[partition];
            totalCount += elementsPerPartitions[partition];
        }
        throw new IllegalArgumentException(String.format("Get %d within RDD that has only %d elements", index, totalCount));
    }

    // Returns the element at a given offset within a single partition.
    public static class FetchWithinPartitionFunction<E>
            extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, E>
            implements Serializable {
        private static final long serialVersionUID = 1L;
        private final long indexWithinPartition;
        public FetchWithinPartitionFunction(long indexWithinPartition) {
            this.indexWithinPartition = indexWithinPartition;
        }
        @Override public E apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
            int count = 0;
            while (iterator.hasNext()) {
                E element = iterator.next();
                if (count == indexWithinPartition)
                    return element;
                count++;
            }
            throw new IllegalArgumentException(String.format("Fetch %d within partition that has only %d elements", indexWithinPartition, count));
        }
    }

    public E getWithinPartition(int partition, long indexWithinPartition) {
        System.out.format("getWithinPartition(%d, %d)%n", partition, indexWithinPartition);
        SparkContext context = rdd.context();
        scala.Function2<TaskContext, scala.collection.Iterator<E>, E> function = new FetchWithinPartitionFunction<E>(indexWithinPartition);
        // Run the job on only the single partition that contains the requested index.
        scala.collection.Seq<Object> partitions = new scala.collection.mutable.WrappedArray.ofInt(new int[] { partition });
        ClassTag<E> classTag = scala.reflect.ClassTag$.MODULE$.<E>apply(clazz);
        E[] result = (E[]) context.<E, E>runJob(rdd, function, partitions, true, classTag);
        return result[0];
    }
}
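For illustration, here is a hedged usage sketch under my own assumptions: it assumes IndexedFetcher is reachable from the calling code (it is a static nested class, so it must live inside some host class), and the demo class name, sample data, and partition count are invented:

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class IndexedFetcherDemo {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("IndexedFetcherDemo").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        // Assumed sample data: five elements spread over two partitions.
        JavaRDD<String> words = sc.parallelize(Arrays.asList("a", "b", "c", "d", "e"), 2);
        // The per-partition counting happens once, in the constructor.
        IndexedFetcher<String> fetcher = new IndexedFetcher<String>(words.rdd(), String.class);
        // get(3) runs a job on only the partition that owns global index 3.
        System.out.println(fetcher.get(3)); // prints "d"
        sc.stop();
    }
}

The point of the design is that the full counting pass runs once up front, so repeated get(n) calls each touch only one partition instead of scanning the whole RDD.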