基本原理

全阶段代码生成,即WSCG。它指的是基于同一 Stage 内操作符之间的调用关系,把多个 RDD 的 compute 函数捏合成一个,生成一份“手写代码”
然后把这一个函数一次性地作用在输入数据上,真正把所有计算融合为一个统一的函数

全阶段代码生成具有以下的优势:

  • 消除虚函数调度。
  • 将中间数据从内存移动到 CPU 寄存器。
  • 利用现代 CPU 功能循环展开和使用 SIMD。通过向量化技术,引擎将加快对复杂操作代码生成运行的速度。

Catalyst全阶段代码生成的入口是CollapseCodegenStages规则,它的注入是在代码执行前的preparations阶段。

1
2
3
4
5
6
7
8
9
  def apply(plan: SparkPlan): SparkPlan = {
    if (conf.wholeStageEnabled && CodegenObjectFactoryMode.withName(conf.codegenFactoryMode)
      != CodegenObjectFactoryMode.NO_CODEGEN) {
      insertWholeStageCodegen(plan)
    } else {
      plan
    }
  }
}

如果开启了下面这个参数,才会生成代码

1
spark.sql.codegen.wholeStage

预处理部分

插入WholeStageCodegenExec算子需要满足以下条件:

  • 算子的输出个数为1且数据类型为ObjectType,表明其不是unsafe row,跳过,迭代判断其孩子节点。
  • plan为LocalTableScanExec的不进行处理。
  • plan为CommandResultExec的不进行处理。
  • plan为CodegenSupport的还需判断其所有表达式是否都支持Codegen, 当前plan和其孩子plan的schema的字段个数是否超过了conf.wholeStageMaxNumFields(默认100)。
  • 另外需要注意的是whole-stage-codegen是基于row的,如果plan支持columnar, 则不能同时支持全阶段代码生成。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
  private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = {
    plan match {
      // For operators that will output domain object, do not insert WholeStageCodegen for it as
      // domain object can not be written into unsafe row.
      case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] =>
        plan.withNewChildren(plan.children.map(insertWholeStageCodegen))
      case plan: LocalTableScanExec =>
        // Do not make LogicalTableScanExec the root of WholeStageCodegen
        // to support the fast driver-local collect/take paths.
        plan
      case plan: CommandResultExec =>
        // Do not make CommandResultExec the root of WholeStageCodegen
        // to support the fast driver-local collect/take paths.
        plan
      case plan: CodegenSupport if supportCodegen(plan) =>
        // The whole-stage-codegen framework is row-based. If a plan supports columnar execution,
        // it can't support whole-stage-codegen at the same time.
        assert(!plan.supportsColumnar)
        WholeStageCodegenExec(insertInputAdapter(plan))(codegenStageCounter.incrementAndGet())
      case other =>
        other.withNewChildren(other.children.map(insertWholeStageCodegen))
    }
  }

对于不支持的,或者 SortMergeJoinExec、ShuffledHashJoinExec 会插入 InputAdapter

InputAdapter 作用

  • 他们都可以看作一个codegen的分割点,可将整个物理计划拆分成多个代码段
  • 而 InputAdapter节点可以看作是对应 WholeStageCodegenExec所包含子树的叶 子节点,起到 InternalRow的数据输入作用
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
  private def insertInputAdapter(plan: SparkPlan): SparkPlan = {
    plan match {
      case p if !supportCodegen(p) =>
        // collapse them recursively
        InputAdapter(insertWholeStageCodegen(p))
      case j: SortMergeJoinExec =>
        // The children of SortMergeJoin should do codegen separately.
        j.withNewChildren(j.children.map(
          child => InputAdapter(insertWholeStageCodegen(child))))
      case j: ShuffledHashJoinExec =>
        // The children of ShuffledHashJoin should do codegen separately.
        j.withNewChildren(j.children.map(
          child => InputAdapter(insertWholeStageCodegen(child))))
      case p => p.withNewChildren(p.children.map(insertInputAdapter))
    }
  }

3

