2017-06-21 6 views
1

Rで構築したランダムフォレストを、getTreeが出力するif文をすべて手作業で書き起こすことなく、SASコードに変換する方法はありますか? (R randomForestパッケージからBase SASへ)

木は30本あり、getTree関数の出力は合計で1900行になりました。

+0

これまでに何を試みましたか? – Stedy

+0

コード変換を依頼する質問はStack Overflowではトピック外です。まず調査して自分で取り組んでみてから、そのSASプログラムについて質問してください。 – Joe

+0

ルールや形式にもよりますが、ほとんどの人はRやSASに精通しているので、いくつかの例を示すために時間を割く必要があります。 – Reeza

答えて

0

手元にあったコードですが、取りかかりの参考になるはずです。今のところ回帰のみに対応していますが、少し手を加えれば分類にも対応できます:

# R code for exporting the randomForest object
# (plain R comments so this part can be run as-is in R)

# Output dataset to csv for validation in SAS
write.csv(iris, file = "C:/temp/iris.csv", row.names = FALSE)

# Train a 2-tree random forest for testing purposes.
# library() stops with an error if the package is missing; require() would
# only return FALSE and let the script fail later on randomForest().
library(randomForest)
rf2 <- randomForest(iris[, -1], iris[, 1], ntree = 2)

# Get predictions and write to csv
write.csv(predict(rf2, iris), file = "c:/temp/pred_rf2b.csv")

# Export factor levels: one row per (variable, level) with the level's 1-based
# code, which is the bit position the SAS side decodes from categorical split
# points. is.factor() also catches ordered factors, whose class() has two
# elements and would not compare equal to "factor".
mydata <- iris
factor_cols <- names(mydata)[vapply(mydata, is.factor, logical(1))]
output <- lapply(factor_cols, function(x) {
    data.frame(VarName = x,
               Level = levels(mydata[[x]]),
               Number = seq_len(nlevels(mydata[[x]])))
})
# With no factor columns, rbind over an empty list gives NULL; write an empty
# table with the expected columns instead so the CSV still has a header.
factorlevels <- if (length(output) > 0) {
    do.call(rbind, output)
} else {
    data.frame(VarName = character(0), Level = character(0), Number = integer(0))
}
write.csv(factorlevels, file = "c:/temp/factorlevels.csv", row.names = FALSE)

# Export all trees in one file, tagging each row with its node number
# (row position within getTree's output) and tree number.
treeoutput <- lapply(seq_len(rf2$ntree), function(x) {
    res <- getTree(rf2, x, labelVar = TRUE)
    res$node <- seq_len(nrow(res))
    res$treenum <- x
    res
})

write.csv(do.call(rbind, treeoutput), file = "c:/temp/treeexport.csv", row.names = FALSE)
# End of R code

/*Import into SAS, replacing . with _ so we have usable variable names*/

/* GUESSINGROWS: PROC IMPORT sizes character columns from the first rows it
   scans (20 by default). iris.csv starts with "setosa" rows, so with the
   default Species would be read as $6 and "versicolor"/"virginica" would be
   truncated. The same would happen to a long SPLIT_VAR value that first
   appears late in treeexport.csv. */
proc import
    datafile = "c:\temp\treeexport.csv"
    out = tree
    dbms = csv
    replace;
    getnames = yes;
    guessingrows = 32767;
run;

/* Make split variable names match the SAS column names (dots -> underscores).
   BEST32. shows the split points at full precision rather than 3 decimals. */
data tree;
    set tree;
    SPLIT_VAR = translate(SPLIT_VAR,'_','.');
    format SPLIT_POINT best32.;
run;

proc import
    datafile = "c:\temp\factorlevels.csv"
    out = factorlevels
    dbms = csv
    replace;
    getnames = yes;
    guessingrows = 32767;
run;

/* VARNAME must use the same dot -> underscore naming as SPLIT_VAR. Otherwise
   a factor such as My.Factor is not found in the level lookup, and its split
   is treated as numeric. */
data factorlevels;
    set factorlevels;
    VARNAME = translate(VARNAME,'_','.');
run;

/* Copy iris.csv, rewriting only the header line (_n_ = 1) so column names use
   underscores; data lines (decimal points) are written unchanged.
   An explicit LRECL avoids truncating long header lines. */
data _null_;
    infile "c:\temp\iris.csv" lrecl=32767;
    file "c:\temp\iris2.csv" lrecl=32767;
    input;
    if _n_ = 1 then _infile_ = translate(_infile_,'_','.');
    put _infile_;
run;

proc import
    datafile = "c:\temp\iris2.csv"
    out = iris
    dbms = csv
    replace;
    getnames = yes;
    guessingrows = 32767;
run;


