Matplotlib - 3D Scatter Plots



A scatter plot is a way to display individual data points on a two-dimensional graph. Each point on the graph represents the values of two variables.

In a 3D scatter plot, three variables are considered instead of just two, creating a three-dimensional space where each point is defined by three values. Each data point is represented by a marker in this 3D space, with its position determined by its values along the three axes.

For example, you can use a 3D scatter plot to show the heights, weights, and ages of a group of people. In this representation, the x-axis represents height, the y-axis represents weight, and the z-axis represents age. Each person is then drawn as a single point in three-dimensional space, located at their specific combination of height, weight, and age −

3D Scatter Plot
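The height, weight, and age example above can be sketched as follows. The measurements are synthetic values drawn with NumPy purely for illustration (they are not real data), and the variable names are hypothetical. The sketch assumes Matplotlib 3.2 or later, where projection='3d' works without importing Axes3D explicitly.

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic, made-up measurements for 20 hypothetical people
rng = np.random.default_rng(0)
heights = rng.uniform(150, 190, size=20)   # cm
weights = rng.uniform(50, 95, size=20)     # kg
ages = rng.integers(18, 70, size=20)       # years

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

# One marker per person: x = height, y = weight, z = age
ax.scatter(heights, weights, ages)
ax.set_xlabel('Height (cm)')
ax.set_ylabel('Weight (kg)')
ax.set_zlabel('Age (years)')
plt.show()
```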

3D Scatter Plots in Matplotlib

We can create a 3D scatter plot in Matplotlib using the scatter() method of a 3D Axes, which is provided by the "mpl_toolkits.mplot3d" toolkit. You get such an Axes by passing projection='3d' when you create it. The method takes the X, Y, and Z coordinates of each data point and places a marker at each of those positions in 3D space.

The resulting plot shows how the data points are distributed and helps reveal relationships among the three variables.
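Because scatter() is called on a 3D Axes object rather than imported as a standalone function, the first step is always to obtain an Axes that uses the '3d' projection. The sketch below shows two common ways to do that and checks the types involved; it assumes Matplotlib 3.2 or later.

```python
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D, art3d

# Option 1: ask the figure for an Axes that uses the '3d' projection
fig1 = plt.figure()
ax1 = fig1.add_subplot(projection='3d')

# Option 2: request the same projection through plt.subplots
fig2, ax2 = plt.subplots(subplot_kw={'projection': '3d'})

xs = np.arange(5)
ys = xs ** 2
zs = xs ** 3

# scatter() is a method of the 3D Axes; it returns a 3D point collection
points = ax1.scatter(xs, ys, zs)

print(isinstance(ax1, Axes3D), isinstance(ax2, Axes3D))  # True True
print(isinstance(points, art3d.Path3DCollection))        # True
plt.show()
```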

Let's start by creating a basic 3D scatter plot.

Basic 3D Scatter Plot

In a basic 3D scatter plot, each data point has three coordinates: its position along the X-axis, the Y-axis, and the Z-axis. Matplotlib draws a marker at each of these positions, giving a visual representation of your data in three dimensions.
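Since each point is just a triple of coordinates, it is often convenient to keep the points in an array with one row per point and split the columns when plotting. This is a minimal sketch with a few hand-picked, illustrative points:

```python
import matplotlib.pyplot as plt
import numpy as np

# Each row is one point: [x, y, z]
points = np.array([
    [0.0, 0.0, 0.0],
    [1.0, 0.5, 0.2],
    [0.3, 1.0, 0.8],
    [0.9, 0.9, 1.0],
])

# Transposing gives three 1-D arrays: all x values, all y values, all z values
x, y, z = points.T

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x, y, z, marker='o')
plt.show()
```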

Example

In the following example, we are generating random 3D data points and creating a basic 3D scatter plot. The x, y, and z coordinates of each point are randomly chosen, and the resulting plot visualizes these points in three-dimensional space using blue markers −

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

# Generating 3D data
x = np.random.rand(50)
y = np.random.rand(50)
z = np.random.rand(50)

# Creating a basic 3D scatter plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x, y, z, c='blue', marker='o', label='Data Points')

ax.set_xlabel('X Axis')
ax.set_ylabel('Y Axis')
ax.set_zlabel('Z Axis')
ax.set_title('Basic 3D Scatter Plot')
plt.legend()
plt.show()

Output

After executing the above code, we get the following output −

Basic 3D Scatter Plot

Multiple 3D Scatter Plots

Multiple 3D scatter plots in Matplotlib let you display several sets of points in the same three-dimensional space. Each set corresponds to a separate dataset, typically drawn with its own scatter() call, color, and marker, and within each dataset, individual points are defined by their X, Y, and Z coordinates.
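When there are more than two datasets, calling scatter() once per dataset in a loop keeps the code short. In this sketch the datasets, their names, and their markers are hypothetical random data; colors are left to Matplotlib's default color cycle.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical datasets: name -> ((N, 3) array of points, marker style)
datasets = {
    'Set 1': (rng.random((30, 3)), 'o'),
    'Set 2': (rng.random((30, 3)) + 0.5, '^'),
    'Set 3': (rng.random((30, 3)) - 0.5, 's'),
}

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

# One scatter() call per dataset; each call gets the next color from the
# default color cycle, its own marker, and its own legend entry
for name, (pts, marker) in datasets.items():
    ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], marker=marker, label=name)

ax.legend()
plt.show()
```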

Example

Here, we generate two sets of random 3D data points and plot both in the same 3D Axes. The x, y, and z coordinates of every point in Set 1 and Set 2 are chosen randomly. The resulting plot shows the two sets in three-dimensional space, distinguished by blue circles (Set 1) and green triangles (Set 2) −

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

# Generating 3D data for two sets of points
x1 = np.random.rand(50)
y1 = np.random.rand(50)
z1 = np.random.rand(50)

x2 = np.random.rand(50)
y2 = np.random.rand(50)
z2 = np.random.rand(50)

# Creating multiple 3D scatter plots
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x1, y1, z1, c='blue', marker='o', label='Set 1')
ax.scatter(x2, y2, z2, c='green', marker='^', label='Set 2')

ax.set_xlabel('X Axis')
ax.set_ylabel('Y Axis')
ax.set_zlabel('Z Axis')
ax.set_title('Multiple 3D Scatter Plots')
plt.legend()
plt.show()

Output

Following is the output of the above code −

Multiple 3D Scatter Plots

3D Scatter Plot with Annotations

A 3D Scatter Plot with Annotations in Matplotlib is a way to represent points in a three-dimensional space, where each point is marked with additional information through text labels.

To add these annotations, you can use the text() method of the 3D Axes, which takes the x, y, and z coordinates followed by the label string. It places the text at that position in the plot, usually next to a specific data point.
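Labels placed exactly at a point's coordinates can sit on top of its marker. One simple option, sketched below, is to shift each label by a small fixed offset; the point names and the offset value are arbitrary, hypothetical choices for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(0)
x, y, z = np.random.rand(3, 5)        # three rows -> x, y, z for 5 points
labels = ['A', 'B', 'C', 'D', 'E']    # hypothetical point names

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x, y, z)

# ax.text(x, y, z, s) draws the string s at data coordinates (x, y, z).
# The small offset (an arbitrary value) keeps each label off its marker.
offset = 0.02
for xi, yi, zi, label in zip(x, y, z, labels):
    ax.text(xi + offset, yi + offset, zi + offset, label, fontsize=9)

plt.show()
```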

Example

Now, we are creating a 3D scatter plot with annotations. We label each point with a number from 1 to the total number of points, and these labels are displayed next to their respective points in the plot −

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

# Generating random data
np.random.seed(42)
n_points = 10
x = np.random.rand(n_points)
y = np.random.rand(n_points)
z = np.random.rand(n_points)

# Creating a 3D scatter plot with annotations
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x, y, z)

# Annotating points
for i in range(n_points):
    # Label each point with its 1-based index
    ax.text(x[i], y[i], z[i], str(i + 1), color='red')

# Show the plot
plt.show()

Output

Output of the above code is as follows −

3D Scatter Plot with Annotations

Connected 3D Scatter Plot

A connected 3D scatter plot is a variation in which the data points are still drawn individually, but lines are also drawn between consecutive points. This helps visualize the order of the points and the path they trace through three-dimensional space.
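Connecting lines are most meaningful when the points have a natural order, such as positions sampled over time, because plot() simply joins them in array order. This sketch uses a helix sampled at evenly spaced parameter values as a hypothetical ordered trajectory:

```python
import matplotlib.pyplot as plt
import numpy as np

# A helix sampled at 30 ordered steps; the array order is the order in
# which plot() joins the points
t = np.linspace(0, 4 * np.pi, 30)
x, y, z = np.cos(t), np.sin(t), t / (4 * np.pi)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x, y, z, c='blue', label='Points')

# plot() draws segments between consecutive (x[i], y[i], z[i]) points
lines = ax.plot(x, y, z, color='red', linestyle='--', label='Path')
ax.legend()
plt.show()
```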

Example

In the example below, we create a connected 3D scatter plot that shows individual data points (blue) connected by red dashed lines (linestyle='--') −

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

# Generating random data
np.random.seed(42)
n_points = 10
x = np.random.rand(n_points)
y = np.random.rand(n_points)
z = np.random.rand(n_points)

# Creating a connected 3D scatter plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Scatter plot
ax.scatter(x, y, z, c='blue', marker='o', label='Points')

# Connected lines
ax.plot(x, y, z, c='red', linestyle='--', linewidth=2, label='Lines')

# Adding labels
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

# Adding a legend
ax.legend()

# Displaying the plot
plt.show()

Output

The output obtained is as shown below −

Connected 3D Scatter Plot