Получить числовые ограничения из sklearn.tree [duplicate]

Downvoters: этот код предоставляется только для информации.

Это было протестировано в Fx 19 и Chrome 24 на Mac

DEMO

var new_comment; /*<<
       
          You
       
       
          $text
       
       
          2d
       
    
EOF*/
// note the script tag here is hardcoded as the FIRST tag 
new_comment=document.currentScript.innerHTML.split("EOF")[1]; 
alert(new_comment.replace('$text','Here goes some text'));

91
задан petezurich 22 April 2018 в 19:04
поделиться

12 ответов

Я считаю, что этот ответ более корректен, чем другие ответы здесь:

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)

Это выводит действительную функцию Python. Вот пример вывода для дерева, которое пытается вернуть свой вход, число от 0 до 10.

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]

Вот несколько камней преткновения, которые я вижу в других ответах:

  1. Использование tree_.threshold == -2 для определения того, является ли узел листом, не является хорошей идеей. Что, если это реальный узел решения с порогом -2? Вместо этого вы должны посмотреть на tree.feature или tree.children_*.
  2. Линия features = [feature_names[i] for i in tree_.feature] выходит из строя с моей версией sklearn, потому что некоторые значения tree.tree_.feature равны -2 (специально для листовых узлов).
  3. Нет необходимости иметь несколько операторов if в рекурсивной функции, только один из них прав.
72
ответ дан NirIzr 15 August 2018 в 17:29
поделиться
  • 1
    Этот код отлично работает для меня. Тем не менее, у меня есть 500 + feature_names, поэтому выходной код почти невозможно понять человеку. Есть ли способ позволить мне вводить только те функции, которые мне интересны в функции? – user3768495 8 September 2017 в 19:05
  • 2
    Я согласен с предыдущим комментарием. IIUC, print "{}return {}".format(indent, tree_.value[node]) следует изменить на print "{}return {}".format(indent, np.argmax(tree_.value[node][0])) для функции, возвращающей индекс класса. – soupault 19 October 2017 в 09:56
  • 3
    Привет @paulkernfeld, большое спасибо за это! Вы сделали то же самое для случайного леса? – Nathan Lloyd 1 November 2017 в 19:42
  • 4
    @NathanLloyd Я думаю, что я изначально написал этот код для случайного леса. Если я правильно помню, все, что мне нужно было сделать, это перебрать каждое дерево и запустить на нем этот код. – paulkernfeld 2 November 2017 в 23:31
  • 5
    @paulkernfeld Ах, да, я вижу, что вы можете перебирать RandomForestClassifier.estimators_, но мне не удалось разобраться, как совместить результаты оценок. – Nathan Lloyd 3 November 2017 в 00:36

Вот функция, правила печати дерева решений scikit-learn под python 3 и смещения для условных блоков, чтобы сделать структуру более читаемой:

def print_decision_tree(tree, feature_names=None, offset_unit='    '):
    '''Plots textual representation of rules of a decision tree
    tree: scikit-learn representation of tree
    feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
    offset_unit: a string of offset of the conditional block'''

    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value
    if feature_names is None:
        features  = ['f%d'%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0):
            offset = offset_unit*depth
            if (threshold[node] != -2):
                    print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node],depth+1)
                    print(offset+"} else {")
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node],depth+1)
                    print(offset+"}")
            else:
                    print(offset+"return " + str(value[node]))

    recurse(left, right, threshold, features, 0,0)
1
ответ дан Apogentus 15 August 2018 в 17:29
поделиться

По-видимому, давно кто-то уже решил попытаться добавить следующую функцию в функции экспорта дерева дерева scikit (которая в основном поддерживает только export_graphviz)

