How to get element by Index in Spark RDD (Java)

Asked by 失恋的感觉 on 2020-12-01 06:52 · 3 answers · 1744 views

I know the method rdd.first(), which gives me the first element in an RDD.

Also there is the method rdd.take(num), which gives me the first "num" elements. But is there a way to get an element at an arbitrary index?
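
For concreteness, here is a minimal sketch of what I have versus what I want (my own illustration, assuming a local JavaSparkContext; the rdd.get(1) call is hypothetical, no such method exists on RDD or JavaRDD):

    import java.util.Arrays;
    import java.util.List;
    
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    
    public class FirstVsTakeExample {
        public static void main(String[] args) {
            JavaSparkContext sc = new JavaSparkContext(
                    new SparkConf().setMaster("local").setAppName("first-vs-take"));
            JavaRDD<String> rdd = sc.parallelize(Arrays.asList("a", "b", "c"));
            String first = rdd.first();          // "a"
            List<String> firstTwo = rdd.take(2); // [a, b]
            // Wanted: something like rdd.get(1) -> "b" (hypothetical; does not exist)
            sc.stop();
        }
    }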

3 Answers
  •  一整个雨季
    2020-12-01 07:07

    I tried this class to fetch an item by index. First, when you construct new IndexedFetcher(rdd, itemClass), it counts the number of elements in each partition of the RDD. Then, when you call indexedFetcher.get(n), it runs a job on only the partition that contains that index.

    Note that I needed to compile this using Java 1.7 instead of 1.8; as of Spark 1.1.0, the bundled org.objectweb.asm within com.esotericsoftware.reflectasm cannot read Java 1.8 classes yet (throws IllegalStateException when you try to runJob a Java 1.8 function).

    import java.io.Serializable;
    
    import org.apache.spark.SparkContext;
    import org.apache.spark.TaskContext;
    import org.apache.spark.rdd.RDD;
    
    import scala.reflect.ClassTag;
    
    // Generic over the RDD's element type E; Serializable so Spark can ship the
    // contained functions to executors.
    public static class IndexedFetcher<E> implements Serializable {
        private static final long serialVersionUID = 1L;
        public final RDD<E> rdd;
        public Integer[] elementsPerPartitions;
        private Class<?> clazz;
        public IndexedFetcher(RDD<E> rdd, Class<?> clazz) {
            this.rdd = rdd;
            this.clazz = clazz;
            SparkContext context = this.rdd.context();
            ClassTag<Integer> intClassTag = scala.reflect.ClassTag$.MODULE$.<Integer>apply(Integer.class);
            // One Spark job up front: count the elements in every partition.
            elementsPerPartitions = (Integer[]) context.runJob(rdd, IndexedFetcher.<E>countFunction(), intClassTag);
        }
        public static class IteratorCountFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, Integer> implements Serializable {
            private static final long serialVersionUID = 1L;
            @Override public Integer apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
                int count = 0;
                while (iterator.hasNext()) {
                    count++;
                    iterator.next();
                }
                return count;
            }
        }
        static <E> scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> countFunction() {
            scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> function = new IteratorCountFunction<E>();
            return function;
        }
        public E get(long index) {
            // Walk the per-partition counts to locate the partition holding the index.
            long remaining = index;
            long totalCount = 0;
            for (int partition = 0; partition < elementsPerPartitions.length; partition++) {
                if (remaining < elementsPerPartitions[partition]) {
                    return getWithinPartition(partition, remaining);
                }
                remaining -= elementsPerPartitions[partition];
                totalCount += elementsPerPartitions[partition];
            }
            throw new IllegalArgumentException(String.format("Get %d within RDD that has only %d elements", index, totalCount));
        }
        public static class FetchWithinPartitionFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, E> implements Serializable {
            private static final long serialVersionUID = 1L;
            private final long indexWithinPartition;
            public FetchWithinPartitionFunction(long indexWithinPartition) {
                this.indexWithinPartition = indexWithinPartition;
            }
            @Override public E apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
                int count = 0;
                while (iterator.hasNext()) {
                    E element = iterator.next();
                    if (count == indexWithinPartition)
                        return element;
                    count++;
                }
                throw new IllegalArgumentException(String.format("Fetch %d within partition that has only %d elements", indexWithinPartition, count));
            }
        }
        public E getWithinPartition(int partition, long indexWithinPartition) {
            System.out.format("getWithinPartition(%d, %d)%n", partition, indexWithinPartition);
            SparkContext context = rdd.context();
            scala.Function2<TaskContext, scala.collection.Iterator<E>, E> function = new FetchWithinPartitionFunction<E>(indexWithinPartition);
            // Run a job restricted to the single partition that contains the index.
            scala.collection.Seq<Object> partitions = new scala.collection.mutable.WrappedArray.ofInt(new int[] {partition});
            ClassTag<E> classTag = scala.reflect.ClassTag$.MODULE$.<E>apply(this.clazz);
            E[] result = (E[]) context.runJob(rdd, function, partitions, true, classTag);
            return result[0];
        }
    }
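
    Usage looks like this (a minimal sketch of my own, not part of the class above: it assumes an existing JavaSparkContext named sc, and uses JavaRDD.rdd() to unwrap the underlying RDD<String>):

        // Hypothetical usage sketch; sc is an existing JavaSparkContext.
        JavaRDD<String> javaRdd = sc.parallelize(Arrays.asList("a", "b", "c", "d"));
        IndexedFetcher<String> fetcher = new IndexedFetcher<String>(javaRdd.rdd(), String.class);
        // Per-partition counts were computed in the constructor; get(2) now runs
        // a job on only the partition that contains global index 2.
        String third = fetcher.get(2); // "c"
        System.out.println(third);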