K-means Clustering

K-means clustering is a popular unsupervised machine learning technique that partitions data points into K clusters. Starting from an initial set of centroids, it alternates between two steps: assigning each data point to its closest centroid by Euclidean distance, and updating each centroid to the mean of the points assigned to it. Together, these steps reduce the within-cluster sum of squared distances. K-means is sensitive to the choice of initial centroids and may converge to a local minimum.
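Written out, the objective is the quantity J below. Here C_k is the set of points in cluster k and mu_k is its centroid; this notation is introduced only for illustration. The assignment step minimizes J while the centroids are held fixed. The update step minimizes it while the assignments are held fixed, because the mean is the point with the smallest total squared distance to a set of points. As a result J never increases. This is why the iterations settle down, but only at a local minimum that depends on the starting centroids.

```latex
J = \sum_{k=1}^{K} \sum_{x_i \in C_k} \lVert x_i - \mu_k \rVert_2^2,
\qquad
\mu_k = \frac{1}{|C_k|} \sum_{x_i \in C_k} x_i
```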

Implementing K-means clustering using PyTorch

1. Import the Necessary Libraries


import torch
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

2. Generate Synthetic Data and Convert It to a PyTorch Tensor


# Generate synthetic data
data, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.60, random_state=0)
# Convert data to PyTorch tensor
tensor_data = torch.from_numpy(data).float()
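Two details of this conversion are easy to miss. make_blobs returns a float64 NumPy array, and torch.from_numpy keeps that dtype and shares memory with the array. Calling .float() then casts to float32 (PyTorch's default floating-point dtype) and allocates a new tensor. The sketch below demonstrates this with a small hand-written array standing in for the make_blobs output:

```python
import numpy as np
import torch

# A small stand-in for the float64 array returned by make_blobs
data = np.array([[1.0, 2.0], [3.0, 4.0]])
print(data.dtype)  # float64

# torch.from_numpy keeps the NumPy dtype and shares the same memory
shared = torch.from_numpy(data)
print(shared.dtype)  # torch.float64

# .float() casts to float32, producing a new tensor with its own storage
tensor_data = shared.float()
print(tensor_data.dtype)  # torch.float32

# A write through the shared tensor shows up in the NumPy array, not in the float32 copy
shared[0, 0] = 100.0
print(data[0, 0], tensor_data[0, 0].item())  # 100.0 1.0
```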

3. Perform K-means Clustering

The code below implements K-means clustering as follows:

  • We randomly select 4 data points from the dataset as the initial centroids and define the number of iterations.
  • In the main loop, we compute the Euclidean distance between every data point and every centroid with torch.cdist. Each data point is then assigned to its closest centroid. Finally, each centroid is updated to the mean of the data points assigned to it; a centroid with no assigned points is left unchanged.
  • These steps are repeated for a fixed number of iterations.
  • The process usually converges to a set of centroids that represent the cluster centers. As noted above, however, the result depends on the initial centroids and may be only a local minimum.


# Fix the random seed so the choice of initial centroids is reproducible
torch.manual_seed(0)
# Number of clusters
k = 4
# Initialize centroids by picking k distinct data points at random
centroids = tensor_data[torch.randperm(tensor_data.size(0))[:k]]
# Define the number of iterations
num_iterations = 100
for _ in range(num_iterations):
    # Calculate Euclidean distances from every data point to every centroid (shape: n_samples x k)
    distances = torch.cdist(tensor_data, centroids)
    # Assign each data point to the closest centroid
    labels = torch.argmin(distances, dim=1)
    # Update centroids by taking the mean of the data points assigned to each centroid
    for i in range(k):
        # Skip empty clusters so their centroid keeps its previous position
        if torch.sum(labels == i) > 0:
            centroids[i] = torch.mean(tensor_data[labels == i], dim=0)
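The loop above can be traced by hand on a tiny example. The four points and the two deliberately poor starting centroids below are made up for illustration. Each pass performs the assignment step with torch.cdist and torch.argmin and the update step with a per-cluster mean. The empty-cluster guard is dropped here because no cluster becomes empty in this example.

```python
import torch

# Toy data (hypothetical): two groups, one near x = 0 and one near x = 10
points = torch.tensor([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]])
# Hypothetical poor initial centroids, both close to the left group
centroids = torch.tensor([[0.0, 0.0], [1.0, 1.0]])

for step in range(3):
    # Assignment step: distance matrix of shape (4 points, 2 centroids), then nearest centroid
    labels = torch.argmin(torch.cdist(points, centroids), dim=1)
    # Update step: move each centroid to the mean of the points assigned to it
    centroids = torch.stack([points[labels == i].mean(dim=0) for i in range(2)])
    print(step, labels.tolist(), centroids.tolist())
```

In the first pass only (0, 0) is closest to the first centroid, so the second centroid is pulled to about (6.67, 1.33). In the second pass the labels become [0, 0, 1, 1] and the centroids move to (0, 1) and (10, 1). The third pass changes neither labels nor centroids, which is the converged state.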

4. Visualize Clusters


# Visualize clusters: points colored by assigned cluster, centroids as red crosses
plt.scatter(data[:, 0], data[:, 1], c=labels.numpy(), cmap='viridis')
plt.scatter(centroids[:, 0].numpy(), centroids[:, 1].numpy(), marker='X', s=200, color='red')
plt.show()
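The clusters plotted above come from a single random initialization, so a different seed can give a different and possibly worse result. A common remedy is to run K-means several times from different random initializations. You then keep the run with the lowest within-cluster sum of squared distances, often called inertia. The sketch below wraps the same assignment and update steps into a function that also stops early once no centroid moves more than a tolerance. Several pieces are assumptions for illustration and are not part of the original code: the function names kmeans and kmeans_best_of, the parameters tol and n_restarts, and the synthetic blobs generated with torch.randn.

```python
import torch


def kmeans(x, k, max_iters=100, tol=1e-4, generator=None):
    """Run K-means once and return (centroids, labels, inertia). Hypothetical helper."""
    # Initialize centroids by picking k distinct data points at random
    idx = torch.randperm(x.size(0), generator=generator)[:k]
    centroids = x[idx].clone()
    for _ in range(max_iters):
        # Assignment step: index of the closest centroid for every point
        labels = torch.argmin(torch.cdist(x, centroids), dim=1)
        # Update step: mean of the assigned points; an empty cluster keeps its old centroid
        new_centroids = centroids.clone()
        for i in range(k):
            mask = labels == i
            if mask.any():
                new_centroids[i] = x[mask].mean(dim=0)
        # Largest distance any centroid moved in this iteration (tol is an assumed threshold)
        shift = torch.linalg.norm(new_centroids - centroids, dim=1).max()
        centroids = new_centroids
        if shift <= tol:
            break
    # Final assignment and within-cluster sum of squared distances (inertia)
    labels = torch.argmin(torch.cdist(x, centroids), dim=1)
    inertia = ((x - centroids[labels]) ** 2).sum().item()
    return centroids, labels, inertia


def kmeans_best_of(x, k, n_restarts=10, seed=0):
    """Run K-means from n_restarts random initializations and keep the lowest-inertia run."""
    generator = torch.Generator().manual_seed(seed)
    runs = [kmeans(x, k, generator=generator) for _ in range(n_restarts)]
    return min(runs, key=lambda run: run[2])


# Usage on synthetic data: four blobs around hypothetical centers
torch.manual_seed(0)
true_centers = torch.tensor([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0], [5.0, 0.0]])
x = torch.cat([center + 0.3 * torch.randn(50, 2) for center in true_centers])
centroids, labels, inertia = kmeans_best_of(x, k=4, n_restarts=20)
print(centroids)
print(inertia)
```

Keeping the best of several runs by inertia is the same idea behind the n_init parameter of scikit-learn's KMeans. That estimator also uses k-means++ initialization by default instead of picking data points uniformly at random.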


PyTorch for Unsupervised Clustering

Unsupervised clustering is a fundamental machine learning problem. Its aim is to divide data into groups, or clusters, based on similarity or some underlying structure. PyTorch, a popular deep learning framework, can also be used to solve unsupervised clustering problems.

