Creating Animation to Show 4 Centroid-Based Clustering Algorithms using Python and Sklearn

Author:Murphy  |  View: 26894  |  Time: 2025-03-23 13:03:17
Photo by Mel Poole on Unsplash

Clustering analysis

Clustering analysis is an effective machine learning technique that groups data by their similarities and differences. The obtained data groups can be used for various purposes, such as segmenting, structuring, and decision-making.

To perform clustering analysis, many methods are available based on different algorithms. This article will mainly focus on centroid-based clustering, which is a common and useful technique.

Centroid-based clustering

Centroid-based techniques work by repeatedly recalculating the centroids (cluster centers) and assigning each data point to the nearest one until the centroids settle.

Because the process involves many iterations, data visualization can be used to show what happens at each step. The purpose of this article is therefore to create animations of the centroid-based clustering process with Python and Sklearn.

An example of a clustering animation in this article. Image by Author.

Sklearn (Scikit-learn) is a powerful library that helps us perform clustering analysis efficiently. The following are the centroid-based clustering techniques that we will work with.

  1. K-means clustering
  2. MiniBatch K-means clustering
  3. Bisecting K-means clustering
  4. Mean-Shift clustering

Let's get started


Getting Data

Start with importing libraries.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

As an example, this article will use a generated dataset, which can be easily created using sklearn's make_blobs(). If you have your own dataset, this step can be skipped.

from sklearn.datasets import make_blobs
X, y = make_blobs(cluster_std=5, n_samples=1200,
                  n_features=2, random_state=42)
df_X = pd.DataFrame(X)
df_X.dropna(inplace=True)

sns.set_style('darkgrid')
sns.scatterplot(data = df_X, x = 0, y = 1, linewidth=0.5)
plt.show()

1. K-Means clustering

This is a common method for centroid-based clustering. The process can be briefly explained as follows. Start by defining the number of clusters. Next, some data points are randomly selected as the initial centroids, and each of the remaining data points is assigned to its closest centroid by Euclidean distance.

Then, the centroids are updated by calculating the average of each cluster's data points. After that, the assigning and updating steps are repeated. The algorithm keeps iterating until the centroids stop changing or a maximum number of iterations is reached.
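
The assign-and-update cycle described above can be sketched in plain NumPy. This is only a minimal illustration, not Sklearn's implementation: the function name lloyd_step, the toy points, and the hand-picked starting centroids are all assumptions made for the example.

```python
import numpy as np

def lloyd_step(points, centroids):
    """One K-means iteration: assign each point, then recompute the centroids."""
    # Squared Euclidean distance from every point to every centroid: shape (n, k)
    dists = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    labels = dists.argmin(axis=1)
    new_centroids = centroids.copy()
    for k in range(len(centroids)):
        members = points[labels == k]
        if len(members) > 0:          # keep the old centroid if a cluster is empty
            new_centroids[k] = members.mean(axis=0)
    return labels, new_centroids

# Toy data (hypothetical): two well-separated groups of three points each
pts = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
                [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]])
cents = np.array([[0.0, 0.0], [10.0, 10.0]])   # initial centroids picked by hand

labels, cents = lloyd_step(pts, cents)
print(labels)   # cluster index of each point
print(cents)    # centroids after one update
```

After a single step, each centroid has moved to the mean of its own group, which is the movement the animations in this article make visible.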

Now, let's work with Python code. Start by defining a list of iteration numbers. As an example, this article will run only the first ten iterations. If you want to change the number, please feel free to modify the code below.

iter_num = [i+1 for i in range(10)]
iter_num

#[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

Define a function and variables.

def apply_model(model_in, df):
    clus = model_in.fit_predict(df)
    cent = model_in.cluster_centers_
    #decision boundary
    z = model_in.predict(np.c_[xx.ravel(), yy.ravel()])
    z = z.reshape(xx.shape)
    return clus, cent, z

h = 0.02 
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max,h))
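
To see what the grid is for, here is a small sketch on a much coarser grid. Every grid point is labelled by its nearest centroid and the labels are reshaped back to the grid layout, which is the kind of array that imshow can draw as coloured regions. The names gx, gy, and demo_centroids are made up for this illustration, and the nearest-centroid rule stands in for the model's predict().

```python
import numpy as np

# Coarse grid so the shapes are easy to inspect (the article uses h = 0.02)
step = 1.0
gx, gy = np.meshgrid(np.arange(0.0, 4.0, step), np.arange(0.0, 3.0, step))
grid_points = np.c_[gx.ravel(), gy.ravel()]      # one (x, y) row per grid cell

# Stand-in for model.predict(): label each grid point by its nearest centroid
demo_centroids = np.array([[0.0, 1.0], [3.0, 1.0]])
d = np.linalg.norm(grid_points[:, None, :] - demo_centroids[None, :, :], axis=2)
regions = d.argmin(axis=1).reshape(gx.shape)     # back to the 2-D grid layout

print(gx.shape, grid_points.shape, regions.shape)
print(regions)
```

A smaller step gives smoother boundaries but more grid points to predict, so h is a trade-off between image quality and run time.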

Here comes the clustering process. We can use Sklearn's KMeans, the defined function, and Python's for-loop to return three values in each iteration: clustering labels, centroids, and boundaries. These values will be kept in lists for plotting later.

from sklearn.cluster import KMeans
df_ = df_X.copy()
centroids = None
keep_cent, keep_clus, keep_Z = [], [], []

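for n in iter_num:
    # warm start: after the first pass, each model starts from the previous
    # centroids and runs up to n more iterations
    model= KMeans(n_clusters=5, random_state=42, max_iter=n, n_init=1,
                  init=(centroids if centroids is not None else 'k-means++'))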
    cluster, centroids, Z = apply_model(model, df_)
    keep_clus.append(cluster)
    keep_cent.append(centroids)
    keep_Z.append(Z)

Create a DataFrame from the list of clustering labels.

col_name = ['Iter '+str(i) for i in iter_num]

df_iter = pd.DataFrame([list(i) for i in zip(*keep_clus)],columns=col_name)
df_plot = df_.join(df_iter)
df_plot.head()

Define a function that uses a for-loop to create the scatter plots. This function will also be used to visualize the other clustering techniques in this article. Note that the results are exported to your computer as PNG files, which will be combined into a GIF animation later.

def plot_clus(names, Z_val, ctds):
    sns.set_style('darkgrid', {'axes.grid' : False})
    for i,z,c in zip(names, Z_val, ctds):
        plt.figure(1)
        plt.clf()
        plt.imshow(z, interpolation='nearest',
               extent=(xx.min(), xx.max(), yy.min(), yy.max()),
               cmap='coolwarm_r',
               aspect='auto',
               origin='lower', alpha=0.6)
        sns.scatterplot(data = df_plot, x = 0, y = 1, hue = i,
                        palette='viridis', linewidth=0.1, alpha=0.8)
        plt.scatter(c[:, 0], c[:, 1], s=92, marker = '^', c='red', lw=0.5)
        plt.xlabel('')
        plt.ylabel('')
        plt.legend(title=i, loc='upper right')
        plt.savefig(i+'.png', bbox_inches = 'tight', dpi=240)
    return plt.show()

Plot the K-means clustering result.

plot_clus(col_name, keep_Z, keep_cent)

An example of a scatter plot showing the K-means clustering process. Image by Author.

Define a function to combine the plots into an animation. The result will be saved onto your computer.

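from PIL import Image
import imageio

def animation(names, save_name, time_speed):
    # read the PNG files exported by plot_clus(), in order
    img = [Image.open(name + '.png') for name in names]
    # export the GIF file; the output location can be changed
    # time_speed is the time per frame (whether it is read as seconds or
    # milliseconds depends on the installed imageio version)
    imageio.mimsave(save_name, img, duration=time_speed)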

Apply the function.

animation(col_name, 'animation_KMeans.gif', 0.4)
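
As an alternative, Pillow alone can also write an animated GIF. The sketch below is not the method used in this article. It uses three solid-colour frames as stand-ins for the exported PNG plots, and the output file name demo.gif is hypothetical.

```python
import os
import tempfile
from PIL import Image

# Three solid-colour frames stand in for the saved PNG plots
frames = [Image.new('RGB', (64, 48), color) for color in ('red', 'green', 'blue')]

out_path = os.path.join(tempfile.mkdtemp(), 'demo.gif')   # hypothetical output file
frames[0].save(out_path, save_all=True, append_images=frames[1:],
               duration=400,   # milliseconds per frame
               loop=0)         # 0 means loop forever

# Re-open the file to confirm that all frames were written
with Image.open(out_path) as gif:
    print(gif.format, gif.n_frames)
```

With real plots, the frames list would be built with Image.open() on each exported PNG file instead of Image.new().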

Voilà!!

Animation showing the K-means clustering process. Image by Author.

The animation shows that data points are assigned to clusters in the first iteration. Then, as the centroids are recalculated, some points are reassigned to adjacent clusters. The red triangles show the centroids at each step. The process continues until the centroids stop moving.
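
One way to put a number on what the animation shows is to track how far the centroids move between updates. The sketch below is a rough illustration under a few assumptions: it regenerates the same blobs, picks the five starting centroids at random from the data (a simplification of 'k-means++'), and performs one update per pass by combining max_iter=1 with a warm start. The variable names are made up for this example.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

pts, _ = make_blobs(cluster_std=5, n_samples=1200, n_features=2, random_state=42)

rng = np.random.default_rng(42)
start = pts[rng.choice(len(pts), size=5, replace=False)]   # random initial centroids

current = start
shifts = []
for _ in range(10):
    # max_iter=1 with a warm start performs a single update per pass
    km = KMeans(n_clusters=5, init=current, n_init=1, max_iter=1).fit(pts)
    # largest distance moved by any centroid during this update
    shifts.append(np.linalg.norm(km.cluster_centers_ - current, axis=1).max())
    current = km.cluster_centers_

print(np.round(shifts, 3))
```

When the printed values approach zero, further iterations no longer change the picture, which is the point at which an animation can stop.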


2. MiniBatch K-means clustering

Instead of processing all data points in every iteration as K-means does, MiniBatch K-means randomly takes a small batch of data for each iteration. This speeds up the clustering, at the cost of slightly different results.
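
The batch-by-batch idea can be made explicit with MiniBatchKMeans.partial_fit(), which updates the centroids from one batch at a time. This is only a sketch of the principle, not the workflow used in this article: the batch size of 100, the number of steps, and the manual random sampling are assumptions chosen for illustration.

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import make_blobs

pts, _ = make_blobs(cluster_std=5, n_samples=1200, n_features=2, random_state=42)

mbk = MiniBatchKMeans(n_clusters=5, random_state=42, n_init=1)
rng = np.random.default_rng(0)
batch_size = 100                      # hypothetical batch size for illustration

history = []
for step in range(10):
    batch = pts[rng.choice(len(pts), size=batch_size, replace=False)]
    mbk.partial_fit(batch)            # update the centroids from this batch only
    history.append(mbk.cluster_centers_.copy())

labels = mbk.predict(pts)             # assign every point to its nearest centroid
print(len(history), history[-1].shape, np.unique(labels))
```

Each update touches only 100 of the 1,200 points, which is where the speed-up comes from.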

Sklearn's MiniBatchKMeans class can be used to perform MiniBatch K-means clustering. We will use the same steps as previously explained in the K-means process.

from sklearn.cluster import MiniBatchKMeans
ctrd = None
keep_cent, keep_clus, keep_Z = [], [], []

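for n in iter_num:
    # note: ctrd is never reassigned, so every run starts from 'k-means++'
    # and max_iter=n sets the maximum number of passes over the data
    model = MiniBatchKMeans(n_clusters=5, random_state=42,
                            max_iter=n, n_init=1,
                            init=(ctrd if ctrd is not None else 'k-means++'))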
    cluster, centroids, Z = apply_model(model, df_)
    keep_clus.append(cluster)
    keep_cent.append(centroids)
    keep_Z.append(Z)

df_iter = pd.DataFrame([list(i) for i in zip(*keep_clus)], columns=col_name)
df_plot = df_.join(df_iter)
plot_clus(col_name, keep_Z, keep_cent)

An example of a scatter plot showing the MiniBatch K-means clustering process. Image by Author.

Thanks to the function defined in the previous step, we can create an animation with just one line of code.

animation(col_name, 'animation_miniBKMeans.gif', 0.2)

Animation showing the MiniBatch K-means clustering process. Image by Author.

Compared with the result from K-means, the clustering areas in the first iteration of MiniBatch K-means are close to those of K-means' fifth or sixth iteration. In this example, MiniBatch K-means therefore reaches a similar clustering in fewer iterations.


3. Bisecting K-means clustering

Bisecting K-means starts by applying K-means to divide all the data points into two clusters. After that, the algorithm selects the cluster with the largest sum of squares and divides it into two clusters again. The process repeats until the total number of clusters equals K.

This algorithm can also be considered a hybrid between divisive hierarchical clustering and K-means. It is an effective way to deal with a large K.
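
The splitting rule can be sketched with ordinary 2-means runs. This is a simplified illustration of the idea rather than Sklearn's BisectingKMeans implementation; the helper sse and the variable names are assumptions made for the example.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

pts, _ = make_blobs(cluster_std=5, n_samples=1200, n_features=2, random_state=42)

def sse(members):
    """Sum of squared distances from the members to their own mean."""
    return ((members - members.mean(axis=0)) ** 2).sum()

target_k = 5
labels = np.zeros(len(pts), dtype=int)          # start with everything in cluster 0
while labels.max() + 1 < target_k:
    # pick the cluster with the largest sum of squares ...
    worst = max(range(labels.max() + 1), key=lambda c: sse(pts[labels == c]))
    idx = np.flatnonzero(labels == worst)
    # ... and split it in two with an ordinary 2-means run
    halves = KMeans(n_clusters=2, n_init=1, random_state=42).fit_predict(pts[idx])
    labels[idx[halves == 1]] = labels.max() + 1  # one half gets a new label

print(np.bincount(labels))                       # cluster sizes after the last split
```

Each pass adds exactly one cluster, so reaching K clusters takes K - 1 splits.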

Now let's move on to the code. Start by creating a list of numbers to use with the for-loop. To compare with the other algorithms in this article, I will make a list of numbers from one to five.

n_num = [i+1 for i in range(5)]
col_name = ['Iter '+str(i) for i in n_num]
n_num

#[1, 2, 3, 4, 5]

Next, we will use Sklearn's BisectingKMeans class to do the Bisecting K-means clustering.

from sklearn.cluster import BisectingKMeans
keep_cent, keep_clus, keep_Z = [], [], []

for n in n_num:
    model = BisectingKMeans(n_clusters=n, random_state=42,
                            max_iter=1, n_init=1)
    cluster, centroids, Z = apply_model(model, df_)
    keep_clus.append(cluster)
    keep_cent.append(centroids)
    keep_Z.append(Z)

df_iter = pd.DataFrame([list(i) for i in zip(*keep_clus)],columns=col_name)
df_plot = df_.join(df_iter)
plot_clus(col_name, keep_Z, keep_cent)

An example of a scatter plot showing the Bisecting K-means clustering process. Image by Author.

Combine the plots into an animation.

animation(col_name, 'animation_BisectingKMeans.gif', 0.9)

Animation showing the Bisecting K-means clustering process. Image by Author.

From the animation, it can be seen that all the data points, which form a single cluster in the first step, are divided into two clusters in the second step. Then, the cluster with the larger sum of squares is divided again into two clusters. Thus, we have three clusters in the third iteration. The bisecting K-means process continues until the number of clusters reaches K, which is five in this article.


4. Mean-shift clustering

Mean-shift clustering calculates the local mean of the data points within a certain radius (bandwidth) and shifts each candidate centroid toward that mean, that is, toward the region of highest density. The algorithm keeps recalculating until convergence. Because of this behavior, the method is also known as a hill-climbing algorithm.

Note: Mean-shift clustering is also considered a density-based algorithm [link1, link2].
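
The hill-climbing behavior can be sketched in NumPy with a flat kernel, where every point inside the bandwidth counts equally. This is a minimal illustration under stated assumptions, not Sklearn's implementation: the function name mean_shift_step, the toy data, and the bandwidth of 2.5 are chosen only for the example.

```python
import numpy as np

def mean_shift_step(seeds, points, bandwidth):
    """Move each seed to the mean of the points within `bandwidth` of it (flat kernel)."""
    moved = np.empty_like(seeds)
    for i, s in enumerate(seeds):
        near = points[np.linalg.norm(points - s, axis=1) <= bandwidth]
        moved[i] = near.mean(axis=0) if len(near) else s
    return moved

# Toy data (hypothetical): two tight groups on a line
pts = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0],
                [10.0, 0.0], [11.0, 0.0], [12.0, 0.0]])
seeds = pts.copy()                     # start one seed on every data point

for _ in range(5):
    seeds = mean_shift_step(seeds, pts, bandwidth=2.5)

# Seeds that ended up at the same location represent one cluster centre
modes = np.unique(np.round(seeds, 6), axis=0)
print(modes)
```

Six seeds collapse onto two locations, one per group, without the number of clusters ever being specified.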

Next, let's use Sklearn's MeanShift class to do the Mean-shift clustering. The main difference compared with the previous K-means methods is that there is no need to define the number of clusters (K). However, we need to calculate the bandwidth to use as a parameter for Mean-shift clustering.
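
Before fixing the bandwidth, it can help to see how the quantile argument of estimate_bandwidth() changes the estimate. The sketch below regenerates the same blobs and compares three quantile values; the values 0.1, 0.2, and 0.5 and the random_state are picked only for illustration.

```python
from sklearn.cluster import estimate_bandwidth
from sklearn.datasets import make_blobs

pts, _ = make_blobs(cluster_std=5, n_samples=1200, n_features=2, random_state=42)

# quantile controls how many neighbours are considered when estimating the radius
bw = {q: estimate_bandwidth(pts, quantile=q, n_samples=250, random_state=0)
      for q in (0.1, 0.2, 0.5)}
for q, value in bw.items():
    print(q, round(value, 2))
```

A larger quantile gives a larger bandwidth, which generally leads to fewer and broader clusters.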

from sklearn.cluster import MeanShift, estimate_bandwidth
keep_cent, keep_clus, keep_Z = [], [], []
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=250)

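seeds = None
for n in n_num:
    # seeds=None lets MeanShift choose its own starting points on the first pass;
    # later passes start from the previous centers and move them one more step
    model = MeanShift(bandwidth=bandwidth, seeds=seeds, max_iter=1)
    cluster, seeds, Z = apply_model(model, df_)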
    keep_clus.append(cluster)
    keep_cent.append(seeds)
    keep_Z.append(Z)

df_iter = pd.DataFrame([list(i) for i in zip(*keep_clus)],columns=col_name)
df_plot = df_.join(df_iter)
plot_clus(col_name, keep_Z, keep_cent)

An example of a scatter plot showing the Mean-shift clustering process. Image by Author.

Apply the function to create an animation.

animation(col_name, 'animation_MeanShift.gif', 1.1)

Ta-da!!

Animation showing the Meanshift clustering process. Image by Author.

In the first iteration, we can see many centroids because the algorithm is still searching for the highest density of data points within the bandwidth. There are only three centroids in the second iteration, and two centroids are left in the following iterations. This happens because the centroids keep moving toward regions with a higher density of data points, and centroids that end up close together are merged.


Key takeaways

In summary, centroid-based clustering methods aim to find centroids that serve as references for clustering the data. To achieve that, the algorithm has to recalculate the centroids repeatedly until they reach the optimal positions. As a result, the centroids and clusters change during the process.

The purpose of this article is to apply data visualization to express the process, which can be useful in showing how each algorithm works and monitoring the change in the process. If you have any suggestions or questions, please feel free to comment.

Thanks for reading.


Here are some of my data visualization articles that you may find interesting:

  • Visualizing 3 Sklearn Cross-validation: K-Fold, Shuffle & Split, and Time Series Split (link)
  • 9 Visualizations with Python that Catch More Attention than a Bar Chart (link)
  • 8 Visualizations with Python to Handle Multiple Time-Series Data (link)
  • Visualizing the Effect of Multicollinearity on Multiple Regression Model (link)
