• A.7 用Numba编写快速NumPy函数
    • 用Numba创建自定义numpy.ufunc对象

    A.7 用Numba编写快速NumPy函数

    Numba是一个开源项目,它可以利用CPUs、GPUs或其它硬件为类似NumPy的数据创建快速函数。它使用了LLVM项目(http://llvm.org/),将Python代码转换为机器代码。

    为了介绍Numba,来考虑一个纯粹的Python函数,它使用for循环计算表达式(x - y).mean():

    1. import numpy as np
    2. def mean_distance(x, y):
    3. nx = len(x)
    4. result = 0.0
    5. count = 0
    6. for i in range(nx):
    7. result += x[i] - y[i]
    8. count += 1
    9. return result / count

    这个函数很慢:

    1. In [209]: x = np.random.randn(10000000)
    2. In [210]: y = np.random.randn(10000000)
    3. In [211]: %timeit mean_distance(x, y)
    4. 1 loop, best of 3: 2 s per loop
    5. In [212]: %timeit (x - y).mean()
    6. 100 loops, best of 3: 14.7 ms per loop

    NumPy的版本要比它快过100倍。我们可以转换这个函数为编译的Numba函数,使用numba.jit函数:

    1. In [213]: import numba as nb
    2. In [214]: numba_mean_distance = nb.jit(mean_distance)

    也可以写成装饰器:

    1. @nb.jit
    2. def mean_distance(x, y):
    3. nx = len(x)
    4. result = 0.0
    5. count = 0
    6. for i in range(nx):
    7. result += x[i] - y[i]
    8. count += 1
    9. return result / count

    它要比矢量化的NumPy快:

    1. In [215]: %timeit numba_mean_distance(x, y)
    2. 100 loops, best of 3: 10.3 ms per loop

    Numba不能编译Python代码,但它支持纯Python写的一个部分,可以编写数值算法。

    Numba是一个深厚的库,支持多种硬件、编译模式和用户插件。它还可以编译NumPy Python API的一部分,而不用for循环。Numba也可以识别可以便以为机器编码的结构体,但是若调用CPython API,它就不知道如何编译。Numba的jit函数有一个选项,nopython=True,它限制了可以被转换为Python代码的代码,这些代码可以编译为LLVM,但没有任何Python C API调用。jit(nopython=True)有一个简短的别名numba.njit。

    前面的例子,我们还可以这样写:

    1. from numba import float64, njit
    2. @njit(float64(float64[:], float64[:]))
    3. def mean_distance(x, y):
    4. return (x - y).mean()

    我建议你学习Numba的线上文档(http://numba.pydata.org/)。下一节介绍一个创建自定义Numpy ufunc对象的例子。

    用Numba创建自定义numpy.ufunc对象

    numba.vectorize创建了一个编译的NumPy ufunc,它与内置的ufunc很像。考虑一个numpy.add的Python例子:

    1. from numba import vectorize
    2. @vectorize
    3. def nb_add(x, y):
    4. return x + y

    现在有:

    1. In [13]: x = np.arange(10)
    2. In [14]: nb_add(x, x)
    3. Out[14]: array([ 0., 2., 4., 6., 8., 10., 12., 14., 16., 18.])
    4. In [15]: nb_add.accumulate(x, 0)
    5. Out[15]: array([ 0., 1., 3., 6., 10., 15., 21., 28., 36., 45.])