The Implementation of Spark Broadcast

Posted on Sun, Dec 26, 2021 · Spark Source Code · Scala

Overall Flow

Spark version: 3.1.0

Value lookup: TorrentBroadcast(_value) -> BroadcastManager(cachedValues) -> BlockManager (local + remote)

Mind map: (figure omitted)
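The layered lookup order above can be sketched as a chain of caches. This is a hypothetical, simplified model (the names `LayeredLookup`, `processCache`, and `blockStore` are stand-ins, not Spark's actual API): each broadcast variable first checks its own memoized value, then the process-wide cache, then the block store.

```scala
import scala.collection.mutable

// Sketch of the lookup chain: one LayeredLookup instance per broadcast
// variable (so `memoized` plays the role of TorrentBroadcast._value),
// `processCache` stands in for BroadcastManager.cachedValues, and
// `blockStore` stands in for the BlockManager (local + remote).
class LayeredLookup[K, V](
    processCache: mutable.Map[K, V],
    blockStore: K => Option[V]
) {
  private var memoized: Option[V] = None

  def value(key: K): Option[V] = memoized.orElse {
    // miss in the per-broadcast cache: consult the cheaper layer first
    val v = processCache.get(key).orElse(blockStore(key))
    v.foreach { x =>
      processCache.put(key, x) // populate the upper layers on a hit
      memoized = Some(x)
    }
    v
  }
}
```

A hit at any layer back-fills the layers above it, which is exactly why repeated reads of a broadcast value on the same JVM are cheap.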

Creating the BroadcastManager

Let's start from the following snippet:

     val spark = SparkSession
       .builder
       .master("local[3]")
       .appName("Spark Pi")
       .getOrCreate()
     // create a broadcast variable
     val name = spark.sparkContext.broadcast("zkx")
     spark.range(10).rdd.foreach(x => println(name.value + "-" + x))
     spark.stop()

The spark.sparkContext.broadcast("zkx") call is implemented as follows:

   // org.apache.spark.SparkContext#broadcast
   def broadcast[T: ClassTag](value: T): Broadcast[T] = {
     assertNotStopped()
     require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
       "Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
     // create the broadcast variable
     val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
     val callSite = getCallSite
     logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
     cleaner.foreach(_.registerBroadcastForCleanup(bc))
     bc
   }

The env.broadcastManager reference shows that broadcast variables are managed by BroadcastManager, which is created inside SparkEnv as follows:

 val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)

Besides the three constructor arguments, BroadcastManager has several internal members, including broadcastFactory, an initialized flag, a nextBroadcastId counter, and the cachedValues cache (figure omitted).

When constructed, BroadcastManager calls its initialize() method, which creates a broadcastFactory instance of type TorrentBroadcastFactory and marks initialized as true:

   // org.apache.spark.broadcast.BroadcastManager#initialize
   // Called by SparkContext or Executor before using Broadcast
   private def initialize(): Unit = {
     synchronized {
       if (!initialized) {
         broadcastFactory = new TorrentBroadcastFactory
        // internally this is actually a no-op
         broadcastFactory.initialize(isDriver, conf, securityManager)
         initialized = true
       }
     }
   }

That completes the creation of BroadcastManager: when SparkEnv is initialized (on the Driver as part of creating the SparkContext, and on each Executor at startup), it constructs the BroadcastManager and triggers its initialization.

Creating a Broadcast Variable on the Driver

Following env.broadcastManager.newBroadcast[T](value, isLocal), its implementation is:

   // org.apache.spark.broadcast.BroadcastManager#newBroadcast
   def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
     val bid = nextBroadcastId.getAndIncrement()
     value_ match {
       case pb: PythonBroadcast => pb.setBroadcastId(bid)
       case _ => // do nothing
     }
     broadcastFactory.newBroadcast[T](value_, isLocal, bid)
   }

It delegates to newBroadcast(...) on the TorrentBroadcastFactory instance created during initialization, which returns a TorrentBroadcast:

  // org.apache.spark.broadcast.TorrentBroadcastFactory#newBroadcast 
  override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
     new TorrentBroadcast[T](value_, id)
   }

So a broadcast variable is actually a TorrentBroadcast instance, a subclass of Broadcast.

Besides the object to broadcast and the broadcast ID passed to its constructor, TorrentBroadcast has several other members, including broadcastId, blockSize, checksumEnabled/checksums, _value, and numBlocks (figure omitted).

Among them, numBlocks is initialized by calling writeBlocks(obj), which writes the value to be broadcast into the BlockManager:

// org.apache.spark.broadcast.TorrentBroadcast#writeBlocks
  /**
   * Divide the object into multiple blocks and put those blocks in the block manager.
   *
   * @param value the object to divide
   * @return number of blocks this broadcast variable is divided into
   */
  private def writeBlocks(value: T): Int = {
    import StorageLevel._
    // get the BlockManager
    val blockManager = SparkEnv.get.blockManager
    // store the broadcast value in the BlockManager as a single object,
    // so the Driver does not have to re-create it later
    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
      throw new SparkException(s"Failed to store $broadcastId in BlockManager")
    }
    try {
      // serialize the value into an array of ByteBuffers of at most blockSize bytes each
      val blocks =
        TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
      if (checksumEnabled) {
        checksums = new Array[Int](blocks.length)
      }
      // build a BroadcastBlockId per piece and store each piece in the BlockManager under that key
      blocks.zipWithIndex.foreach { case (block, i) =>
        if (checksumEnabled) {
          checksums(i) = calcChecksum(block)
        }
        val pieceId = BroadcastBlockId(id, "piece" + i)
        val bytes = new ChunkedByteBuffer(block.duplicate())
        // store each piece as a separate block in the BlockManager
        if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
          throw new SparkException(s"Failed to store $pieceId of $broadcastId " +
            s"in local BlockManager")
        }
      }
      blocks.length
    } catch {
      // ...
    }
  }

So the Driver first stores the broadcast value in the BlockManager as a whole object, and then again as pieces wrapped in ChunkedByteBuffers (each internally an array of ByteBuffers).

At this point the Driver-side broadcast variable is fully materialized, existing in memory both as a deserialized object and as a series of blocks (backed by ByteBuffer arrays).
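The chunking performed by blockifyObject can be illustrated with a plain-Scala sketch. Note this only mimics the splitting step; the real pipeline also serializes the object with SparkEnv's serializer and optionally compresses it, and the function names here (`blockify`, `unBlockify`) are hypothetical:

```scala
// Sketch: divide a serialized payload into pieces of at most blockSize bytes,
// the way blockifyObject divides a broadcast value into blocks.
def blockify(payload: Array[Byte], blockSize: Int): Seq[Array[Byte]] =
  payload.grouped(blockSize).toSeq

// Reassembly, as unBlockifyObject would do before deserializing.
def unBlockify(blocks: Seq[Array[Byte]]): Array[Byte] =
  blocks.flatten.toArray
```

A 10-byte payload with blockSize = 4 yields three pieces (4 + 4 + 2 bytes), and concatenating them restores the original bytes, which is why fetchers can pull pieces independently and in any order as long as they reassemble them by index.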

NOTE: How the BlockManager itself manages and stores data is out of scope here and will be explored in a later post.

Reading a Broadcast Variable

The broadcast value is read via the value method:

 spark.range(10).rdd.foreach(x => println(name.value + "-" + x))

value is implemented as follows:

   // org.apache.spark.broadcast.Broadcast#value
   def value: T = {
     assertValid()
     getValue()
   }

From the creation path above we know the concrete type is TorrentBroadcast, so getValue() is implemented there:

   // org.apache.spark.broadcast.TorrentBroadcast#getValue
   override protected def getValue() = synchronized {
     val memoized: T = if (_value == null) null.asInstanceOf[T] else _value.get
     if (memoized != null) {
       memoized
     } else {
      // when the cached value is absent (or its SoftReference was cleared), rebuild it
       val newlyRead = readBroadcastBlock()
       _value = new SoftReference[T](newlyRead)
       newlyRead
     }
   }
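The SoftReference memoization in getValue() can be sketched in isolation (the `SoftMemo` helper below is hypothetical, not Spark code): the cached value may be reclaimed by the GC under memory pressure, in which case it is transparently rebuilt on the next read.

```scala
import java.lang.ref.SoftReference

// Sketch of the memoization pattern used by TorrentBroadcast.getValue():
// hold the value behind a SoftReference so the GC may reclaim it when memory
// is tight, and rebuild via `load` when the reference is empty or cleared.
class SoftMemo[T <: AnyRef](load: () => T) {
  private var ref: SoftReference[T] = null

  def get: T = synchronized {
    val cached = if (ref == null) null.asInstanceOf[T] else ref.get()
    if (cached != null) {
      cached
    } else {
      val fresh = load() // corresponds to readBroadcastBlock()
      ref = new SoftReference[T](fresh)
      fresh
    }
  }
}
```

Using a SoftReference rather than a strong field means a large broadcast value does not pin executor memory forever, at the cost of occasionally re-reading it from the BlockManager.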

readBroadcastBlock is implemented as follows:

// org.apache.spark.broadcast.TorrentBroadcast#readBroadcastBlock
  private def readBroadcastBlock(): T = Utils.tryOrIOException {
    // lock keyed by the broadcast's block id
    TorrentBroadcast.torrentBroadcastLock.withLock(broadcastId) {
      val broadcastCache = SparkEnv.get.broadcastManager.cachedValues
      // fall back to the BlockManager if cachedValues misses
      Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
        setConf(SparkEnv.get.conf)
        val blockManager = SparkEnv.get.blockManager
        // try to read the value from the local memoryStore or diskStore by blockId
        blockManager.getLocalValues(broadcastId) match {
          // if found locally, also cache it in cachedValues
          case Some(blockResult) =>
            if (blockResult.data.hasNext) {
              val x = blockResult.data.next().asInstanceOf[T]
              releaseBlockManagerLock(broadcastId)
              if (x != null) {
                broadcastCache.put(broadcastId, x)
              }
              x
            } else {
              throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
            }
          // otherwise fetch the pieces from remote nodes, reassemble the object,
          // and store it back into the BlockManager
          case None =>
            val estimatedTotalSize = Utils.bytesToString(numBlocks * blockSize)
            logInfo(s"Started reading broadcast variable $id with $numBlocks pieces " +
              s"(estimated total size $estimatedTotalSize)")
            val startTimeNs = System.nanoTime()
            // fetch the broadcast pieces block by block and store them locally as bytes
            val blocks = readBlocks()
            logInfo(s"Reading broadcast variable $id took ${Utils.getUsedTimeNs(startTimeNs)}")

            try {
              // merge the blocks and deserialize them back into a Java object
              val obj = TorrentBroadcast.unBlockifyObject[T](
                blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
              // store the object in the BlockManager so later tasks on the same executor can reuse it
              val storageLevel = StorageLevel.MEMORY_AND_DISK
              if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
                throw new SparkException(s"Failed to store $broadcastId in BlockManager")
              }
              // also keep a copy in cachedValues
              if (obj != null) {
                broadcastCache.put(broadcastId, obj)
              }

              obj
            } finally {
              blocks.foreach(_.dispose())
            }
        }
      }
    }
  }

Note that the BlockManager looks up objects by BlockId, and the broadcast's BroadcastBlockId is a subclass of BlockId.
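The relationship can be sketched with a simplified version of the hierarchy (real Spark defines many more subclasses in storage/BlockId.scala, e.g. RDDBlockId and ShuffleBlockId): each id renders to a unique string name that the BlockManager uses as its storage key.

```scala
// Simplified sketch of Spark's BlockId hierarchy: the abstract `name` is the
// globally unique key under which the BlockManager stores the block.
sealed abstract class BlockId {
  def name: String
}

// Mirrors Spark's BroadcastBlockId: "broadcast_<id>" for the whole-object
// entry, "broadcast_<id>_<field>" for pieces such as "piece0", "piece1", ...
case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId {
  def name: String =
    "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
}
```

This matches what writeBlocks produces: the whole object is stored under BroadcastBlockId(id), and each piece under BroadcastBlockId(id, "piece" + i).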

One thing to note about the readBlocks method: the current Executor's BlockManager asks the Driver-side BlockManagerMaster for the locations of each block; if another Executor on the same node already holds the block, it is fetched from that local node, otherwise the pieces are pulled from remote holders (initially only the Driver has them).
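That location preference can be sketched as follows. The `Loc` type is a hypothetical stand-in for Spark's BlockManagerId, and this omits details of the real implementation such as randomly shuffling locations to spread load and retrying on checksum failures:

```scala
// Sketch: order block locations so same-host executors are tried first,
// mirroring how a fetcher prefers blocks already present on its own node.
case class Loc(host: String, executorId: String)

def sortByLocality(locations: Seq[Loc], localHost: String): Seq[Loc] = {
  val (sameHost, remote) = locations.partition(_.host == localHost)
  // executors on the local node first, then remote ones as a fallback
  sameHost ++ remote
}
```

Because every Executor that finishes fetching re-registers its pieces with tellMaster = true, later fetchers see more and more locations, which is what gives the "torrent" broadcast its BitTorrent-like scalability.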