Thursday, 21 January 2016

Multiplication table using a matrix in Spark

In this post, I am sharing a code snippet that demonstrates how to build a multiplication table in Spark.

In addition, it also demonstrates the following:

  • How to create a custom Spark RDD
  • How to create an Iterator/Generator
  • How to work with matrices in Spark

Assumptions:

  • You are using JDK 8.
  • You have added the Spark and spark-mllib libraries to the classpath. I assume you use a build tool such as Maven and have declared them as dependencies.



import org.apache.spark.Partition;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
import org.apache.spark.rdd.RDD;
import scala.collection.AbstractIterator;
import scala.collection.Iterator;
import scala.collection.mutable.ArrayBuffer;
import scala.reflect.ClassManifestFactory$;

/**
 * This class illustrates how to build a multiplication table on Spark.
 * @author  Thamme Gowda N
 *
 */
public class Main {

    /**
     * A minimal custom RDD of consecutive {@code long} numbers, exposed as a
     * single partition.
     * Useful as a numeric source for illustrating math operations on Spark.
     */
    public static class SequenceRDD extends RDD<Long>{

        private final long first;
        private final long last;
        private final Partition[] parts;

        /**
         * @param _sc the Spark context handed to the {@code RDD} superclass
         * @param start the first number of the sequence
         * @param end the last number of the sequence
         */
        public SequenceRDD(SparkContext _sc, long start, long end) {
            // an empty buffer is passed as the dependency list
            super(_sc, new ArrayBuffer<> (),
                    ClassManifestFactory$.MODULE$.fromClass(Long.class));
            Partition only = () -> 0; // the one and only partition, at index 0
            this.parts = new Partition[]{only};
            this.first = start;
            this.last = end;
        }

        /**
         * Produces the numbers of this RDD.
         * The {@code split} argument is not consulted because there is
         * exactly one partition.
         */
        @Override
        public Iterator<Long> compute(Partition split, TaskContext ctx) {
            SequenceIterator numbers = new SequenceIterator(first, last);
            return numbers;
        }

        /**
         * @return the one-element partition array created in the constructor
         */
        @Override
        public Partition[] getPartitions() {
            return parts;
        }
    }

    /**
     * Iterator that yields consecutive {@code long} numbers over the closed
     * range [start, end], so an RDD can be built on it.
     * It is usable both as a Scala iterator and as a {@code java.util.Iterator}.
     */
    private static class SequenceIterator
            extends AbstractIterator<Long>
            implements java.util.Iterator<Long> {

        private final long end;
        private long nextStart;
        // Explicit end-of-sequence flag instead of comparing "nextStart <= end":
        // with end == Long.MAX_VALUE the increment would wrap around to
        // Long.MIN_VALUE and such a comparison would stay true forever.
        private boolean exhausted;

        /**
         * Number generator for [start, end]
         * @param start the start
         * @param end the end, inclusive
         * @throws IllegalArgumentException if {@code end} is smaller than {@code start}
         */
        public SequenceIterator(long start, long end) {
            // a real check rather than an assert: asserts are off unless the JVM runs with -ea
            if (end < start) {
                throw new IllegalArgumentException(
                        "Invalid Args: end (" + end + ") < start (" + start + ")");
            }
            this.nextStart = start;
            this.end = end;
        }

        /**
         * @return {@code true} while the inclusive end has not been handed out yet
         */
        @Override
        public boolean hasNext(){
            return !exhausted;
        }

        /**
         * @return the next number of the sequence
         * @throws java.util.NoSuchElementException if the sequence is exhausted
         */
        @Override
        public Long next() {
            if (exhausted) {
                throw new java.util.NoSuchElementException(
                        "sequence exhausted at " + end);
            }
            long current = nextStart;
            if (current == end) {
                // last element: stop here without incrementing, so no overflow can happen
                exhausted = true;
            } else {
                nextStart = current + 1;
            }
            return current;
        }
    }

    /**
     * Builds the multiplication table of the numbers 0 to n (both inclusive)
     * as a coordinate matrix on a local Spark master and saves its indexed
     * rows as text under the output path. Prints the elapsed time at the end.
     *
     * @param args command-line arguments; not used
     */
    public static void main(String[] args) {
        long n = 1_000;
        String outpath = "multiplication-matrix";

        long st = System.currentTimeMillis();
        SparkConf conf = new SparkConf()
                .setAppName("Large Matrix")
                .setMaster("local[2]");
        JavaSparkContext ctx = new JavaSparkContext(conf);
        try {
            // cached because the cartesian product below uses this RDD on both sides
            JavaRDD<Long> rdd = new SequenceRDD(ctx.sc(), 0, n)
                    .toJavaRDD()
                    .cache();

            // every (row, column) pair becomes one matrix entry holding the product
            JavaPairRDD<Long, Long> pairs = rdd.cartesian(rdd);
            JavaRDD<MatrixEntry> entries = pairs.map(tup ->
                    new MatrixEntry(tup._1(), tup._2(), tup._1() * tup._2()));
            CoordinateMatrix matrix = new CoordinateMatrix(entries.rdd());
            // NOTE(review): saving presumably fails when the output path already
            // exists (e.g. on a second run) - confirm against the Spark docs
            matrix.toIndexedRowMatrix()
                    .rows()
                    .saveAsTextFile(outpath);

            System.out.printf("n=%d\t outpath=%s\nTime taken : %dms\n", n,
                    outpath, System.currentTimeMillis() - st);
        } finally {
            // stop the context even when the job above throws, so it is not leaked
            ctx.stop();
        }
    }
}

No comments:

Post a Comment