Introduction

Byzer is a low-code, cloud-native, open-source programming language for Data and AI.

Official website: https://www.mlsql.tech/home

Byzer lets you realize the value of data quickly: data ingestion, cleaning, storage, computation, and machine learning are handled in one place, and most tasks can be solved with SQL alone. For special requirements, it also offers extension mechanisms such as plugin development.
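To make the workflow concrete, here is a minimal sketch of a Byzer script that loads, cleans, and saves data with plain SQL. The paths, table names, and columns are hypothetical and only for illustration:

-- Load raw CSV data (hypothetical path)
load csv.`/tmp/raw/users.csv` where header="true" as raw_users;

-- Clean: keep rows with an id and normalize the name column
select id, lower(name) as name from raw_users where id is not null
as cleaned_users;

-- Persist the cleaned table as Parquet (hypothetical path)
save overwrite cleaned_users as parquet.`/tmp/cleaned/users`;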

A work project required XGBoost's modeling capability, but we did not want to raise the cost or difficulty of use, so we wanted to expose XGBoost model training through the same Byzer workflow.

The following is a brief walkthrough of our current extension process.

Environment

Engine version: byzer-lang-3.1.1-2.12-2.3.2

3.1.1 is the Spark version

2.12 is the Scala version

2.3.2 is the byzer-lang version

byzer-lang source code: git@github.com:byzer-org/byzer-lang.git

byzer-extension source code: git@github.com:byzer-org/byzer-extension.git

Development

Download and build the code

git clone git@github.com:byzer-org/byzer-lang.git
./dev/package.sh

Copy the built artifact to the runtime environment

cp <dev dir>/byzer-lang/streamingpro-mlsql/target/byzer-lang-3.1.1-2.12-2.3.0-SNAPSHOT.jar <runtime dir>/byzer-lang/main

Download the plugin extension project source code

git clone git@github.com:byzer-org/byzer-extension.git

Install the byzer-lang artifacts into the local Maven repository

# Run the following command inside the byzer-lang directory to install the jars locally
mvn clean install -DskipTests -Ponline -Phive-thrift-server -Pjython-support -Pscala-2.12 -Pspark-3.0.0 -Pstreamingpro-spark-3.0.0-adaptor

Create a new plugin module

1. Create a module project named byzer-xgboost
2. The project directory is byzer-xgboost; the module name is byzer-xgboost-3.0_2.12
3. Create the pom file and declare the dependencies the plugin needs
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>mlsql-plugins-3.0_2.12</artifactId>
        <groupId>tech.mlsql</groupId>
        <version>0.1.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>byzer-xgboost-3.0_2.12</artifactId>

    <properties>
        <!-- spark 3.0 start -->
        <scala.version>2.12.10</scala.version>
        <scala.binary.version>2.12</scala.binary.version>
        <spark.version>3.1.1</spark.version>
        <spark.bigversion>3.0</spark.bigversion>
        <scalatest.version>3.0.3</scalatest.version>
        <!-- spark 3.0 end -->
    </properties>

    <dependencies>
        <dependency>
            <groupId>ml.dmlc</groupId>
            <artifactId>xgboost4j_${scala.binary.version}</artifactId>
            <version>1.6.2</version>
        </dependency>
        <dependency>
            <groupId>ml.dmlc</groupId>
            <artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
            <version>1.6.2</version>
        </dependency>

    </dependencies>

    <profiles>
        <profile>
            <id>shade</id>
            <build>
                <plugins>
                    <plugin>
                        <groupId>org.apache.maven.plugins</groupId>
                        <artifactId>maven-shade-plugin</artifactId>
                        <version>3.2.0</version>
                        <configuration>
                            <filters>
                                <filter>
                                    <artifact>*:*</artifact>
                                    <excludes>
                                        <exclude>META-INF/*.SF</exclude>
                                        <exclude>META-INF/*.DSA</exclude>
                                        <exclude>META-INF/*.RSA</exclude>
                                    </excludes>
                                </filter>
                            </filters>
                            <createDependencyReducedPom>false</createDependencyReducedPom>
                            <relocations>
                                <relocation>
                                    <pattern>org.apache.poi</pattern>
                                    <shadedPattern>shadeio.poi</shadedPattern>
                                </relocation>
                                <relocation>
                                    <pattern>com.norbitltd.spoiwo</pattern>
                                    <shadedPattern>shadeio.spoiwo</shadedPattern>
                                </relocation>
                                <relocation>
                                    <pattern>com.github.pjfanning</pattern>
                                    <shadedPattern>shadeio.pjfanning</shadedPattern>
                                </relocation>
                                <relocation>
                                    <pattern>org.apache.commons.compress</pattern>
                                    <shadedPattern>shadeio.commons.compress</shadedPattern>
                                </relocation>
                            </relocations>
                        </configuration>

                        <executions>
                            <execution>
                                <phase>package</phase>
                                <goals>
                                    <goal>shade</goal>
                                </goals>
                            </execution>
                        </executions>
                    </plugin>
                </plugins>
            </build>
        </profile>
    </profiles>

</project>
4. Create the plugin descriptor file desc.plugin
moduleName=byzer-xgboost-3.0
mainClass=tech.mlsql.plugins.xgboost.app.XGBoostETApp
scala_version=2.12
spark_version=3.0
version=0.1.0-SNAPSHOT
author=tangf
mlsqlVersions=""
githubUrl=https://github.com/byzer-org/byzer-extension/tree/master/byzer-xgboost
mlsqlPluginType=app
desc=byzer-xgboost
5. Create the registration class tech.mlsql.plugins.xgboost.app.XGBoostETApp
package tech.mlsql.plugins.xgboost.app


import tech.mlsql.common.utils.log.Logging
import tech.mlsql.ets.register.ETRegister
import tech.mlsql.plugins.xgboost.ets.SQLXGBoostClassifier
import tech.mlsql.version.VersionCompatibility

/**
 * @author tangfei(tangfeizz@outlook.com)
 * @date 2022/9/27 14:29
 */
class XGBoostETApp extends tech.mlsql.app.App with VersionCompatibility with Logging {
  override def run(args: Seq[String]): Unit = {
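    // Map the name "XGBoostClassifier" to the ET class so that Byzer scripts
    // can invoke it with `train ... as XGBoostClassifier.`...``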
    ETRegister.register("XGBoostClassifier", classOf[SQLXGBoostClassifier].getName)

  }


  override def supportedVersions: Seq[String] = {
    XGBoostETApp.versions
  }
}

object XGBoostETApp {
  val versions = Seq(">=2.0.0")
}

6. Create the core ET class tech.mlsql.plugins.xgboost.ets.SQLXGBoostClassifier
package tech.mlsql.plugins.xgboost.ets

import ml.dmlc.xgboost4j.scala.spark.{TrackerConf, XGBoostClassificationModel, XGBoostClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.param._
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.ml.util.MLWritable
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import streaming.dsl.mmlib._
import streaming.dsl.mmlib.algs.{CodeExampleText, Functions, MetricValue, MllibFunctions, SQLPythonFunc}
import streaming.dsl.mmlib.algs.classfication.BaseClassification
import streaming.dsl.mmlib.algs.param.BaseParams

import scala.collection.mutable.ArrayBuffer

/**
 * @author tangfei(tangfeizz@outlook.com)
 * @date 2022/9/5 14:37
 */
class SQLXGBoostClassifier(override val uid: String) extends SQLAlg with MllibFunctions with Functions with BaseClassification {

  // Whether grid search (via CrossValidator) is enabled; disabled by default
  private var cvEnabled = false

  def this() = this(BaseParams.randomUID())

  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {

    val keepVersion = params.getOrElse("keepVersion", "true").toBoolean

    cvEnabled = params.getOrElse("cvEnabled", "false").toBoolean

    setKeepVersion(keepVersion)

    val evaluateTable = params.get("evaluateTable")
    setEvaluateTable(evaluateTable.getOrElse("None"))

    SQLPythonFunc.incrementVersion(path, keepVersion)
    val spark = df.sparkSession

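    // Train one model per fitParam group; when cvEnabled is true, the group's
    // cv_* parameters drive a CrossValidator grid search instead of a single fit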
    trainModelsWithMultiParamGroupOrCrossValidator(df, path, params, () => {
      getXGBoostClassifier(params)
    }, (_model, fitParam) => {
      evaluateTable match {
        case Some(etable) =>
          val model = _model.asInstanceOf[XGBoostClassificationModel]
          val evaluateTableDF = spark.table(etable)
          val predictions = model.transform(evaluateTableDF)
          multiclassClassificationEvaluate(predictions, (evaluator) => {
            evaluator.setLabelCol(fitParam.getOrElse("labelCol", "label"))
            evaluator.setPredictionCol("prediction")
          })

        case None => List()
      }
    }
    )

    formatOutput(getModelMetaData(spark, path))
  }


  override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
    val model = load(df.sparkSession, path, params).asInstanceOf[ArrayBuffer[XGBoostClassificationModel]].head
    model.transform(df)
  }

  override def explainParams(sparkSession: SparkSession): DataFrame = {
    _explainParams(sparkSession, () => {
      new XGBoostClassifier(Map("tracker_conf" -> TrackerConf(0L, "scala")))
    })
  }

  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {

    val (bestModelPath, baseModelPath, metaPath) = mllibModelAndMetaPath(path, params, sparkSession)
    val model = XGBoostClassificationModel.load(bestModelPath(0))
    ArrayBuffer(model)
  }


  override def explainModel(sparkSession: SparkSession, path: String, params: Map[String, String]): DataFrame = {
    val models = load(sparkSession, path, params).asInstanceOf[ArrayBuffer[XGBoostClassificationModel]]
    val rows = models.flatMap { model =>
      val modelParams = model.params.filter(param => model.isSet(param)).map { param =>
        val tmp = model.get(param).get
        val str = if (tmp == null) {
          "null"
        } else tmp.toString
        Seq(("fitParam.[group]." + param.name), str)
      }
      Seq(
        Seq("uid", model.uid)
      ) ++ modelParams
    }.map(Row.fromSeq(_))
    sparkSession.createDataFrame(sparkSession.sparkContext.parallelize(rows, 1),
      StructType(Seq(StructField("name", StringType), StructField("value", StringType))))
  }

  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
    predict_classification(sparkSession, _model, name)
  }

  override def modelType: ModelType = AlgType

  override def doc: Doc = Doc(MarkDownDoc,
    """
      |
      |XGBoost is an optimized distributed gradient boosting library designed to be highly efficient, flexible, and portable. It implements machine learning algorithms under the gradient boosting framework and provides parallel tree boosting, solving many data science problems in a fast and accurate way. The same code runs on major distributed environments such as Hadoop and can scale to billions of examples.
      |
    """.stripMargin)


  override def codeExample: Code = Code(SQLCode, CodeExampleText.jsonStr +
    """
      |-- Test dataset
      |set jsonStr='''
      |{"features":[5.1,3.5,1.4,0.2],"label":0.0},
      |{"features":[4.9,3.0,1.4,0.2],"label":0.0},
      |{"features":[4.9,3.1,1.5,0.1],"label":0.0},
      |{"features":[5.4,3.7,1.5,0.2],"label":0.0},
      |{"features":[5.1,3.8,1.9,0.4],"label":0.0},
      |{"features":[6.4,3.2,4.5,1.5],"label":1.0},
      |{"features":[6.9,3.1,4.9,1.5],"label":1.0},
      |{"features":[5.7,2.9,4.2,1.3],"label":1.0},
      |{"features":[6.2,2.9,4.3,1.3],"label":1.0},
      |{"features":[5.1,2.5,3.0,1.1],"label":1.0},
      |{"features":[5.7,2.8,4.1,1.3],"label":1.0},
      |{"features":[6.3,3.3,6.0,2.5],"label":2.0},
      |{"features":[7.2,3.6,6.1,2.5],"label":2.0},
      |{"features":[6.2,2.8,4.8,1.8],"label":2.0},
      |{"features":[6.1,3.0,4.9,1.8],"label":2.0}
      |''';
      |-- Load the dataset
      |load jsonStr.`jsonStr` as data;
      |select vec_dense(features) as features ,label as label from data
      |as data1;
      |
      |-- Train with the XGBoost algorithm
      |train data1 as XGBoostClassifier.`/tmp/model` where
      |
      |-- Set to true to keep each trained model version
      |keepVersion="true"
      |
      |-- Dataset used for evaluation
      |and evaluateTable="data1"
      |
      |-- Enables grid search; disabled by default
      |and cvEnabled="false"
      |
      |-- Multiple parameter groups can be set
      |-- Parameters prefixed with cv require cvEnabled="true"
      |-- cv_xx are general grid-search (CrossValidator) parameters
      |-- cv_xgb_xx are XGBoost parameters
      |-- cv_grid_xx define the grid-search value ranges
      |and `fitParam.0.featuresCol`="features"
      |and `fitParam.0.labelCol`="label"
      |and `fitParam.0.cv_xgb_objective`="multi:softprob"
      |and `fitParam.0.cv_xgb_numClass`="3"
      |and `fitParam.0.cv_grid_maxDepth`="2,8"
      |;
    """.stripMargin)

  /**
   * Build the XGBoostClassifier; when grid search is enabled, wrap it in a CrossValidator
   *
   * @param params
   * @return
   */
  def getXGBoostClassifier(params: Map[String, String]): Params = {

    var modelParams: Params = new XGBoostClassifier(Map(
      "tracker_conf" -> TrackerConf(0L, "scala")
    ))

    if (cvEnabled) {
      val booster = modelParams.asInstanceOf[XGBoostClassifier]
      modelParams = new CrossValidator()
        .setEstimator(booster)
    }

    modelParams
  }

  /**
   * Run the grid search and return the best model
   *
   * @param alg
   * @param trainData
   * @param fitParam
   * @return
   */
  def getBestXGBoostModel(alg: Params, trainData: DataFrame, fitParam: Map[String, String]): XGBoostClassificationModel = {

    val cvModel: CrossValidator = alg.asInstanceOf[CrossValidator]
    val booster = cvModel.getEstimator.asInstanceOf[XGBoostClassifier]

    // Configure general cross-validation parameters
    if (fitParam.isDefinedAt("cv_numFolds")) {
      cvModel.setNumFolds(fitParam.getOrElse("cv_numFolds", "3").toInt)
    }
    if (fitParam.isDefinedAt("cv_parallelism")) {
      cvModel.setParallelism(fitParam.getOrElse("cv_parallelism", "1").toInt)
    }

    // Grid search
    val paramGrid = new ParamGridBuilder()
    // Configure the cv_grid_* search-range parameters
    fitParam.filter(a => a._1.startsWith("cv_grid_")).foreach(a => {
      val cvName = a._1.stripPrefix("cv_grid_")
      if (booster.hasParam(cvName)) {
        val para = booster.getParam(cvName)
        val values = a._2.split(",")
        para match {
          case _ if para.isInstanceOf[IntParam] => paramGrid.addGrid(para, values.map(b => b.toInt))
          case _ if para.isInstanceOf[FloatParam] => paramGrid.addGrid(para, values.map(b => b.toFloat))
          case _ => paramGrid.addGrid(para, values)
        }
      }
    })

    // Configure the cv_xgb_* (XGBoost) parameters
    fitParam.filter(a => a._1.startsWith("cv_xgb_")).foreach(a => {
      val cvName = a._1.stripPrefix("cv_xgb_")
      if (booster.hasParam(cvName)) {
        val para = booster.getParam(cvName)
        para match {
          case _ if para.isInstanceOf[IntParam] => booster.set(booster.getParam(cvName), a._2.toInt)
          case _ if para.isInstanceOf[LongParam] => booster.set(booster.getParam(cvName), a._2.toLong)
          case _ if para.isInstanceOf[FloatParam] => booster.set(booster.getParam(cvName), a._2.toFloat)
          case _ if para.isInstanceOf[DoubleParam] => booster.set(booster.getParam(cvName), a._2.toDouble)
          case _ => booster.set(booster.getParam(cvName), a._2)
        }

      }
    })

    // Configure model evaluation
    val evaluator = new MulticlassClassificationEvaluator()
    if (fitParam.isDefinedAt("cv_evaluator")) {
      evaluator.setMetricName(fitParam.getOrElse("cv_evaluator", "f1"))
    }
    evaluator.setLabelCol("label")
    evaluator.setPredictionCol("prediction")

    cvModel
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid.build())

    val cvm = cvModel.fit(trainData)

    val model = cvm.asInstanceOf[CrossValidatorModel].bestModel.asInstanceOf[XGBoostClassificationModel]

    model
  }

  /**
   * Train models, supporting multiple parameter groups and grid search
   *
   * @param df
   * @param path
   * @param params
   * @param modelType
   * @param evaluate
   */
  def trainModelsWithMultiParamGroupOrCrossValidator(df: DataFrame, path: String, params: Map[String, String],
                                                     modelType: () => Params,
                                                     evaluate: (Params, Map[String, String]) => List[MetricValue]
                                                    ) = {

    val keepVersion = params.getOrElse("keepVersion", "true").toBoolean

    val mf = (trainData: DataFrame, fitParam: Map[String, String], modelIndex: Int) => {
      val alg = modelType()
      configureModel(alg, fitParam)

      logInfo(format(s"[training] [alg=${alg.getClass.getName}] [keepVersion=${keepVersion}]"))

      var status = "success"
      var info = ""
      val modelTrainStartTime = System.currentTimeMillis()
      val modelPath = SQLPythonFunc.getAlgModelPath(path, keepVersion) + "/" + modelIndex
      var scores: List[MetricValue] = List()

      var model: XGBoostClassificationModel = null
      try {
        if (cvEnabled) {
          model = getBestXGBoostModel(alg, trainData, fitParam)
        } else {
          model = alg.asInstanceOf[XGBoostClassifier].fit(trainData)
        }
        model.asInstanceOf[MLWritable].write.overwrite().save(modelPath)
        scores = evaluate(model, fitParam)
        logInfo(format(s"[trained] [alg=${alg.getClass.getName}] [metrics=${scores}] [model hyperparameters=${
          model.asInstanceOf[Params].explainParams().replaceAll("\n", "\t")
        }]"))
      } catch {
        case e: Exception =>
          info = format_exception(e)
          logInfo(info)
          status = "fail"
      }
      val modelTrainEndTime = System.currentTimeMillis()

      val metrics = scores.map(score => Row.fromSeq(Seq(score.name, score.value))).toArray
      Row.fromSeq(Seq(
        modelPath.substring(path.length),
        modelIndex,
        alg.getClass.getName,
        metrics,
        status,
        info,
        modelTrainStartTime,
        modelTrainEndTime,
        fitParam))
    }
    var fitParam = arrayParamsWithIndex("fitParam", params)

    if (fitParam.size == 0) {
      fitParam = Array((0, Map[String, String]()))
    }

    val wowRes = fitParam.map { fp =>
      mf(df, fp._2, fp._1)
    }

    val wowRDD = df.sparkSession.sparkContext.parallelize(wowRes, 1)

    df.sparkSession.createDataFrame(wowRDD, StructType(Seq(
      StructField("modelPath", StringType),
      StructField("algIndex", IntegerType),
      StructField("alg", StringType),
      StructField("metrics", ArrayType(StructType(Seq(
        StructField(name = "name", dataType = StringType),
        StructField(name = "value", dataType = DoubleType)
      )))),

      StructField("status", StringType),
      StructField("message", StringType),
      StructField("startTime", LongType),
      StructField("endTime", LongType),
      StructField("trainParams", MapType(StringType, StringType))
    ))).
      write.
      mode(SaveMode.Overwrite).
      parquet(SQLPythonFunc.getAlgMetalPath(path, keepVersion) + "/0")
  }
}

7. Install the plugin tool
# Python is required, so creating a dedicated conda environment is recommended
# Run the install inside that Python environment
pip install mlsql_plugin_tool

8. Build and package the plugin module
# Build errors here are usually version mismatches: check which byzer-lang artifact versions are installed locally, and if the build cannot resolve a version, update the version in the pom and rebuild
mlsql_plugin_tool build --module_name byzer-xgboost --spark spark311

9. Copy the built plugin jar into the plugin directory of the byzer-lang runtime
cp byzer-xgboost/build/byzer-xgboost-3.0_2.12-0.1.0-SNAPSHOT.jar <runtime dir>/byzer-lang/plugin

10. Edit byzer-lang's configuration file byzer.properties (or override it in byzer.properties.override)
# Append the new plugin's main class tech.mlsql.plugins.xgboost.app.XGBoostETApp to streaming.plugin.clzznames
streaming.plugin.clzznames=tech.mlsql.plugins.ds.MLSQLExcelApp,tech.mlsql.plugins.assert.app.MLSQLAssert,tech.mlsql.plugins.shell.app.MLSQLShell,tech.mlsql.plugins.ext.ets.app.MLSQLETApp,tech.mlsql.plugins.mllib.app.MLSQLMllib,tech.mlsql.plugins.visualization.ByzerVisualizationApp,tech.mlsql.plugins.xgboost.app.XGBoostETApp
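After restarting the engine, you can verify the plugin was loaded. Assuming the `!show` shell commands are available in your Byzer version, listing the registered ETs should now include XGBoostClassifier:

-- List registered ETs; XGBoostClassifier should appear in the output
!show et;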

Testing

-- Create the test datasets
set train_json='''
{"features":[5.8,4.0,1.2,0.2],"label":0.0},
{"features":[5.7,4.4,1.5,0.4],"label":0.0},
{"features":[5.4,3.9,1.3,0.4],"label":0.0},
{"features":[5.1,3.5,1.4,0.3],"label":0.0},
{"features":[5.7,3.8,1.7,0.3],"label":0.0},
{"features":[5.2,3.5,1.5,0.2],"label":0.0},
{"features":[5.2,3.4,1.4,0.2],"label":0.0},
{"features":[4.9,3.1,1.5,0.1],"label":0.0},
{"features":[4.4,3.2,1.3,0.2],"label":0.0},
{"features":[5.0,3.5,1.6,0.6],"label":0.0},
{"features":[5.1,3.8,1.9,0.4],"label":0.0},
{"features":[5.6,2.7,4.2,1.3],"label":1.0},
{"features":[5.7,3.0,4.2,1.2],"label":1.0},
{"features":[5.7,2.9,4.2,1.3],"label":1.0},
{"features":[6.2,2.9,4.3,1.3],"label":1.0},
{"features":[5.1,2.5,3.0,1.1],"label":1.0},
{"features":[5.7,2.8,4.1,1.3],"label":1.0}
''';

set test_json='''
{"features":[5.1,3.5,1.4,0.2],"label":0.0},
{"features":[4.9,3.0,1.4,0.2],"label":0.0},
{"features":[4.7,3.2,1.3,0.2],"label":0.0},
{"features":[4.6,3.1,1.5,0.2],"label":0.0},
{"features":[5.0,3.6,1.4,0.2],"label":0.0},
{"features":[5.4,3.9,1.7,0.4],"label":0.0},
{"features":[4.6,3.4,1.4,0.3],"label":0.0},
{"features":[5.0,3.4,1.5,0.2],"label":0.0},
{"features":[4.8,3.0,1.4,0.1],"label":0.0},
{"features":[4.3,3.0,1.1,0.1],"label":0.0},
{"features":[6.7,3.1,4.4,1.4],"label":1.0},
{"features":[5.6,3.0,4.5,1.5],"label":1.0},
{"features":[5.8,2.7,4.1,1.0],"label":1.0},
{"features":[6.2,2.2,4.5,1.5],"label":1.0},
{"features":[5.6,2.5,3.9,1.1],"label":1.0},
{"features":[5.9,3.2,4.8,1.8],"label":1.0},
{"features":[6.1,2.8,4.0,1.3],"label":1.0}
''';
load jsonStr.`train_json` as train_data_ds;
select vec_dense(features) as features ,label as label from train_data_ds
as train_data;

load jsonStr.`test_json` as test_data_ds;
select vec_dense(features) as features ,label as label from test_data_ds
as test_data;

-- Train with the xgboost algorithm
train train_data as XGBoostClassifier.`/tmp/model` where
keepVersion="true" 
and evaluateTable="test_data"
and cvEnabled="true"
and `fitParam.0.featuresCol`="features"
and `fitParam.0.labelCol`="label"
and `fitParam.0.cv_grid_maxDepth`="2,8"
;

-- Predict on the test data
predict test_data as XGBoostClassifier.`/tmp/model` as predict_data;
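Because SQLXGBoostClassifier also overrides predict, the trained model can be exposed as a UDF using Byzer's standard register syntax. A minimal sketch (the UDF name xgb_predict is arbitrary):

-- Register the trained model as a UDF and score rows directly in SQL
register XGBoostClassifier.`/tmp/model` as xgb_predict;
select xgb_predict(features) as predicted, label from test_data as output;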

Summary

The new xgboost plugin has been submitted to the community; we hope everyone will help refine and improve it together.

Next, we plan to build on this work to add XGBoost regression support.