How to compute AUC when Spark's GBTClassifier can't provide rawPrediction?


Thanks for the invite!

AUC is a standard metric for evaluating binary classifiers; it equals the probability that a randomly chosen positive sample is ranked above a randomly chosen negative one. The AUC routine in Spark MLlib's evaluation package first materializes the points of the ROC curve and then integrates under them. In practice, though (take logistic regression as an example), we usually score every sample and want to compute AUC directly from pairs of a per-sample score and its 0/1 label, so the MLlib routine is not a great fit. Here is an alternative: an approximate method for 0/1 binary classification, called BinaryAUC.

The idea: sort all samples by predict_score in descending order; walking down the ranking while tracking how many positives and negatives have been seen so far, compute the area of each small trapezoid under the curve, and sum the areas into the final AUC. Because the data is a distributed RDD in Spark, each partition needs the cumulative offset contributed by the partitions before it, so the data is traversed once up front to collect per-partition totals; this avoids collecting everything onto a single machine for the computation.
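Before the distributed version, here is a minimal local sketch of the trapezoid rule on made-up counts (everything here is hypothetical, just to show the geometry): each point is the cumulative (negatives seen, positives seen) pair as the score threshold drops, starting from the ROC origin.

  // local illustration only, with hypothetical counts
  val points = Array((0.0, 0.0), (0.0, 2.0), (1.0, 2.0), (1.0, 3.0), (2.0, 3.0), (3.0, 3.0))
  val area = points.sliding(2).map { case Array((x1, y1), (x2, y2)) =>
    (x2 - x1) * (y1 + y2) / 2.0 // trapezoid: base times average height
  }.sum
  val auc = area / (3.0 * 3.0) // totalNeg * totalPos, gives 8/9 ≈ 0.889

The class below computes exactly this sum, but per-partition with broadcast offsets: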



package org.apache.spark.mllib.wml

/**
 * @author wangben 2015
 */
import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.rdd.RDDFunctions._ // brings RDD.sliding into scope
import Array._ // ofDim

class BinaryAUC extends Serializable {
  // input: (prediction, label) pairs
  def auc(data: RDD[(Double, Double)]): Double = {
    // count negatives and positives within each group of identical scores
    val group_result = data.groupByKey().map(x => {
      val r = new Array[Double](2)
      for (item <- x._2) {
        if (item > 0.0) r(1) += 1.0
        else r(0) += 1.0
      }
      (x._1, r) // (score, [negativeCount, positiveCount])
    })

    // the curve points must be accumulated, so rank by score, highest first
    val group_rank = group_result.sortByKey(false)

    // per-partition totals, collected to the driver to build the offsets
    val step_sizes = group_rank.mapPartitions(x => {
      var fn_sum = 0.0
      var pn_sum = 0.0
      while (x.hasNext) {
        val cur = x.next
        fn_sum += cur._2(0)
        pn_sum += cur._2(1)
      }
      Iterator((fn_sum, pn_sum))
    }, true).collect

    // step_sizes_sum(i) = cumulative (negatives, positives) before partition i
    val step_sizes_sum = ofDim[Double](step_sizes.size, 2)
    for (i <- 1 to (step_sizes.size - 1)) {
      step_sizes_sum(i)(0) = step_sizes_sum(i - 1)(0) + step_sizes(i - 1)._1
      step_sizes_sum(i)(1) = step_sizes_sum(i - 1)(1) + step_sizes(i - 1)._2
    }
    val sss_len = step_sizes_sum.size
    val total_fn = step_sizes_sum(sss_len - 1)(0) + step_sizes(sss_len - 1)._1
    val total_pn = step_sizes_sum(sss_len - 1)(1) + step_sizes(sss_len - 1)._2

    // broadcast the offsets and turn every score group into a cumulative curve point
    val bc_step_sizes_sum = data.context.broadcast(step_sizes_sum)
    val modified_group_rank = group_rank.mapPartitionsWithIndex((index, x) => {
      val sss = bc_step_sizes_sum.value
      var r = List[(Double, Array[Double])]()
      var fn = sss(index)(0) // this partition's starting point
      var pn = sss(index)(1)
      while (x.hasNext) {
        val p = new Array[Double](2)
        val cur = x.next
        p(0) = fn + cur._2(0)
        p(1) = pn + cur._2(1)
        fn += cur._2(0)
        pn += cur._2(1)
        r = (cur._1, p) :: r
      }
      // prepend the origin (0, 0) in the first partition so the first trapezoid
      // is counted (this only matters when the top score group mixes labels)
      if (index == 0) ((Double.MaxValue, Array(0.0, 0.0)) :: r.reverse).toIterator
      else r.reverse.toIterator
    }, true)

    // sum the trapezoid areas over consecutive points;
    // mllib's sliding(2) windows correctly span partition boundaries
    val score = modified_group_rank.sliding(2).aggregate(0.0)(
      seqOp = (auc: Double, points: Array[(Double, Array[Double])]) => auc + TrapezoidArea(points),
      combOp = _ + _
    )
    System.out.println("debug auc_mid: " + score
      + "\t" + (total_fn * total_pn).toString()
      + "\t" + total_fn.toString()
      + "\t" + total_pn.toString())

    // normalize by the enclosing (total negatives) x (total positives) rectangle
    score / (total_fn * total_pn)
  }

  // x is the cumulative negative count (false-positive axis),
  // y the cumulative positive count (true-positive axis)
  private def TrapezoidArea(points: Array[(Double, Array[Double])]): Double = {
    val x1 = points(0)._2(0)
    val y1 = points(0)._2(1)
    val x2 = points(1)._2(0)
    val y2 = points(1)._2(1)
    val base = x2 - x1
    val height = (y1 + y2) / 2.0
    base * height
  }
}

object AUCTest {
  def main(args: Array[String]) {
    val conf = new SparkConf()
    conf.setAppName("TestEvaluation")
    val sc = new SparkContext(conf)
    // each input line looks like "(score,label)"
    val input_file = sc.textFile(args(0))
    val predict_label = input_file.map(l => {
      val x = l.stripPrefix("(").stripSuffix(")").split(",")
      (x(0).toDouble, x(1).toDouble)
    })
    val auc = new BinaryAUC()
    val auc_score = auc.auc(predict_label)
    System.out.println("debug auc_score: " + auc_score.toString())
  }
}
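Back to the question in the title: the old mllib GradientBoostedTreesModel.predict returns a hard 0/1 label rather than a score, but you can rebuild a raw score yourself by summing the weighted outputs of the individual trees (the boosting margin); any monotone transform of that margin preserves the ranking that AUC depends on. A minimal sketch, assuming a trained model and a held-out test: RDD[LabeledPoint] built elsewhere (ml's GBTClassifier likewise exposes its trees and treeWeights, so the same trick carries over):

  import org.apache.spark.rdd.RDD
  import org.apache.spark.mllib.regression.LabeledPoint
  import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel

  // margin = sum_i weight_i * tree_i(features); monotone in the model's
  // class-1 score, which is all a ranking metric like AUC needs
  def scoreAndLabel(model: GradientBoostedTreesModel,
                    test: RDD[LabeledPoint]): RDD[(Double, Double)] = {
    val trees = model.trees
    val weights = model.treeWeights
    test.map { point =>
      val margin = trees.zip(weights)
        .map { case (tree, w) => w * tree.predict(point.features) }
        .sum
      (margin, point.label) // (prediction, label), the format BinaryAUC expects
    }
  }

  // val auc_score = new BinaryAUC().auc(scoreAndLabel(model, test))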


Give this approach a try!