def export_dict(tree, feature_names=None, max_depth=None) :
    """Export a decision tree in dict format.

Вот его полная фиксация:

https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py

Не совсем уверен, что случилось с этим комментарием. Но вы также можете попытаться использовать эту функцию.

Я думаю, что это требует серьезного запроса документации хорошим людям scikit-learn, чтобы правильно документировать API sklearn.tree.Tree, который является базовой структурой дерева, которая DecisionTreeClassifier раскрывает его атрибут tree_.

0
ответ дан Aris Koning 15 August 2018 в 17:29
поделиться

Изменен код Zelazny7 для извлечения SQL из дерева решений.

# SQL from decision tree

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]
     le='<='               
     g ='>'
     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'
          lineage.append((parent, split, threshold[parent], features[parent]))
          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)
     print 'case '
     for j,child in enumerate(idx):
        clause=' when '
        for node in recurse(left, right, child):
            if len(str(node))<3:
                continue
            i=node
            if i[1]=='l':  sign=le 
            else: sign=g
            clause=clause+i[3]+sign+str(i[2])+' and '
        clause=clause[:-4]+' then '+str(j)
        print clause
     print 'else 99 end as clusters'
0
ответ дан Arslán 15 August 2018 в 17:29
поделиться

Я изменил код, представленный Zelazny7 , чтобы напечатать некоторый псевдокод:

def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value

        def recurse(left, right, threshold, features, node):
                if (threshold[node] != -2):
                        print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node])
                        print "} else {"
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node])
                        print "}"
                else:
                        print "return " + str(value[node])

        recurse(left, right, threshold, features, 0)

, если вы вызываете get_code(dt, df.columns) в том же примере, который вы получите:

if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}
34
ответ дан Community 15 August 2018 в 17:29
поделиться
  • 1
    Можете ли вы сказать, что именно [[1. 0.]] в возвращаемом выражении означает в вышесказанном выходе. Я не парень на Питоне, но работаю над такими же вещами. Так что это будет полезно для меня, если вы, пожалуйста, подтвердите некоторые детали, чтобы мне было легче. – Subhradip Bose 30 May 2015 в 02:14
  • 2
    @ user3156186 Это означает, что в классе «0» есть один объект и нулевые объекты в классе «1», – Daniele 3 June 2015 в 07:39
  • 3
    @ Даниле, ты знаешь, как упорядочиваются классы? Я бы предположил буквенно-цифровое, но я нигде не нашел подтверждения. – IanS 4 September 2015 в 08:27
  • 4
    Благодаря! Для сценария краевого случая, где пороговое значение фактически равно -2, нам может потребоваться изменить (threshold[node] != -2) на ( left[node] != -1) (аналогично методу ниже для получения идентификаторов дочерних узлов) – tlingf 12 May 2016 в 21:26
  • 5
    @Daniele, любая идея, как сделать вашу функцию & quot; get_code & quot; & Quot; возвращение & Quot; значение, а не "печать" это, потому что мне нужно отправить его на другую функцию? – RoyaumeIX 26 May 2016 в 04:52

Это основано на ответе @paulkernfeld. Если у вас есть фреймворк X с вашими функциями и целевым фреймворком y с вашими резонами, и вы хотите получить представление о том, какое значение y закончилось, в каком узле (а также муравье, чтобы построить его соответственно) вы можете сделать следующее:

    def tree_to_code(tree, feature_names):
        codelines = []
        codelines.append('def get_cat(X_tmp):\n')
        codelines.append('   catout = []\n')
        codelines.append('   for codelines in range(0,X_tmp.shape[0]):\n')
        codelines.append('      Xin = X_tmp.iloc[codelines]\n')
        tree_ = tree.tree_
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        #print "def tree({}):".format(", ".join(feature_names))

        def recurse(node, depth):
            indent = "      " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                codelines.append ('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                codelines.append( '{}else:  # if Xin["{}"] > {}\n'.format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                codelines.append( '{}mycat = {}\n'.format(indent, node))

        recurse(0, 1)
        codelines.append('      catout.append(mycat)\n')
        codelines.append('   return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
        codelines.append('node_ids = get_cat(X)\n')
        return codelines
    mycode = tree_to_code(clf,X.columns.values)

    # now execute the function and obtain the dataframe with all nodes
    exec(''.join(mycode))
    node_ids = [int(x[0]) for x in node_ids.values]
    node_ids2 = pd.DataFrame(node_ids)

    print('make plot')
    import matplotlib.cm as cm
    colors = cm.rainbow(np.linspace(0, 1, 1+max( list(set(node_ids)))))
    #plt.figure(figsize=cm2inch(24, 21))
    for i in list(set(node_ids)):
        plt.plot(y[node_ids2.values==i],'o',color=colors[i], label=str(i))  
    mytitle = ['y colored by node']
    plt.title(mytitle ,fontsize=14)
    plt.xlabel('my xlabel')
    plt.ylabel(tagname)
    plt.xticks(rotation=70)       
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
    plt.tight_layout()
    plt.show()
    plt.close 

не самая элегантная версия, но она выполняет эту работу ...

1
ответ дан horseshoe 15 August 2018 в 17:29
поделиться

В 0.18.0 имеется новый метод DecisionTreeClassifier , decision_path. Разработчики предоставляют обширное (хорошо документированное) прохождение .

Первый раздел кода в пошаговом руководстве, который печатает древовидную структуру, кажется, в порядке. Тем не менее, я изменил код во втором разделе, чтобы опросить один образец. Мои изменения, обозначенные # <--

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--

    else: # < -- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

Rules used to predict sample 0: 
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here

Измените sample_id, чтобы увидеть пути решения для других образцов. Я не спрашивал разработчиков об этих изменениях, просто казался более интуитивным при работе с примером.

12
ответ дан Kevin 15 August 2018 в 17:29
поделиться
  • 1
    ты мой друг легенда! любые идеи о том, как построить дерево решений для этой конкретной выборки? много помощи – Victor 20 February 2018 в 15:45
  • 2
    Спасибо Виктору, вероятно, лучше спросить об этом как о отдельном вопросе, поскольку требования к графике могут быть специфическими для потребностей пользователя. Вероятно, вы получите хороший ответ, если вы предоставите представление о том, как вы хотите, чтобы результат выглядел. – Kevin 20 February 2018 в 16:12
  • 3
    hey kevin, я создал вопрос stackoverflow.com/questions/48888893/… – Victor 20 February 2018 в 16:38
from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()

Вы можете увидеть дерево орграфа. Затем clf.tree_.feature и clf.tree_.value представляют собой массив функций разбиения узлов и массива значений узлов соответственно. Вы можете обратиться к более подробной информации из этого источника github .

11
ответ дан lennon310 15 August 2018 в 17:29
поделиться

Я прошел через это, но мне нужны были правила, которые должны быть записаны в этом формате

if A>0.4 then if B<0.2 then if C>0.8 then class='X' 

. Поэтому я адаптировал ответ @paulkernfeld (спасибо), который вы можете настроить в соответствии с вашими потребностями

def tree_to_code(tree, feature_names, Y):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    pathto=dict()

    global k
    k = 0
    def recurse(node, depth, parent):
        global k
        indent = "  " * depth

        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            s= "{} <= {} ".format( name, threshold, node )
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s

            recurse(tree_.children_left[node], depth + 1, node)
            s="{} > {}".format( name, threshold)
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s
            recurse(tree_.children_right[node], depth + 1, node)
        else:
            k=k+1
            print(k,')',pathto[parent], tree_.value[node])
    recurse(0, 1, 0)
1
ответ дан Rene B. 15 August 2018 в 17:29
поделиться

Просто потому, что все были так полезны, я просто добавлю изменения в красивые решения Zelazny7 и Daniele. Это для python 2.7, с вкладками, чтобы сделать его более читаемым:

def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
            if (threshold[node] != -2):
                    print '\t' * tabdepth,
                    print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "} else {"
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "}"
            else:
                    print '\t' * tabdepth,
                    print "return " + str(value[node])

    recurse(left, right, threshold, features, 0)
2
ответ дан Ruslan 15 August 2018 в 17:29
поделиться

Ниже приведен мой подход под anaconda python 2.7 плюс имя пакета «pydot-ng» для создания файла PDF с правилами принятия решений. Надеюсь, что это полезно.

from sklearn import tree

clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)

feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)

def output_pdf(clf_, name):
    from sklearn import tree
    from sklearn.externals.six import StringIO
    import pydot_ng as pydot
    dot_data = StringIO()
    tree.export_graphviz(clf_, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=class_name,
                         filled=True, rounded=True,
                         special_characters=True,
                          node_ids=1,)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("%s.pdf"%name)

output_pdf(clf_, name='filename%s'%n)

здесь показана деревная графика

2
ответ дан TED Zhao 15 August 2018 в 17:29
поделиться

Я создал свою собственную функцию для извлечения правил из деревьев решений, созданных sklearn:

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)

Эта функция сначала начинается с узлов (идентифицированных -1 в дочерних массивах), а затем рекурсивно находит родителей. Я называю это «родословной» узла. Попутно я получаю значения, которые мне нужно создать, если / then / else логика SAS:

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]

     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'

          lineage.append((parent, split, threshold[parent], features[parent]))

          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)

     for child in idx:
          for node in recurse(left, right, child):
               print node

Наборы кортежей ниже содержат все, что мне нужно для создания SAS, если / then / else. Мне не нравится использование блоков do в SAS, поэтому я создаю логику, описывающую весь путь узла. Единственное целое после кортежей - это идентификатор конечного узла в пути. Все предыдущие кортежи объединяются для создания этого узла.

In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6

GraphViz output of example tree [/g0]

44
ответ дан Zelazny7 15 August 2018 в 17:29
поделиться
  • 1
    этот тип дерева является правильным, потому что col1 возвращается снова, то есть col1 & lt; = 0.50000 и один col1 = 2.5000, если да, это любой тип рекурсии, используемый в библиотеке – jayant singh 1 March 2017 в 17:12
  • 2
    у правой ветви были бы записи между (0.5, 2.5]. Деревья сделаны с рекурсивным разбиением. Нет ничего, что мешало бы выбирать переменную несколько раз. – Zelazny7 1 March 2017 в 18:25
  • 3
    хорошо, вы можете объяснить часть рекурсии, что происходит xactly, потому что я использовал ее в своем коде, и аналогичный результат – jayant singh 1 March 2017 в 18:38
Другие вопросы по тегам:

Похожие вопросы: