Logistic Regression Text Classification with Scikit-Learn
We’ll use the popular SMS Spam Collection Dataset, which consists of SMS (Short Message Service) messages labeled as either “ham” (legitimate, non-spam) or “spam” based on their content. The implementation classifies these messages into the two categories, spam (unwanted messages) and ham (legitimate messages), using a logistic regression model. The process is broken down into several key steps:
Step 1. Import Libraries
The first step is importing the necessary libraries:
- Pandas for data manipulation.
- CountVectorizer (from sklearn.feature_extraction.text) for converting text data into a numeric format.
- train_test_split from sklearn.model_selection and LogisticRegression from sklearn.linear_model for splitting the data and training the model.
- accuracy_score and confusion_matrix from sklearn.metrics for evaluating the model’s performance.
Python
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix
Step 2. Load and Prepare the Data
- Load the dataset from a CSV file and rename its columns v1 and v2 to label and text for clarity.
- The latin-1 encoding is specified to handle any non-ASCII characters that may be present in the file.
- Map the labels from text to numeric values (0 for ham, 1 for spam) so they can be used for model training.
Python
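# Read the raw file; latin-1 accepts any byte value, so non-ASCII characters do not break decoding
data = pd.read_csv("spam.csv", encoding="latin-1")
data.rename(columns={"v1": "label", "v2": "text"}, inplace=True)
# Encode the target: 0 = ham, 1 = spam
data["label"] = data["label"].map({"ham": 0, "spam": 1})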
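The renaming and mapping steps can be tried without the real file. The sketch below runs them on a few hypothetical rows in the same v1/v2 layout (the rows are made up, not taken from the dataset) and adds one check that the original code does not have: Series.map returns NaN for any label missing from the dictionary, so an unexpected value would otherwise slip through silently.

```python
import io

import pandas as pd

# Hypothetical rows in the same v1/v2 layout as spam.csv (made up, not real data)
csv_text = """v1,v2
ham,See you at lunch?
spam,WINNER! Claim your free prize now
ham,Ok lar...
"""

data = pd.read_csv(io.StringIO(csv_text))
data.rename(columns={"v1": "label", "v2": "text"}, inplace=True)
data["label"] = data["label"].map({"ham": 0, "spam": 1})

# .map() turns any label other than "ham"/"spam" into NaN, so fail loudly on it
assert not data["label"].isna().any(), "unexpected label values in column v1"
print(data)
```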
Step 3. Text Vectorization
Convert text data into a numeric format using CountVectorizer, which transforms the text into a sparse matrix of token counts.
Python
vectorizer = CountVectorizer()
# Learn the vocabulary and build a sparse matrix: one row per message, one column per token
X = vectorizer.fit_transform(data["text"])
y = data["label"]
Step 4. Split Data into Training and Testing Sets
Divide the dataset into training and testing sets to evaluate the model’s performance on unseen data.
Python
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
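With test_size=0.25, a quarter of the rows (rounded up) go to the test set, and random_state=42 makes the shuffle reproducible. The SMS data is imbalanced (the confusion matrix below shows far more ham than spam in the test set), so one option the original code does not use is stratify=y, which keeps the class ratio the same in both splits. A sketch on made-up labels:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced labels: 16 "ham" (0) and 4 "spam" (1)
X = np.arange(20).reshape(-1, 1)
y = np.array([0] * 16 + [1] * 4)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)

print(len(X_train), len(X_test))  # 15 training rows, 5 test rows
print(y_test.sum())               # 1 spam in the test set: the same 20% ratio as the full data
```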
Step 5. Train the Logistic Regression Model
Create and train the logistic regression model using the training set.
Python
model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)
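Logistic regression scores each message with a weighted sum of its token counts, z = w·x + b, and passes that score through the sigmoid σ(z) = 1 / (1 + e^(-z)) to get the probability of spam. The sketch below fits a model on four made-up messages and recomputes decision_function and predict_proba by hand from the learned coef_ and intercept_, to show what fit actually learns.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression

# Tiny made-up training set: 0 = ham, 1 = spam
texts = ["free prize call now", "win cash now", "see you at lunch", "are we still on for dinner"]
labels = [1, 1, 0, 0]

vectorizer = CountVectorizer()
X = vectorizer.fit_transform(texts).toarray()
model = LogisticRegression(random_state=42).fit(X, labels)

# Score z = w . x + b for each message, using the learned weights and bias
z = X @ model.coef_.ravel() + model.intercept_[0]
# The sigmoid maps the score to P(spam | message)
p_spam = 1.0 / (1.0 + np.exp(-z))

print(np.allclose(z, model.decision_function(X)))         # True
print(np.allclose(p_spam, model.predict_proba(X)[:, 1]))  # True
# predict() returns spam exactly when z > 0, i.e. when P(spam) > 0.5
print(np.array_equal(model.predict(X), (z > 0).astype(int)))  # True
```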
Step 6. Model Evaluation
Use the trained model to make predictions on the test set, then compute the accuracy and the confusion matrix to understand its performance.
Python
y_pred = model.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
Output:
Accuracy: 0.9763101220387652
Confusion Matrix:
 [[1201    1]
 [  32  159]]
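The model correctly classifies about 97.6% of the test messages. In scikit-learn’s confusion matrix, rows are the true labels and columns are the predicted labels (ham first, then spam), so:
- 1201 ham messages were correctly classified as ‘ham’.
- 159 spam messages were correctly classified as ‘spam’.
- 32 ‘spam’ messages were wrongly labeled as ‘ham’.
- 1 ‘ham’ message was wrongly labeled as ‘spam’.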
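Accuracy alone hides where the errors fall. Using the four counts from the confusion matrix above, this short sketch computes accuracy, precision, and recall for the spam class. Precision comes out at 159/160 ≈ 0.99, while recall is 159/191 ≈ 0.83, which reflects the 32 spam messages that got through as ham. (sklearn.metrics.classification_report(y_test, y_pred) prints the same per-class figures directly.)

```python
import numpy as np

# Confusion matrix from the output above: rows = true label, columns = predicted label
cm = np.array([[1201, 1],
               [32, 159]])
tn, fp, fn, tp = cm.ravel()  # label order: 0 (ham), 1 (spam)

accuracy = (tn + tp) / cm.sum()  # share of all test messages classified correctly
precision = tp / (tp + fp)       # of the messages flagged as spam, share that really are spam
recall = tp / (tp + fn)          # of the real spam messages, share that were caught

print(f"accuracy={accuracy:.4f} precision={precision:.4f} recall={recall:.4f}")
```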
Manual Testing: Function to Classify Text Messages
To simplify the use of this model for predicting the category of new messages, we create a function that takes a text input and classifies it as spam or ham.
Python
def classify_message(model, vectorizer, message):
    # Reuse the vocabulary fitted on the training data; unseen words are ignored
    message_vect = vectorizer.transform([message])
    prediction = model.predict(message_vect)
    return "ham" if prediction[0] == 0 else "spam"


# Example of using the function
message = "Congratulations! You've won a free ticket to Bahamas!"
print(classify_message(model, vectorizer, message))
Output:
ham
This function first vectorizes the input text using the previously fitted CountVectorizer, then predicts the category using the trained logistic regression model, and finally returns the prediction as a human-readable label.
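A design alternative, not used in the original: scikit-learn’s make_pipeline can bundle the CountVectorizer and the LogisticRegression into one object, so the fitted vocabulary and model always travel together and the classifier accepts raw strings. The sketch below trains on four made-up messages and adds a hypothetical batch helper, classify_messages, that returns the spam probability from predict_proba alongside each label.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

# Tiny made-up training set: 0 = ham, 1 = spam
texts = ["free prize call now", "win cash now", "see you at lunch", "are we still on for dinner"]
labels = [1, 1, 0, 0]

# One estimator that runs CountVectorizer, then LogisticRegression
pipeline = make_pipeline(CountVectorizer(), LogisticRegression(random_state=42))
pipeline.fit(texts, labels)


def classify_messages(pipeline, messages):
    """Hypothetical batch helper: returns (label, P(spam)) for each message."""
    predictions = pipeline.predict(messages)
    spam_probs = pipeline.predict_proba(messages)[:, 1]  # column 1 = class 1 (spam)
    return [("ham" if pred == 0 else "spam", float(prob))
            for pred, prob in zip(predictions, spam_probs)]


for label, prob in classify_messages(pipeline, ["win a free prize now", "lunch at noon?"]):
    print(label, round(prob, 3))
```

With a pipeline, a call like classify_message(model, vectorizer, message) reduces to pipeline.predict([message])[0].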
Text Classification using Logistic Regression
Text classification is the process of automatically assigning labels or categories to pieces of text. It has many applications, such as sorting emails into spam or not-spam, deciding whether a product review is positive or negative, or identifying the topic of a news article. In this article, we will see how logistic regression is used for text classification with Scikit-Learn.