模拟程序

用来 模拟 code-gen 的代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
  test("mock code-gen") {
    import org.apache.spark.sql.catalyst.expressions.{
      And, AttributeReference, GreaterThan, LessThan, Literal}
    import org.apache.spark.sql.execution.{CollapseCodegenStages, FilterExec, RangeExec}
    import org.apache.spark.sql.types.LongType

    val ref = AttributeReference("id", LongType)()
    val and = And(
      GreaterThan(ref, new Literal(0L, LongType)),
      LessThan(ref, new Literal(10L, LongType))
    )
    val range = RangeExec(
      new org.apache.spark.sql.catalyst.plans.logical.Range(0, 10, 1, None)
    )
    val filter = FilterExec(and, range)
    val c = CollapseCodegenStages()
    val p = c.apply(filter)
    println(p)
  }

打印结果:

1
2
*(1) !Filter ((id#0L > 0) AND (id#0L < 10))
+- *(1) Range (0, 10, step=1, splits=16)

递归的从根节点一直遍历到子节点,如果有不合适的则插入 InputAdapter 1

代码生成骨架部分

各个算子的 doProduce(),只是简单的继续调用其子类的 produce()

1
2
3
  protected override def doProduce(ctx: CodegenContext): String = {
    child.asInstanceOf[CodegenSupport].produce(ctx, this)
  }

CodegenSupport 的主要函数

produce()是在父接口 CodegenSupport 中定义的,它会继续调用子节点的 doProduce
于是一层层的递归调用下去,直到子节点,或者 shuffle 边界

1
2
3
4
5
6
7
8
  final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery {
    this.parent = parent
    ctx.freshNamePrefix = variablePrefix
    s"""
       |${ctx.registerComment(s"PRODUCE: ${this.simpleString(conf.maxToStringFields)}")}
       |${doProduce(ctx)}
     """.stripMargin
  }

CodegenSupport#consume的逻辑:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
  final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
    val inputVarsCandidate =
      if (outputVars != null) {
        assert(outputVars.length == output.length)
        // outputVars will be used to generate the code for UnsafeRow, so we should copy them
        outputVars.map(_.copy())
      } else {
        assert(row != null, "outputVars and row cannot both be null.")
        ctx.currentVars = null
        ctx.INPUT_ROW = row
        output.zipWithIndex.map { case (attr, i) =>
          BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
        }
      }

    val inputVars = inputVarsCandidate match {
      case stream: Stream[ExprCode] => stream.force
      case other => other
    }

    val rowVar = prepareRowVar(ctx, row, outputVars)

    // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
    // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to
    // generate code of `rowVar` manually.
    ctx.currentVars = inputVars
    ctx.INPUT_ROW = null
    ctx.freshNamePrefix = parent.variablePrefix
    val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)

    // Under certain conditions, we can put the logic to consume the rows of this operator into
    // another function. So we can prevent a generated function too long to be optimized by JIT.
    // The conditions:
    // 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled.
    // 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses
    //    all variables in output (see `requireAllOutput`).
    // 3. The number of output variables must less than maximum number of parameters in Java method
    //    declaration.
    val confEnabled = conf.wholeStageSplitConsumeFuncByOperator
    val requireAllOutput = output.forall(parent.usedInputs.contains(_))
    val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0)
    val consumeFunc = if (confEnabled && requireAllOutput
        && CodeGenerator.isValidParamLength(paramLength)) {
      constructDoConsumeFunction(ctx, inputVars, row)
    } else {
      parent.doConsume(ctx, inputVars, rowVar)
    }
    s"""
       |${ctx.registerComment(s"CONSUME: ${parent.simpleString(conf.maxToStringFields)}")}
       |$evaluated
       |$consumeFunc
     """.stripMargin
  }

这里会继续调用父类的doConsume,完成进一步代码填充
需要注意的是,这里判断了代码是否太大,如果太大,则将调用逻辑封装为一个函数,然后直接调用这个函数

