A summary of Matplotlib features

Visualization with Matplotlib



Introduction

  • Matplotlib is a multiplatform data visualization library built on NumPy arrays and designed to work with the broader SciPy stack.
  • It supports numerous backends and output types, so we can count on it to work regardless of the operating system we are using or the output format we want.
  • This cross-platform flexibility has led to a large user base.
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
plt.style.use('seaborn-v0_8-whitegrid')

The two interfaces of Matplotlib

  1. a user-friendly functional-style state-based interface
  2. a more powerful object-oriented interface.

The simplest method, plot(), accepts two arrays (x and y) as inputs.

x = np.linspace(-np.pi, np.pi, 256)
# `x` is now an array of 256 values ranging from $-\pi$ to $\pi$ (inclusive).
C, S = np.cos(x), np.sin(x)
# `C` is the cosine (256 values) and `S` is the sine (256 values).

Functional interface

# 1. create a plot figure
plt.figure() 
# 2. create the first of two panels and set current axis
plt.subplot(2, 1, 1) # (rows, columns, panel number)
plt.plot(x, S)
# 3. create the second panel and set current axis
plt.subplot(2, 1, 2)
plt.plot(x, C)
plt.savefig("test.jpg");
# Note that the semicolon at the end of the last line is intentional: it suppresses the textual
# representation of the plot from the output
  • This interface is stateful:
    it keeps track of the “current” figure and axes, which are the targets of all plt commands.

We can obtain a reference to these by using

  • plt.gcf() (get current figure)
  • plt.gca() (get current axes)

Although quick and convenient for basic plots, it can lead to difficulties. For instance,

  • after creating the second panel, how can we return to the first one and add something?

Object-oriented interface

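For more complex situations, or when more control is needed, the object-oriented interface is helpful.
It treats each component of a figure as an object, rather than treating the figure as a single whole.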

# 1. First create a grid of plots
# ax will be an array of two Axes objects
fig, ax = plt.subplots(2)
# 2. Call plot() method on the appropriate object
ax[0].plot(x, S)
ax[1].plot(x, C);
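With explicit Axes objects, returning to an earlier panel is just another method call. The sketch below is one possible illustration; the zero line and the title are made-up additions, not something the original notes draw:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-np.pi, np.pi, 256)
fig, ax = plt.subplots(2)
ax[0].plot(x, np.sin(x))
ax[1].plot(x, np.cos(x))

# Go back to the first panel: no "current axes" bookkeeping is needed,
# because ax[0] still refers to that panel.
ax[0].axhline(0, color='gray', linewidth=0.5)
ax[0].set_title('sin(x) with a zero line')
```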

Simple plots

To create a simple plot, we:

  1. Call plt.figure() to create a new figure (optional with %matplotlib inline).
  2. Generate a sequence of x values, usually with np.linspace().
  3. Generate a sequence of y values, usually by substituting the x values into a function.
  4. Call plt.plot(x, y, [format], **kwargs), where [format] is an optional format string and **kwargs are optional keyword arguments specifying the line properties of the plot.
  5. Use plt functions to enhance the figure with features such as a title, legend, grid lines, etc.
  6. Call plt.show() to display the resulting figure (optional in a Jupyter notebook).
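The six steps could be sketched as follows. The function (a sine curve), the format string, and the title are arbitrary choices for illustration, and the Agg backend line is an assumption made only so the sketch runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; remove this line in a notebook or GUI session
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()                                          # step 1
x = np.linspace(0, 2 * np.pi, 200)                          # step 2
y = np.sin(x)                                               # step 3
lines = plt.plot(x, y, '--r', linewidth=2, label='sin(x)')  # step 4
plt.title('A sine curve')                                   # step 5
plt.legend()
plt.grid(True)
plt.show()                                                  # step 6
```

plt.plot() returns a list of Line2D objects, one per curve, which is why the result is stored in `lines` here.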

Adjusting the plot: Line colors, styles and widths

Line colors

color keyword:

plt.plot(x, np.cos(x - 0), color='blue')         # specify color by name
plt.plot(x, np.cos(x - 1), color='g')            # short color code (rgbcmyk)
plt.plot(x, np.cos(x - 2), color='0.75')         # grayscale between 0 and 1
plt.plot(x, np.cos(x - 4), color=(1.0,0.2,0.3)); # RGB tuple, values 0 to 1

Line styles

linestyle keyword:

plt.plot(x, x - 0, linestyle='-')  # solid
plt.plot(x, x - 1, linestyle='--') # dashed
plt.plot(x, x - 2, linestyle='-.') # dashdot
plt.plot(x, x - 3, linestyle=':')  # dotted
plt.plot(x, x - 4, ':k');          # dotted black 
# You can save some keystrokes by combining these linestyle and color codes into a single non-keyword argument
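As a quick check that the format string really is shorthand for the keywords, this sketch (with illustrative variable names) draws one line each way and compares the resulting properties:

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_rgba

x = np.linspace(0, 10, 50)
plt.figure()

# Keyword form and format-string form of the same request.
(by_keyword,) = plt.plot(x, x - 4, linestyle=':', color='k')
(by_format,) = plt.plot(x, x - 5, ':k')

same_style = by_keyword.get_linestyle() == by_format.get_linestyle()
# Compare colors as RGBA tuples so the check does not depend on how
# the color was spelled.
same_color = to_rgba(by_keyword.get_color()) == to_rgba(by_format.get_color())
print(same_style, same_color)
```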

Line widths

linewidth keyword:

plt.plot(x, np.cos(x - 0)) 
plt.plot(x, np.cos(x - 1), linewidth=5);  # width in points, given as a number

Axes limits

plt.plot(x, np.cos(x))

plt.xlim(-0.5, 10.5)
plt.ylim(-1.5, 1.5);

Labeling

plt.plot(x, np.sin(x), '-g', label='sin(x)') # solid green line
plt.plot(x, np.cos(x), ':b', label='cos(x)') # dotted blue line

plt.title("A Sin/Cos Curve", fontsize=18)       # we can also specify the font size
plt.xlabel("x", fontsize=14)
plt.ylabel("sin(x)", fontsize=14)
plt.legend(fontsize=12)

plt.axis('equal');

Matplotlib tips

Functional      OOP
plt.xlabel()    ax.set_xlabel()
plt.ylabel()    ax.set_ylabel()
plt.xlim()      ax.set_xlim()
plt.ylim()      ax.set_ylim()
plt.title()     ax.set_title()
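A short sketch of the object-oriented column in use; the labels and limits are placeholder values. It also shows ax.set(), which takes the same properties as keywords in one call:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))

# One setter per plt function from the listing above.
ax.set_xlabel('x')
ax.set_ylabel('sin(x)')
ax.set_xlim(0, 10)
ax.set_ylim(-2, 2)
ax.set_title('A simple plot')

# ax.set() accepts the same properties as keyword arguments.
ax.set(xlabel='x', ylabel='sin(x)', xlim=(0, 10), ylim=(-2, 2),
       title='A simple plot')
```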

Simple scatter plots

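A scatter plot represents each data point as an individual marker on the plot, instead of joining the points with line segments.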

x = np.linspace(0, 10, 30)
y = np.sin(x)
plt.plot(x, y, 'o', color='black');

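The third argument, 'o', tells plt.plot to draw each data point as a circle marker instead of connecting the points with a line.
Many marker styles are available; the full list is in the Matplotlib marker documentation.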
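To compare a few marker codes side by side, something like the following could be used. The selection of markers and the vertical offsets are arbitrary choices for illustration:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 30)
plt.figure()

# A handful of single-character marker codes; each curve is shifted up
# by one unit so the markers do not overlap.
markers = ['o', '.', 'x', '+', 'v', '^', 's', 'd']
for offset, marker in enumerate(markers):
    plt.plot(x, np.sin(x) + offset, marker, label=f"marker='{marker}'")
plt.legend(fontsize=8, ncol=2)
```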

Scatter plots with plt.scatter

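More advanced scatter plots can be drawn with plt.scatter.
The main difference from plt.plot is that plt.scatter lets us control the properties of each individual point, such as its color and size.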

np.random.seed(42)

x = np.random.randn(100)
y = np.random.randn(100)
colors = np.random.rand(100)
sizes = 1000 * np.random.rand(100)

plt.scatter(x, y, c=colors, s=sizes, alpha=0.3, cmap='viridis')
plt.colorbar(); # show color scale

To see all of the Matplotlib colormaps, refer to the list of colormaps.

Density plots

Histograms

A histogram is a common way to visualize data, and it can be drawn with plt.hist.

np.random.seed(42)
data = np.random.normal(size=1000)
plt.hist(data);

Binning

plt.hist has a bins parameter that sets the number of intervals (bins) in the histogram.

plt.hist(data, bins=30, density=True, alpha=0.5, color='steelblue', edgecolor='none')
x = np.linspace(-4,4,100)
y = 1/(2*np.pi)**0.5 * np.exp(-x**2/2)
plt.plot(x,y,'b',alpha=0.8);

Density

density=True normalizes the histogram so that its total area is 1, which lets us treat it as a probability density function.
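The area claim can be checked numerically from the values plt.hist returns. This is a minimal sketch; the variable names and the use of np.random.default_rng are choices made here, not part of the original notes:

```python
import matplotlib.pyplot as plt
import numpy as np

sample = np.random.default_rng(42).normal(size=1000)
plt.figure()
heights, edges, _ = plt.hist(sample, bins=30, density=True)

n_bins = len(heights)      # bins=30 gives 30 bars and 31 bin edges
widths = np.diff(edges)    # width of each bar
area = np.sum(heights * widths)
print(n_bins, area)
```

The sum of height times width over all bars comes out as 1 up to floating-point rounding, whatever number of bins is chosen.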

Advanced plots

Filling the area between lines

To shade the region between two curves, use plt.fill_between().

x = np.linspace(0, 2*np.pi, 1000)

plt.plot(x, np.sin(x), 'r')
plt.plot(x, np.cos(x), 'g')
plt.fill_between(x, np.cos(x), np.sin(x), color='red', alpha=0.1);
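plt.fill_between() also accepts a `where` mask, which restricts the shading to the x values where a condition holds. As a sketch that goes slightly beyond the example above (the colors are arbitrary), the region can be shaded in the color of whichever curve is on top:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 1000)
s, c = np.sin(x), np.cos(x)

plt.figure()
plt.plot(x, s, 'r')
plt.plot(x, c, 'g')
# Two calls, each limited by a boolean mask over x.
plt.fill_between(x, c, s, where=(s >= c), color='red', alpha=0.1)
plt.fill_between(x, c, s, where=(s < c), color='green', alpha=0.1)
```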

Plotting in polar coordinates

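To plot in polar coordinates, we can use plt.polar() or, as in this example, create polar axes with plt.axes(projection='polar').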

t = np.linspace(0, 2*np.pi, 64)

# plot in polar coordinates
plt.axes(projection='polar')
plt.plot(t, np.sin(t), '-');


# Set the angular ticks; raw strings keep the LaTeX backslashes intact
plt.xticks([0, np.pi/2, np.pi, 3*np.pi/2], ['0', r'$\pi/2$', r'$\pi$', r'$3\pi/2$']);

Note that we would expect a radius of 0 to designate the origin, and a negative radius to be reflected across the origin;
in particular, the polar coordinates $(r, t)$ and $(-r, t+\pi)$ should designate the same point.
Matplotlib's polar axes do not apply this convention to negative radii by default, but we can enforce it using the following code:

t = np.linspace(0, 2*np.pi, 64)
r = np.sin(t)
# plot in polar coordinates
plt.axes(projection='polar')
plt.plot(t+(r<0)*np.pi, np.abs(r), '-')

# Set the angular ticks; raw strings keep the LaTeX backslashes intact
plt.xticks([0, np.pi/2, np.pi, 3*np.pi/2], ['0', r'$\pi/2$', r'$\pi$', r'$3\pi/2$']);
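To see why the adjustment t+(r<0)*np.pi with np.abs(r) leaves every point where the convention says it should be, both versions can be converted to Cartesian coordinates and compared. This is a NumPy-only sketch, and `to_cartesian` is a helper defined here for illustration:

```python
import numpy as np

t = np.linspace(0, 2 * np.pi, 64)
r = np.sin(t)

def to_cartesian(radius, angle):
    """Standard polar-to-Cartesian conversion: (r cos t, r sin t)."""
    return radius * np.cos(angle), radius * np.sin(angle)

# Points as given, including the negative radii.
x_raw, y_raw = to_cartesian(r, t)
# Points after the adjustment: flip the sign of r and rotate by pi.
x_adj, y_adj = to_cartesian(np.abs(r), t + (r < 0) * np.pi)

max_gap = max(np.abs(x_raw - x_adj).max(), np.abs(y_raw - y_adj).max())
print(max_gap)
```

The two sets of points agree up to floating-point rounding, because cos(t+π) = -cos(t) and sin(t+π) = -sin(t) cancel the sign flip of the radius.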

Customizing Plot

Customizing plot legends

x = np.linspace(0, 10, 100)
plt.plot(x, np.sin(x), '-g', label='sin(x)')
plt.plot(x, np.cos(x), ':b', label='cos(x)')
plt.axis('equal')

plt.legend(loc='upper left', frameon=True)

Text and Annotation

To add text or annotations to a plot, use plt.text() or plt.annotate().

plt.figure()
x = np.linspace(0, 20, 1000)
plt.plot(x, np.cos(x))
plt.axis('equal')
plt.annotate('local maximum', xy=(6.28, 1), xytext=(10, 4), arrowprops=dict(facecolor='black'), fontsize=14)

plt.annotate('local minimum', xy=(5 * np.pi, -1), xytext=(2, -6), arrowprops=dict(arrowstyle="->"), fontsize=14)
plt.text(3.14, -1, 'local minimum', fontsize=14, ha='center'); # (x, y, text)

Customizing ticks and spines

Tick positions and labels can be adjusted with plt.xticks() and plt.yticks().

x = np.linspace(0, 20, 1000)
plt.plot(x, np.cos(x))
plt.plot(x, np.sin(x))
plt.axis('equal')

# Set the ticks and tick labels
plt.xticks([0, np.pi, 2 * np.pi, 3 * np.pi, 4 * np.pi],
           [r'$0$', r'$\pi$', r'$2\pi$', r'$3\pi$', r'$4\pi$'], fontsize=14)
plt.yticks([-1, 0, +1],  [r'$-1$', r'$0$', r'$+1$'], fontsize=14);
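The same tick setup can be written against an Axes object. This sketch assumes the usual ax.set_xticks() / ax.set_xticklabels() pair; `tick_positions` and `tick_labels` are names introduced here:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 20, 1000)
fig, ax = plt.subplots()
ax.plot(x, np.cos(x))

tick_positions = [0, np.pi, 2 * np.pi, 3 * np.pi, 4 * np.pi]
tick_labels = [r'$0$', r'$\pi$', r'$2\pi$', r'$3\pi$', r'$4\pi$']

# Set the positions first, then attach one label per position.
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels, fontsize=14)
ax.set_yticks([-1, 0, 1])
```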

Spines are the border lines of a plot. We can get the current axes with plt.gca() and then adjust the spines through ax.spines.

plt.figure(figsize=(8,5), dpi=80)
X = np.linspace(-np.pi, np.pi, 256, endpoint=True)
C = np.cos(X)
S = np.sin(X)

plt.plot(X, C, color="blue", linewidth=2.5, linestyle="-")
plt.plot(X, S, color="red", linewidth=2.5, linestyle="-")

ax = plt.gca()
ax.spines[['top', 'right']].set_visible(False)
ax.spines['bottom'].set_position(('data',0))
ax.spines['left'].set_position(('data',0))

Multiple Subplots

plt.subplots()

To draw multiple subplots, use plt.subplots(), which returns a figure and an array of axes.

fig, ax = plt.subplots(2, 3)
fig.subplots_adjust(hspace=0.5, wspace=0.4)
for i in range(2):
    for j in range(3):
        ax[i, j].text(0.5, 0.5, str((i, j)), fontsize=18, ha='center', va='center')
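What plt.subplots() returns depends on the grid that was requested. The sketch below, with illustrative variable names, shows the shapes for a few cases, along with the `squeeze=False` option and the `.flat` iterator as alternatives to nested loops:

```python
import matplotlib.pyplot as plt
import numpy as np

fig1, single = plt.subplots()     # one Axes object, not an array
fig2, column = plt.subplots(2)    # 1-D array of 2 Axes
fig3, grid = plt.subplots(2, 3)   # 2-D array, indexed as grid[row, col]

# squeeze=False always returns a 2-D array, which keeps indexing uniform.
fig4, always_2d = plt.subplots(2, squeeze=False)

# .flat visits every panel of the grid without nested loops.
for panel in grid.flat:
    panel.set_xticks([])

print(column.shape, grid.shape, always_2d.shape)
```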

The spacing between subplots can be adjusted with plt.subplots_adjust(), or with the equivalent fig.subplots_adjust() method used here.

fig, ax = plt.subplots(2, 2, figsize=(8,8))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

x = np.linspace(0, 10, 1000)
ax[0,0].plot(x, np.sin(x))
ax[0,1].plot(x, np.cos(x))
ax[1,0].plot(x, x**2)
ax[1,0].set_xscale('log') # Set the scale to log scale
ax[1,0].set_yscale('log')
ax[1,1].plot(x, x**2);