Scikit-Learn for K-Means Clustering
Learning Objectives
- Gain a high-level understanding of what machine learning is, and when you should or shouldn't use it.
- Apply `scikit-learn` and other popular data science Python libraries to a data science project: K-means clustering on the Iris dataset.
- Develop proficiency in data analysis and data visualization.
Before we get started, here are some tips for beginners:
- We'll be covering a wide range of data science libraries. Pause and revisit concepts as needed. Repetition will help solidify understanding of topics.
- If you don't understand a specific topic, or are running into issues, there are many sources out there that can help you:
- You can read the documentation for a specific library or function.
- Googling concepts or visiting StackOverflow is another great way to learn about common issues you'll encounter.
- Finally, ChatGPT can be helpful for co-coding and getting started with ideas, but be wary of its answers.
- Don't be afraid to explore different solutions and make mistakes. There are many paths to the same destination.
1. What is machine learning?
Data is being generated and captured in every industry, from e-commerce websites to hospitals, because it contains information we can use to generate insights and make predictions about upcoming events. To leverage all this data, we can use machine learning and artificial intelligence.
Machine learning uses math to learn patterns in data and stores those patterns in a model. New data can then be fed into the model to produce predictions. Note that a machine learning model learns these patterns without any rules being hard-coded into it. This is powerful, because the model learns through examples and data rather than human intervention.
The Machine Learning Workflow

In a machine learning workflow, there are 6 general steps (a compact code sketch of the full loop follows the list):

1. **Get data** related to a particular subject area or problem. You can source it from a database, generate it, or receive it from someone else.
2. **Prepare your data.** When you first get your data, it is likely not in a format that computers can use for learning; you need to process it first.
3. **Create a model.** A model is a mathematical representation that learns patterns in data. Different types of models suit different learning tasks; you can use a known model architecture or design a custom model from first principles.
4. **Train your model.** Below are some of the kinds of learning tasks machines can generally do:
   - Supervised learning: predict numeric values or categories based on associated attributes.
   - Unsupervised learning: group similar objects together.
   - Semi-supervised learning: group data into labels (unsupervised) and make predictions (supervised).
   - Reinforcement learning: have a model interact with an environment and learn the best actions to take.
5. **Evaluate your model's performance.** With a new model, we need to see if it is good at its task. We usually hold out a subset of the data to validate and measure its performance.
6. **Use your model for something.** You've got this cool model that can make predictions. Use it to make a difference in the world!
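Here is that sketch: a minimal, hedged walkthrough of steps 1–6 with scikit-learn. It uses a simple supervised model purely for illustration (this notebook's project uses unsupervised K-means instead), and the dataset and model choices are illustrative, not prescriptive.

from sklearn.datasets import load_iris                     # step 1: get data
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler           # step 2: prepare data
from sklearn.linear_model import LogisticRegression        # step 3: create a model
from sklearn.metrics import accuracy_score                 # step 5: evaluate

X, y = load_iris(return_X_y=True)                                   # 1. get data
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
scaler = StandardScaler().fit(X_train)                              # 2. prepare data
model = LogisticRegression(max_iter=1000)                           # 3. create a model
model.fit(scaler.transform(X_train), y_train)                       # 4. train
predictions = model.predict(scaler.transform(X_test))               # 5. evaluate...
print(accuracy_score(y_test, predictions))                          # ...and 6. use it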
When can you use machine learning?
Machine learning can save time, effort, and capital: you have a computer do the heavy lifting for tedious, repetitive, and mundane tasks instead of a human.
Below are some situations when you would want to use machine learning:
- There are patterns in the data that are predictable.
- The patterns are too complex for humans to reliably make decisions or take action on, but can be learned by computers.
- There is data that is representative of the problem you want to solve.
- Incoming or new data (likely) shares the patterns with your dataset.
Check-in 1: What are some situations where machine learning is NOT useful, or where you wouldn't want to use it?
When NOT to use machine learning
While machine learning is useful for many things, it's not a catch-all. There are many situations where you wouldn't want to, or couldn't, apply machine learning.
Here are some scenarios where you wouldn't want to use ML:
- The event you're trying to predict doesn't have a distinct pattern.
- The problem you're trying to solve is simple.
- It's not cost effective to implement ML.
- It's unethical to apply ML for a given problem.
2. Project Overview
Now that we've discussed what machine learning is at a high level, let's go over the data science project we'll use to demonstrate how to examine and visualize data, train models, and evaluate model performance.
The Iris Dataset
In this project, we'll be using the Iris dataset, a well-curated dataset that contains the sepal and petal length/width data (in cm) for three species of Iris flowers. Below is an image that shows what the sepal and petal are on a plant:

And here are images of the flowers that are described in the Iris dataset.
*(Images of the three species: Iris setosa, Iris versicolor, and Iris virginica.)*
Unsupervised Learning and Clustering
Unsupervised learning algorithms try to learn patterns within unlabeled data, i.e. data where no category has been assigned to each entity. Because labeling data requires a lot of time and money, unsupervised algorithms are ideal when you don't have labels or want to examine the underlying characteristics of a dataset.

One task within unsupervised learning is clustering. The goal is to take several data points (represented as gray dots above), identify similarities between them using a model (the black box), and group them together based on how similar they are relative to other clusters (here, red and blue).
Project Goal
In this notebook, we'll demonstrate the entire data science workflow that you can use in your other machine learning projects. We'll specifically cover:
- Basics of machine learning and data visualization Python libraries
- Downloading the Iris dataset and data processing
- Training a K-means clustering algorithm
- Evaluating how well K-means clustering separates out the Iris flowers
3. Training a K-means clustering algorithm on the Iris Dataset
Overview of `scikit-learn` and relevant data science libraries

`scikit-learn` is a popular Python library used for several common machine learning tasks. It is built on other data science libraries such as `NumPy`, `SciPy`, and `matplotlib`, and is well integrated into many data science environments and workflows. We'll go over how to train a machine learning model called K-means clustering using `scikit-learn` and other common data science libraries for data analysis on the Iris dataset.
A short description of each library is provided below, with a link to its documentation. I highly recommend reading the documentation before continuing, but we'll also cover usage of these libraries in the project (a one-line taste of each follows the list):
- `numpy`: a library for scientific computing with arrays and matrices.
- `pandas`: a data manipulation and analysis library.
- `scikit-learn`: a machine learning library, built on top of other libraries like `numpy`.
- `matplotlib`: a plotting and data visualization library.
- `seaborn`: a data visualization library built on top of `matplotlib` that makes it easier to produce visually appealing statistical plots.
Installing ML/Data Science Libraries
To install the libraries within this notebook, run the code block below. If you want to create a virtual environment for your notebook, follow the instructions in the README file.
!pip install --upgrade pip
!pip install scikit-learn
!pip install matplotlib
!pip install seaborn
!pip install numpy
Requirement already satisfied: pip in ./venv/lib/python3.10/site-packages (23.2)
Requirement already satisfied: scikit-learn in ./venv/lib/python3.10/site-packages (1.3.0)
Requirement already satisfied: matplotlib in ./venv/lib/python3.10/site-packages (3.7.2)
Requirement already satisfied: seaborn in ./venv/lib/python3.10/site-packages (0.12.2)
Requirement already satisfied: numpy in ./venv/lib/python3.10/site-packages (1.25.1)
(... transitive dependency lines trimmed ...)
4. The Iris Dataset
Background
The Iris dataset is a well-known dataset for practicing machine learning concepts such as clustering and classification. We are interested in how a clustering algorithm will group the three species of Iris flowers in the dataset: Setosa, Versicolour, and Virginica. These designations are known as labels.
The dataset has four attributes, or features, associated with each Iris sample: the sepal length, sepal width, petal length, and petal width, all in cm. There are 50 samples per species, for a total of 150 samples.
Load the Iris dataset
Let's first load the Python libraries we just installed into the notebook:
from sklearn import datasets
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
Now we'll import the Iris dataset from `scikit-learn` using the `datasets` module.
iris = datasets.load_iris(as_frame=True)
Let's go over what the code above does:

- The `datasets` module contains a function called `load_iris` that fetches the Iris dataset.
- The optional argument `as_frame=True` returns a pandas `DataFrame`, which is a common data structure for handling tabular data.
  - A `DataFrame` contains rows and columns:
    - Rows represent a single entity (e.g. a flower).
    - Columns represent attributes of that entity (e.g. sepal length).
- The variable `iris` is a dictionary-like object that contains several attributes of the Iris dataset. We can call its keys to get these attributes. Relevant keys include:
  - `data`: the pandas dataframe of features.
  - `target`: the numeric labels representing the Iris species of each flower in `data`.
  - `feature_names`: the feature names for the columns in `data`.
  - `target_names`: the Iris species names associated with the numeric labels in `target`.

More information about what `load_iris` can return can be found in the scikit-learn documentation.
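Before printing each object in full, it can help to peek at the available keys and shapes; a quick sketch:

print(iris.keys())          # dict-like keys: 'data', 'target', 'target_names', ...
print(iris.data.shape)      # (150, 4): 150 flowers, 4 features
print(iris.feature_names)   # names of the four feature columns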
We'll also print the contents of each relevant object:
print(f"Contents of data: \n {iris.data}")
Contents of data:
     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
0                  5.1               3.5                1.4               0.2
1                  4.9               3.0                1.4               0.2
2                  4.7               3.2                1.3               0.2
3                  4.6               3.1                1.5               0.2
4                  5.0               3.6                1.4               0.2
..                 ...               ...                ...               ...
145                6.7               3.0                5.2               2.3
146                6.3               2.5                5.0               1.9
147                6.5               3.0                5.2               2.0
148                6.2               3.4                5.4               2.3
149                5.9               3.0                5.1               1.8

[150 rows x 4 columns]
print(f"Contents of target: \n {iris.target}")
Contents of target:
0      0
1      0
2      0
3      0
4      0
      ..
145    2
146    2
147    2
148    2
149    2
Name: target, Length: 150, dtype: int64
print(f"Contents of target_names. \nNote that 0 = setosa, 1 = versicolor, and 2 = virginica: \n {iris.target_names}")
Contents of target_names. Note that 0 = setosa, 1 = versicolor, and 2 = virginica: ['setosa' 'versicolor' 'virginica']
Now that we have loaded the dataset into this notebook, we'll examine the data in more depth.
5. Exploratory Data Analysis
Before we think about using the data for a machine learning project, we should understand the properties of the dataset. This process of understanding the dataset by examining and visualizing its contents is called exploratory data analysis (EDA).
There are many benefits to performing EDA, including:
- Summarizing the main characteristics for communicating the properties of the data to others
- Identifying obvious patterns and relationships
- Detecting anomalies and outliers in the data
- Formulating hypotheses, based on data
We'll go over two data visualization methods that will help with EDA: box plots and pairplots.
Box plots
First, let's examine the distribution of each feature using a box plot. Box plots are useful for visualizing the distribution of a feature within a dataset, including the median value, quartiles, and potential outliers.
The diagram below shows the components of a boxplot:
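If you'd like these same statistics as numbers rather than a picture, pandas' `describe()` reports them directly:

# Numeric counterpart to a box plot: the 25%/50%/75% rows are the quartiles,
# and min/max bound the whiskers (ignoring outlier handling).
print(iris.data.describe())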
The code below draws boxplots for the different features in the Iris dataset. You don't need to know all the components of the code below; just know that it makes the plots more readable and interpretable.
sns.set()
plt.figure(figsize=(12, 4))
for index, feature in enumerate(iris.feature_names):
    plt.subplot(1, 4, index + 1)
    sns.boxplot(x=iris.target_names[iris.target], y=iris.data[feature], palette='Set3')
    plt.xlabel('Species')
    plt.ylabel(feature)
plt.tight_layout()
plt.show()
By eyeballing the distributions, you can clearly see differences between the feature distributions across the three species. In general, that's great, because it suggests that the targets (species) are separable by the values of the features.
In our case, that means the three species are likely to group into distinct clusters. However, note that differing distributions do not guarantee that the data will be separable by any given clustering algorithm.
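To back up the visual impression with numbers, we can compare mean feature values per species; this mirrors the grouping the box plots use:

# Group the feature table by each row's species name and compare means.
species = pd.Series(iris.target_names[iris.target], name='species')
print(iris.data.groupby(species).mean())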
Exercise 2: What other distinguishing characteristics do you notice about the feature distributions?
Potential answers include:

- `setosa` tends to be very different from `versicolor` and `virginica` in terms of its data distribution.
  - We can hypothesize that `setosa` will be much easier to distinguish from the other two Iris species.
  - We may also hypothesize that, because there is distribution overlap between `versicolor` and `virginica`, the algorithm may swap labels between those two species.
- Sepal width is the only feature that looks different from the distribution patterns seen in the other features.
  - Additionally, there is a lot of distribution overlap in this feature.
  - We could hypothesize that this feature is not informative if the goal is to cluster the flowers by species.
Pairplots
A pairplot plots the pairwise relationships between features in a dataset.

- It returns an N × N grid of histograms and scatter plots, where N is the number of features.
- Each row and column of the grid corresponds to a feature.
- A histogram is shown where the two axes contain the same feature (e.g. x = sepal_width and y = sepal_width).
- A scatter plot is shown where the two axes contain different features, showing their pairwise relationship (e.g. x = sepal_width and y = sepal_length).

Note that the off-diagonal plots (the plots on either side of the histograms) show duplicate information, with the axes flipped.
To get the pairplot code to work, we first need to combine the data objects into a single dataframe. We'll store the result in the variable `df`.
df = pd.concat([iris.data, iris.target], axis=1)
df.head()
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
Now we can show the pairwise distribution of each feature, separated by the `target` variable, using the `pairplot` function in Seaborn.
sns.pairplot(data=df, hue='target')
/Users/scottcampit/Projects/intro-to-ml/venv/lib/python3.10/site-packages/seaborn/axisgrid.py:118: UserWarning: The figure layout has changed to tight self._figure.tight_layout(*args, **kwargs)
<seaborn.axisgrid.PairGrid at 0x17e7f8d00>
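As an aside, since the off-diagonal panels mirror each other (as noted above), seaborn's `pairplot` accepts `corner=True` to draw only the lower triangle and halve the visual clutter:

# Same plot, but without the redundant upper-triangle panels.
sns.pairplot(data=df, hue='target', corner=True)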
6. Training a K-Means Clustering Algorithm
Intuition for K-Means Clustering
Things that belong to the same category tend to have similar attributes. In other words, members of the same category tend to have small differences in feature values between one another, while things in different categories have larger differences.

- For example, consider two Iris Setosa plants. Their sepal length, sepal width, petal length, and petal width are more likely to be close to each other's than to those of an Iris Versicolour or Iris Virginica.

Using this intuition, we can now understand the big picture of what the K-means algorithm tries to learn. K-means aims to find the best way to group samples together by minimizing the difference between a given sample (a single Iris Setosa plant) and the average of the samples in a cluster (the average of all Iris Setosa plants).
There's more to the algorithm in a mathematical sense, but that is outside the scope of this lesson; you're well equipped to apply K-means clustering to real data. (For the curious, a sketch of the core update step follows.)
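Here is a minimal NumPy sketch of one K-means update step (Lloyd's algorithm). This is an illustration, not scikit-learn's actual implementation, which adds smart initialization and multiple restarts:

import numpy as np

def kmeans_step(X, centroids):
    """One assign-then-update iteration; X is (n_samples, n_features)."""
    # 1. Assign each sample to the nearest centroid (Euclidean distance).
    distances = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
    assignments = distances.argmin(axis=1)
    # 2. Move each centroid to the mean of its assigned samples
    #    (for simplicity, this sketch assumes no cluster ends up empty).
    new_centroids = np.array([X[assignments == k].mean(axis=0)
                              for k in range(len(centroids))])
    return assignments, new_centroids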
Training a K-Means Algorithm using `scikit-learn`
Let's train our first ML model! Luckily, scikit-learn has already implemented the K-means clustering algorithm for us. You can load it using the following code:
from sklearn.cluster import KMeans
An important note about K-means: we need to tell it how many groups to look for before it learns how to assign each sample to a group. Since there are three species of Iris flowers, we should tell K-means that there are 3 potential clusters in the data:
model = KMeans(n_clusters=3)
The trained K-means model will live in the object we call `model`. This object provides the method `fit_predict()`: given the training data, it fits the model and outputs a cluster prediction for each sample. We'll store the model's predictions in the variable `kmeans_pred`.
kmeans_pred = model.fit_predict(iris.data)
/Users/scottcampit/Projects/intro-to-ml/venv/lib/python3.10/site-packages/sklearn/cluster/_kmeans.py:1412: FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the value of `n_init` explicitly to suppress the warning super()._check_params_vs_input(X, default_n_init=10)
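The `FutureWarning` above just announces an upcoming change to the default `n_init`. To silence it (and make runs reproducible) you can set the parameters explicitly; the values below are one reasonable choice, not the only one:

# n_init=10 matches the current default; random_state fixes the random initialization.
model = KMeans(n_clusters=3, n_init=10, random_state=42)
kmeans_pred = model.fit_predict(iris.data)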
The cluster IDs that K-means assigns are arbitrary and don't necessarily correspond to the label encoding in the original dataset. The code below maps each predicted cluster to the true label it most often coincides with. You don't need to worry about the details.
import numpy as np

# Get the cluster label K-means assigned to each sample
predicted_labels = model.labels_

# Create a mapping from predicted clusters to true labels:
# each cluster is mapped to the true label that occurs most often within it
label_mapping = {}
for cluster in np.unique(predicted_labels):
    cluster_labels = iris.target[predicted_labels == cluster]
    mapped_label = np.argmax(np.bincount(cluster_labels))
    label_mapping[cluster] = mapped_label

# Map the predicted labels to the true labels
mapped_predicted_labels = np.array([label_mapping[cluster] for cluster in predicted_labels])

# Print the mapped predicted labels
print("Mapped Predicted Labels:", mapped_predicted_labels)
Mapped Predicted Labels: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 2 2 2 2 1 2 2 2 2 2 2 1 1 2 2 2 2 1 2 1 2 1 2 2 1 1 2 2 2 2 2 1 2 2 2 2 1 2 2 2 1 2 2 2 1 2 2 1]
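With the labels aligned, we can also quantify agreement numerically. Below is a short sketch (an addition, not part of the original analysis): a raw match rate, plus scikit-learn's `adjusted_rand_score`, which compares clusterings while ignoring label permutations:

from sklearn.metrics import adjusted_rand_score

# Fraction of samples whose mapped cluster label matches the true species.
match_rate = np.mean(mapped_predicted_labels == iris.target)
print(f"Agreement with true labels: {match_rate:.2%}")

# Permutation-invariant comparison of raw cluster IDs vs. true labels.
print(f"Adjusted Rand index: {adjusted_rand_score(iris.target, predicted_labels):.3f}")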
7. Evaluating K-Means Clustering Through Visual Inspection
There are several metrics you can use to evaluate clustering algorithms. However, one intuitive way to assess how well the model did is visual inspection. The code below generates pairplots that help us compare the quality of our predictions against the true labels. We'll use dataframe objects for the pair plots.
First, we'll generate a dataframe for the K-means predictions.
df2 = pd.concat([iris.data, pd.Series(mapped_predicted_labels, name='prediction')], axis=1)
df2.head()
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | prediction |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
Visualize the original data.
sns.pairplot(data=df, hue='target')
/Users/scottcampit/Projects/intro-to-ml/venv/lib/python3.10/site-packages/seaborn/axisgrid.py:118: UserWarning: The figure layout has changed to tight self._figure.tight_layout(*args, **kwargs)
<seaborn.axisgrid.PairGrid at 0x2aef5a530>
Now visualize the prediction data.
sns.pairplot(df2, hue='prediction')
/Users/scottcampit/Projects/intro-to-ml/venv/lib/python3.10/site-packages/seaborn/axisgrid.py:118: UserWarning: The figure layout has changed to tight self._figure.tight_layout(*args, **kwargs)
<seaborn.axisgrid.PairGrid at 0x2bd475060>
If we look at the scatter plots side by side, they look very similar, with some minor differences in the label distribution. That is expected: we cannot expect a machine learning algorithm to perfectly capture the patterns in real data.
If, on the other hand, the algorithm had fit the data perfectly, that would hint at a common machine learning problem called overfitting, where the model "memorizes" the patterns in the training data and cannot generalize to new data. Conversely, if the model had clustered the data poorly, that would be underfitting, where the model fails to capture the underlying patterns of the data (for example, because it is too simple).
8. Challenges with evaluating K-means clustering
You can see that unsupervised learning and clustering can be powerful for grouping data without any labels. However, we should mention the challenges with evaluating unsupervised learning models:
- Lack of objective evaluation metrics
- While we had the true labels in the example above, in real data, we tend to not have labeled data. That makes it hard to measure the correctness of an algorithm.
- Subjective to interpretation
- Since there are no labels, interpreting clusters becomes more of an art than a science.
- Determining the optimal number of clusters (K-means specific)
  - In the exercise above, we knew there were 3 species of flowers, so we set the number of clusters to 3. With real data, it is much harder to determine the optimal number of clusters for K-means; one common heuristic is sketched below.
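That heuristic (an addition here, not from the exercise above) is the elbow method: fit K-means for a range of K values and plot the inertia, looking for the point where adding clusters stops reducing it substantially.

# Elbow method: within-cluster sum of squared distances (inertia) vs. K.
inertias = []
ks = range(1, 10)
for k in ks:
    km = KMeans(n_clusters=k, n_init=10, random_state=42).fit(iris.data)
    inertias.append(km.inertia_)

plt.plot(list(ks), inertias, marker='o')
plt.xlabel('Number of clusters (K)')
plt.ylabel('Inertia')
plt.show()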
While these are common challenges when evaluating K-means clustering, it remains a powerful algorithm to know when faced with unlabeled data, and it can still provide a lot of insight into the patterns within your dataset.