Testing Numba Parallelism in Python

2024-12-06
#Python #Parallel
from numba import njit, prange
import numpy as np
import time

# Pure Python implementation
def python_matrix_mult(A, B):
    rows, cols = A.shape[0], B.shape[1]
    result = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            for k in range(A.shape[1]):
                result[i][j] += A[i][k] * B[k][j]
    return result

# Numba parallel implementation
@njit(parallel=True)
def numba_matrix_mult(A, B):
    rows, cols = A.shape[0], B.shape[1]
    result = np.zeros((rows, cols))
    for i in prange(rows):
        for j in range(cols):
            for k in range(A.shape[1]):
                result[i, j] += A[i, k] * B[k, j]
    return result

# Benchmark setup
size = 1000
A = np.random.rand(size, size)
B = np.random.rand(size, size)

# Time pure Python
start = time.time()
result_python = python_matrix_mult(A, B)
print("Pure Python time:", time.time() - start, "s")

# Time NumPy
start = time.time()
result_numpy = np.dot(A, B)
print("NumPy time:", time.time() - start, "s")

# Time Numba parallel (this first call also includes JIT compilation)
start = time.time()
result_numba_parallel = numba_matrix_mult(A, B)
print("Numba parallel time:", time.time() - start, "s")

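Notes:

  • @njit(parallel=True): compiles the function to machine code and enables Numba's automatic parallelization, which tries to run supported array operations and prange loops on multiple CPU cores
  • prange is Numba's parallel counterpart to range: it marks a loop whose iterations may be distributed across CPU cores. It only has this effect under parallel=True, and the iterations must be independent of each other (apart from supported reductions)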
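
To illustrate the independence requirement for prange, here is a minimal sketch that is not from the original benchmark. row_norms is a hypothetical function in which each iteration of the outer prange loop writes only to its own output slot, so iterations never touch each other's data. The try/except fallback is an assumption added only so the sketch still runs (serially and slowly) where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Assumption for illustration: without Numba, run as plain Python.
    def njit(**kwargs):
        return lambda fn: fn
    prange = range


@njit(parallel=True)
def row_norms(M):
    """Euclidean norm of each row (hypothetical example function)."""
    n_rows, n_cols = M.shape
    out = np.empty(n_rows)
    for i in prange(n_rows):        # rows may be handled by different cores
        acc = 0.0                   # accumulator local to this iteration
        for j in range(n_cols):     # inner loop stays serial
            acc += M[i, j] * M[i, j]
        out[i] = np.sqrt(acc)       # each iteration writes only out[i]
    return out


M = np.random.rand(200, 50)
# Cross-check against NumPy's own row-wise norm.
assert np.allclose(row_norms(M), np.linalg.norm(M, axis=1))
```

The matrix multiplication above follows the same pattern: the prange loop runs over i, and iteration i only writes to row i of result.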

Results:

Pure Python time: 257.57 s
NumPy time: 0.05 s
Numba parallel time: 0.74 s
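
When reading the Numba figure, note that the script times the first call to numba_matrix_mult, and the first call to an @njit function also includes JIT compilation. One common way to separate the two is to time the first call and the later calls separately. The helper below is a minimal sketch of that idea, not part of the original benchmark. time_calls is a hypothetical name, and np.dot on small matrices stands in for the function under test so that the block runs on its own.

```python
import time
import numpy as np


def time_calls(fn, *args, repeats=3):
    """Return (first_call_seconds, best_later_call_seconds) for fn(*args)."""
    start = time.perf_counter()
    fn(*args)                       # first call: may include one-off costs
    first = time.perf_counter() - start

    later = []
    for _ in range(repeats):        # later calls: steady-state cost
        start = time.perf_counter()
        fn(*args)
        later.append(time.perf_counter() - start)
    return first, min(later)


# Stand-in workload (assumption): small matrices and np.dot.
A_small = np.random.rand(100, 100)
B_small = np.random.rand(100, 100)
first, steady = time_calls(np.dot, A_small, B_small)
print("first call:", first, "s; best later call:", steady, "s")
```

Calling time_calls(numba_matrix_mult, A, B) with the definitions from the script would report the compile-inclusive first call next to the steady-state time. No numbers are claimed here for that case.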