""" decision_tree_classification_with_scikit_learn.py creates, trains, and prediction tests a decision tree classifier """ # Import libraries. from sklearn.datasets import load_iris from sklearn import tree import graphviz # Load iris measurements test data. X, y = load_iris(return_X_y=True) # Create a decision tree classifier classifier = tree.DecisionTreeClassifier() # Train the classifier with the test data. classifier = classifier.fit(X, y) # Export the classifier structure. dot_data = tree.export_graphviz(classifier, out_file=None) # Graph the structure. graph = graphviz.Source(dot_data) # Output the graph into a pdf file. graph.render("iris") # Create an array of values for a prediction input. X_new = X - .1 # Predict results based on the X training data and modified X_new data. classifier.predict(X) classifier.predict(X_new)