DL85Classifier example: export a learned tree as an image

######################################################################
#                      DL8.5 default classifier                      #
######################################################################
Model built in 0.0017 seconds
Found tree: {'feat': 5, 'left': {'feat': 32, 'left': {'value': 1, 'error': 51.0, 'proba': [0.1275, 0.8725]}, 'right': {'value': 0, 'error': 2.0, 'proba': [0.95, 0.05]}, 'proba': [0.20227272727272727, 0.7977272727272727]}, 'right': {'feat': 78, 'left': {'value': 1, 'error': 84.0, 'proba': [0.2346368715083799, 0.7653631284916201]}, 'right': {'value': 0, 'error': 0.0, 'proba': [1.0, 0]}, 'proba': [0.26344086021505375, 0.7365591397849462]}, 'proba': [0.23029556650246305, 0.7697044334975369]}
Confusion Matrix below
 [[ 13  21]
 [  0 129]]
Accuracy on training set = 0.8313
Accuracy on test set = 0.8712

'plots/anneal_odt.png'

import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from dl85 import DL85Classifier
import graphviz


# Print a three-line banner identifying which DL8.5 configuration this run uses.
banner_rule = "######################################################################"
banner_title = "#                      DL8.5 default classifier                      #"
print(banner_rule, banner_title, banner_rule, sep="\n")

# Read the dataset: column 0 holds the class label, the remaining columns
# hold the (space-separated) features.
dataset = np.genfromtxt("../datasets/anneal.txt", delimiter=' ')
X, y = dataset[:, 1:], dataset[:, 0]
# Hold out 20% of the rows as a test set (fixed seed for reproducibility).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Learn a depth-2 optimal decision tree on the TRAINING split only.
# Fitting on the full (X, y) would leak the held-out rows into the model and
# make the "test set" accuracy reported below meaningless.
clf = DL85Classifier(max_depth=2)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Report fit time, the learned tree, and train/test quality metrics.
# clf.accuracy_ is the accuracy the classifier reports for the data it was fit on.
fit_seconds = round(clf.runtime_, 4)
test_confusion = confusion_matrix(y_test, y_pred)
train_accuracy = round(clf.accuracy_, 4)
test_accuracy = round(accuracy_score(y_test, y_pred), 4)

print("Model built in", fit_seconds, "seconds")
print("Found tree:", clf.tree_)
print("Confusion Matrix below\n", test_confusion)
print("Accuracy on training set =", train_accuracy)
print("Accuracy on test set =", test_accuracy)

# Render the learned tree to a PNG with Graphviz. render() is left as the
# final expression so the gallery displays the path of the written image.
tree_dot_source = clf.export_graphviz()
graphviz.Source(tree_dot_source, format="png").render("plots/anneal_odt")

Total running time of the script: ( 0 minutes 0.109 seconds)

Gallery generated by Sphinx-Gallery