Spark3 - ListenerBus 消息总线源码解读

Posted on Thu, Dec 30, 2021 Spark 源码 Scala

Spark 版本:3.1.0

1. ListenerBus 概述

ListenerBus 是 Spark 的消息总线接口,会维护一个 Listener 队列,并提供一个全局 Event 分发功能,将事件分发给注册了的 Listener,事件的具体处理逻辑则交由 Listener 自行实现,其继承结构如下:

Spark 2.3 开始为 Listener 添加了 Event 处理时间的统计功能,可以很方便的查看各个 Event 的处理时间,能够帮助开发员人快速定位瓶颈。

Listener 的时间统计功能通过 spark.scheduler.listenerbus.logSlowEventspark.scheduler.listenerbus.logSlowEvent.threshold 参数控制,需要 ListenerBus 子类自行实现 org.apache.spark.util.ListenerBus#getTimer 方法,此处不是重点,暂且不关注。

ListenerBus 核心代码如下:

 private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
   // 带时间统计的 Listener 队列
   private[this] val listenersPlusTimers = new CopyOnWriteArrayList[(L, Option[Timer])]
 
   // Marked `private[spark]` for access in tests.
   private[spark] def listeners = listenersPlusTimers.asScala.map(_._1).asJava
 
   // 添加 Listener
   final def addListener(listener: L): Unit = {
     listenersPlusTimers.add((listener, getTimer(listener)))
   }
 
   // 移除队列中的 Listener
   final def removeListener(listener: L): Unit = {
     listenersPlusTimers.asScala.find(_._1 eq listener).foreach { listenerAndTimer =>
       listenersPlusTimers.remove(listenerAndTimer)
     }
   }
 
   // 向队列中所有的 Listener 投递消息
   def postToAll(event: E): Unit = {
     // JavaConverters can create a JIterableWrapper if we use asScala.
     // However, this method will be called frequently. To avoid the wrapper cost, here we use
     // Java Iterator directly.
     val iter = listenersPlusTimers.iterator
     while (iter.hasNext) {
       val listenerAndMaybeTimer = iter.next()
       val listener = listenerAndMaybeTimer._1
       val maybeTimer = listenerAndMaybeTimer._2
       val maybeTimerContext = if (maybeTimer.isDefined) {
         maybeTimer.get.time()
       } else {
         null
       }
       lazy val listenerName = Utils.getFormattedClassName(listener)
       try {
         // 向 Listener 投递事件,该方法由各个 ListenerBus 的子类自行实现
         doPostEvent(listener, event)
         if (Thread.interrupted()) {
           // We want to throw the InterruptedException right away so we can associate the interrupt
           // with this listener, as opposed to waiting for a queue.take() etc. to detect it.
           throw new InterruptedException()
         }
       } catch {
         case ie: InterruptedException =>
           logError(s"Interrupted while posting to ${listenerName}. Removing that listener.", ie)
           removeListenerOnError(listener)
         case NonFatal(e) if !isIgnorableException(e) =>
           logError(s"Listener ${listenerName} threw an exception", e)
       } finally {
         if (maybeTimerContext != null) {
           val elapsed = maybeTimerContext.stop()
           if (logSlowEventEnabled && elapsed > logSlowEventThreshold) {
             logInfo(s"Process of event ${redactEvent(event)} by listener ${listenerName} took " +
               s"${elapsed / 1000000000d}s.")
           }
         }
       }
     }
   }
 
   // 子类需要实现 Listener 如何接收事件
   protected def doPostEvent(listener: L, event: E): Unit
 }

可以看到 ListenerBus 维护了一个带执行时间统计功能的队列 listenersPlusTimers 以及不带时间统计功能的队列 listeners ,本质上都是同一条队列,后者用于测试使用。

ListenerBus 提供了添加、移除 Listener 的常用功能,代码比较简单,此处不做赘述,重点关注 postToAll 方法。这个方法会将遍历队列中的所有 Listener,并调用子类的 doPostEvent 处理 Listener 与 event。

