转载:CART算法的简单实现

mac2022-06-30  108

花了两天时间将cart算法中离散数据分类写完(后面还有连续数据的处理和决策树裁剪)。这次感觉比id3实现要更有成就感,毕竟一般以上的代码自己写的。不过看看写好的代码还是有些不堪回首啊。写代码还不熟练以后要多加锻炼!

cart算法介绍:

与id3相比cart主要在度量参数方面不同,cart用gini指标用作属性划分的标准。,其中pi为D中元素属于Ci类的概率。

对于元素的二元分裂由另一公式判断:

对于单列属性的二元分裂要选取GiniA(D)最小的一个来最为该属性列上的一个合理划分。而选择作为节点的属性列也要根据最小的gini指标判断。

大致的特点就是这样。

 

1 for (int i = 0; i < columns.size(); i++) {2 tempgini=getColGini(columns.get(i), temppropersep);3 if (gini > tempgini) {4 gini = tempgini;5 propersep=new CARTProperClassify(temppropersep);6 minColIndex = i;7 }8 }

 

1  int  totalcount  =  totalTargets.totalCount; 2  for  ( int  i  =  0 ; i  <  totalList.size(); i ++ ) { 3  double  p = 0.0 ; 4  int  itemcount  =  totalList.get(i).counts; 5  p  =  ( double ) itemcount  /  totalcount; 6  p  *=  p; 7  gini += p; 8  } 9  gini  =  1  -  gini;

 

 

 

但是,算法实现的外有一个难点,就是在选择二元分裂时,对属性项的真子集选择。对于有n个属性值的属性,会有2^n种不同的组合方式。

因为算法水平还没那么高,这个我就直接借鉴别人的代码了。大致思想就是将每个属性用二进制位(n个属性就用n个二进制为来表示)来表示,1表示选择该属性,0表示不选择该属性。

 

代码 package  BaseStructure.Tree; import  java.util.BitSet; import  java.util.HashSet; import  java.util.Set; public  class  ProperSubsetCombination { private  static  Integer[] array; private  static  BitSet startBitSet;  //  比特集合起始状态 private  static  BitSet endBitSet;  //  比特集合终止状态,用来控制循环 private  static  Set < Set < Integer >>  properSubset;  //  真子集集合 public  static  Set < Set < Integer >>  getProperSubset( int  n, Set < Integer >  itemSet) {Integer[] array  =  new  Integer[itemSet.size()];ProperSubsetCombination.array  =  itemSet.toArray(array);properSubset  =  new  HashSet < Set < Integer >> ();startBitSet  =  new  BitSet();endBitSet  =  new  BitSet(); //  初始化startBitSet,左侧占满1 for  ( int  i  =  0 ; i  <  n; i ++ ) {startBitSet.set(i,  true );} //  初始化endBit,右侧占满1 for  ( int  i  =  array.length  -  1 ; i  >=  array.length  -  n; i -- ) {endBitSet.set(i,  true );} //  根据起始startBitSet,将一个组合加入到真子集集合中 get(startBitSet); while  ( ! startBitSet.equals(endBitSet)) { int  zeroCount  =  0 ;  //  统计遇到10后,左边0的个数 int  oneCount  =  0 ;  //  统计遇到10后,左边1的个数 int  pos  =  0 ;  //  记录当前遇到10的索引位置 //  遍历startBitSet来确定10出现的位置 for  ( int  i  =  0 ; i  <  array.length; i ++ ) { if  ( ! startBitSet.get(i)) {zeroCount ++ ;} if  (startBitSet.get(i)  &&  ! startBitSet.get(i  +  1 )) {pos  =  i;oneCount  =  i  -  zeroCount; //  将10变为01 startBitSet.set(i,  false );startBitSet.set(i  +  1 ,  true ); break ;}} //  将遇到10后,左侧的1全部移动到最左侧 int  counter  =  Math.min(zeroCount, oneCount); int  startIndex  =  0 ; int  endIndex  =  0 ; if  (pos  >  1  &&  counter  >  0 ) {pos -- ;endIndex  =  pos; for  ( int  i  =  0 ; i  <  counter; i ++ ) {startBitSet.set(startIndex,  true );startBitSet.set(endIndex,  false );startIndex  =  i  +  1 ;pos -- ; if  (pos  >  0 ) {endIndex  =  pos;}}}get(startBitSet);} return  properSubset;} private  static  void  get(BitSet bitSet) {Set < Integer >  set  =  new  HashSet < Integer > (); for  ( int  i  =  0 ; i  <  array.length; i ++ ) { if  (bitSet.get(i)) {set.add(array[i]);}}properSubset.add(set);}}

 

 

接下来就的工作就是完善算法的了。

转载于:https://www.cnblogs.com/PierreDelatour/archive/2011/11/12/2246651.html

相关资源:JAVA上百实例源码以及开源项目
最新回复(0)