/*
 * Translate every tree in TREE into a nested IF-THEN-ELSE statement and write
 * the generated DATA step code to the log: one "treeN = ...;" assignment per
 * tree, followed (for regression) by RFprediction, the mean of the trees.
 *
 * Settings:
 *   debug         - 1 writes node-by-node tracing to the log.
 *   type          - "regression" writes leaf values as numbers; any other
 *                   value writes them as quoted character literals.
 *   maxiterations - cap on nodes visited per tree; a WARNING is written if
 *                   a tree is not finished within the cap.
 *
 * Each tree is walked depth first, left branch first. Right daughters are
 * parked in hash Q (keyed by queue position, with a PROCESSED flag), and hash
 * Q2 maps a node number back to its queue position.
 *
 * NOTE(review): SAS missing values compare lower than any number, so a
 * missing numeric predictor always takes the left branch - confirm this is
 * acceptable for the scoring data.
 */
data _null_;
    debug = 0;
    type = "regression";
    maxiterations = 10000;
    file log;
    /* Quoted copies are built in separate, wider variables so that adding
       quotes never truncates the imported LEVEL / PREDICTION values. */
    length qlevel predtext $ 200;
    if 0 then set tree factorlevels;

    /*Hash to hold the whole tree, keyed by tree number and node number*/
    declare hash t(dataset:'tree');
    rc = t.definekey('treenum');
    rc = t.definekey('node');
    rc = t.definedata(all:'yes');
    rc = t.definedone();

    /*Hash for looking up factor levels by variable name and level code*/
    declare hash fl(dataset:'factorlevels');
    rc = fl.definekey('VARNAME','NUMBER');
    rc = fl.definedata('LEVEL');
    rc = fl.definedone();

    /*One pass per tree; stop at the first tree number that has no node 1*/
    do treenum = 1 by 1 while(t.find(key:treenum,key:1)=0);
        /*Hash to hold the queue for current tree*/
        length position qnode processed 8;
        declare hash q(ordered:'a');
        rc = q.definekey('position');
        rc = q.definedata('qnode','position','processed');
        rc = q.definedone();
        declare hiter qi('q');
        /*Hash for reverse queue lookup (node number -> queue position)*/
        declare hash q2();
        rc = q2.definekey('qnode');
        rc = q2.definedata('position');
        rc = q2.definedone();

        /*Load the starting node for the current tree*/
        node = 1;
        nodetype = 'L'; /*Track whether current node is a Left or Right node*/
        complete = 0;
        length treename $10;
        treename = cats('tree',treenum);

        do iteration = 1 by 1 while(complete = 0 and iteration <= maxiterations);
            rc = t.find();
            if debug then put "Processing node " node;

            /*Logic for terminal nodes*/
            if status = -1 then do;
                /*Leaf value: plain number for regression, quoted literal otherwise*/
                if type = "regression" then predtext = cats(prediction);
                else predtext = cats('"',prediction,'"');
                put treename '=' predtext ';';
                /*If current node is a right node, mark it processed in the queue*/
                if nodetype = 'R' then do;
                    rc = q2.find();
                    if debug then put "Unqueueing node " qnode "in position " position;
                    processed = 1;
                    rc = q.replace();
                end;
                /*Done if nothing was ever queued (the whole tree is a single
                  leaf) or every queued node has been processed*/
                rc = qi.last();
                if rc ne 0 then complete = 1;
                do while(rc = 0 and processed = 1);
                    if position = 1 then complete = 1;
                    rc = qi.prev();
                end;
                /*Otherwise, process the most recently queued unprocessed node*/
                if complete = 0 then do;
                    put "else ";
                    node = qnode;
                    nodetype = 'R';
                end;
            end;

            /*Logic for split nodes - status ne -1*/
            else do;
                /*Add right_daughter to queue*/
                position = q.num_items + 1;
                qnode = right_daughter;
                processed = 0;
                rc = q.add();
                rc = q2.add();
                if debug then put "Queueing node " qnode "in position " position;

                /*Check whether current split var is a (categorical) factor*/
                rc = fl.find(key:split_var,key:1);
                /*If yes, factor levels corresponding to 1s in the binary representation of the split point go left*/
                if rc = 0 then do;
                    /*Binary representation of the split point, least significant bit first.
                      binaryw. behaves very differently above width 58, so at most 58 levels
                      are handled; R randomForest only supports 53 levels per factor anyway*/
                    binarysplit = reverse(put(split_point,binary58.));
                    put 'if ' @;
                    j = 0; /*Number of levels written so far for this split*/
                    do i = 1 to 58 while(rc = 0);
                        if i > 1 then rc = fl.find(key:split_var,key:i);
                        if rc = 0 then do;
                            qlevel = cats('"',LEVEL,'"');
                            if debug then put _all_;
                            if substr(binarysplit,i,1) = '1' then do;
                                if j > 0 then put ' or ' @;
                                put split_var ' = ' qlevel @;
                                j + 1;
                            end;
                        end;
                    end;
                    put 'then';
                end;
                /*If not, values less than or equal to the split point go to the
                  left daughter (randomForest's rule); :best32. writes the split
                  point at full precision instead of rounding it to 3 decimals*/
                else put "if " split_var "<= " split_point :best32. " then ";
                if nodetype = 'R' then do;
                    qnode = node;
                    rc = q2.find();
                    if debug then put "Unqueueing node " qnode "in position " position;
                    processed = 1;
                    rc = q.replace();
                end;
                node = left_daughter;
                nodetype = 'L';
            end;
        end;
        /*Report a tree cut short by the iteration cap instead of failing silently*/
        if complete = 0 then
            putlog "WARNING: tree " treenum "was not fully translated within " maxiterations "iterations.";
        /*End of tree function definition!*/
        put ';';
        /*Clear the queue between trees*/
        rc = q.delete();
        rc = q2.delete();
    end;

    /*We end up going 1 past the actual number of trees after the end of the do loop*/
    treenum = treenum - 1;

    /*Regression: average the tree outputs (skipped when no tree was found,
      which would otherwise write a division by zero)*/
    if type = "regression" and treenum > 0 then do;
        put 'RFprediction=(';
        do i = 1 to treenum;
            treename = cats('tree',i);
            put treename +1 @;
            if i < treenum then put '+' +1 @;
        end;
        put ')/' treenum ';';
    end;

    /*To do - write code to aggregate predictions from multiple trees for classification*/

    stop;
