【Python】这个装饰器竟让 Python 提速了 30 倍!
Python是一种解释语言,其代码不是直接编译成机器码,而是由另一个叫做
解释器
的程序实时解释的(一般是
cpython
)。因此,与其他编译语言相比,Python灵活性高(动态类型,兼容性高,...)。但这也造成了Python非常慢的缺点。
加速 Python的方法
实际上,有多种解决方案可以解决python的缓慢问题。
-
使用
cython
:一种编程语言,是python的超集。Cython是Python编程语言和扩展 Cython 编程语言(基于Pyrex)的优化静态编译器。它使得为 Python 编写 C 扩展就像 Python 本身一样容易。 -
使用C/C++语言结合
ctypes
,pybind11
或CFFI
来编写Python的绑定程序 - 用C/C++扩展Python
- 使用其他编译过的语言,如rust[1]
而所有这些方法,都需要使用除Python外的另一种语言,并编译代码使之与Python一起工作。尽管这些方法都很不错,但并不是最适合我们初学者的使python更快的方法,更别提他们通常比较难以设置了。
Numba & JIT 编译器
Numba[2]是一个Python包,在兼具Python的便利的同时,可以使你的代码更快。
numba
使用Just-in-time (JIT)编译(即在Python代码执行过程中的实时编译的),使用起来非常方便,无需向其他工具一样,还需安装一个C/C++编译器,它仅需用 pip/conda 安装它即可。
pip install numba
接下来试一个例子:用蒙特卡洛模拟来计算π的估计值。
import random
from numba import njit
def monte_carlo_pi_without_numba(nsamples):
acc = 0
for i in range(nsamples):
x = random.random()
y = random.random()
if (x ** 2 + y ** 2) < 1.0:
acc += 1
return 4.0 * acc / nsamples
# 添加numba的装饰器,使该函数更快。
@njit
def monte_carlo_pi_with_numba(nsamples):
acc = 0
for i in range(nsamples):
x = random.random()
y = random.random()
if (x ** 2 + y ** 2) < 1.0:
acc += 1
return 4.0 * acc / nsamples
在使用该方法时,我们只需要导入numba
的一个装饰器(njit
),剩下的都由它自己完成即可,可以说是非常方便。
我们运行两个版本的代码,并进行计时对比。显示numba比普通python快30倍。
%timeit monte_carlo_pi_with_numba(100_000)
# 1.24 ms ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit monte_carlo_pi_without_numba(100_000)
# 40.6 ms ± 814 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
一些注意事项
值得一提的是,numba
确实有一些缺点:
-
在首次运行
numba
装饰的函数时,有一定的时间开销。这是因为首次执行时,numba
会试图找出参数的类型并编译函数,从而导致程序有一定的延迟。 -
不是所有的Python代码都能用
numba
编译,例如,如果你对同一个变量或对列表元素使用混合数据类型,此种情况将会抛出异常。
加速 Pandas
Numba
是专门为numpy设计的,对numpy数组非常友好。而 pandas
是建立在 numpy
之上的,这使得在使用用户定义的函数或甚至执行不同的Dataframe操作时,可以进行疯狂优化。
首先创建一个DataFrame数据集。
import numpy as np
import pandas as pd
n = 1_000_000
df = pd.DataFrame({
'height': 1 + 1.3 * np.random.random(n),
'weight': 40 + 260 * np.random.random(n),
'hip_circumference': 94 + 14 * np.random.random(n)
})
用户定义的函数
numba 的另一个重要的方法是 vectorize
,使用该方法可以很容易的创建numpy通用函数(ufuncs[3])
通用函数(或简称ufunc)是以ndarrays逐个元素的方式运行的函数,支持数组广播、类型转换和其他几个标准特性。也就是说,ufunc 是一个函数的“矢量化”包装器,它接受固定数量的特定输入并产生固定数量的特定输出。
下面是计算数据集中列height
的平方。
from numba import vectorize
def get_squared_height_without_numba(height):
return height ** 2
@vectorize
def get_squared_height_with_numba(height):
return height ** 2
%timeit df['height'].apply(get_squared_height_without_numba)
# 279 ms ± 7.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit df['height'] ** 2
# 2.04 ms ± 229 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# 我们首先将列转换为numpy数组,
# 因为numba与numpy兼容,与pandas并不兼容。
%timeit get_squared_height_with_numba(df['height'].to_numpy())
# 1.6 ms ± 51.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
基本操作
使用njit
,并计算 BMI(身体质量指数)。
from numba import njit
@njit
def get_bmi(weight_col, height_col):
n = len(weight_col)
result = np.empty(n, dtype="float64")
# 与python循环相比,Numba的循环非常快
for i, (weight, height) in enumerate(zip(weight_col, height_col)):
result[i] = weight / (height ** 2)
return result
# 不要忘记将列转换为 numpy
%timeit df['bmi'] = get_bmi(df['weight'].to_numpy(),
df['height'].to_numpy())
# 6.77 ms ± 230 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit df['bmi'] = df['weight'] / (df['height'] ** 2)
# 8.63 ms ± 316 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
你可以看到,即使是基本的操作,numba仍然比原始 pandas 花费的时间更少(6.77ms vs 8.63ms)。
写在最后
numba
是一种开箱即用的方法,可以轻而易举地 让你的 Python 代码变得更快。当然,在成功编译代码之前可能需要多几次尝试,你可以试试使用它。如果本文对你有用,那就点个赞和在看支持下云朵君吧!
拓展阅读
rust: https://github.com/PyO3/pyo3
[2]Numba: https://numba.pydata.org/
[3]ufuncs: https://numpy.org/doc/stable/reference/ufuncs.html
往期 精彩 回顾
- 适合初学者入门人工智能的路线及资料下载
- (图文+视频)机器学习入门系列下载
- 机器学习及深度学习笔记等资料打印
- 《统计学习方法》的代码复现专辑
- 机器学习交流qq群955171419,加入微信群请 扫码