UMAP on the MNIST Digits dataset
A simple example demonstrating how to use UMAP on a larger dataset such as MNIST. We first pull the MNIST dataset and then use UMAP to reduce it to just two dimensions for easy visualisation.
Note that UMAP manages both to group the individual digit classes and to retain the overall global structure among them. It keeps 1 far from 0, and it groups the triplets 3, 5, 8 and 4, 7, 9, whose digits can blend into one another in some cases.
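```python
import umap
from sklearn.datasets import fetch_openml
import matplotlib.pyplot as plt
import seaborn as sns

sns.set(context="paper", style="white")

# Download the 70,000 flattened 28x28 MNIST images as plain NumPy arrays
# (as_frame=False avoids getting a pandas DataFrame back).
mnist = fetch_openml("mnist_784", version=1, as_frame=False)

# Fit UMAP and project every 784-dimensional image down to 2-D.
reducer = umap.UMAP(random_state=42)
embedding = reducer.fit_transform(mnist.data)

fig, ax = plt.subplots(figsize=(12, 10))
# The targets are the strings "0".."9"; convert them to integers for colouring.
color = mnist.target.astype(int)
ax.scatter(embedding[:, 0], embedding[:, 1], c=color, cmap="Spectral", s=0.1)
# Remove the axis ticks for a cleaner plot.
plt.setp(ax, xticks=[], yticks=[])
plt.title("MNIST data embedded into two dimensions by UMAP", fontsize=18)
plt.show()
```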
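The claims about global structure (1 far from 0, the 3/5/8 and 4/7/9 groupings) are read off the plot by eye. One rough way to put numbers on them is to compare per-class centroids in the embedding. The sketch below is a minimal, NumPy-only illustration under that assumption; the helpers `class_centroid_distances` and `nearest_classes` are hypothetical names, not part of UMAP, and a centroid distance is only a crude summary of how clusters sit relative to one another. The `__main__` demo uses synthetic clusters rather than MNIST so it runs without downloading anything.

```python
import numpy as np


def class_centroid_distances(embedding, labels):
    """Pairwise Euclidean distances between per-class centroids.

    Hypothetical helper: returns (classes, dist), where dist[i, j] is the
    distance between the centroids of classes[i] and classes[j].
    """
    embedding = np.asarray(embedding, dtype=float)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    # Mean embedding coordinate of each class, shape (n_classes, n_dims).
    centroids = np.stack([embedding[labels == c].mean(axis=0) for c in classes])
    # Broadcast to (n_classes, n_classes, n_dims) differences, then take norms.
    diff = centroids[:, None, :] - centroids[None, :, :]
    return classes, np.linalg.norm(diff, axis=-1)


def nearest_classes(classes, dist, target):
    """Hypothetical helper: other classes ordered by centroid distance to `target`."""
    i = int(np.flatnonzero(classes == target)[0])
    order = np.argsort(dist[i])
    return [classes[j] for j in order if j != i]


if __name__ == "__main__":
    # Synthetic stand-in for a 2-D embedding: three tight clusters.
    rng = np.random.default_rng(0)
    centres = {0: (0.0, 0.0), 1: (10.0, 0.0), 3: (1.0, 1.0)}
    points, point_labels = [], []
    for label, centre in centres.items():
        points.append(rng.normal(centre, 0.1, size=(50, 2)))
        point_labels.extend([label] * 50)
    classes, dist = class_centroid_distances(np.vstack(points), point_labels)
    print(classes)
    print(nearest_classes(classes, dist, 0))
```

With the script's own variables, `class_centroid_distances(embedding, color)` followed by `nearest_classes(classes, dist, 1)` would list the digit classes ordered by how close their centroids lie to the centroid of 1.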