WholeStageCodegenExec 的 doExecute 流程
代码生成可以看作是两个方向相反的递归过程:

  • 代码的整体框架由 produce/doProduce 方法负责,父节点调用子节点
  • 代码具体处理逻辑由 consume/doConsume 方法负责,由子节点调用父节点
  • 整个物理算子树的执行过程被InputAdapter分隔开

详细分析

简单查询执行过程

下面这个SQL

1
SELECT  id  FROM people WHERE id > 0 AND id < 100

InputRDDCodegen#doProduce主要逻辑:

1
2
3
4
5
6
7
8
    s"""
       | while ($limitNotReachedCond $input.hasNext()) {
       |   InternalRow $row = (InternalRow) $input.next();
       |   ${updateNumOutputRowsMetrics}
       |   ${consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim}
       |   ${shouldStopCheckCode}
       | }
     """.stripMargi

FilterExec#doCosume()的主要生成逻辑:

1
2
3
4
5
6
7
    s"""
       |do {
       |  $predicateCode
       |  $numOutput.add(1);
       |  ${consume(ctx, resultVars)}
       |} while(false);
     """.stripMargin

生成的核心代码片段: 可读性不好,但是也大概能看出来,是计算 id > 0 AND id < 100

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
boolean rdd_isNull_1 = rdd_row_0.isNullAt(1);
long rdd_value_1 = rdd_isNull_1 ?
  -1L : (rdd_row_0.getLong(1));

boolean filter_value_2 = !rdd_isNull_1;
if (!filter_value_2) continue;


boolean filter_value_3 = false;
        filter_value_3 = rdd_value_1 > 0L;
if (!filter_value_3) continue;
boolean filter_value_8 = !rdd_isNull_1;
if (!filter_value_8) continue;


boolean filter_value_9 = false;
        filter_value_9 = rdd_value_1 < 100L;
if (!filter_value_9) continue;

ProjectExec#doConsume()主要逻辑:

1
2
3
4
5
6
7
    s"""
       |// common sub-expressions
       |${evaluateVariables(localValInputs)}
       |$subExprsCode
       |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
       |${consume(ctx, resultVars)}
     """.stripMargin

生成的 时序图:
2

生成过程:

  • 从 WholeStageCodegenExec#doExecute出发,在这里调用 doCodeGen 开始生成代码
  • 调用 CodegenSupport#produce(),这里只是产生一些metrics的逻辑,然后调用doProduce
  • 每个具体算子中 doProduce 只是简单的调用子节点的 produce(),于是由回到CodegenSupport,继续递归往下调用
  • 到叶子节点 RDDScanExec时候,其父类 InputRDDCodegen 开始生成骨架代码,然后调用CodegenSupport#consume
  • consume中调用每个具体算子的 doConsume,也是继续递归调用到 FilterExec的 doConsume,ProjectExec#doConsume
  • 叶子节点RDD 算子只是一个 while循环,读取读取数据,然后调用子节点 Filter,Project等来处理这些数据
  • 最后调用到最上层节点 WholeStageCodegenExec#doConsume 获取结果
  • 整个过程是从 WholeStageCodegenExec 开始递归往下调用的,一直到叶子节点(读取数据的节点)
  • 在读取数据的节点生成骨架代码,然后又递归调用父节点,填充骨架代码

另一个视角 4
spark性能优化

ProjectExec + FilterExec + ScanRDDExec 生成的代码片段:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
 while ( rdd_input_0.hasNext()) {
   InternalRow rdd_row_0 = (InternalRow) rdd_input_0.next();
   ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
   do {
  
boolean rdd_isNull_1 = rdd_row_0.isNullAt(1);
long rdd_value_1 = rdd_isNull_1 ?
  -1L : (rdd_row_0.getLong(1));

boolean filter_value_2 = !rdd_isNull_1;
if (!filter_value_2) continue;


boolean filter_value_3 = false;
        filter_value_3 = rdd_value_1 > 0L;
if (!filter_value_3) continue;
boolean filter_value_8 = !rdd_isNull_1;
if (!filter_value_8) continue;


boolean filter_value_9 = false;
        filter_value_9 = rdd_value_1 < 100L;
if (!filter_value_9) continue;
   
  ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* numOutputRows */).add(1);
  
// common sub-expressions

UTF8String project_value_0;
if (false) {
  project_value_0 = UTF8String.fromString("NULL");
} else {
  project_value_0 = UTF8String.fromString(String.valueOf(rdd_value_1));
}
rdd_mutableStateArray_0[2].reset();


rdd_mutableStateArray_0[2].write(0, project_value_0);
append((rdd_mutableStateArray_0[2].getRow()));
      
} while(false);
   if (shouldStop()) return;
 }

