用迭代器模块为Python提速

Python中文社区

共 3095字,需浏览 7分钟

 · 2020-12-13

Python官方文档用"高效的循环"来形容itertools模块,有些tools会带来性能提升,而另外一些tools并不快,只是会节省一些开发时间而已,如果滥用还会导致代码可读性变差。我们不妨把itertools的兄弟们拉出来溜溜。

1. 数列累加

给定一个列表An,返回数列累加和Sn。举例说明:

输入: [1, 2, 3, 4, 5]

返回: [1, 3, 6, 10, 15]

使用accumulate,性能提升了2.5倍

from itertools import accumulate

def _accumulate_list(arr):
    tot = 0
    for x in arr:
        tot += x
        yield tot

def accumulate_list(arr):
    return list(_accumulate_list(arr))

def fast_accumulate_list(arr):
    return list(accumulate(arr))

arr = list(range(1000))

%timeit accumulate_list(arr)
61 µs ± 2.91 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit fast_accumulate_list(arr)
21.3 µs ± 811 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

2. 选择数据

给定一个列表data,一个用0/1表示的列表selectors,返回被选择的数据。举例说明:

输入: [1, 2, 3, 4, 5], [0, 1, 0, 1, 0]

返回: [2, 4]

使用compress,性能提升了2.8倍

from itertools import compress
from random import randint

def select_data(data, selectors):
    return [x for x, y in zip(data, selectors) if y]

def fast_select_data(data, selectors):
    return list(compress(data, selectors))

data = list(range(10000))
selectors = [randint(0, 1) for _ in range(10000)]

%timeit select_data(data, selectors)
341 µs ± 17.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit fast_select_data(data, selectors)
130 µs ± 3.19 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

3. 组合

给定一个列表arr和一个数字k,返回从arr中选择k个元素的所有情况。举例说明:

输入: [1, 2, 3], 2

返回: [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]

使用permutations,性能提升了10倍

from itertools import permutations

def _get_permutations(arr, k, i):
    if i == k:
        return [arr[:k]]
    res = []
    for j in range(i, len(arr)):
        arr_cpy = arr.copy()
        arr_cpy[i], arr_cpy[j] = arr_cpy[j], arr_cpy[i]
        res += _get_permutations(arr_cpy, k, i + 1)
    return res

def get_permutations(arr, k):
    return _get_permutations(arr, k, 0)


def fast_get_permutations(arr, k):
    return list(permutations(arr, k))

arr = list(range(10))
k = 5

%timeit -n 1 get_permutations(arr, k)
15.5 ms ± 1.96 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit -n 1 fast_get_permutations(arr, k)
1.56 ms ± 284 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

4. 筛选数据

给定一个列表arr,筛选出所有的偶数。举例说明:

输入: [3, 1, 4, 5, 9, 2]

返回: [(4, 2]

使用filterfalse,性能反而会变慢,所以不要迷信itertools

from itertools import filterfalse

def get_even_nums(arr):
    return [x for x in arr if x % 2 == 0]

def fast_get_even_nums(arr):
    return list(filterfalse(lambda x: x % 2, arr))

arr = list(range(10000))

%timeit get_even_nums(arr)
417 µs ± 18.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit fast_get_even_nums(arr)
823 µs ± 22.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

5. 条件终止

给定一个列表arr,依次对列表的所有数字进行求和,若遇到某个元素大于target之后则终止求和,返回这个和。举例说明:

输入: [1, 2, 3, 4, 5], 3

返回: 6 (4 > 3,终止)

使用takewhile,性能反而会变慢,所以不要迷信itertools

from itertools import takewhile

def cond_sum(arr, target):
    res = 0
    for x in arr:
        if x > target:
            break
        res += x
    return res

def fast_cond_sum(arr, target):
    return sum(takewhile(lambda x: x <= target, arr))

arr = list(range(10000))
target = 5000

%timeit cond_sum(arr, target)
245 µs ± 11.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit fast_cond_sum(arr, target)
404 µs ± 13.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

6. 循环嵌套

给定列表arr1arr2,返回两个列表的所有元素两两相加的和。举例说明:

输入: [1, 2], [4, 5]

返回: [1 + 4, 1 + 5, 2 + 4, 2 + 5]

使用product,性能提升了1.25倍。

from itertools import product

def _cross_sum(arr1, arr2):
    for x in arr1:
        for y in arr2:
            yield x + y

def cross_sum(arr1, arr2):
    return list(_cross_sum(arr1, arr2))


def fast_cross_sum(arr1, arr2):
    return [x + y for x, y in product(arr1, arr2)]

arr1 = list(range(100))
arr2 = list(range(100))

%timeit cross_sum(arr1, arr2)
484 µs ± 16.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit fast_cross_sum(arr1, arr2)
373 µs ± 11.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

7. 二维列表转一维列表

给定二维列表arr,转为一维列表 举例说明:

输入: [[1, 2], [3, 4]]

返回: [1, 2, 3, 4]

使用chain,性能提升了6倍。

from itertools import chain

def _flatten(arr2d):
    for arr in arr2d:
        for x in arr:
            yield x

def flatten(arr2d):
    return list(_flatten(arr2d))


def fast_flatten(arr2d):
    return list(chain(*arr2d))

arr2d = [[x + y * 100 for x in range(100)] for y in range(100)]

%timeit flatten(arr2d)
379 µs ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit fast_flatten(arr2d)
66.9 µs ± 3.43 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

作者:李小文,先后从事过数据分析、数据挖掘工作,主要开发语言是Python,现任一家小型互联网公司的算法工程师。

Github: https://github.com/tushushu


更多阅读



程序运行慢?你怕是写的假 Python


让 Pandas DataFrame 性能飞升 40 倍


用 PyQt 打造具有专业外观的GUI(上)


特别推荐


程序员摸鱼指南


为你精选的硅谷极客资讯,
来自FLAG巨头开发者、技术、创投一手消息




点击下方阅读原文加入社区会员


浏览 8
点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

举报