Understanding Feature Importance and Visualization of Tree Models

Feature importance is a crucial concept in machine learning, particularly in tree-based models. It refers to techniques that assign a score to input features based on their usefulness in predicting a target variable. This article will delve into the methods of calculating feature importance, the significance of these scores, and how to visualize them effectively.

Table of Contents

  • Feature Importance in Tree Models
  • Methods to Calculate Feature Importance
    • 1. Decision Tree Feature Importance
    • 2. Random Forest Feature Importance
    • 3. Permutation Feature Importance
  • Demonstrating Visualization of Tree Models
  • Yellowbrick for Visualization of Tree Models

Feature Importance in Tree Models

Feature importance scores provide insights into the data and the model. They help in understanding which features contribute the most to the prediction, aiding in dimensionality reduction and feature selection. This can improve the efficiency and effectiveness of a predictive model.
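For instance, once a tree-based model exposes importance scores, they can drive automatic feature selection. The snippet below is a minimal sketch (not part of the worked examples later in this article) that uses scikit-learn's SelectFromModel to keep only the features whose importance exceeds the mean:

Python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel

# Synthetic data, analogous to the datasets used in the examples below
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
                           n_redundant=5, random_state=1)

# Fit a forest and keep only features whose importance exceeds the mean importance
selector = SelectFromModel(RandomForestClassifier(random_state=1), threshold='mean')
X_selected = selector.fit_transform(X, y)

print('Original shape:', X.shape)          # (1000, 10)
print('Reduced shape:', X_selected.shape)  # fewer columns after selection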

In tree-based models, feature importance can be derived in several ways:

  • Gini Importance (Mean Decrease in Impurity): In decision trees and random forests, the importance of a feature is often calculated as the total decrease in node impurity (Gini impurity or entropy) that the feature achieves across all its splits, summed over all the trees in the forest (a small worked sketch follows this list).
  • Mean Decrease in Accuracy: This method involves shuffling the values of each feature and observing the drop in model accuracy. A significant drop indicates that the feature is important.
  • Permutation Importance: Closely related to mean decrease in accuracy, permutation importance measures the change in model performance after randomly permuting a feature's values, thus breaking the relationship between the feature and the target.
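To make the impurity-based idea concrete, the sketch below (illustrative only, with made-up node counts) computes the Gini impurity of a parent node and the weighted impurity decrease produced by one hypothetical split:

Python
# Illustrative numbers only: a parent node with 100 samples (60 of class 0, 40 of class 1)
def gini(counts):
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)

parent = [60, 40]                 # class counts in the parent node
left, right = [45, 5], [15, 35]   # class counts after a hypothetical split

n_parent, n_left, n_right = sum(parent), sum(left), sum(right)
decrease = (gini(parent)
            - (n_left / n_parent) * gini(left)
            - (n_right / n_parent) * gini(right))
print(f'Impurity decrease from this split: {decrease:.4f}')

# Summing such (sample-weighted) decreases over every split that uses a feature,
# and normalizing across features, yields that feature's importance score.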

Methods to Calculate Feature Importance

There are several methods to calculate feature importance, each with its own advantages and applications. Here, we will explore some of the most common methods used in tree-based models.

1. Decision Tree Feature Importance

Decision trees, such as Classification and Regression Trees (CART), calculate feature importance based on the reduction in a criterion (e.g., Gini impurity or entropy) used to select split points. The importance score for each feature is the total reduction of the criterion brought by that feature.

Example: DecisionTreeClassifier

Python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier
from matplotlib import pyplot as plt

# Define dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5, n_redundant=5, random_state=1)
model = DecisionTreeClassifier()
model.fit(X, y)

# Get importance
importance = model.feature_importances_

# Summarize feature importance
for i, v in enumerate(importance):
    print(f'Feature: {i}, Score: {v:.5f}')
# Plot feature importance
plt.bar([x for x in range(len(importance))], importance)
plt.show()

Output:

Feature: 0, Score: 0.01078
Feature: 1, Score: 0.01851
Feature: 2, Score: 0.18831
Feature: 3, Score: 0.30516
Feature: 4, Score: 0.08657
Feature: 5, Score: 0.00733
Feature: 6, Score: 0.18437
Feature: 7, Score: 0.02780
Feature: 8, Score: 0.12904
Feature: 9, Score: 0.04215

Decision Tree Feature Importance
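The raw scores above are indexed by position only. As an optional refinement (not part of the original example, and continuing with the importance array and plt from the code above), the scores can be paired with hypothetical feature names and sorted, which usually makes the bar chart easier to read:

Python
import pandas as pd

# Hypothetical names for the 10 synthetic features
feature_names = [f'feature_{i}' for i in range(len(importance))]

# Pair scores with names and sort for a more readable plot
scores = pd.Series(importance, index=feature_names).sort_values(ascending=False)
scores.plot.bar()
plt.tight_layout()
plt.show()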

2. Random Forest Feature Importance

Random forests are ensembles of decision trees. They calculate feature importance by averaging the importance scores of each feature across all the trees in the forest.

Example: RandomForestClassifier

Python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from matplotlib import pyplot as plt

X, y = make_classification(n_samples=1000, n_features=10, n_informative=5, n_redundant=5, random_state=1)
model = RandomForestClassifier()
model.fit(X, y)

# Get importance
importance = model.feature_importances_
# Summarize feature importance
for i, v in enumerate(importance):
    print(f'Feature: {i}, Score: {v:.5f}')

# Plot feature importance
plt.bar([x for x in range(len(importance))], importance)
plt.show()

Output:

Feature: 0, Score: 0.06806
Feature: 1, Score: 0.10468
Feature: 2, Score: 0.15456
Feature: 3, Score: 0.20209
Feature: 4, Score: 0.08275
Feature: 5, Score: 0.09979
Feature: 6, Score: 0.10596
Feature: 7, Score: 0.04535
Feature: 8, Score: 0.09206
Feature: 9, Score: 0.04471

Random Forest Feature Importance
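Because a random forest averages per-tree scores, the spread across the individual trees can also be inspected. Continuing with the fitted model, importance array, and plt from the example above, the following optional sketch plots the mean importances with error bars derived from the per-tree standard deviation:

Python
import numpy as np

# Standard deviation of each feature's importance across the individual trees
std = np.std([tree.feature_importances_ for tree in model.estimators_], axis=0)

plt.bar(range(len(importance)), importance, yerr=std)
plt.xlabel('Feature index')
plt.ylabel('Mean decrease in impurity')
plt.show()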

3. Permutation Feature Importance

Permutation feature importance involves shuffling the values of each feature and measuring the decrease in model performance. This method can be applied to any machine learning model, not just tree-based models.

Python
from sklearn.inspection import permutation_importance
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt

X, y = make_classification(n_samples=1000, n_features=10, n_informative=5, n_redundant=5, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
model = RandomForestClassifier()
model.fit(X_train, y_train)

# Perform permutation importance
results = permutation_importance(model, X_test, y_test, scoring='accuracy')
for i, v in enumerate(results.importances_mean):
    print(f'Feature: {i}, Score: {v:.5f}')
plt.bar([x for x in range(len(results.importances_mean))], results.importances_mean)
plt.show()

Output:

Feature: 0, Score: 0.00800
Feature: 1, Score: 0.06200
Feature: 2, Score: 0.12000
Feature: 3, Score: 0.10200
Feature: 4, Score: -0.00100
Feature: 5, Score: 0.03600
Feature: 6, Score: 0.01800
Feature: 7, Score: 0.00300
Feature: 8, Score: 0.03500
Feature: 9, Score: -0.00500

Permutation Feature Importance
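The permutation_importance function repeats the shuffling several times (n_repeats, which defaults to 5), so the result also carries a per-feature standard deviation. Continuing with the results object and plt from the example above, a small optional sketch plots the mean score drops with error bars:

Python
# importances_std holds the spread of the score drop across the repeated shuffles
plt.bar(range(len(results.importances_mean)),
        results.importances_mean,
        yerr=results.importances_std)
plt.xlabel('Feature index')
plt.ylabel('Mean accuracy decrease')
plt.show()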

Demonstrating Visualization of Tree Models

A fitted decision tree can be visualized using scikit-learn's plot_tree() function. The tree structure is displayed with internal nodes representing split decisions and leaves representing the predicted class labels.

Python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree

iris = load_iris()
X, y = iris.data, iris.target
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X, y)

# Plot the decision tree
plt.figure(figsize=(12, 8))
plot_tree(clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names)
plt.title('Decision Tree Visualization')
plt.show()

Output:

Visualize the decision tree
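For trees that are too deep to read comfortably as a figure, scikit-learn also provides a plain-text rendering. As a complement to the plot above (reusing the fitted clf and the iris feature names), the following sketch prints the split rules with export_text:

Python
from sklearn.tree import export_text

# Text rendering of the fitted tree: one indented line per node
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)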

Yellowbrick for Visualization of Tree Models

Yellowbrick is a Python library for visualizing model performance. To evaluate a decision tree classifier with Yellowbrick, we can use the ClassPredictionError visualizer, which shows how the model's predictions are distributed across the actual classes.

Python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from yellowbrick.classifier import ClassPredictionError

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)
visualizer = ClassPredictionError(clf, classes=iris.target_names)
visualizer.fit(X_train, y_train)
visualizer.score(X_test, y_test)
visualizer.show()

Output:

Visualize through Yellowbrick
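Yellowbrick also ships a visualizer aimed directly at this article's topic. The sketch below is a separate example that reuses the same train split and assumes a Yellowbrick version where FeatureImportances is available from yellowbrick.model_selection; it draws a labelled, sorted bar chart of the tree's importance scores:

Python
from yellowbrick.model_selection import FeatureImportances

# Labelled bar chart of the decision tree's feature_importances_
viz = FeatureImportances(DecisionTreeClassifier(random_state=42),
                         labels=iris.feature_names)
viz.fit(X_train, y_train)
viz.show()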

Conclusion

Understanding which features matter most in our machine learning models is crucial for making accurate predictions. By figuring out which factors have the biggest impact on our outcomes, we can better understand how our models work. Visualizing this information, whether through bar charts or other methods, helps us see the big picture and explain our findings to others easily.


