让 Pandas DataFrame 性能飞升 40 倍
val
这一列进行处理:如果是偶数则减1,奇数则加1。实际的数据分析工作要比这个例子复杂的多,但考虑到我们没有那么多时间等待运行结果,所以就偷个懒吧。可以看到transform
函数的平均运行时间是284ms:import pandas as pd
import numpy as np
def gen_data(size):
d = dict()
d["genre"] = np.random.choice(["A", "B", "C", "D"], size=size)
d["val"] = np.random.randint(low=0, high=100, size=size)
return pd.DataFrame(d)
data = gen_data(1000000)
data.head()
def transform(data):
data.loc[:, "new_val"] = data.val.apply(lambda x: x + 1 if x % 2 else x - 1)
%timeit -n 1 transform(data)
284 ms ± 8.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
x + 1 if x % 2 else x - 1
这个函数。平均运行时间降低到了202ms,果然速度变快了。性能大约提升了1.4倍,离40倍的flag还差的好远。%load_ext cython
%%cython
cpdef int _transform(int x):
if x % 2:
return x + 1
return x - 1
def transform(data):
data.loc[:, "new_val"] = data.val.apply(_transform)
%timeit -n 1 transform(data)
202 ms ± 13.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
val
这一列作为Numpy数组传递给Cython函数,注意区分cnp
和np
。平均运行时间直接降到10.8毫秒,性能大约提升了26倍,仿佛看到了一丝希望。%%cython
import numpy as np
cimport numpy as cnp
ctypedef cnp.int_t DTYPE_t
cpdef cnp.ndarray[DTYPE_t] _transform(cnp.ndarray[DTYPE_t] arr):
cdef:
int i = 0
int n = arr.shape[0]
int x
cnp.ndarray[DTYPE_t] new_arr = np.empty_like(arr)
while i < n:
x = arr[i]
if x % 2:
new_arr[i] = x + 1
else:
new_arr[i] = x - 1
i += 1
return new_arr
def transform(data):
data.loc[:, "new_val"] = _transform(data.val.values)
%timeit -n 1 transform(data)
10.8 ms ± 512 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
@cython.boundscheck(False)
,@cython.wraparound(False)
装饰器关闭数组的边界检查和负下标处理,平均运行时间变为5.9毫秒。性能提升了42倍左右,顺利完成任务。%%cython
import cython
import numpy as np
cimport numpy as cnp
ctypedef cnp.int_t DTYPE_t
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef cnp.ndarray[DTYPE_t] _transform(cnp.ndarray[DTYPE_t] arr):
cdef:
int i = 0
int n = arr.shape[0]
int x
cnp.ndarray[DTYPE_t] new_arr = np.empty_like(arr)
while i < n:
x = arr[i]
if x % 2:
new_arr[i] = x + 1
else:
new_arr[i] = x - 1
i += 1
return new_arr
def transform(data):
data.loc[:, "new_val"] = _transform(data.val.values)
%timeit -n 1 transform(data)
6.76 ms ± 545 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
作者:李小文,先后从事过数据分析、数据挖掘工作,主要开发语言是Python,现任一家小型互联网公司的算法工程师。
Github: https://github.com/tushushu
推荐阅读
点击下方阅读原文加入社区会员
点赞鼓励一下
评论