Example

Create the following under the resources directory (typically src/main/resources):

  • a META-INF directory
  • a services directory inside it
  • a file named org.apache.spark.sql.sources.DataSourceRegister
  • containing the fully qualified names of the implementation classes to register

For example, the file META-INF/services/org.apache.spark.sql.sources.DataSourceRegister holds a single line:

org.apache.spark.sql.XXRelationProvider
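
Spark discovers this registration through the standard JDK SPI mechanism, java.util.ServiceLoader. A quick way to verify that the file is picked up from the classpath (the CheckRegistration object below is just a hypothetical local test, not part of Spark):

import java.util.ServiceLoader
import scala.jdk.CollectionConverters._ // use scala.collection.JavaConverters._ on Scala 2.12

import org.apache.spark.sql.sources.DataSourceRegister

object CheckRegistration {
  def main(args: Array[String]): Unit = {
    // ServiceLoader reads every META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
    // file on the classpath and instantiates the classes listed inside
    val registered = ServiceLoader.load(classOf[DataSourceRegister]).asScala
    // "xx" should appear alongside Spark's built-in aliases (csv, json, parquet, jdbc, ...)
    registered.foreach(r => println(r.shortName()))
  }
}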

A custom data source implementation class:


import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}

// Registers the alias "xx" and delegates both the read and the write path to the built-in JDBC provider
class XXRelationProvider extends RelationProvider with CreatableRelationProvider with DataSourceRegister {
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]):
  BaseRelation = {
    // Add custom handling for your `xx` parameters if needed
    val jdbcParameters = parameters + ("driver" -> "com.mysql.cj.jdbc.Driver")
    new JdbcRelationProvider().createRelation(sqlContext, jdbcParameters)
  }

  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], 
  data: DataFrame): BaseRelation = {
    val jdbcParameters = parameters + ("driver" -> "com.mysql.cj.jdbc.Driver")
    new JdbcRelationProvider().createRelation(sqlContext, mode, jdbcParameters, data)
  }

  override def shortName(): String = {
    "xx"
  }
}

Running it

def start(): Unit = {
    val spark = SparkSession.builder()
      .appName("Custom Data Source Example")
      .master("local[*]")
      .getOrCreate()

    // Reading data using the custom data source
    val df = spark.read
      .format("xx")
      .option("url", "jdbc:mysql://localhost:3306/test")
      .option("dbtable", "tt")
      .option("user", "root")
      .option("password", "123456")
      .load()

    df.show()

    df.write
      .format("xx")
      .option("url", "jdbc:mysql://localhost:3306/test")
      .option("dbtable", "tt")
      .option("user", "root")
      .option("password", "123456")
      .mode(SaveMode.Append)
      .save()
  }

How it works

The DataSourceRegister class hierarchy

Subclasses of DataSourceRegister and their aliases

How the lookup works

  • The lookup happens in org.apache.spark.sql.execution.datasources.DataSource
  • Its backwardCompatibilityMap (shown below) maps many legacy names to the current built-in providers
  • Resolution order: the built-in map is consulted first, then the DataSourceRegister implementations registered in the services file (matched by shortName), and finally the provider name is loaded as a class, falling back to <name>.DefaultSource; a sketch of this order follows the map below
private val backwardCompatibilityMap: Map[String, String] = {
    val jdbc = classOf[JdbcRelationProvider].getCanonicalName
    val json = classOf[JsonFileFormat].getCanonicalName
    val parquet = classOf[ParquetFileFormat].getCanonicalName
    val csv = classOf[CSVFileFormat].getCanonicalName
    val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat"
    val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat"
    val nativeOrc = classOf[OrcFileFormat].getCanonicalName
    val socket = classOf[TextSocketSourceProvider].getCanonicalName
    val rate = classOf[RateStreamProvider].getCanonicalName

    Map(
      "org.apache.spark.sql.jdbc" -> jdbc,
      "org.apache.spark.sql.jdbc.DefaultSource" -> jdbc,
      "org.apache.spark.sql.execution.datasources.jdbc.DefaultSource" -> jdbc,
      "org.apache.spark.sql.execution.datasources.jdbc" -> jdbc,
      "org.apache.spark.sql.json" -> json,
      "org.apache.spark.sql.json.DefaultSource" -> json,
      "org.apache.spark.sql.execution.datasources.json" -> json,
      "org.apache.spark.sql.execution.datasources.json.DefaultSource" -> json,
      "org.apache.spark.sql.parquet" -> parquet,
      "org.apache.spark.sql.parquet.DefaultSource" -> parquet,
      "org.apache.spark.sql.execution.datasources.parquet" -> parquet,
      "org.apache.spark.sql.execution.datasources.parquet.DefaultSource" -> parquet,
      "org.apache.spark.sql.hive.orc.DefaultSource" -> orc,
      "org.apache.spark.sql.hive.orc" -> orc,
      "org.apache.spark.sql.execution.datasources.orc.DefaultSource" -> nativeOrc,
      "org.apache.spark.sql.execution.datasources.orc" -> nativeOrc,
      "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
      "org.apache.spark.ml.source.libsvm" -> libsvm,
      "com.databricks.spark.csv" -> csv,
      "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket,
      "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate
    )
  }
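
A minimal sketch of the resolution order inside DataSource.lookupDataSource (paraphrased and simplified; the real method also takes a SQLConf and handles Hive ORC, deprecated sources, and detailed error reporting):

import java.util.ServiceLoader
import scala.jdk.CollectionConverters._
import scala.util.Try

import org.apache.spark.sql.sources.DataSourceRegister

// Sketch only, not Spark's actual code
def lookupDataSource(provider: String): Class[_] = {
  // 1. Map legacy names such as "org.apache.spark.sql.jdbc" to the current provider class
  val provider1 = backwardCompatibilityMap.getOrElse(provider, provider)
  val provider2 = s"$provider1.DefaultSource"
  val loader = Thread.currentThread().getContextClassLoader

  // 2. Ask the ServiceLoader for a registered DataSourceRegister whose shortName matches;
  //    this is how format("xx") finds XXRelationProvider
  val registered = ServiceLoader.load(classOf[DataSourceRegister], loader).asScala
    .filter(_.shortName().equalsIgnoreCase(provider1)).toList

  registered match {
    case head :: Nil => head.getClass
    case Nil =>
      // 3. Fall back to treating the name as a class, then as "<name>.DefaultSource"
      Try(loader.loadClass(provider1)).orElse(Try(loader.loadClass(provider2))).get
    case sources =>
      sys.error(s"Multiple sources found for $provider1: ${sources.map(_.getClass.getName)}")
  }
}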

The spark.read implementation in SparkSession looks like this:

def read: DataFrameReader = new DataFrameReader(self)

So the work is delegated to DataFrameReader, which

  • records the format
  • records the options
  • and, on load, resolves and loads the implementation class

load itself uses DataSource to look the source up:

    DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf).flatMap { provider =>
      DataSourceV2Utils.loadV2Source(sparkSession, provider, userSpecifiedSchema, extraOptions,
        source, paths: _*)
    }.getOrElse(loadV1Source(paths: _*))
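
For the custom "xx" source, lookupDataSourceV2 does not yield a usable V2 TableProvider, so the call falls through to loadV1Source. Roughly (a paraphrased sketch, not the exact Spark source), that path builds a DataSource and resolves it to a BaseRelation, which is where XXRelationProvider.createRelation gets invoked:

// Sketch of DataFrameReader.loadV1Source, simplified
private def loadV1Source(paths: String*) = {
  sparkSession.baseRelationToDataFrame(
    DataSource(
      sparkSession,
      paths = paths,
      userSpecifiedSchema = userSpecifiedSchema,
      className = source,           // "xx" here
      options = extraOptions.toMap
    ).resolveRelation())            // calls XXRelationProvider.createRelation for a RelationProvider
}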

The internal representation of DataFrame

type DataFrame = Dataset[Row]

This Dataset provides

  • the SQL-style operators: join, agg, select, groupBy, and so on
  • plus toDF, schema, filter, writeTo, etc.; a short usage example follows
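
A small illustrative example of a few of these operators (the column names and data are made up for illustration):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").appName("ops-demo").getOrCreate()
import spark.implicits._

// A small in-memory DataFrame (i.e. Dataset[Row]) built from a local Seq
val orders = Seq(("alice", 10), ("bob", 20), ("alice", 5)).toDF("user", "amount")

val result = orders
  .filter($"amount" > 3)             // filter rows
  .groupBy($"user")                  // group
  .agg(sum($"amount").as("total"))   // aggregate
  .select($"user", $"total")         // project

result.show()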

The implementation of df.write

  def write: DataFrameWriter[T] = {
    if (isStreaming) {
      logicalPlan.failAnalysis(
        errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
        messageParameters = Map("methodName" -> toSQLId("write")))
    }
    new DataFrameWriter[T](this)
  }

The save implementation

  • first determines whether the source is a V1 or a V2 data source
  • then, based on the SaveMode and whether the truncate option is set,
  • chooses between dropping and recreating the table before inserting, or inserting directly (see the sketch after this list)
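
A minimal sketch of that decision, paraphrased from the V1 JDBC write path (JdbcRelationProvider.createRelation); the helpers here are hypothetical stand-ins for the corresponding JdbcUtils methods:

import org.apache.spark.sql.SaveMode

// Hypothetical stand-ins for JdbcUtils.tableExists / truncateTable / dropTable / createTable / saveTable
object JdbcWriteSketch {
  var tableExists = true
  def truncateTable(): Unit = println("TRUNCATE TABLE tt")
  def dropTable(): Unit     = println("DROP TABLE tt")
  def createTable(): Unit   = println("CREATE TABLE tt ...")
  def saveTable(): Unit     = println("INSERT INTO tt ...")

  // How SaveMode plus the truncate option choose the write strategy
  def write(mode: SaveMode, truncate: Boolean): Unit = {
    if (tableExists) {
      mode match {
        case SaveMode.Overwrite if truncate => truncateTable(); saveTable() // keep the table, wipe the rows
        case SaveMode.Overwrite             => dropTable(); createTable(); saveTable()
        case SaveMode.Append                => saveTable()                  // insert directly
        case SaveMode.ErrorIfExists         => throw new IllegalStateException("table tt already exists")
        case SaveMode.Ignore                => ()                           // leave the existing table untouched
      }
    } else {
      createTable()
      saveTable()
    }
  }
}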

Inside save:

// Obtain the V2 TableCatalog
val catalog: TableCatalog = CatalogV2Util.getTableProviderCatalog(
                supportsExtract, catalogManager, dsOptions)

// Obtain the V2 Table
val table: Table = catalog.loadTable(ident)

// Create the V2 relation
val relation: DataSourceV2Relation = DataSourceV2Relation
            .create(table, catalog, ident, dsOptions)

// Create a QueryExecution
new QueryExecution(session, command, df.queryExecution.tracker)

// Invoke runCommand with it
runCommand(df.sparkSession) {
  AppendData.byName(relation, df.logicalPlan, finalOptions)
}

DataSourceV2ScanRelation and DataSourceV2Relation are also used along this path.

The CheckpointRDD abstract class has two subclasses

  • LocalCheckpointRDD
  • ReliableCheckpointRDD

ReliableCheckpointRDD

  • covers both reading and writing checkpoints
  • writing a checkpoint fetches the partition information and then writes a checkpoint file for each partition
  • the write itself is submitted as a job through SparkContext#runJob()
  • the DAGScheduler then submits that job, it is scheduled through the StandaloneSchedulerBackend, and finally dispatched to the executors to run (a usage sketch follows this list)
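
A minimal usage example that triggers this write path (the checkpoint directory is an arbitrary choice; any local or HDFS path works for a demo):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("checkpoint-demo").getOrCreate()
val sc = spark.sparkContext

// Reliable checkpoints need a directory; in production this usually lives on HDFS
sc.setCheckpointDir("/tmp/spark-checkpoints")

val rdd = sc.parallelize(1 to 1000, numSlices = 4).map(_ * 2)
rdd.checkpoint()   // only marks the RDD; nothing is written yet

// The first action materializes the RDD and then runs the extra job that
// ReliableCheckpointRDD submits via SparkContext#runJob to write each partition
println(rdd.count())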

SparkSession has the following members

  • sparkContext: the SparkContext.
  • sharedState: state shared across multiple SparkSessions (the SparkContext, cached data, listeners, and the catalog information used to talk to external systems).
  • sessionState: the per-session state (SessionState), which holds the state specific to this SparkSession.
  • sqlContext: the SQLContext, i.e. the Spark SQL context.
  • conf: a RuntimeConfig, the runtime configuration interface (a short example of the public accessors follows)
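
A short example of the publicly reachable members (sessionState and sharedState are private[sql], so user code normally only touches the others):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("session-demo").getOrCreate()

println(spark.sparkContext.applicationId)           // the underlying SparkContext
println(spark.version)                              // Spark version string

spark.conf.set("spark.sql.shuffle.partitions", "8") // RuntimeConfig
println(spark.conf.get("spark.sql.shuffle.partitions"))

val sqlCtx = spark.sqlContext                       // the SQLContext wrapper
println(sqlCtx.sparkSession eq spark)               // true: it points back to the same session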

A complete example

A word count example:


import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.SparkSession;
import scala.Tuple2;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;

public class WordCount implements Serializable {
    private static final Pattern SPACE = Pattern.compile(" ");
    public static void main(String[] args) {

        String content = "/data/test/README.md";
        SparkSession spark = SparkSession
                .builder()
                .master("local[1]")
                .appName("JavaWordCount")
                .getOrCreate();
        JavaRDD<String> lines = spark.read().textFile(content).javaRDD();

        JavaRDD<String> words = lines.flatMap(
            new FlatMapFunction<String, String>() {
                @Override
                public Iterator<String> call(String s) throws Exception {
                    return Arrays.asList(SPACE.split(s)).iterator();
                }
            }
        );

        JavaPairRDD<String, Integer> ones = words.mapToPair(
            new PairFunction<String, String, Integer>() {
                @Override
                public Tuple2<String, Integer> call(String s) throws Exception {
                    return new Tuple2<>(s, 1);
                }
            }
        );

        JavaPairRDD<String, Integer> counts = ones.reduceByKey(
            new Function2<Integer, Integer, Integer>() {
                @Override
                public Integer call(Integer v1, Integer v2) throws Exception {
                    return v1 + v2; // sum the per-word counts
                }
            }
        );

        List<Tuple2<String, Integer>> output = counts.collect();
        for (Tuple2<?,?> tuple : output) {
            System.out.println(tuple._1() + ":" + tuple._2());
        }
        spark.stop();
    }
}

References