JOIN

Broadcast Hash Join

boradcast-hash-join 的待生成 code-gen 时的物理计划

  • 到了 BroadcastHashJoinExec 时候分为左右两个节点
  • 其中右边节点是 RowDataSourceScanExec,在下面是 JDBCRDD
  • 左边节点是 InputAdapter,相当于对生成代码做了分割
  • 这部分跟非code-gen 的 bhj 类似,使用了 BroadcastQueryStageExec 再委托BroadcastExchangeExec 获取广播数据
  • 广播exchange 下面又是 WholeStageCodegenExec,下面是又是代码生成部分了 8
1
2
3
4
5
6
7
8
HashJoin 中先建立好 build 表,然后将这个 build 放到 拼接代码中

for(遍历 流表的 RDD) {
    执行 bhj,探测 build 表
    if (找到了匹配值) {
        做投影,再输出结果
    }
}

获取 build 表的时序图,这里跟上面的依赖父子依赖关系展示的类似
读取数据也是通过 代码生成的
然后通过广播的方式发送出去,这里还使用了线程池,等待超时等机制
9

bhj 的总体流程

  • 跟普通的查询的时序图差不多
    -produce()和consume()也是委托给 CodegenSupport 来调用的
  • 先读取数据,也就是InputRDDCode部分,生成读取 JDBCRDD的代码
  • 然后调用 bhj,执行join逻辑,在 join 中拿到事先构建好的 build 表去判断
  • 如果有匹配的则输出 10

Shuffle Hash Join

shuffle-hash-join 的待生成 code-gen 时的物理计划
总体来说,跟 bhj 类似,只是左右两个子节点都增加了 InputAdapter,作为code-gen 的分割
这两个子节点都是通过广播获取的
11

时序图

  • 总体来说,跟 bhj 非常类似
  • 除了把 执行的 join exec 换成了 ShuffledHashJoinExec,其他调用过程都是一样
  • 执行 join 的code-gen是父类 HashJoin 完成的 12

Shuffle Sort Merge Join

待生成 code-gen 时的物理计划
13 上图中

  • 总体比 bhj,shj 要复杂很多
  • sort-merge-join 左右两边都是 InputAdapter,对code-gen做了分割
  • 这之后又调用了 SortExec,所以再次增加了 InputAdapter 做了分割
  • 之后就是 shuffle逻辑,跟 shj 差不多了
  • 这里一共会生成 5个代码片段
  • sort-merge-join 一个,左边的 sort,右边的 sort,以及左右的读取数据

执行的时序图如下:

  • 相比 bhj,shj 要简单不少
  • SortMergeJoinExec 左右节点都是 InputAdapter 类型的,所以到它这里就结束了 14

SortMergeJoinExec 的 codegenInner 的生成代码逻辑:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
    s"""
       |while ($findNextJoinRows) {
       |  $beforeLoop
       |  while ($matchIterator.hasNext()) {
       |    InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
       |    $conditionCheck
       |    $outputRow
       |  }
       |  if (shouldStop()) return;
       |}
       |$eagerCleanup
     """.stripMargin

SortMergeJoinExec + ProjectExec 生成 的代码片段:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
while (smj_findNextJoinRows_0(smj_streamedInput_0, smj_bufferedInput_0)) {
	boolean smj_isNull_2 = false;
	int smj_value_4 = -1;
           
	boolean smj_isNull_3 = false;
	UTF8String smj_value_5 = null;
           
	smj_isNull_2 = smj_streamedRow_0.isNullAt(0);
	smj_value_4 = smj_isNull_2 ? -1 : (smj_streamedRow_0.getInt(0));
	smj_isNull_3 = smj_streamedRow_0.isNullAt(1);
	smj_value_5 = smj_isNull_3 ? null : (smj_streamedRow_0.getUTF8String(1));
	scala.collection.Iterator<UnsafeRow> smj_iterator_0 = smj_matches_0.generateIterator();
       
	while (smj_iterator_0.hasNext()) {
		InternalRow smj_bufferedRow_1 = (InternalRow) smj_iterator_0.next();
		((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
		
		// common sub-expressions
		UTF8String project_value_0;
		if (smj_isNull_2) {
			project_value_0 = UTF8String.fromString("NULL");
		} else {
			project_value_0 = UTF8String.fromString(String.valueOf(smj_value_4));
		}
		UTF8String project_value_2;
		
		if (smj_isNull_3) {
			project_value_2 = UTF8String.fromString("NULL");
		} else {
			project_value_2 = smj_value_5;
		}
		boolean smj_isNull_4 = smj_bufferedRow_1.isNullAt(0);
		int smj_value_6 = smj_isNull_4 ? -1 : (smj_bufferedRow_1.getInt(0));
		
		UTF8String project_value_4;
		if (smj_isNull_4) {
			project_value_4 = UTF8String.fromString("NULL");
		} else {
			project_value_4 = UTF8String.fromString(String.valueOf(smj_value_6));
		}
		boolean smj_isNull_5 = smj_bufferedRow_1.isNullAt(1);
		UTF8String smj_value_7 = smj_isNull_5 ? null : (smj_bufferedRow_1.getUTF8String(1));
		UTF8String project_value_6;

		if (smj_isNull_5) {
			project_value_6 = UTF8String.fromString("NULL");
		} else {
			project_value_6 = smj_value_7;
		}
		
		smj_mutableStateArray_0[1].reset();
		smj_mutableStateArray_0[1].write(0, project_value_0);
		smj_mutableStateArray_0[1].write(1, project_value_2);
		smj_mutableStateArray_0[1].write(2, project_value_4);
		smj_mutableStateArray_0[1].write(3, project_value_6);
		append((smj_mutableStateArray_0[1].getRow()).copy());    
	} //  end inner while`
	
	if (shouldStop()) return;
} //  end outer while

BroadcastNestedLoopJoin

跟普的 doExecute 很类似,广播然后用 nested loop 实现的

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
    (joinType, buildSide) match {
      case (_: InnerLike, _) => codegenInner(ctx, input)
      case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => codegenOuter(ctx, input)
      case (LeftSemi, BuildRight) => codegenLeftExistence(ctx, input, exists = true)
      case (LeftAnti, BuildRight) => codegenLeftExistence(ctx, input, exists = false)
      case _ =>
        throw new IllegalArgumentException(
          s"BroadcastNestedLoopJoin code-gen should not take neither $joinType as the JoinType " +
          s"nor $buildSide as the BuildSide")
    }
  }

inner join 的code-gen 函数如下:就是 一个 for 里面再不断 probe hash-table

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
  private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
    val (_, buildRowArrayTerm) = prepareBroadcast(ctx)
    val (buildRow, checkCondition, buildVars) = getJoinCondition(ctx, input, streamed, broadcast)

    val resultVars = buildSide match {
      case BuildLeft => buildVars ++ input
      case BuildRight => input ++ buildVars
    }
    val arrayIndex = ctx.freshName("arrayIndex")
    val numOutput = metricTerm(ctx, "numOutputRows")

    s"""
       |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) {
       |  UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex];
       |  $checkCondition {
       |    $numOutput.add(1);
       |    ${consume(ctx, resultVars)}
       |  }
       |}
     """.stripMargin
  }

CartesianProduct 没有 code-gen

参考