One option is to perform a broadcast join by collecting rdd1 to the driver and broadcasting it to all mappers; done correctly, this will let us avoid an expensive shuffle of the large rdd2:
val rdd1 = sc.parallelize(Seq((1, "A"), (2, "B"), (3, "C")))
val rdd2 = sc.parallelize(Seq(((1, "Z"), 111), ((1, "ZZ"), 111), ((2, "Y"), 222), ((3, "X"), 333)))
// Collect the small RDD as a map on the driver and broadcast it to every executor:
val rdd1Broadcast = sc.broadcast(rdd1.collectAsMap())
val joined = rdd2.mapPartitions({ iter =>
  val m = rdd1Broadcast.value
  for {
    ((t, w), u) <- iter
    if m.contains(t)  // inner join semantics: drop records with no match in rdd1
  } yield ((t, w), (u, m(t)))
}, preservesPartitioning = true)
The preservesPartitioning = true tells Spark that this map function doesn't modify the keys of rdd2; this allows Spark to avoid re-partitioning rdd2 for any subsequent operations that join on the (t, w) key.
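To see what preservesPartitioning buys us, here is a minimal sketch (partitionedRdd2 and joined2 are illustrative names, not part of the code above) that gives rdd2 an explicit HashPartitioner first, so there is something to preserve; the mapped RDD then reports the same partitioner, meaning a later join on the same keys can skip the shuffle:

import org.apache.spark.HashPartitioner

// Give rdd2 an explicit partitioner so there is something to preserve:
val partitionedRdd2 = rdd2.partitionBy(new HashPartitioner(4))
val joined2 = partitionedRdd2.mapPartitions({ iter =>
  val m = rdd1Broadcast.value
  iter.collect { case ((t, w), u) if m.contains(t) => ((t, w), (u, m(t))) }
}, preservesPartitioning = true)
// Because the keys were untouched, the parent's partitioner is retained:
println(joined2.partitioner)  // Some(org.apache.spark.HashPartitioner@...)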
This broadcast could be inefficient since it involves a communication bottleneck at the driver. In principle, it's possible to broadcast one RDD to another without involving the driver; I have a prototype of this that I'd like to generalize and add to Spark.
Another option is to re-map the keys of rdd2 and use Spark's join method; this will involve a full shuffle of rdd2 (and possibly rdd1):
rdd1.join(rdd2.map {
  case ((t, w), u) => (t, (w, u))  // re-key rdd2 by t, carrying w into the value
}).map {
  case (t, (v, (w, u))) => ((t, w), (u, v))  // restore the ((t, w), ...) shape
}.collect()
On my sample input, both of these methods produce the same result:
res1: Array[((Int, java.lang.String), (Int, java.lang.String))] = Array(((1,Z),(111,A)), ((1,ZZ),(111,A)), ((2,Y),(222,B)), ((3,X),(333,C)))
A third option would be to restructure rdd2 so that t is its key, then perform the above join, as sketched below.
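A minimal sketch of that restructuring, assuming rdd2's data can be built keyed by t from the start (rdd2ByT and joined3 are illustrative names, not from the original code):

// Keep rdd2 permanently keyed by t, with (w, u) as the value:
val rdd2ByT = sc.parallelize(Seq((1, ("Z", 111)), (1, ("ZZ", 111)), (2, ("Y", 222)), (3, ("X", 333))))
val joined3 = rdd1.join(rdd2ByT).map {
  case (t, (v, (w, u))) => ((t, w), (u, v))
}
joined3.collect()  // same Array(((1,Z),(111,A)), ...) as above

This avoids re-mapping rdd2 before every join, at the cost of carrying the less natural key shape everywhere else rdd2 is used.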