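## A Spark Implementation of the DBSCAN Algorithm

### I. Implementation approach

A distributed implementation of DBSCAN has to solve the following main problems.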


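**1. How do we compute the pairwise distances between sample points?**

On a single machine, computing the pairwise distances is straightforward: it is a nested loop over the sample points.
To reduce the amount of computation, a spatial index such as an R-tree can be used to speed it up.

In a distributed environment the sample points live in different partitions, so a nested loop across partitions cannot be run directly.
To solve this, my approach is to pull the sample points to the Driver in several batches, one group of partitions at a time,
then broadcast each batch to every executor, which computes the distances against its local points. The per-batch results are combined with union, which indirectly implements the nested loop.

To reduce the amount of computation, an R-tree spatial index is built on the data pulled to the Driver before it is broadcast.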
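
The batching idea can be illustrated without Spark. The following is a minimal sketch in plain Scala, in which each inner `Seq` stands in for one partition; the names `Pt`, `bruteForce` and `neighbourPairs` are hypothetical, and the R-tree is left out so that only the "batch, broadcast, union" structure is shown.

```scala
// Plain-Scala simulation (no Spark): each inner Seq stands in for one RDD partition.
// Pt, bruteForce and neighbourPairs are hypothetical names used only in this sketch.
case class Pt(id: Long, x: Double, y: Double)

def dist(a: Pt, b: Pt): Double = math.hypot(a.x - b.x, a.y - b.y)

// Reference result: the single-machine nested loop over all points.
def bruteForce(points: Seq[Pt], r: Double): Set[(Long, Long)] =
  (for (s <- points; m <- points if dist(s, m) < r) yield (s.id, m.id)).toSet

// Batched version: one batch at a time plays the role of "collected and broadcast",
// every partition compares its own points against that batch,
// and the per-batch results are unioned.
def neighbourPairs(partitions: Seq[Seq[Pt]], r: Double): Set[(Long, Long)] = {
  val perBatch = for (batch <- partitions) yield {
    partitions.flatMap { part =>
      for (s <- part; m <- batch if dist(s, m) < r) yield (s.id, m.id)
    }.toSet
  }
  perBatch.foldLeft(Set.empty[(Long, Long)])(_ union _)
}

val samplePartitions = Seq(
  Seq(Pt(0L, 0.0, 0.0), Pt(1L, 0.1, 0.0)),
  Seq(Pt(2L, 0.15, 0.1), Pt(3L, 1.0, 1.0)),
  Seq(Pt(4L, 1.1, 1.0))
)
val batchedPairs = neighbourPairs(samplePartitions, 0.2)
```

Every (partition, batch) combination is visited exactly once, so the union covers the same pairs as the nested loop, while only one batch has to be held by all workers at any time.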



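**2. How do we build the temporary clusters?**

This part is not hard, and the single-machine and distributed implementations are much the same.

Both group the neighbour pairs by sample point to count how many sample points fall within the neighbourhood radius R of each point, and record their ids. If the count is at least minpoints, a temporary cluster is built for that point and the point is added to the core-point list.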
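
As a rough illustration of the grouping step, here is a plain-Scala sketch. The input shape (pairs `(s_id, m_id)` meaning "m lies within radius R of s", including the pair `(s, s)`) and the names `temporaryClusters` and `samplePairs` are assumptions made for this example.

```scala
// Hypothetical input: (s_id, m_id) pairs meaning "m lies within radius R of s",
// including the pair (s, s) for every point.
def temporaryClusters(pairs: Seq[(Long, Long)], minPoints: Int): Map[Long, Set[Long]] =
  pairs.groupBy(_._1)                                   // group neighbours by centre point
    .map { case (s, ps) => (s, ps.map(_._2).toSet) }    // ids inside the neighbourhood of s
    .filter { case (_, neighbours) => neighbours.size >= minPoints }

val samplePairs = Seq(
  (1L, 1L), (1L, 2L), (1L, 3L),
  (2L, 1L), (2L, 2L), (2L, 3L),
  (3L, 1L), (3L, 2L), (3L, 3L), (3L, 4L),
  (4L, 3L), (4L, 4L),
  (5L, 5L)
)
val tempClusters = temporaryClusters(samplePairs, 3)
val corePoints = tempClusters.keySet   // the keys double as the core-point list
```

With `minPoints = 3`, points 1, 2 and 3 become core points, point 4 is only a neighbour of a core point, and point 5 has no neighbour other than itself.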


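**3. How do we merge connected temporary clusters into final clusters?**

This is the most important step of the distributed implementation.

On a single machine, the standard approach is to go through each temporary cluster and check whether each of its sample points is in the core-point list. If it is, the temporary cluster centred on that sample point is merged into the current temporary cluster, and the point is removed from the core-point list.

This is repeated until none of the points in the current temporary cluster is in the core-point list.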
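
The single-machine procedure just described can be sketched as follows. This is an illustrative version only; the input shape (core point id mapped to the ids in its neighbourhood) and the name `mergeSingleMachine` are assumptions.

```scala
import scala.collection.mutable

// tempClusters: core point id -> ids of all points in its neighbourhood
// (hypothetical input shape for this sketch).
def mergeSingleMachine(tempClusters: Map[Long, Set[Long]]): List[Set[Long]] = {
  val coreList = mutable.Set.empty[Long]      // core points not yet absorbed into a cluster
  coreList ++= tempClusters.keys
  val result = mutable.ListBuffer.empty[Set[Long]]
  for (seed <- tempClusters.keys.toList.sorted) {
    if (coreList.contains(seed)) {
      coreList -= seed
      var current = tempClusters(seed)
      var pending = current.filter(coreList.contains)
      while (pending.nonEmpty) {
        for (p <- pending) {
          current = current | tempClusters(p)  // absorb the temporary cluster centred on p
          coreList -= p                        // and drop p from the core-point list
        }
        pending = current.filter(coreList.contains)
      }
      result += current
    }
  }
  result.toList
}

val toyTempClusters = Map(
  1L -> Set(1L, 2L, 7L),
  2L -> Set(1L, 2L, 3L),
  3L -> Set(2L, 3L, 8L),
  5L -> Set(5L, 6L)
)
val toyMerged = mergeSingleMachine(toyTempClusters)
```

Here core points 1, 2 and 3 chain together through shared members, so their temporary clusters collapse into one cluster, while the cluster around point 5 stays separate.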

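In a distributed environment the temporary clusters are spread over different partitions, so the merge cannot be done by directly scanning a global core-point list.

My approach is to first merge the temporary clusters inside each partition, then repartition into fewer partitions and merge inside each partition again.

Repeating this eventually brings all temporary clusters into a single partition, where the merge of all temporary clusters is completed.

To reduce the storage pressure on that last partition, I use a merge algorithm that differs from the standard one.

For each temporary cluster only the ids of its core points are kept; the ids of non-core points are ignored, which reduces storage. Temporary clusters that share a core-point id are merged. This is sufficient because temporary clusters are only ever linked through core points; a shared non-core point does not connect two clusters.

To speed up the merge, partitioning is not random. The smallest core-point id in each temporary cluster, min_core_id, is used as the hash key of the partitioner, so temporary clusters that share core-point ids are more likely to land in the same partition, which speeds up the merge.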
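
The staged merge can be simulated in plain Scala to see how the rounds interact. In this sketch `mergeOverlapping` and `mergeRound` are hypothetical helpers, and `min % partitionCnt` only mimics a hash partitioner for small non-negative ids; it is not the Spark code itself.

```scala
// Merge all sets that share at least one id. The accumulator always holds
// pairwise-disjoint sets, so one pass with partition + union is enough.
def mergeOverlapping(sets: Seq[Set[Long]]): List[Set[Long]] =
  sets.foldLeft(List.empty[Set[Long]]) { (merged, s) =>
    val (hit, miss) = merged.partition(m => (m & s).nonEmpty)
    hit.foldLeft(s)(_ | _) :: miss
  }

// One round: send each cluster to "partition" (min_core_id mod n),
// then merge only inside each partition.
def mergeRound(clusters: Seq[Set[Long]], partitionCnt: Int): Seq[Set[Long]] =
  clusters.groupBy(c => (c.min % partitionCnt).toInt)
    .values.toSeq
    .flatMap(mergeOverlapping)

val coreSets: Seq[Set[Long]] = Seq(
  Set(1L, 2L), Set(2L, 3L, 4L), Set(6L, 8L, 9L), Set(4L, 5L),
  Set(9L, 10L, 11L), Set(15L, 17L), Set(10L, 11L, 18L)
)

// Shrink the partition count round by round: 8 -> 4 -> 1.
val stagedResult = Seq(8, 4, 1).foldLeft(coreSets)(mergeRound)
```

Only the final single-partition round guarantees a complete merge; the earlier rounds are there to shrink the data before it reaches that partition. With this tiny input the first two rounds happen to merge nothing, which shows that the min_core_id key raises the chance of co-location but does not guarantee it.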


![](./data/DBSCAN算法步骤.png)

### II. Core code

In [3]:
import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .appName("dbscan")
  .getOrCreate()

val sc = spark.sparkContext
import spark.implicits._


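**1. Find the core points and form temporary clusters.**

This step normally uses a spatial index plus broadcasting. It is omitted here; we assume the temporary clusters have already been obtained.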

In [4]:
// Each row of rdd_core is one temporary cluster: (min_core_id, core_id_set)
// core_id_set holds the ids of all core points in the temporary cluster; min_core_id is the smallest of these ids
var rdd_core = sc.parallelize(List((1L,Set(1L,2L)),(2L,Set(2L,3L,4L)),
                                   (6L,Set(6L,8L,9L)),(4L,Set(4L,5L)),
                                   (9L,Set(9L,10L,11L)),(15L,Set(15L,17L)),
                                   (10L,Set(10L,11L,18L))))
rdd_core.collect.foreach(println)

(1,Set(1, 2))
(2,Set(2, 3, 4))
(6,Set(6, 8, 9))
(4,Set(4, 5))
(9,Set(9, 10, 11))
(15,Set(15, 17))
(10,Set(10, 11, 18))



![](data/dbscan核心算法的输入.png)

**2. Merge the temporary clusters into final clusters.**

In [5]:
import scala.collection.mutable.ListBuffer
import org.apache.spark.HashPartitioner

// Merge function: merge temporary clusters that share at least one core point.
// Note that set_list is consumed (emptied) while merging.
val mergeSets = (set_list: ListBuffer[Set[Long]]) => {
  var result = ListBuffer[Set[Long]]()
  while (set_list.size > 0) {
    var cur_set = set_list.remove(0)
    // Indices are listed in descending order so that removing by index below stays valid.
    var intersect_idxs = List.range(set_list.size-1, -1, -1).filter(i => (cur_set & set_list(i)).size > 0)
    while (intersect_idxs.size > 0) {
      for (idx <- intersect_idxs) {
        cur_set = cur_set | set_list(idx)
      }
      for (idx <- intersect_idxs) {
        set_list.remove(idx)
      }
      // The enlarged cur_set may now overlap sets it did not touch before, so search again.
      intersect_idxs = List.range(set_list.size-1, -1, -1).filter(i => (cur_set & set_list(i)).size > 0)
    }
    result = result :+ cur_set
  }
  result
}

// Partition rdd_core and merge inside each partition, then keep reducing the number of partitions until everything ends up in one partition.
// If the data is too large to merge in a single partition, the merge can stop at several partitions, which gives an approximate result.
// rdd: (min_core_id, core_id_set)

def mergeRDD(rdd: org.apache.spark.rdd.RDD[(Long,Set[Long])], partition_cnt: Int):
org.apache.spark.rdd.RDD[(Long,Set[Long])] = {
  val rdd_merged = rdd.partitionBy(new HashPartitioner(partition_cnt))
    .mapPartitions(iter => {
      val buffer = ListBuffer[Set[Long]]()
      for (t <- iter) {
        val core_id_set: Set[Long] = t._2
        buffer.append(core_id_set)
      }
      val merged_buffer = mergeSets(buffer)
      var result = List[(Long,Set[Long])]()
      for (core_id_set <- merged_buffer) {
        val min_core_id = core_id_set.min
        result = result :+ ((min_core_id, core_id_set))
      }
      result.iterator
    })
  rdd_merged
}



In [6]:
// Iterate the partition-wise merge; adjust the number of rounds and the partition counts as needed
rdd_core = mergeRDD(rdd_core,8)
rdd_core = mergeRDD(rdd_core,4)
rdd_core = mergeRDD(rdd_core,1)
rdd_core.collect.foreach(println)

(1,Set(5, 1, 2, 3, 4))
(6,Set(10, 6, 9, 18, 11, 8))
(15,Set(15, 17))



![](data/dbscan核心算法的输出.png)
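
As a side note, the per-partition merge could also be written with a union-find (disjoint-set) structure instead of repeated pairwise intersections. This is an alternative technique, not the method used in this article, and `mergeSetsUF` is a hypothetical name. The sketch always keeps the smaller id as the root, so each root is the min_core_id of its cluster.

```scala
import scala.collection.mutable

// Union-find over core ids: ids that appear together in any set end up under one root.
// Returns (min_core_id, core_id_set) pairs sorted by min_core_id.
def mergeSetsUF(sets: Seq[Set[Long]]): List[(Long, Set[Long])] = {
  val parent = mutable.Map[Long, Long]()

  def find(x: Long): Long = {
    val p = parent.getOrElseUpdate(x, x)
    if (p == x) x
    else {
      val root = find(p)
      parent(x) = root                     // path compression
      root
    }
  }

  def union(a: Long, b: Long): Unit = {
    val (ra, rb) = (find(a), find(b))
    // Keep the smaller id as root, so every root is the minimum id of its component.
    if (ra != rb) parent(math.max(ra, rb)) = math.min(ra, rb)
  }

  for (s <- sets if s.nonEmpty) {
    val head = s.head
    s.foreach(union(head, _))
  }

  parent.keys.toList
    .groupBy(find)
    .map { case (root, ids) => (root, ids.toSet) }
    .toList
    .sortBy(_._1)
}

val ufInput = Seq(
  Set(1L, 2L), Set(2L, 3L, 4L), Set(6L, 8L, 9L), Set(4L, 5L),
  Set(9L, 10L, 11L), Set(15L, 17L), Set(10L, 11L, 18L)
)
val ufMerged = mergeSetsUF(ufInput)
```

On the seven sample sets used above this yields the same three clusters as `mergeSets`, while doing work roughly proportional to the total number of ids rather than to the number of set pairs.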

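### III. Complete example

The complete example code follows. It uses exactly the same data source and parameters as the article 《20分钟学会DBSCAN聚类算法》 ("Learn the DBSCAN Clustering Algorithm in 20 Minutes") and produces exactly the same result.

The difference is that the code below is a Spark-based distributed implementation, so it can be scaled out to large datasets.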


In [7]:
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.{DataFrame, Row, Column}
import org.apache.spark.sql.functions._
import org.locationtech.jts.geom.{Geometry,GeometryFactory,Coordinate,Point}
import org.locationtech.jts.index.strtree.STRtree
import org.apache.spark.sql.jts.registerTypes
import scala.collection.mutable.WrappedArray
import scala.collection.mutable.ListBuffer
import org.apache.spark.HashPartitioner
import scala.collection.JavaConversions._

def printlog(info:String): Unit ={
 val dt = new java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new java.util.Date)
 println("=========="*8+dt)
 println(info+"\n")
}

val spark = SparkSession
 .builder()
 .appName("dbscan")
 .enableHiveSupport()
 .getOrCreate()

val sc = spark.sparkContext
import spark.implicits._
registerTypes



In [8]:
/*================================================================================*/
// Step 1: load the input data -> dfinput
/*================================================================================*/
printlog("step1: get input data -> dfinput ...")

val dfdata = spark.read.option("header","true") 
 .option("inferSchema","true") 
 .option("delimiter", "\t") 
 .csv("data/moon_points.csv")

spark.udf.register("makePoint", (x:Double,y:Double) =>{
 val gf = new GeometryFactory
 val pt = gf.createPoint(new Coordinate(x,y))
 pt
})

val dfinput = dfdata.selectExpr("makePoint(feature1,feature2) as point")
 .rdd.map(row=>row.getAs[Point]("point"))
 .zipWithIndex().toDF("geometry","id").selectExpr("id","geometry")
 .persist(StorageLevel.MEMORY_AND_DISK)

dfinput.show

step1: get input data -> dfinput ...

+---+--------------------+
| id|            geometry|
+---+--------------------+
|  0|POINT (0.31655567...|
|  1|POINT (0.74088269...|
|  2|POINT (0.87172637...|
|  3|POINT (0.55552787...|
|  4|POINT (2.03872887...|
|  5|POINT (1.99136342...|
|  6|POINT (0.22384428...|
|  7|POINT (0.97295674...|
|  8|POINT (-0.9213036...|
|  9|POINT (0.46670632...|
| 10|POINT (0.49217803...|
| 11|POINT (-0.4223529...|
| 12|POINT (0.31358610...|
| 13|POINT (0.64848081...|
| 14|POINT (0.31549460...|
| 15|POINT (-0.9118786...|
| 16|POINT (1.70164131...|
| 17|POINT (0.10851453...|
| 18|POINT (-0.3098724...|
| 19|POINT (-0.2040816...|
+---+--------------------+
only showing top 20 rows



dfdata = [feature1: double, feature2: double]
dfinput = [id: bigint, geometry: point]



In [9]:
/*================================================================================*/
// Step 2: broadcast the R-tree batch by batch to get neighbour candidates -> dfnear
/*================================================================================*/
printlog("step2: looking for neighbours by broadcasting Rtree -> dfnear ...")

spark.udf.register("getBufferBox", (p: Point) => p.getEnvelope.buffer(0.2).getEnvelope)

// Broadcast in batches
val partition_cnt = 10
val rdd_input = dfinput.rdd.repartition(20).persist(StorageLevel.MEMORY_AND_DISK)
val dfbuffer = dfinput.selectExpr("id","getBufferBox(geometry) as envelop").repartition(partition_cnt)
var dfnear = List[(Long,Long,Point)]().toDF("s_fid","m_fid","s_geom")


for (partition_id <- 0 until partition_cnt) {
  // Pull one partition of buffer boxes to the Driver and index it with an STRtree
  val bufferi = dfbuffer.rdd.mapPartitionsWithIndex(
    (idx, iter) => if (idx == partition_id) iter else Iterator())
  val Rtree = new STRtree()
  bufferi.collect.foreach(x => Rtree.insert(x.getAs[Geometry]("envelop").getEnvelopeInternal, x))
  val tree_broads = sc.broadcast(Rtree)

  // On each executor, query the broadcast tree with every local point
  val dfneari = rdd_input.mapPartitions(iter => {
    var res_list = List[(Long,Long,Point)]() // s_fid, m_fid, s_geom
    val tree = tree_broads.value
    for (cur <- iter) {
      val s_fid = cur.getAs[Long]("id")
      val s_geom = cur.getAs[Point]("geometry")
      val results = tree.query(s_geom.getEnvelopeInternal).asInstanceOf[java.util.List[Row]]

      for (x <- results) {
        val m_fid = x.getAs[Long]("id")
        val m_envelop = x.getAs[Geometry]("envelop")
        if (m_envelop.intersects(s_geom)) {
          res_list = res_list :+ ((s_fid, m_fid, s_geom))
        }
      }
    }
    res_list.iterator
  }).toDF("s_fid","m_fid","s_geom")

  dfnear = dfnear.union(dfneari)
}

dfnear.show(3)

step2: looking for neighbours by broadcasting Rtree -> dfnear ...

+-----+-----+--------------------+
|s_fid|m_fid|              s_geom|
+-----+-----+--------------------+
|   19|  271|POINT (-0.2040816...|
|   19|  489|POINT (-0.2040816...|
|   19|  488|POINT (-0.2040816...|
+-----+-----+--------------------+
only showing top 3 rows



partition_cnt = 10
rdd_input = MapPartitionsRDD[40] at repartition at :64
dfbuffer = [id: bigint, envelop: geometry]
dfnear = [s_fid: bigint, m_fid: bigint ... 1 more field]



In [10]:
dfnear.count

13062

In [11]:
/*================================================================================*/
// Step 3: keep the pairs that lie within the DBSCAN neighbourhood radius -> dfpair
/*================================================================================*/
printlog("step3: looking for effective pairs by DNN model-> dfpair...")



val dfpair_raw = dfinput.join(dfnear, dfinput("id")===dfnear("m_fid"), "right")
 .selectExpr("s_fid","m_fid","s_geom","geometry as m_geom")
 
spark.udf.register("distance", (p: Point, q:Point) => p.distance(q))
val dfpair = dfpair_raw.where("distance(s_geom,m_geom) < 0.2") // the neighbourhood radius R is set to 0.2
 .persist(StorageLevel.MEMORY_AND_DISK)

dfpair.show

step3: looking for effective pairs by DNN model-> dfpair...

+-----+-----+--------------------+--------------------+
|s_fid|m_fid|              s_geom|              m_geom|
+-----+-----+--------------------+--------------------+
|   19|  489|POINT (-0.2040816...|POINT (-0.0796163...|
|   19|  488|POINT (-0.2040816...|POINT (-0.0587587...|
|   19|  465|POINT (-0.2040816...|POINT (-0.1552430...|
|   39|  311|POINT (1.00599833...|POINT (1.16695573...|
|   39|   64|POINT (1.00599833...|POINT (1.15057045...|
|   39|  416|POINT (1.00599833...|POINT (0.89927494...|
|   59|  416|POINT (0.75345168...|POINT (0.89927494...|
|   79|  271|POINT (-0.1792310...|POINT (-0.3258320...|
|   79|  489|POINT (-0.1792310...|POINT (-0.0796163...|
|   79|  488|POINT (-0.1792310...|POINT (-0.0587587...|
|   79|  465|POINT (-0.1792310...|POINT (-0.1552430...|
|   99|  147|POINT (0.22302604...|POINT (0.33332249...|
|   99|  283|POINT (0.22302604...|POINT (0.36557375...|
|   99|  216|POINT (0.22302604...|POINT (0.19487138...|
|  119|  288|POINT (1.62074019...|POINT (1.59041478...|
| 

dfpair_raw = [s_fid: bigint, m_fid: bigint ... 2 more fields]
dfpair = [s_fid: bigint, m_fid: bigint ... 2 more fields]


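
Steps 2 and 3 together act as a two-stage filter: a cheap bounding-box test that an index can answer quickly, followed by an exact distance test on the surviving candidates. The sketch below shows the idea in plain Scala without JTS; `Xy`, `inBox` and `inRadius` are hypothetical names, and the box is assumed to be the square of half-width R around a point.

```scala
// Hypothetical helper names; plain Scala, no JTS required.
case class Xy(x: Double, y: Double)

// Stage 1 (cheap, index-friendly): is s inside the square of half-width r centred on m?
def inBox(s: Xy, m: Xy, r: Double): Boolean =
  math.abs(s.x - m.x) <= r && math.abs(s.y - m.y) <= r

// Stage 2 (exact): is the Euclidean distance between s and m below r?
def inRadius(s: Xy, m: Xy, r: Double): Boolean =
  math.hypot(s.x - m.x, s.y - m.y) < r

val centre = Xy(0.0, 0.0)
val candidates = Seq(Xy(0.1, 0.1), Xy(0.19, 0.19), Xy(0.3, 0.0))

val afterBox = candidates.filter(inBox(_, centre, 0.2))       // candidate neighbours (superset)
val afterRadius = afterBox.filter(inRadius(_, centre, 0.2))   // effective neighbours
```

The point (0.19, 0.19) passes the box test but sits in a corner of the square, outside the circle, so only the second stage rejects it. Because the square contains the circle, the first stage never drops a true neighbour.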

In [12]:
/*================================================================================*/
// Step 4: build the temporary clusters -> dfcore
/*================================================================================*/
printlog("step4: looking for temporary clusters -> dfcore ...")


val dfcore = dfpair.groupBy("s_fid").agg(
  first("s_geom") as "s_geom",
  count("m_fid") as "neighbour_cnt",
  collect_list("m_fid") as "neighbour_ids"
).where("neighbour_cnt>=20") // the minimum number of points (minpoints) is set to 20 here
 .persist(StorageLevel.MEMORY_AND_DISK)

dfcore.show(3)

step4: looking for temporary clusters -> dfcore ...

+-----+--------------------+-------------+--------------------+
|s_fid|              s_geom|neighbour_cnt|       neighbour_ids|
+-----+--------------------+-------------+--------------------+
|   26|POINT (0.95199382...|           25|[220, 460, 26, 22...|
|   65|POINT (0.46872165...|           30|[491, 65, 258, 44...|
|  418|POINT (0.04187413...|           22|[392, 475, 291, 4...|
+-----+--------------------+-------------+--------------------+
only showing top 3 rows



dfcore = [s_fid: bigint, s_geom: point ... 2 more fields]



In [13]:
/*================================================================================*/
// Step 5: get the core-point information of each temporary cluster -> rdd_core
/*================================================================================*/
printlog("step5: get information for temporary clusters -> rdd_core ...")

// Keep only the pairs whose centre s_fid is a core point ...
val dfpair_join = dfcore.selectExpr("s_fid").join(dfpair,Seq("s_fid"),"inner")
val df_fids = dfcore.selectExpr("s_fid as m_fid")
// ... and whose neighbour m_fid is a core point as well
val dfpair_core = df_fids.join(dfpair_join,Seq("m_fid"),"inner")
var rdd_core = dfpair_core.groupBy("s_fid").agg(
  min("m_fid") as "min_core_id",
  collect_set("m_fid") as "core_id_set"
).rdd.map(row => {
  val min_core_id = row.getAs[Long]("min_core_id")
  val core_id_set = row.getAs[WrappedArray[Long]]("core_id_set").toArray.toSet
  (min_core_id, core_id_set)
})

println(s"before dbscan: rdd_core.count = ${rdd_core.count}")

step5: get information for temporary clusters -> rdd_core ...

before dbscan: rdd_core.count = 358


dfpair_join = [s_fid: bigint, m_fid: bigint ... 2 more fields]
df_fids = [m_fid: bigint]
dfpair_core = [m_fid: bigint, s_fid: bigint ... 2 more fields]
rdd_core = MapPartitionsRDD[192] at map at :69


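
The two joins in step 5 amount to keeping only the pairs in which both endpoints are core points. A plain-Scala sketch of that restriction follows; `coreOnlyClusters` and the toy data are hypothetical and only mirror the shape of `rdd_core`.

```scala
// pairs: (s_fid, m_fid) with m inside the neighbourhood of s; coreIds: ids of all core points.
// Returns one (min_core_id, core_id_set) entry per core point s.
def coreOnlyClusters(pairs: Seq[(Long, Long)], coreIds: Set[Long]): Seq[(Long, Set[Long])] =
  pairs
    .filter { case (s, m) => coreIds(s) && coreIds(m) }   // both joins, expressed as one filter
    .groupBy(_._1)
    .values.toSeq
    .map { ps =>
      val ids = ps.map(_._2).toSet
      (ids.min, ids)
    }

val toyPairs = Seq(
  (1L, 1L), (1L, 2L), (1L, 3L),
  (2L, 1L), (2L, 2L), (2L, 3L),
  (3L, 1L), (3L, 2L), (3L, 3L), (3L, 4L),   // point 4 is a neighbour of 3 but not a core point
  (4L, 3L), (4L, 4L)
)
val toyCore = coreOnlyClusters(toyPairs, Set(1L, 2L, 3L))
```

The non-core point 4 disappears from every set, which is what keeps the data reaching the final merge partition small.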

In [14]:
/*================================================================================*/
// Step 6: merge rdd_core partition by partition in several rounds; rdd_core: (min_core_id, core_id_set)
/*================================================================================*/
printlog("step6: run dbscan clustering ...")

// Merge function: merge temporary clusters that share at least one core point.
// Note that set_list is consumed (emptied) while merging.
val mergeSets = (set_list: ListBuffer[Set[Long]]) => {
  var result = ListBuffer[Set[Long]]()
  while (set_list.size > 0) {
    var cur_set = set_list.remove(0)
    // Indices are listed in descending order so that removing by index below stays valid.
    var intersect_idxs = List.range(set_list.size-1, -1, -1).filter(i => (cur_set & set_list(i)).size > 0)
    while (intersect_idxs.size > 0) {
      for (idx <- intersect_idxs) {
        cur_set = cur_set | set_list(idx)
      }
      for (idx <- intersect_idxs) {
        set_list.remove(idx)
      }
      // The enlarged cur_set may now overlap sets it did not touch before, so search again.
      intersect_idxs = List.range(set_list.size-1, -1, -1).filter(i => (cur_set & set_list(i)).size > 0)
    }
    result = result :+ cur_set
  }
  result
}

// Partition rdd_core and merge inside each partition, then keep reducing the number of partitions until everything ends up in one partition.
// If the data is too large to merge in a single partition, the merge can stop at several partitions, which gives an approximate result.
// rdd: (min_core_id, core_id_set)

def mergeRDD(rdd: org.apache.spark.rdd.RDD[(Long,Set[Long])], partition_cnt: Int):
org.apache.spark.rdd.RDD[(Long,Set[Long])] = {
  val rdd_merged = rdd.partitionBy(new HashPartitioner(partition_cnt))
    .mapPartitions(iter => {
      val buffer = ListBuffer[Set[Long]]()
      for (t <- iter) {
        val core_id_set: Set[Long] = t._2
        buffer.append(core_id_set)
      }
      val merged_buffer = mergeSets(buffer)
      var result = List[(Long,Set[Long])]()
      for (core_id_set <- merged_buffer) {
        val min_core_id = core_id_set.min
        result = result :+ ((min_core_id, core_id_set))
      }
      result.iterator
    })
  rdd_merged
}


// NOTE: the partition counts and the number of rounds need to be tuned to the data size

for (pcnt <- Array(16,8,4,1)) {
  rdd_core = mergeRDD(rdd_core, pcnt)
}

println(s"after dbscan: rdd_core.count = ${rdd_core.count}")

step6: run dbscan clustering ...

after dbscan: rdd_core.count = 2



In [15]:
/*================================================================================*/
// Step 7: get the cluster id of every core point
/*================================================================================*/
printlog("step7: get cluster ids ...")

val dfcluster_ids = rdd_core.flatMap(t => {
 val cluster_id = t._1
 val id_set = t._2
 for(core_id<-id_set) yield (cluster_id, core_id)
}).toDF("cluster_id","s_fid")

val dfclusters = dfcore.join(dfcluster_ids, Seq("s_fid"), "left")

step7: get cluster ids ...



dfcluster_ids = [cluster_id: bigint, s_fid: bigint]
dfclusters = [s_fid: bigint, s_geom: point ... 3 more fields]



In [16]:
/*================================================================================*/
// Step 8: find the representative core point and the number of members of each cluster
/*================================================================================*/
printlog("step8: evaluate cluster representation ...")

val rdd_cluster = dfclusters.rdd.map(row=> {
 val cluster_id = row.getAs[Long]("cluster_id")
 val s_geom = row.getAs[Point]("s_geom")
 val neighbour_cnt = row.getAs[Long]("neighbour_cnt")
 val id_set = row.getAs[WrappedArray[Long]]("neighbour_ids").toSet
 (cluster_id,(s_geom,neighbour_cnt,id_set))
})

val rdd_result = rdd_cluster.reduceByKey((a,b)=>{
 val id_set = a._3 | b._3
 val result = if(a._2>=b._2) (a._1,a._2,id_set)
 else (b._1,b._2,id_set)
 result
})

val dfresult = rdd_result.map(t=>{
 val cluster_id = t._1
 val representation_point = t._2._1
 val neighbour_points_cnt = t._2._2
 val id_set = t._2._3
 val cluster_points_cnt = id_set.size
 (cluster_id,representation_point,neighbour_points_cnt,cluster_points_cnt,id_set)
}).toDF("cluster_id","representation_point","neighbour_points_cnt","cluster_points_cnt","id_set")

dfresult.show(3)

step8: evaluate cluster representation ...

+----------+--------------------+--------------------+------------------+--------------------+
|cluster_id|representation_point|neighbour_points_cnt|cluster_points_cnt|              id_set|
+----------+--------------------+--------------------+------------------+--------------------+
|         0|POINT (1.95163238...|                  32|               242|[365, 138, 101, 4...|
|         2|POINT (0.95067226...|                  34|               241|[69, 347, 468, 35...|
+----------+--------------------+--------------------+------------------+--------------------+



rdd_cluster = MapPartitionsRDD[215] at map at :59
rdd_result = ShuffledRDD[216] at reduceByKey at :67
dfresult = [cluster_id: bigint, representation_point: point ... 3 more fields]



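Note that in our result:

The number of clusters is 2.

The number of noise points is 500 - 242 - 241 = 17.

This matches the result obtained by calling sklearn exactly.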

![](data/sklearn的DBSCAN聚类结果.png)
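
The noise count above is a plain subtraction, which is valid as long as the clusters' `id_set`s do not overlap. A slightly more defensive way to compute it is to take the union of the id sets first. The helper below is a hypothetical sketch on toy data, not part of the pipeline above.

```scala
// Noise = points that belong to no cluster. Taking the union first avoids
// double counting a border point that shows up in two clusters' id sets.
def noiseCount(totalPoints: Int, clusterIdSets: Seq[Set[Long]]): Int =
  totalPoints - clusterIdSets.foldLeft(Set.empty[Long])(_ | _).size

// Toy data: 10 points in total, two disjoint clusters with 4 and 3 members.
val toyClusters = Seq(Set(0L, 1L, 2L, 3L), Set(5L, 6L, 7L))
val toyNoise = noiseCount(10, toyClusters)          // 10 - 7

// Overlapping case: point 2 is listed in both sets but must be counted once.
val overlapNoise = noiseCount(6, Seq(Set(0L, 1L, 2L), Set(2L, 3L)))
```

For disjoint sets this reduces to the subtraction used above, 500 - 242 - 241 = 17.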