Creating Animations to Show 4 Centroid-Based Clustering Algorithms using Python and Sklearn

Clustering analysis
Clustering analysis is an effective machine learning technique that groups data by their similarities and differences. The obtained data groups can be used for various purposes, such as segmenting, structuring, and decision-making.
To perform clustering analysis, many methods are available based on different algorithms. This article will mainly focus on centroid-based clustering, which is a common and useful technique.
Centroid-based clustering
Basically, the centroid-based technique works by iteratively recalculating optimal centroids (cluster centers) and then assigning data points to the nearest ones.
Since the process involves many iterations, data visualization can be used to express what happens along the way. Thus, the purpose of this article is to create animations showing the centroid-based process with Python and Sklearn.

Sklearn (Scikit-learn) is a powerful library that helps us perform clustering analysis efficiently. The following are the centroid-based clustering techniques that we will work with.
- K-means clustering
- MiniBatch K-means clustering
- Bisecting K-means clustering
- Mean-Shift clustering
Let's get started
Getting Data
Start with importing libraries.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
As an example, this article will use a generated dataset, which can be easily created with Sklearn's make_blobs(). If you already have your own dataset, this step can be skipped.
from sklearn.datasets import make_blobs
X, y = make_blobs(cluster_std=5, n_samples=1200,
                  n_features=2, random_state=42)
df_X = pd.DataFrame(X)
df_X.dropna(inplace=True)
sns.set_style('darkgrid')
sns.scatterplot(data=df_X, x=0, y=1, linewidth=0.5)
plt.show()

1. K-means clustering
This is the most common method for centroid-based clustering. The process can be briefly explained as follows: start by defining the number of clusters K. Next, K data points are randomly selected as initial centroids, and every other data point is assigned to its closest centroid by minimum Euclidean distance.
Then, the centroids are re-initialized by calculating the average of each cluster's data points, which updates their positions. After that, the assigning and re-initializing steps are repeated. The algorithm keeps iterating until the optimal centroids are reached.
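To make the assign-and-update loop concrete, here is a minimal NumPy sketch of a single K-means iteration. This is an illustration only: it assumes no cluster ends up empty, while Sklearn's implementation adds smarter initialization and convergence checks.
import numpy as np

def kmeans_step(X, centroids):
    # assign each point to the nearest centroid (Euclidean distance)
    dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
    labels = dists.argmin(axis=1)
    # update each centroid to the mean of its assigned points
    # (assumes every cluster keeps at least one point)
    new_centroids = np.array([X[labels == k].mean(axis=0)
                              for k in range(len(centroids))])
    return labels, new_centroids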
Now, let's work with the Python code. Start with defining a list of iteration numbers. As an example, this article will run only the first ten iterations. If you want to change the number, please feel free to modify the code below.
iter_num = [i+1 for i in range(10)]
iter_num
#[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
Define a function and variables.
def apply_model(model_in, df):
    # fit the model, then get the cluster labels and centroids
    clus = model_in.fit_predict(df)
    cent = model_in.cluster_centers_
    # predict over the meshgrid to draw the decision boundary
    z = model_in.predict(np.c_[xx.ravel(), yy.ravel()])
    z = z.reshape(xx.shape)
    return clus, cent, z
# build a meshgrid covering the data range, used to color the decision regions
h = 0.02  # step size of the mesh
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
Here comes the clustering process. We can use Sklearn's KMeans, the defined function, and a Python for-loop to collect three values in each iteration: the clustering labels, the centroids, and the decision boundaries. These values will be kept in lists for plotting later. Capping max_iter at n and warm-starting init from the previous run's centroids lets us capture the intermediate states of the model.
from sklearn.cluster import KMeans

df_ = df_X.copy()
centroids = None
keep_cent, keep_clus, keep_Z = [], [], []
for n in iter_num:
    # warm-start each run from the previous centroids to capture the progress
    model = KMeans(n_clusters=5, random_state=42, max_iter=n, n_init=1,
                   init=(centroids if centroids is not None else 'k-means++'))
    cluster, centroids, Z = apply_model(model, df_)
    keep_clus.append(cluster)
    keep_cent.append(centroids)
    keep_Z.append(Z)
Create a DataFrame from the list of clustering labels.
col_name = ['Iter '+str(i) for i in iter_num]
df_iter = pd.DataFrame([list(i) for i in zip(*keep_clus)], columns=col_name)
df_plot = df_.join(df_iter)
df_plot.head()
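Each row of df_plot now holds a data point's coordinates (columns 0 and 1) together with its cluster label at every iteration (columns 'Iter 1' through 'Iter 10').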

Define a function that uses a for-loop to create the scatter plots. This function will also be applied to visualize the other clustering techniques in this article. Note that the plots are exported to your computer as PNG files so they can be combined into a GIF animation later.
def plot_clus(names, Z_val, ctds):
    sns.set_style('darkgrid', {'axes.grid': False})
    for i, z, c in zip(names, Z_val, ctds):
        plt.figure(1)
        plt.clf()
        # color the decision regions from the meshgrid predictions
        plt.imshow(z, interpolation='nearest',
                   extent=(xx.min(), xx.max(), yy.min(), yy.max()),
                   cmap='coolwarm_r', aspect='auto',
                   origin='lower', alpha=0.6)
        sns.scatterplot(data=df_plot, x=0, y=1, hue=i,
                        palette='viridis', linewidth=0.1, alpha=0.8)
        # mark the centroids with red triangles
        plt.scatter(c[:, 0], c[:, 1], s=92, marker='^', c='red', lw=0.5)
        plt.xlabel('')
        plt.ylabel('')
        plt.legend(title=i, loc='upper right')
        plt.savefig(i + '.png', bbox_inches='tight', dpi=240)
    return plt.show()
Plot the K-means clustering result.
plot_clus(col_name, keep_Z, keep_cent)

Define a function to combine the plots into an animation. The result will be saved to your computer.
from PIL import Image
import imageio

def animation(names, save_name, time_speed):
    img = []
    for i in names:  # read the exported PNG files
        myImage = Image.open(i + '.png')
        img.append(myImage)
    # export the GIF file; the output location can be changed
    imageio.mimsave(save_name, img, duration=time_speed)
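Note: depending on your imageio version, the duration argument may be interpreted in seconds (older releases) or in milliseconds (the newer Pillow-based GIF plugin). If the resulting animation plays far too fast or too slow, try scaling the value by 1000.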
Apply the function.
animation(col_name, 'animation_KMeans.gif', 0.4)
Voilà!!

The animation shows that the data points are assigned to different clusters in the first iteration. Then, some data points are reallocated to adjacent clusters as the centroids are recalculated. The red triangles mark the centroids at each step. The process continues until the centroids reach their optimal positions.
2. MiniBatch K-means clustering
Instead of processing all the data points as K-means does, MiniBatch K-means randomly takes small batches of data in each iteration. This improves the clustering speed while returning slightly different outcomes.
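For intuition, roughly the same incremental behavior can be reproduced by hand with partial_fit, which updates the centroids from one small random batch at a time. This is a rough sketch: the batch size of 100 and the 10 update steps are arbitrary choices for illustration, not Sklearn's defaults.
from sklearn.cluster import MiniBatchKMeans
import numpy as np

rng = np.random.default_rng(42)
mbk = MiniBatchKMeans(n_clusters=5, random_state=42, n_init=1)
for step in range(10):
    # draw a small random batch and update the centroids from it alone
    batch = X[rng.choice(len(X), size=100, replace=False)]
    mbk.partial_fit(batch)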
Sklearn's MiniBatchKMeans function can be used to perform MiniBatch K-means clustering. We will follow the same steps as previously explained in the K-means process.
from sklearn.cluster import MiniBatchKMeans

ctrd = None
keep_cent, keep_clus, keep_Z = [], [], []
for n in iter_num:
    model = MiniBatchKMeans(n_clusters=5, random_state=42,
                            max_iter=n, n_init=1,
                            init=(ctrd if ctrd is not None else 'k-means++'))
    # ctrd is updated here so that the next run warm-starts from it
    cluster, ctrd, Z = apply_model(model, df_)
    keep_clus.append(cluster)
    keep_cent.append(ctrd)
    keep_Z.append(Z)
df_iter = pd.DataFrame([list(i) for i in zip(*keep_clus)], columns=col_name)
df_plot = df_.join(df_iter)
plot_clus(col_name, keep_Z, keep_cent)

Thanks to the functions defined in the previous steps, we can create an animation with just one line of code.
animation(col_name, 'animation_miniBKMeans.gif', 0.2)

Compared with the result from K-means, the clustering areas in the first iteration of MiniBatch K-means are already close to those of K-means' fifth or sixth iteration. Thus, it can be seen that MiniBatch K-means clusters the data faster.
3. Bisecting K-means clustering
Bisecting K-means first applies K-means to divide the whole dataset into two clusters. After that, the algorithm selects the cluster with the largest sum of squared errors and divides it into two clusters again. The process keeps repeating until the total number of clusters equals K.
This algorithm can also be considered a hybrid between divisive hierarchical clustering and K-means. It is an effective way to deal with a large number of clusters K.
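Conceptually, the bisecting loop can be sketched with plain KMeans: keep splitting the cluster with the largest within-cluster sum of squares until K clusters exist. This is a simplified illustration, not Sklearn's exact implementation.
from sklearn.cluster import KMeans
import numpy as np

def bisecting_sketch(X, K):
    clusters = [X]  # start with all the points in one cluster
    while len(clusters) < K:
        # pick the cluster with the largest within-cluster sum of squares
        sse = [((c - c.mean(axis=0)) ** 2).sum() for c in clusters]
        target = clusters.pop(int(np.argmax(sse)))
        # split it in two with a standard 2-means run
        labels = KMeans(n_clusters=2, n_init=1,
                        random_state=42).fit_predict(target)
        clusters += [target[labels == 0], target[labels == 1]]
    return clusters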
Now we are going to work on the code part. Start by creating a list of numbers to use with the for-loop. To compare with the other algorithms in this article, I will make a list of numbers from one to five.
n_num = [i+1 for i in range(5)]
col_name = ['Iter '+str(i) for i in n_num]
n_num
#[1, 2, 3, 4, 5]
Next, we will use Sklearn's BisectingKMeans function to perform the Bisecting K-means clustering.
from sklearn.cluster import BisectingKMeans

keep_cent, keep_clus, keep_Z = [], [], []
for n in n_num:
    # here n is the number of clusters, growing by one each step
    model = BisectingKMeans(n_clusters=n, random_state=42,
                            max_iter=1, n_init=1)
    cluster, centroids, Z = apply_model(model, df_)
    keep_clus.append(cluster)
    keep_cent.append(centroids)
    keep_Z.append(Z)
df_iter = pd.DataFrame([list(i) for i in zip(*keep_clus)], columns=col_name)
df_plot = df_X.join(df_iter)
plot_clus(col_name, keep_Z, keep_cent)

Combine the plots into an animation.
animation(col_name, 'animation_BisectingKMeans.gif', 0.9)

From the animation, it can be seen that the data points, which all belong to a single cluster in the first step, are divided into two clusters in the second step. Then, the cluster with the larger sum of squared errors is divided into two clusters again, giving three clusters in the third iteration. The bisecting K-means process continues until the number of clusters reaches K, which is five in this article.
4. Mean-shift clustering
Mean-shift clustering calculates the local mean point within a certain radius (bandwidth) and shifts candidate centroids toward regions of higher data-point density. The algorithm keeps recalculating until convergence. This method is also known as a hill-climbing algorithm due to its behavior.
Note: the Mean-shift clustering technique is also considered a density-based algorithm [link1, link2].
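The core hill-climbing step is simple to express directly: a candidate centroid is repeatedly moved to the mean of all the points within the bandwidth, drifting toward a density peak. Below is a minimal sketch with a flat (uniform) kernel; the bandwidth of 5.0 and the starting candidate X[0] are arbitrary choices for illustration.
import numpy as np

def mean_shift_step(point, X, bandwidth):
    # flat-kernel mean shift: average all the points within the bandwidth radius
    neighbors = X[np.linalg.norm(X - point, axis=1) < bandwidth]
    return neighbors.mean(axis=0) if len(neighbors) else point

# shift one candidate centroid until it stops moving (a density peak)
p = X[0].copy()
for _ in range(50):
    p_new = mean_shift_step(p, X, bandwidth=5.0)  # bandwidth chosen arbitrarily
    if np.linalg.norm(p_new - p) < 1e-3:
        break
    p = p_new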
Next, let's use Sklearn's MeanShift function to perform the Mean-shift clustering. The main difference compared with the previous K-means methods is that there is no need to define the number of clusters K. However, we do need to estimate the bandwidth to use as a parameter for Mean-shift clustering.
from sklearn.cluster import MeanShift, estimate_bandwidth

keep_cent, keep_clus, keep_Z = [], [], []
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=250)
seeds = None
for n in n_num:
    # seeds=None lets the first run pick its own starting points;
    # later runs continue from the previous cluster centers
    model = MeanShift(bandwidth=bandwidth, seeds=seeds, max_iter=1)
    cluster, seeds, Z = apply_model(model, df_)
    keep_clus.append(cluster)
    keep_cent.append(seeds)
    keep_Z.append(Z)
df_iter = pd.DataFrame([list(i) for i in zip(*keep_clus)], columns=col_name)
df_plot = df_.join(df_iter)
plot_clus(col_name, keep_Z, keep_cent)

Apply the function to create an animation.
animation(col_name, 'animation_MeanShift.gif', 1.1)
Ta-da!!

In the first iteration, we can see many centroids because the algorithm is searching for the highest density of data points within the given bandwidth. Only three centroids remain in the second iteration, and two in the iterations after that. This happens because the algorithm keeps shifting the centroids toward regions of higher data-point density.
Key takeaways
In summary, centroid-based clustering methods aim to find centroids that serve as references for grouping the data. To achieve that, the algorithms have to recalculate the centroids repeatedly until they reach the optimal ones. This causes the centroids and clusters to change during the process.
The purpose of this article is to apply data visualization to express that process, which can be useful for showing how each algorithm works and for monitoring the changes as they happen. If you have any suggestions or questions, please feel free to comment.
Thanks for reading.
Here are some of my data visualization articles that you may find interesting:
- Visualizing 3 Sklearn Cross-validation: K-Fold, Shuffle & Split, and Time Series Split (link)
- 9 Visualizations with Python that Catch More Attention than a Bar Chart (link)
- 8 Visualizations with Python to Handle Multiple Time-Series Data (link)
- Visualizing the Effect of Multicollinearity on Multiple Regression Model (link)
References
- Techniques, A. (n.d.). What are the benefits and challenges of using cluster analysis in decision making?. Cluster Analysis for Decision Making: Benefits and Challenges. https://www.linkedin.com/advice/3/what-benefits-challenges-using-cluster-analysis
- Wikimedia Foundation. (2023, July 24). Cluster analysis. Wikipedia. https://en.wikipedia.org/wiki/Cluster_analysis
- Examples. Scikit. (n.d.). https://scikit-learn.org/stable/auto_examples
- Sharma, N. (2023, April 19). K-means clustering explained. neptune.ai. https://neptune.ai/blog/k-means-clustering
- GeeksforGeeks. (2023b, January 23). ML: Mini batch K-means clustering algorithm. GeeksforGeeks. https://www.geeksforgeeks.org/ml-mini-batch-k-means-clustering-algorithm
- Firdaus, A. (2020, May 9). Bisecting K-means clustering. Medium. https://medium.com/@afrizalfir/bisecting-kmeans-clustering-5bc17603b8a2
- Yufeng. (2022, February 22). Understanding mean shift clustering and implementation with Python. Medium. https://towardsdatascience.com/understanding-mean-shift-clustering-and-implementation-with-python-6d5809a2ac40
- Wikimedia Foundation. (2023, July 24). Mean shift. Wikipedia. https://en.wikipedia.org/wiki/Mean_shift