{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## DBSCAN算法的Spark实现" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 一,实现思路" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "DBSCAN算法的分布式实现需要解决以下一些主要的问题。\n", "\n", "\n", "**1,如何计算样本点中两两之间的距离?**\n", "\n", "在单机环境下,计算样本点两两之间的距离比较简单,是一个双重遍历的过程。\n", "为了减少计算量,可以用空间索引如Rtree进行加速。\n", "\n", "在分布式环境,样本点分布在不同的分区,难以在不同的分区之间直接进行双重遍历。\n", "为了解决这个问题,我的方案是将样本点不同的分区分成多个批次拉到Driver端,\n", "然后依次广播到各个excutor分别计算距离,将最终结果union,从而间接实现双重遍历。\n", "\n", "为了减少计算量,广播前对拉到Driver端的数据构建空间索引Rtree进行加速。\n", "\n", "\n", "\n", "**2,如何构造临时聚类簇?**\n", "\n", "这个问题不难,单机环境和分布式环境的实现差不多。\n", "\n", "都是通过group的方式统计每个样本点周边邻域半径R内的样本点数量,\n", "\n", "并记录它们的id,如果这些样本点数量超过minpoints则构造临时聚类簇,并维护核心点列表。\n", "\n", "\n", "**3,如何合并相连的临时聚类簇得到聚类簇?**\n", "\n", "这个是分布式实现中最最核心的步骤。\n", "\n", "在单机环境下,标准做法是对每一个临时聚类簇,\n", "\n", "判断其中的样本点是否在核心点列表,如果是,则将该样本点所在的临时聚类簇与当前临时聚类簇合并。并在核心点列表中删除该样本点。\n", "\n", "重复此过程,直到当前临时聚类簇中所有的点都不在核心点列表。\n", "\n", "在分布式环境下,临时聚类簇分布在不同的分区,无法直接扫描全局核心点列表进行临时聚类簇的合并。\n", "\n", "我的方案是先在每一个分区内部对各个临时聚类簇进行合并,然后缩小分区数量重新分区,再在各个分区内部对每个临时聚类簇进行合并。\n", "\n", "不断重复这个过程,最终将所有的临时聚类簇都划分到一个分区,完成对全部临时聚类簇的合并。\n", "\n", "为了降低最后一个分区的存储压力,我采用了不同于标准的临时聚类簇的合并算法。\n", "\n", "对每个临时聚类簇只关注其中的核心点id,而不关注非核心点id,以减少存储压力。合并时将有共同核心点id的临时聚类簇合并。\n", "\n", "为了加快临时聚类的合并过程,分区时并非随机分区,而是以每个临时聚类簇的核心点id中的最小值min_core_id作为分区的Hash参数,\n", "\n", "具有共同核心点id的临时聚类簇有更大的概率被划分到同一个分区,从而加快了合并过程。\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](./data/DBSCAN算法步骤.png)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 二,核心代码" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "spark = org.apache.spark.sql.SparkSession@51a18d93\n", "sc = org.apache.spark.SparkContext@35a42a03\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "org.apache.spark.SparkContext@35a42a03" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import org.apache.spark.sql.SparkSession\n", "\n", "val spark = SparkSession\n", ".builder()\n", ".appName(\"dbscan\")\n", ".getOrCreate()\n", "\n", "val sc = spark.sparkContext\n", "import spark.implicits._" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**1,寻找核心点形成临时聚类簇。**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "该步骤一般要采用空间索引 + 广播的方法,此处从略,假定已经得到了临时聚类簇。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1,Set(1, 2))\n", "(2,Set(2, 3, 4))\n", "(6,Set(6, 8, 9))\n", "(4,Set(4, 5))\n", "(9,Set(9, 10, 11))\n", "(15,Set(15, 17))\n", "(10,Set(10, 11, 18))\n" ] }, { "data": { "text/plain": [ "rdd_core = ParallelCollectionRDD[1] at parallelize at :34\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "ParallelCollectionRDD[1] at parallelize at :34" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "//rdd_core的每一行代表一个临时聚类簇:(min_core_id, core_id_set)\n", "//core_id_set为临时聚类簇所有核心点的编号,min_core_id为这些编号中取值最小的编号\n", "var rdd_core = sc.parallelize(List((1L,Set(1L,2L)),(2L,Set(2L,3L,4L)),\n", " (6L,Set(6L,8L,9L)),(4L,Set(4L,5L)),\n", " (9L,Set(9L,10L,11L)),(15L,Set(15L,17L)),\n", " (10L,Set(10L,11L,18L))))\n", "rdd_core.collect.foreach(println)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](data/dbscan核心算法的输入.png)" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**2,合并临时聚类簇得到聚类簇。**" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "mergeSets = > scala.collection.mutable.ListBuffer[Set[Long]] = \n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "mergeRDD: (rdd: org.apache.spark.rdd.RDD[(Long, Set[Long])], partition_cnt: Int)org.apache.spark.rdd.RDD[(Long, Set[Long])]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "> scala.collection.mutable.ListBuffer[Set[Long]] = " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import scala.collection.mutable.ListBuffer\n", "import org.apache.spark.HashPartitioner\n", "\n", "//定义合并函数:将有共同核心点的临时聚类簇合并\n", "val mergeSets = (set_list: ListBuffer[Set[Long]]) =>{\n", " var result = ListBuffer[Set[Long]]()\n", " while (set_list.size>0){\n", " var cur_set = set_list.remove(0)\n", " var intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)\n", " while(intersect_idxs.size>0){\n", " for(idx<-intersect_idxs){\n", " cur_set = cur_set|set_list(idx)\n", " }\n", " for(idx<-intersect_idxs){\n", " set_list.remove(idx)\n", " }\n", " intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)\n", " }\n", " result = result:+cur_set\n", " }\n", " result\n", "}\n", "\n", "///对rdd_core分区后在每个分区合并,不断将分区数量减少,最终合并到一个分区\n", "//如果数据规模十分大,难以合并到一个分区,也可以最终合并到多个分区,得到近似结果。\n", "//rdd: (min_core_id,core_id_set)\n", "\n", "def mergeRDD(rdd: org.apache.spark.rdd.RDD[(Long,Set[Long])], partition_cnt:Int):\n", "org.apache.spark.rdd.RDD[(Long,Set[Long])] = {\n", " val rdd_merged = rdd.partitionBy(new HashPartitioner(partition_cnt))\n", " .mapPartitions(iter => {\n", " val buffer = ListBuffer[Set[Long]]()\n", " for(t<-iter){\n", " val core_id_set:Set[Long] = t._2\n", " buffer.append(core_id_set)\n", " }\n", " val merged_buffer = mergeSets(buffer)\n", " var result = List[(Long,Set[Long])]()\n", " for(core_id_set<-merged_buffer){\n", " val min_core_id = core_id_set.min\n", " result = result:+(min_core_id,core_id_set)\n", " }\n", " result.iterator\n", " })\n", " rdd_merged\n", "}\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1,Set(5, 1, 2, 3, 4))\n", "(6,Set(10, 6, 9, 18, 11, 8))\n", "(15,Set(15, 17))\n" ] }, { "data": { "text/plain": [ "rdd_core = MapPartitionsRDD[7] at mapPartitions at :63\n", "rdd_core = MapPartitionsRDD[7] at mapPartitions at :63\n", "rdd_core = MapPartitionsRDD[7] at mapPartitions at :63\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "MapPartitionsRDD[7] at mapPartitions at :63" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "//分区迭代计算,可以根据需要调整迭代次数和分区数量\n", "rdd_core = mergeRDD(rdd_core,8)\n", "rdd_core = mergeRDD(rdd_core,4)\n", "rdd_core = mergeRDD(rdd_core,1)\n", "rdd_core.collect.foreach(println)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](data/dbscan核心算法的输出.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 三,完整范例" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "以下为完整范例代码,使用了和《20分钟学会DBSCAN聚类算法》文中完全一样的数据源和参数,并且得到了完全一样的结果。\n", "\n", "不同的是,以下代码是一种基于Spark的分布式实现,可以很好地扩展到大数据集上。\n" ] }, { "cell_type": "code", "execution_count": 7, 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "spark = org.apache.spark.sql.SparkSession@51a18d93\n", "sc = org.apache.spark.SparkContext@35a42a03\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "printlog: (info: String)Unit\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "org.apache.spark.SparkContext@35a42a03" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import org.apache.spark.sql.SparkSession\n", "import org.apache.spark.storage.StorageLevel\n", "import org.apache.spark.sql.{DataFrame, Row, Column}\n", "import org.apache.spark.sql.functions._\n", "import org.locationtech.jts.geom.{Geometry,GeometryFactory,Coordinate,Point}\n", "import org.locationtech.jts.index.strtree.STRtree\n", "import org.apache.spark.sql.jts.registerTypes\n", "import scala.collection.mutable.WrappedArray\n", "import scala.collection.JavaConversions._\n", "\n", "def printlog(info:String): Unit ={\n", " val dt = new java.text.SimpleDateFormat(\"yyyy-MM-dd HH:mm:ss\").format(new java.util.Date)\n", " println(\"==========\"*8+dt)\n", " println(info+\"\\n\")\n", "}\n", "\n", "val spark = SparkSession\n", " .builder()\n", " .appName(\"dbscan\")\n", " .enableHiveSupport()\n", " .getOrCreate()\n", "\n", "val sc = spark.sparkContext\n", "import spark.implicits._\n", "registerTypes\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================================================================================2019-11-24 18:06:12\n", "step1: get input data -> dfinput ...\n", "\n", "+---+--------------------+\n", "| id| geometry|\n", "+---+--------------------+\n", "| 0|POINT (0.31655567...|\n", "| 1|POINT (0.74088269...|\n", "| 2|POINT (0.87172637...|\n", "| 3|POINT (0.55552787...|\n", "| 4|POINT (2.03872887...|\n", "| 5|POINT (1.99136342...|\n", "| 6|POINT (0.22384428...|\n", "| 7|POINT (0.97295674...|\n", "| 8|POINT (-0.9213036...|\n", "| 9|POINT (0.46670632...|\n", "| 10|POINT (0.49217803...|\n", "| 11|POINT (-0.4223529...|\n", "| 12|POINT (0.31358610...|\n", "| 13|POINT (0.64848081...|\n", "| 14|POINT (0.31549460...|\n", "| 15|POINT (-0.9118786...|\n", "| 16|POINT (1.70164131...|\n", "| 17|POINT (0.10851453...|\n", "| 18|POINT (-0.3098724...|\n", "| 19|POINT (-0.2040816...|\n", "+---+--------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "data": { "text/plain": [ "dfdata = [feature1: double, feature2: double]\n", "dfinput = [id: bigint, geometry: point]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[id: bigint, geometry: point]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "/*================================================================================*/\n", "// 一,读入数据 dfinput\n", "/*================================================================================*/\n", "printlog(\"step1: get input data -> dfinput ...\")\n", "\n", "val dfdata = spark.read.option(\"header\",\"true\") \n", " .option(\"inferSchema\",\"true\") \n", " .option(\"delimiter\", \"\\t\") \n", " .csv(\"data/moon_points.csv\")\n", "\n", "spark.udf.register(\"makePoint\", (x:Double,y:Double) =>{\n", " val gf = new GeometryFactory\n", " val pt = gf.createPoint(new Coordinate(x,y))\n", " pt\n", "})\n", "\n", "val dfinput = dfdata.selectExpr(\"makePoint(feature1,feature2) as point\")\n", " .rdd.map(row=>row.getAs[Point](\"point\"))\n", 
" .zipWithIndex().toDF(\"geometry\",\"id\").selectExpr(\"id\",\"geometry\")\n", " .persist(StorageLevel.MEMORY_AND_DISK)\n", "\n", "dfinput.show" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================================================================================2019-11-24 18:06:26\n", "step2: looking for neighbours by broadcasting Rtree -> dfnear ...\n", "\n", "+-----+-----+--------------------+\n", "|s_fid|m_fid| s_geom|\n", "+-----+-----+--------------------+\n", "| 19| 271|POINT (-0.2040816...|\n", "| 19| 489|POINT (-0.2040816...|\n", "| 19| 488|POINT (-0.2040816...|\n", "+-----+-----+--------------------+\n", "only showing top 3 rows\n", "\n" ] }, { "data": { "text/plain": [ "partition_cnt = 10\n", "rdd_input = MapPartitionsRDD[40] at repartition at :64\n", "dfbuffer = [id: bigint, envelop: geometry]\n", "dfnear = [s_fid: bigint, m_fid: bigint ... 1 more field]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[s_fid: bigint, m_fid: bigint ... 1 more field]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "/*================================================================================*/\n", "// 二,分批次广播RTree得到邻近关系 dfnear\n", "/*================================================================================*/\n", "printlog(\"step2: looking for neighbours by broadcasting Rtree -> dfnear ...\")\n", "\n", "spark.udf.register(\"getBufferBox\", (p: Point) => p.getEnvelope.buffer(0.2).getEnvelope)\n", "\n", "//分批次进行广播\n", "val partition_cnt = 10\n", "val rdd_input = dfinput.rdd.repartition(20).persist(StorageLevel.MEMORY_AND_DISK)\n", "val dfbuffer = dfinput.selectExpr(\"id\",\"getBufferBox(geometry) as envelop\").repartition(partition_cnt)\n", "var dfnear = List[(Long,Long,Point)]().toDF(\"s_fid\",\"m_fid\",\"s_geom\")\n", "\n", "\n", "for(partition_id <- 0 until partition_cnt){\n", " val bufferi = dfbuffer.rdd.mapPartitionsWithIndex(\n", " (idx, iter) => if (idx == partition_id ) iter else Iterator())\n", " val Rtree = new STRtree()\n", " bufferi.collect.foreach(x => Rtree.insert(x.getAs[Geometry](\"envelop\").getEnvelopeInternal, x))\n", " val tree_broads = sc.broadcast(Rtree)\n", "\n", " val dfneari = rdd_input.mapPartitions(iter => {\n", " var res_list = List[(Long,Long,Point)]()//s_fid,m_fid,s_geom\n", " val tree = tree_broads.value\n", " for (cur<-iter) {\n", " val s_fid = cur.getAs[Long](\"id\")\n", " val s_geom = cur.getAs[Point](\"geometry\")\n", " val results = tree.query(s_geom.getEnvelopeInternal).asInstanceOf[java.util.List[Row]]\n", "\n", " for (x<-results) {\n", " val m_fid = x.getAs[Long](\"id\")\n", " val m_envelop = x.getAs[Geometry](\"envelop\")\n", " if(m_envelop.intersects(s_geom)){\n", " res_list = res_list:+(s_fid,m_fid,s_geom)\n", " }\n", " }\n", " }\n", " res_list.iterator\n", " }).toDF(\"s_fid\",\"m_fid\",\"s_geom\")\n", "\n", " dfnear = dfnear.union(dfneari)\n", "}\n", "\n", "dfnear.show(3)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "13062" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dfnear.count" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================================================================================2019-11-24 18:06:40\n", "step3: looking for effective pairs by 
DBSCAN radius -> dfpair...\n", "\n", "+-----+-----+--------------------+--------------------+\n", "|s_fid|m_fid| s_geom| m_geom|\n", "+-----+-----+--------------------+--------------------+\n", "| 19| 489|POINT (-0.2040816...|POINT (-0.0796163...|\n", "| 19| 488|POINT (-0.2040816...|POINT (-0.0587587...|\n", "| 19| 465|POINT (-0.2040816...|POINT (-0.1552430...|\n", "| 39| 311|POINT (1.00599833...|POINT (1.16695573...|\n", "| 39| 64|POINT (1.00599833...|POINT (1.15057045...|\n", "| 39| 416|POINT (1.00599833...|POINT (0.89927494...|\n", "| 59| 416|POINT (0.75345168...|POINT (0.89927494...|\n", "| 79| 271|POINT (-0.1792310...|POINT (-0.3258320...|\n", "| 79| 489|POINT (-0.1792310...|POINT (-0.0796163...|\n", "| 79| 488|POINT (-0.1792310...|POINT (-0.0587587...|\n", "| 79| 465|POINT (-0.1792310...|POINT (-0.1552430...|\n", "| 99| 147|POINT (0.22302604...|POINT (0.33332249...|\n", "| 99| 283|POINT (0.22302604...|POINT (0.36557375...|\n", "| 99| 216|POINT (0.22302604...|POINT (0.19487138...|\n", "| 119| 288|POINT (1.62074019...|POINT (1.59041478...|\n", "| 119| 181|POINT (1.62074019...|POINT (1.50625683...|\n", "| 119| 450|POINT (1.62074019...|POINT (1.51993839...|\n", "| 139| 144|POINT (-0.9985058...|POINT (-0.9342012...|\n", "| 159| 372|POINT (0.14456261...|POINT (0.19495316...|\n", "| 159| 20|POINT (0.14456261...|POINT (0.17214778...|\n", "+-----+-----+--------------------+--------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "data": { "text/plain": [ "dfpair_raw = [s_fid: bigint, m_fid: bigint ... 2 more fields]\n", "dfpair = [s_fid: bigint, m_fid: bigint ... 2 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[s_fid: bigint, m_fid: bigint ... 2 more fields]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "/*================================================================================*/\n", "// Step 3: keep the effective pairs within the DBSCAN neighbourhood radius -> dfpair\n", "/*================================================================================*/\n", "printlog(\"step3: looking for effective pairs by DBSCAN radius -> dfpair...\")\n", "\n", "\n", "\n", "val dfpair_raw = dfinput.join(dfnear, dfinput(\"id\")===dfnear(\"m_fid\"), \"right\")\n", " .selectExpr(\"s_fid\",\"m_fid\",\"s_geom\",\"geometry as m_geom\")\n", " \n", "spark.udf.register(\"distance\", (p: Point, q:Point) => p.distance(q))\n", "val dfpair = dfpair_raw.where(\"distance(s_geom,m_geom) < 0.2\") //neighbourhood radius R set to 0.2\n", " .persist(StorageLevel.MEMORY_AND_DISK)\n", "\n", "dfpair.show" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================================================================================2019-11-24 18:06:45\n", "step4: looking for temporary clusters -> dfcore ...\n", "\n", "+-----+--------------------+-------------+--------------------+\n", "|s_fid| s_geom|neighbour_cnt| neighbour_ids|\n", "+-----+--------------------+-------------+--------------------+\n", "| 26|POINT (0.95199382...| 25|[220, 460, 26, 22...|\n", "| 65|POINT (0.46872165...| 30|[491, 65, 258, 44...|\n", "| 418|POINT (0.04187413...| 22|[392, 475, 291, 4...|\n", "+-----+--------------------+-------------+--------------------+\n", "only showing top 3 rows\n", "\n" ] }, { "data": { "text/plain": [ "dfcore = [s_fid: bigint, s_geom: point ... 2 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[s_fid: bigint, s_geom: point ... 
2 more fields]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "/*================================================================================*/\n", "// 四,创建临时聚类簇 dfcore\n", "/*================================================================================*/\n", "printlog(\"step4: looking for temporatory clusters -> dfcore ...\")\n", "\n", "\n", "val dfcore = dfpair.groupBy(\"s_fid\").agg(\n", " first(\"s_geom\") as \"s_geom\",\n", " count(\"m_fid\") as \"neighbour_cnt\",\n", " collect_list(\"m_fid\") as \"neighbour_ids\"\n", ").where(\"neighbour_cnt>=20\") //此处最少点数目minpoits设置为20\n", ".persist(StorageLevel.MEMORY_AND_DISK)\n", "\n", "dfcore.show(3)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================================================================================2019-11-24 18:09:26\n", "step5: get infomation for temporatory clusters -> rdd_core ...\n", "\n", "before dbscan: rdd_core.count = 358\n" ] }, { "data": { "text/plain": [ "dfpair_join = [s_fid: bigint, m_fid: bigint ... 2 more fields]\n", "df_fids = [m_fid: bigint]\n", "dfpair_core = [m_fid: bigint, s_fid: bigint ... 2 more fields]\n", "rdd_core = MapPartitionsRDD[192] at map at :69\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "MapPartitionsRDD[192] at map at :69" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "/*================================================================================*/\n", "// 五,得到临时聚类簇的核心点信息 rdd_core\n", "/*================================================================================*/\n", "printlog(\"step5: get infomation for temporatory clusters -> rdd_core ...\")\n", "\n", "val dfpair_join = dfcore.selectExpr(\"s_fid\").join(dfpair,Seq(\"s_fid\"),\"inner\")\n", "val df_fids = dfcore.selectExpr(\"s_fid as m_fid\")\n", "val dfpair_core = df_fids.join(dfpair_join,Seq(\"m_fid\"),\"inner\")\n", "var rdd_core = dfpair_core.groupBy(\"s_fid\").agg(\n", " min(\"m_fid\") as \"min_core_id\",\n", " collect_set(\"m_fid\") as \"core_id_set\"\n", ").rdd.map(row =>{\n", " val min_core_id = row.getAs[Long](\"min_core_id\")\n", " val core_id_set = row.getAs[WrappedArray[Long]](\"core_id_set\").toArray.toSet\n", " (min_core_id,core_id_set)\n", "})\n", "\n", "println(s\"before dbscan: rdd_core.count = ${rdd_core.count}\")" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================================================================================2019-11-24 18:09:54\n", "step6: run dbscan clustering ...\n", "\n", "after dbscan: rdd_core.count = 2\n" ] }, { "data": { "text/plain": [ "mergeSets = > scala.collection.mutable.ListBuffer[Set[Long]] = \n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "mergeRDD: (rdd: org.apache.spark.rdd.RDD[(Long, Set[Long])], partition_cnt: Int)org.apache.spark.rdd.RDD[(Long, Set[Long])]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "> scala.collection.mutable.ListBuffer[Set[Long]] = " ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "/*================================================================================*/\n", "// 六,对rdd_core分区分步合并 rdd_core(min_core_id, core_id_set)\n", 
"/*================================================================================*/\n", "printlog(\"step6: run dbscan clustering ...\")\n", "\n", "//定义合并方法\n", "val mergeSets = (set_list: ListBuffer[Set[Long]]) =>{\n", " var result = ListBuffer[Set[Long]]()\n", " while (set_list.size>0){\n", " var cur_set = set_list.remove(0)\n", " var intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)\n", " while(intersect_idxs.size>0){\n", " for(idx<-intersect_idxs){\n", " cur_set = cur_set|set_list(idx)\n", " }\n", " for(idx<-intersect_idxs){\n", " set_list.remove(idx)\n", " }\n", " intersect_idxs = List.range(set_list.size-1,-1,-1).filter(i=>(cur_set&set_list(i)).size>0)\n", " }\n", " result = result:+cur_set\n", " }\n", " result\n", "}\n", "\n", "//对rdd_core分区后在每个分区合并,不断将分区数量减少,最终合并到一个分区\n", "//如果数据规模十分大,难以合并到一个分区,也可以最终合并到多个分区,得到近似结果。\n", "//rdd: (min_core_id,core_id_set)\n", "\n", "def mergeRDD(rdd: org.apache.spark.rdd.RDD[(Long,Set[Long])], partition_cnt:Int):\n", "org.apache.spark.rdd.RDD[(Long,Set[Long])] = {\n", " val rdd_merged = rdd.partitionBy(new HashPartitioner(partition_cnt))\n", " .mapPartitions(iter => {\n", " val buffer = ListBuffer[Set[Long]]()\n", " for(t<-iter){\n", " val core_id_set:Set[Long] = t._2\n", " buffer.add(core_id_set)\n", " }\n", " val merged_buffer = mergeSets(buffer)\n", " var result = List[(Long,Set[Long])]()\n", " for(core_id_set<-merged_buffer){\n", " val min_core_id = core_id_set.min\n", " result = result:+(min_core_id,core_id_set)\n", " }\n", " result.iterator\n", " })\n", " rdd_merged\n", "}\n", "\n", "\n", "//!此处需要调整分区数量和迭代次数\n", "\n", "for(pcnt<-Array(16,8,4,1)){\n", " rdd_core = mergeRDD(rdd_core,pcnt)\n", "}\n", "\n", "println(s\"after dbscan: rdd_core.count = ${rdd_core.count}\")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================================================================================2019-11-24 18:10:04\n", "step7: get cluster ids ...\n", "\n" ] }, { "data": { "text/plain": [ "dfcluster_ids = [cluster_id: bigint, s_fid: bigint]\n", "dfclusters = [s_fid: bigint, s_geom: point ... 3 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[s_fid: bigint, s_geom: point ... 
3 more fields]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "/*================================================================================*/\n", "// 七,获取每一个core的簇信息\n", "/*================================================================================*/\n", "printlog(\"step7: get cluster ids ...\")\n", "\n", "val dfcluster_ids = rdd_core.flatMap(t => {\n", " val cluster_id = t._1\n", " val id_set = t._2\n", " for(core_id<-id_set) yield (cluster_id, core_id)\n", "}).toDF(\"cluster_id\",\"s_fid\")\n", "\n", "val dfclusters = dfcore.join(dfcluster_ids, Seq(\"s_fid\"), \"left\")" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================================================================================2019-11-24 18:10:13\n", "step8: evaluate cluster representation ...\n", "\n", "+----------+--------------------+--------------------+------------------+--------------------+\n", "|cluster_id|representation_point|neighbour_points_cnt|cluster_points_cnt| id_set|\n", "+----------+--------------------+--------------------+------------------+--------------------+\n", "| 0|POINT (1.95163238...| 32| 242|[365, 138, 101, 4...|\n", "| 2|POINT (0.95067226...| 34| 241|[69, 347, 468, 35...|\n", "+----------+--------------------+--------------------+------------------+--------------------+\n", "\n" ] }, { "data": { "text/plain": [ "rdd_cluster = MapPartitionsRDD[215] at map at :59\n", "rdd_result = ShuffledRDD[216] at reduceByKey at :67\n", "dfresult = [cluster_id: bigint, representation_point: point ... 3 more fields]\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[cluster_id: bigint, representation_point: point ... 
3 more fields]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "/*================================================================================*/\n", "// 八,求每一个簇的代表核心和簇元素数量\n", "/*================================================================================*/\n", "printlog(\"step8: evaluate cluster representation ...\")\n", "\n", "val rdd_cluster = dfclusters.rdd.map(row=> {\n", " val cluster_id = row.getAs[Long](\"cluster_id\")\n", " val s_geom = row.getAs[Point](\"s_geom\")\n", " val neighbour_cnt = row.getAs[Long](\"neighbour_cnt\")\n", " val id_set = row.getAs[WrappedArray[Long]](\"neighbour_ids\").toSet\n", " (cluster_id,(s_geom,neighbour_cnt,id_set))\n", "})\n", "\n", "val rdd_result = rdd_cluster.reduceByKey((a,b)=>{\n", " val id_set = a._3 | b._3\n", " val result = if(a._2>=b._2) (a._1,a._2,id_set)\n", " else (b._1,b._2,id_set)\n", " result\n", "})\n", "\n", "val dfresult = rdd_result.map(t=>{\n", " val cluster_id = t._1\n", " val representation_point = t._2._1\n", " val neighbour_points_cnt = t._2._2\n", " val id_set = t._2._3\n", " val cluster_points_cnt = id_set.size\n", " (cluster_id,representation_point,neighbour_points_cnt,cluster_points_cnt,id_set)\n", "}).toDF(\"cluster_id\",\"representation_point\",\"neighbour_points_cnt\",\"cluster_points_cnt\",\"id_set\")\n", "\n", "dfresult.show(3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "注意到我们的结果中\n", "\n", "聚类簇数量为2个。\n", "\n", "噪声点数量为500-242-241 = 17个\n", "\n", "和调用sklearn中的结果完全一致。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](data/sklearn的DBSCAN聚类结果.png)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Spark - Scala", "language": "scala", "name": "spark_scala" }, "language_info": { "codemirror_mode": "text/x-scala", "file_extension": ".scala", "mimetype": "text/x-scala", "name": "scala", "pygments_lexer": "scala", "version": "2.11.12" } }, "nbformat": 4, "nbformat_minor": 2 }