In the previous tutorial, we used a decision tree as the classifier. A decision tree is a classifier that is easy to read and understand. In this tutorial, we are going to write a program that visualizes the decision tree.
1. “Iris” Problem
We are going to write a program to solve a classic machine learning problem called “Iris”: identifying the species of a flower based on four measurements, the length and width of the petal and the length and width of the sepal.
https://en.wikipedia.org/wiki/Iris_flower_data_set
| Sepal length | Sepal width | Petal length | Petal width | Species |
|---|---|---|---|---|
| 5.1 | 3.5 | 1.4 | 0.2 | I. setosa |
| 4.9 | 3.0 | 1.4 | 0.2 | I. setosa |
| 4.7 | 3.2 | 1.3 | 0.2 | I. setosa |
| 5.7 | 2.9 | 4.2 | 1.3 | I. versicolor |
| 6.3 | 3.3 | 6.0 | 2.5 | I. virginica |
The Iris data set includes three types of flowers. They are all species of Iris: Setosa, Versicolor and Virginica.
The Iris data is an array of arrays, one inner array of four measurements per flower, that looks like this:
[[ 5.1 3.5 1.4 0.2] [ 4.9 3.0 1.4 0.2] [ 5.4 3.9 1.7 0.4]]
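To see where this array comes from, here is a minimal sketch that loads the data set with scikit-learn and prints the first few rows. It assumes scikit-learn is installed; which rows to print is an arbitrary choice for illustration.

```python
from sklearn.datasets import load_iris

iris = load_iris()

# iris.data holds one row of four measurements (in cm) per flower
print(iris.feature_names)
print(iris.data[:3])

# iris.target holds each flower's species as an integer index into target_names
print(iris.target_names)
print(iris.target[:3])

# the shape is (number of flowers, number of measurements)
print(iris.data.shape)
```

Each row of `iris.data` lines up with the entry at the same position in `iris.target`, which is how the classifier later knows the correct species for every training flower.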
2. Procedure
Our program will follow these steps:
- Import Iris Dataset
- Create training and testing data
- Train a Classifier
- Predict label for a new flower
- Visualize the Decision Tree
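The "create training and testing data" step can be sketched on a toy array before applying it to the real data set. The values, labels and held-out indices below are made up for illustration; only the NumPy calls are the point.

```python
import numpy as np

# Toy stand-in for the Iris data: 5 samples with 2 features each (made-up values)
data = np.arange(10).reshape(5, 2)
labels = np.array([0, 0, 1, 1, 2])

# Indices of the rows to hold out for testing (arbitrary choice)
test_idx = [0, 3]

# np.delete returns a copy without the given rows; the original array is unchanged
train_data = np.delete(data, test_idx, axis=0)
train_labels = np.delete(labels, test_idx)

# Indexing with a list selects exactly the held-out rows
test_data = data[test_idx]
test_labels = labels[test_idx]

print(train_data.shape, test_data.shape)
```

Holding out a few rows this way means the classifier never sees them during training, so predicting their labels is a fair test.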
3. Coding
Create a Python file named iris.py and add the following code.
Please read the comments carefully to understand what each part of the code does.
"""
GoodTecher Machine Learning Coding Tutorial
http://72.44.43.28
"Iris" Machine Learning Program
The program takes the measurements (the length and width of the petal and sepal)
of a flower as input
and predicts whether it is setosa, versicolor or virginica
"""
import numpy as np
from sklearn.datasets import load_iris
from sklearn import tree
import pydotplus
# load Iris dataset
iris = load_iris()
# pick a few samples from the Iris dataset as test data
# (indices 0 and 10 are setosa, 50 is versicolor, 100 is virginica);
# the remaining samples are the training data
test_idx = [0, 10, 50, 100]
# training data: every row except the test rows
train_data = np.delete(iris.data, test_idx, axis=0)
train_target = np.delete(iris.target, test_idx)
# testing data
test_data = iris.data[test_idx]
test_target = iris.target[test_idx]
# train classifier
clf = tree.DecisionTreeClassifier()
clf.fit(train_data, train_target)
# display the true test labels next to the predicted labels for comparison
print("Test target: ")
print(test_target)
print("clf.predict: ")
print(clf.predict(test_data))
# export the trained decision tree to a PDF file
# (pydotplus needs Graphviz installed on the system to write the PDF)
dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf("iris.pdf")
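If Graphviz is not available, the tree can also be inspected as plain text. The sketch below is an alternative to the PDF export above, not part of the original program. It assumes a scikit-learn version that provides `export_text` (0.21 or later) and, to stay short, trains on the full data set rather than on the split used in iris.py.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()

# Assumption for brevity: train on all samples instead of the tutorial's split.
# random_state is fixed only so repeated runs build the same tree.
clf = DecisionTreeClassifier(random_state=0)
clf.fit(iris.data, iris.target)

# export_text returns the decision rules as an indented string
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```

Each indented line is one test on a measurement, and each `class:` line is a leaf holding the predicted label index, so the printout can be read top to bottom in the same way as the boxes in the PDF.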
Run the program with the following command in Terminal (Mac) or Command Prompt (Windows):
python iris.py
Does the program predict the correct label for each test flower?
Can you find the generated PDF file, iris.pdf?
Yep. The machine is clever lol.
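To check the predictions without comparing the two printed arrays by eye, the labels can be translated into species names and scored. This is a small sketch that repeats the split from iris.py; the fixed `random_state` is an addition here so that repeated runs build the same tree.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()

# same held-out samples as in iris.py
test_idx = [0, 10, 50, 100]
train_data = np.delete(iris.data, test_idx, axis=0)
train_target = np.delete(iris.target, test_idx)
test_data = iris.data[test_idx]
test_target = iris.target[test_idx]

clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(train_data, train_target)
predicted = clf.predict(test_data)

# translate the integer labels into species names, one line per test flower
for actual, guess in zip(test_target, predicted):
    print(iris.target_names[actual], "->", iris.target_names[guess])

# score() returns the fraction of test samples that were labeled correctly
accuracy = clf.score(test_data, test_target)
print("Accuracy:", accuracy)
```

With only four test flowers the accuracy is a very rough number; holding out more samples gives a more trustworthy estimate.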