需要注意的是,每个 ListenerBus 的子类都有各自的 Listener 类型,代码中使用了一个 [L <: AnyRef 的泛型表示 ,并未强制做限定。

每个 ListenerBus 的泛型都会实现各自的 doPostEvent 方法,用于处理 Listener 与 Event,以 SparkListenerBus 为例,其 Listener 类型需要 SparkListenerInterface 特质:

 private[spark] trait SparkListenerBus
   extends ListenerBus[SparkListenerInterface, SparkListenerEvent] {
 
   protected override def doPostEvent(
       listener: SparkListenerInterface,
       event: SparkListenerEvent): Unit = {
     event match {
       case stageSubmitted: SparkListenerStageSubmitted =>
         listener.onStageSubmitted(stageSubmitted)
       case stageCompleted: SparkListenerStageCompleted =>
         listener.onStageCompleted(stageCompleted)
       case jobStart: SparkListenerJobStart =>
         listener.onJobStart(jobStart)
       // 后续代码省略...

这样做本质上是一种监听者模式,各个 ListenerBus 子类只需要在 doPostEvent 方法中专注于自己关心的 Event 即可,并且也方便后续扩展关心的事件。

2. Listener 概述

虽然 ListenerBus 维护了一组注册的 Listener,但并没有为其定义一个公共的 Listener 接口,从其定义也可以看出来:

 // org.apache.spark.util.ListenerBus
 private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
   ...
 }

虽然这里使用了一个泛型表示 Listener,但在 Spark 的实现中,许多 Listener 都是 SparkListenerInterface 的子类,除此之外还有 ExternalCatalogEventListener、StreamingListener、QueryExecutionListener 等。

每个 ListenerBus 都可以自定义 Listener 类型,这里以具体的消息总线 SparkListenerBus、事件监听器 SparkListener 以及事件类型 SparkListenerEvent 来举例说明。

查看 SparkListenerBus 定义如下:

 // org.apache.spark.scheduler.SparkListenerBus
 private[spark] trait SparkListenerBus
   extends ListenerBus[SparkListenerInterface, SparkListenerEvent] {
   ...
 }

可以看到 SparkListenerBus 持有的是一个 SparkListenerInterface 类型的 Listener,这个接口的实现类有很多,其中最常见的就是 SparkListener,其定义如下:

 @DeveloperApi
 abstract class SparkListener extends SparkListenerInterface {
   override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { }
 
   override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { }
   ...
 }

其中每个 onXXX 方法都对应了 Spark 在运行时的各个生命周期,每个生命周期关心的事件也不同,事件类型定义如下:

 @DeveloperApi
 case class SparkListenerStageCompleted(stageInfo: StageInfo) extends SparkListenerEvent

可以看到,其实这些事件都是 SparkListenerEvent 的子类。

这里有个值得一提的地方是,整个 event 的

3. LiveListenerBus 与 AsyncEventQueue

Spark 通过 LiveListenerBus 进行 event 的异步投递,LiveListenerBus 构造如下:

 private var _listenerBus: LiveListenerBus = _
 _listenerBus = new LiveListenerBus(_conf)
NOTE: Spark 2.3 版本之前,LiveListenerBus 也是 ListenerBus 的子类,从 2.3 开始将其剥离了出来,存储多个 org.apache.spark.scheduler.AsyncEventQueue(SparkListenerBus 的子类)对象,并以异步的方式将事件投递到 AsyncEventQueue 中。

LiveListenerBus 中定义了四类消息总线:

   // org.apache.spark.scheduler.LiveListenerBus
   private[scheduler] val SHARED_QUEUE = "shared"
 
   private[scheduler] val APP_STATUS_QUEUE = "appStatus"
 
   private[scheduler] val EXECUTOR_MANAGEMENT_QUEUE = "executorManagement"
 
   private[scheduler] val EVENT_LOG_QUEUE = "eventLog"
   // ...
   def addToSharedQueue(listener: SparkListenerInterface): Unit = {
     addToQueue(listener, SHARED_QUEUE)
   }
 
   /** Add a listener to the executor management queue. */
   def addToManagementQueue(listener: SparkListenerInterface): Unit = {
     addToQueue(listener, EXECUTOR_MANAGEMENT_QUEUE)
   }
 
   /** Add a listener to the application status queue. */
   def addToStatusQueue(listener: SparkListenerInterface): Unit = {
     addToQueue(listener, APP_STATUS_QUEUE)
   }
 
   /** Add a listener to the event log queue. */
   def addToEventLogQueue(listener: SparkListenerInterface): Unit = {
     addToQueue(listener, EVENT_LOG_QUEUE)
   }

追踪 addToQueue 方法如下:

   private[spark] def addToQueue(
       listener: SparkListenerInterface,
       queue: String): Unit = synchronized {
     if (stopped.get()) {
       throw new IllegalStateException("LiveListenerBus is stopped.")
     }
     // 查找当前消息总线队列中是否有指定类型的 AsyncEventQueue
     queues.asScala.find(_.name == queue) match {
       // 有则直接往该 AsyncEventQueue 中添加添加 Listener
       case Some(queue) =>
         queue.addListener(listener)
       // 否则先创建对应的 AsyncEventQueue,添加 Listener,再添加到消息总线队列中
       case None =>
         val newQueue = new AsyncEventQueue(queue, conf, metrics, this)
         newQueue.addListener(listener)
         if (started.get()) {
           newQueue.start(sparkContext)
         }
         queues.add(newQueue)
     }
   }
 

AsyncEventQueue 是一个继承了 SparkListenerBus 的消息总线,基于 eventQueuedispatchThread 实现了 event 的异步投递。

这里选择 org.apache.spark.SparkContext#postApplicationStart 作为切入点查看 event 投递的过程。可以看到,SparkContext 调用了 LiveListenerBus 的 post() 方法:

   // org.apache.spark.SparkContext#postApplicationStart
   private def postApplicationStart(): Unit = {
     // Note: this code assumes that the task scheduler has been initialized and has contacted
     // the cluster manager to get an application ID (in case the cluster manager provides one).
     listenerBus.post(SparkListenerApplicationStart(appName, Some(applicationId),
       startTime, sparkUser, applicationAttemptId, schedulerBackend.getDriverLogUrls,
       schedulerBackend.getDriverAttributes))
     _driverLogger.foreach(_.startSync(_hadoopConfiguration))
   }

post() 方法实现如下:

   // org.apache.spark.scheduler.LiveListenerBus#post
   def post(event: SparkListenerEvent): Unit = {
     if (stopped.get()) {
       return
     }
 
     metrics.numEventsPosted.inc()
 
     // 如果发现事件队列为空,证明总线已经启动,可以直接投递 event
     if (queuedEvents == null) {
       postToQueues(event)
       return
     }
 
     // 否则,检查下总线是否启动,未启动则将 event 添加到事件队列中,并返回
     synchronized {
       if (!started.get()) {
         queuedEvents += event
         return
       }
     }
 
     // 如果上一步 check 期间总线启动了,则继续投递 event
     postToQueues(event)
   }

其中 postToQueues 实际上就是将事件投递到 AsyncEventQueue

   // org.apache.spark.scheduler.LiveListenerBus#postToQueues
   private def postToQueues(event: SparkListenerEvent): Unit = {
     // 遍历消息总线队列,传递 event
     val it = queues.iterator()
     while (it.hasNext()) {
       // 调用 AsyncEventQueue 的 post 方法
       it.next().post(event)
     }
   }

再来看下 AsyncEventQueuepost() 方法:

   def post(event: SparkListenerEvent): Unit = {
     if (stopped.get()) {
       return
     }
 
     eventCount.incrementAndGet()
     // 往事件队列中添加 event
     if (eventQueue.offer(event)) {
       return
     }
 
     // ...
   }

可以看到,AsyncEventQueue 会将接收到的事件放到 eventQueue 里,同时 AsyncEventQueue 在实例化时会启动一个线程去不断消费这个队列里的事件:

   // org.apache.spark.scheduler.AsyncEventQueue
   private val dispatchThread = new Thread(s"spark-listener-group-$name") {
     setDaemon(true)
     override def run(): Unit = Utils.tryOrStopSparkContext(sc) {
       dispatch()
     }
   }

通过 Spark UI 界面,也可以看到 spark-listener-group-$name 线程:

再看下线程中调用的 dispatch 方法:

   // org.apache.spark.scheduler.AsyncEventQueue#dispatch
   private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) {
     var next: SparkListenerEvent = eventQueue.take()
     // 当事件队列中获取到的事件不为 POISON_PILL 时,循环消费事件队列中的事件
     while (next != POISON_PILL) {
       val ctx = processingTime.time()
       try {
         // 调用父类 ListenerBus 的事件投递方法向所有 Listener 投递 event
         super.postToAll(next)
       } finally {
         ctx.stop()
       }
       eventCount.decrementAndGet()
       next = eventQueue.take()
     }
     eventCount.decrementAndGet()
   }

此处注意 dispatch() 方法里的 super.postToAll(next) 是调用了父类 ListenerBus 的事件投递方法,将 event 分发给了其他所有 Listener。

至此,一个 event 的投递就算完成了。

4. 如何自定义消息总线

Spark 对消息总线进行了高度封装,再回顾下 org.apache.spark.SparkContext#postApplicationStart 方法:

   /** Post the application start event */
   private def postApplicationStart(): Unit = {
     // Note: this code assumes that the task scheduler has been initialized and has contacted
     // the cluster manager to get an application ID (in case the cluster manager provides one).
     listenerBus.post(SparkListenerApplicationStart(appName, Some(applicationId),
       startTime, sparkUser, applicationAttemptId, schedulerBackend.getDriverLogUrls,
       schedulerBackend.getDriverAttributes))
     _driverLogger.foreach(_.startSync(_hadoopConfiguration))
   }

可以看到,Spark 内部在提交 event 的时候,不同 event 封装了不同的信息,最终会被投递给其他 Listener。

因此如果想要获取这些 event,方式十分简单,只需要继承一个 Listener 接口,重写关注的阶段,注册到 SparkContext 中即可。

继续以 SparkListener 为例,我们只需要在 SparkContext 中添加自定义 Listener,并重写生命周期方法。这里对 onTaskEnd 方法进行重写:

   def main(args: Array[String]): Unit = {
     val spark = SparkSession
       .builder()
       .appName("zkx-test1")
       .master("local[2]")
       .getOrCreate()
     spark.sparkContext.addSparkListener(new SparkListener {
       override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = println("Task End...")
     })
     spark.sql("SELECT 1").show()
   }

执行后输出如下:

 ...
 21/06/12 15:57:26 INFO DAGScheduler: Job 0 finished: show at TransportRPCTest.scala:19, took 0.823017 s
 Task End...
 21/06/12 15:57:26 INFO CodeGenerator: Code generated in 21.0197 ms
 +---+
 |  1|
 +---+
 |  1|
 +---+
 ...