12.3 3D Plotting
To use Matplotlib’s 3D plotting functionality you first need to import the mplot3d module (since Matplotlib 3.2 this import is no longer strictly required, but it does no harm):
from mpl_toolkits import mplot3d
Then pass the keyword argument projection="3d" to any of Matplotlib’s axes-creating functions. For example:
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()

This is a little less straightforward with the plt.subplots() function, where the projection has to be passed inside the subplot_kw dictionary:
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()

Rotating the Viewing Angle
You can rotate the viewing angle of the figure programmatically using the view_init() method of a 3D axes:
view_init(elev=None, azim=None)
where
elev is the elevation angle in the vertical plane, in degrees
azim is the azimuth angle in the horizontal plane, in degrees
An angle left as None keeps the axes' initial value (by default elev=30 and azim=-60).
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 2, subplot_kw=dict(projection='3d'), figsize=(10,8))
# Vertical Rotation
ax[0].set_xlabel('x')
ax[0].set_ylabel('y')
ax[0].set_zlabel('z')
ax[0].set_title('Vertical Rotation')
ax[0].view_init(elev=5) #Vertical rotation
# Horizontal Rotation
ax[1].set_xlabel('x')
ax[1].set_ylabel('y')
ax[1].set_zlabel('z')
ax[1].set_title('Horizontal Rotation')
ax[1].view_init(azim=30) #Horizontal rotation
plt.show()

Of course, vertical and horizontal rotation can be combined by passing both elev and azim to view_init().
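As a minimal sketch, both angles can be set in a single call. The values elev=20 and azim=45 are arbitrary choices for illustration, not values taken from the examples above:

```python
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_title('Combined Rotation')
# Tilt the view 20 degrees up and turn it 45 degrees around the z-axis
ax.view_init(elev=20, azim=45)
plt.show()
```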
3D Plotting Functions
Now you can use the plotting functions available in the mplot3d module. We shall cover a few here, but these are not the full extent of what is available.
Plotting a line with plot() or plot3D()
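You can plot a line on a 3D axis using plot() or plot3D(), which have the call signature:
plot(xs, ys, zs)
where xs, ys and zs are 1D array-like objects of equal length that contain the $x$, $y$ and $z$ coordinates of the points along the line. Formatting arguments such as color and linestyle work just as they do for the 2D plot() function.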
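On a 3D axes, plot3D() is an alias for plot(), so either name can be used. The sketch below uses a hypothetical twisted curve $(t, t^2, t^3)$, chosen only for illustration, and passes the usual 2D formatting keywords:

```python
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
t = np.linspace(0, 1, 50)
# plot3D() takes the same formatting keywords as the 2D plot() and returns a list of lines
lines = ax.plot3D(t, t**2, t**3, color='red', linestyle='--', linewidth=2)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()
```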
Worked Example - Plotting a 3D Spiral Line
For example, let’s plot a spiral, defined by:
$$x = \sin(s), \quad y = \cos(s), \quad z = s, \qquad 0 \le s \le 50$$
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
s = np.linspace(0, 50, 1000)
ax.plot(np.sin(s), np.cos(s), s)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()

Plotting surfaces with plot_surface()
The plot_surface() function creates a surface plot using a grid of points (or vertices) arranged to form quadrilaterals.
plot_surface(X, Y, Z)
The arguments X, Y and Z are 2D arrays of the same shape containing the $x$, $y$ and $z$ coordinates of the grid vertices.
Worked Example - Plotting a Rectangle
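For this first example, let’s plot a rectangle lying above the unit square in the $xy$-plane, with the side at $y = 1$ elevated to $z = 1$. We will give it vertices with the $(x, y, z)$ coordinates $(0, 0, 0)$, $(1, 0, 0)$, $(0, 1, 1)$ and $(1, 1, 1)$.
These coordinates will be stored in separate 2D arrays for $x$, $y$ and $z$, arranged so that neighboring entries in the arrays are neighboring vertices of the surface: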
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'), figsize=(10,8))
X = np.array([
[0, 1],
[0, 1]
])
Y = np.array([
[0, 0],
[1, 1]
])
Z = np.array([
[0, 0],
[1, 1]
])
ax.plot_surface(X, Y, Z)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()

Note that the code used to select the color of the surface and create the numerical annotations is not included above.
We will often want to plot a surface $z = f(x, y)$ over a rectangular grid of $(x, y)$ points. To create this grid we will use the numpy.meshgrid() function. For our purposes, the call signature of meshgrid() is:
meshgrid(*xi)
where xi is a sequence of 1D arrays. Here we are only interested in passing 2 arrays into meshgrid() (one for a sequence of $x$ values and one for a sequence of $y$ values).
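Consider the arrays x and y where:
$$x = [x_0, x_1, \ldots, x_{n-1}], \qquad y = [y_0, y_1, \ldots, y_{m-1}]$$
If we were to put these into meshgrid():
X, Y = np.meshgrid(x, y)
then the resulting 2D arrays would be of the form:
$$X = \begin{bmatrix} x_0 & x_1 & \cdots & x_{n-1} \\ x_0 & x_1 & \cdots & x_{n-1} \\ \vdots & \vdots & \ddots & \vdots \\ x_0 & x_1 & \cdots & x_{n-1} \end{bmatrix}$$
and
$$Y = \begin{bmatrix} y_0 & y_0 & \cdots & y_0 \\ y_1 & y_1 & \cdots & y_1 \\ \vdots & \vdots & \ddots & \vdots \\ y_{m-1} & y_{m-1} & \cdots & y_{m-1} \end{bmatrix}$$
both with $m$ rows and $n$ columns (shape (m, n)), so the vertex in row $i$ and column $j$ of the grid is at $(x_j, y_i)$.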
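To see the pattern concretely, here is a small sketch with hypothetical arrays: a three-element x and a two-element y, with arbitrary values. Each row of X repeats x, and each column of Y repeats y:

```python
import numpy as np

x = np.array([1, 2, 3])   # n = 3 x values
y = np.array([10, 20])    # m = 2 y values
X, Y = np.meshgrid(x, y)

print(X.shape)  # (2, 3): m rows, n columns
print(X)
# [[1 2 3]
#  [1 2 3]]
print(Y)
# [[10 10 10]
#  [20 20 20]]
```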
Let’s consider an example of a grid like this generated from the arrays:
$$x = [0, 1, 2, 3, 4, 5], \qquad y = [0, 1, 2, 3]$$
used to generate a flat surface with $z = 0$:
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'), figsize=(10,8))
x = np.arange(0, 6)
y = np.arange(0, 4)
X, Y = np.meshgrid(x, y)
Z = np.zeros(X.shape)
ax.plot_surface(X, Y, Z)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()

Worked Example - Plotting a Surface Using a Rectangular Grid
Let’s consider a simple surface:
$$z = xy, \qquad -1 \le x \le 1, \quad -1 \le y \le 1$$
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
x = np.linspace(-1, 1, 10)
y = np.linspace(-1, 1, 10)
X, Y = np.meshgrid(x, y)
Z = X * Y
ax.plot_surface(X, Y, Z)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()
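Because meshgrid() returns full 2D arrays, the expression Z = X * Y evaluates $z = xy$ at every vertex at once through NumPy's element-wise arithmetic. The minimal check below computes the same values with an explicit loop. The loop is only there for comparison and is not how you would normally write this:

```python
import numpy as np

x = np.linspace(-1, 1, 10)
y = np.linspace(-1, 1, 10)
X, Y = np.meshgrid(x, y)
Z = X * Y  # element-wise product over the whole grid

# The same values computed one vertex at a time: row i uses y[i], column j uses x[j]
Z_loop = np.empty((len(y), len(x)))
for i in range(len(y)):
    for j in range(len(x)):
        Z_loop[i, j] = x[j] * y[i]

print(np.allclose(Z, Z_loop))  # True
```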

Worked Example - Plotting a Surface Using a Radial Grid
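You may not always want to use a rectangular grid. For example, suppose we have a radially symmetric surface (or just a surface defined using polar coordinates):
$$z = 1 - r^2, \qquad r^2 = x^2 + y^2$$
In that case we may want to use a radial grid. We create a rectangular grid of $r$ and $\theta$ values and then calculate an $x$ and $y$ coordinate for each vertex from:
$$x = r\cos\theta, \qquad y = r\sin\theta$$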
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
r = np.linspace(0, 10, 10)
theta = np.linspace(0, 2 * np.pi, 20)
R, THETA = np.meshgrid(r, theta)
X = R * np.cos(THETA)
Y = R * np.sin(THETA)
Z = 1 - R * R  # z = 1 - r**2, since X**2 + Y**2 == R**2
ax.plot_surface(X, Y, Z)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()
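A quick sanity check of the radial grid is sketched below. It confirms that the converted coordinates satisfy $x^2 + y^2 = r^2$ at every vertex. It also shows that, because np.linspace() includes the endpoint $2\pi$, the last row of vertices coincides with the first, so the surface closes up without a gap:

```python
import numpy as np

r = np.linspace(0, 10, 10)
theta = np.linspace(0, 2 * np.pi, 20)
R, THETA = np.meshgrid(r, theta)
X = R * np.cos(THETA)
Y = R * np.sin(THETA)

# Rows follow theta, columns follow r
print(R.shape)  # (20, 10)
# The polar relation holds at every vertex
print(np.allclose(X**2 + Y**2, R**2))  # True
# theta = 0 and theta = 2*pi give the same vertices, closing the surface
print(np.allclose(X[0], X[-1]) and np.allclose(Y[0], Y[-1]))  # True
```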

There’s much more that you can do with these surface plots, such as coloring them using a colormap.
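As one minimal sketch of this, the cmap keyword of plot_surface() colors each face by its $z$ value, and fig.colorbar() adds a key. The 'viridis' colormap, the finer 30-point grid and the colorbar settings are illustrative choices, not part of the examples above:

```python
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
x = np.linspace(-1, 1, 30)
y = np.linspace(-1, 1, 30)
X, Y = np.meshgrid(x, y)
Z = X * Y
# Color each face according to its z value using the 'viridis' colormap
surf = ax.plot_surface(X, Y, Z, cmap='viridis')
# Add a color bar that maps the colors back to z values
fig.colorbar(surf, ax=ax, shrink=0.6, label='z')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()
```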
Plotting wireframes with plot_wireframe()
The plot_wireframe() function is similar to the plot_surface() function, except that it draws only the edges of the quadrilaterals, leaving the faces unfilled.
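Since both functions take the same X, Y and Z arrays, they can also be layered on one axes. The sketch below draws a faint surface with its wireframe on top. The alpha, color and linewidth values are illustrative assumptions:

```python
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
x = np.linspace(-1, 1, 10)
y = np.linspace(-1, 1, 10)
X, Y = np.meshgrid(x, y)
Z = X * Y
# A semi-transparent filled surface...
ax.plot_surface(X, Y, Z, alpha=0.3)
# ...with the quadrilateral edges drawn on top as a wireframe
wire = ax.plot_wireframe(X, Y, Z, color='black', linewidth=0.5)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.show()
```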
Worked Example - Plotting a Simple Wireframe
Let’s consider the same surface as before:
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 2, subplot_kw=dict(projection='3d'), figsize=(10,8))
x = np.linspace(-1, 1, 10)
y = np.linspace(-1, 1, 10)
X, Y = np.meshgrid(x, y)
Z = X * Y
#Surface Plot
ax[0].plot_surface(X, Y, Z)
ax[0].set_title('Surface Plot')
ax[0].set_xlabel('x')
ax[0].set_ylabel('y')
ax[0].set_zlabel('z')
#Wireframe plot
ax[1].plot_wireframe(X, Y, Z)
ax[1].set_title('Wireframe Plot')
ax[1].set_xlabel('x')
ax[1].set_ylabel('y')
ax[1].set_zlabel('z')
plt.show()