run;


/*
 * Sample of the if-then-else code generated for the 2-tree forest above.
 * Each tree assigns its leaf value to treeN through nested IF/ELSE
 * statements; every ELSE belongs to the nearest unmatched IF, and the
 * indentation below shows that pairing. RFprediction averages the trees.
 */
data predictions;
    set iris;

    /* ---- Tree 1 ---- */
    if Petal_Length < 4.150 then
        if Petal_Width < 1.050 then
            if Petal_Width < 0.350 then
                tree1 = 4.91702127659574;
            else
                if Petal_Width < 0.450 then
                    tree1 = 5.18333333333333;
                else
                    if Species = "versicolor" then
                        tree1 = 5.08888888888889;
                    else
                        tree1 = 5.1;
        else
            if Sepal_Width < 2.550 then
                tree1 = 5.525;
            else
                if Petal_Length < 4.050 then
                    tree1 = 5.8;
                else
                    tree1 = 5.63333333333333;
    else
        if Petal_Width < 1.950 then
            if Sepal_Width < 3.050 then
                if Species = "setosa" or Species = "virginica" then
                    if Petal_Length < 5.700 then
                        tree1 = 6.05833333333333;
                    else
                        tree1 = 7.2;
                else
                    tree1 = 6.176;
            else
                if Sepal_Width < 3.250 then
                    if Sepal_Width < 3.150 then
                        tree1 = 6.62;
                    else
                        tree1 = 6.66666666666667;
                else
                    tree1 = 6.3;
        else
            if Petal_Length < 6.050 then
                if Petal_Width < 2.050 then
                    tree1 = 6.275;
                else
                    tree1 = 6.65;
            else
                if Petal_Length < 6.550 then
                    tree1 = 7.76666666666667;
                else
                    tree1 = 7.7;
    ;

    /* ---- Tree 2 ---- */
    if Petal_Width < 1.150 then
        if Species = "setosa" then
            tree2 = 5.08947368421053;
        else
            tree2 = 5.55714285714286;
    else
        if Species = "setosa" or Species = "versicolor" then
            if Sepal_Width < 2.750 then
                if Petal_Length < 4.450 then
                    tree2 = 5.44;
                else
                    tree2 = 6.06666666666667;
            else
                if Petal_Width < 1.350 then
                    tree2 = 5.85294117647059;
                else
                    if Petal_Width < 1.750 then
                        if Petal_Width < 1.650 then
                            tree2 = 6.3625;
                        else
                            tree2 = 6.7;
                    else
                        tree2 = 5.9;
        else
            if Petal_Length < 5.850 then
                if Sepal_Width < 2.650 then
                    if Petal_Length < 4.750 then
                        tree2 = 4.9;
                    else
                        if Sepal_Width < 2.350 then
                            tree2 = 6;
                        else
                            if Sepal_Width < 2.550 then
                                tree2 = 6.14;
                            else
                                tree2 = 6.1;
                else
                    tree2 = 6.49166666666667;
            else
                if Petal_Length < 6.350 then
                    tree2 = 7.125;
                else
                    tree2 = 7.775;
    ;

    /* Forest prediction: mean of the tree outputs */
    RFprediction = (tree1 + tree2) / 2;
run;
関連する問題