Python使用Numba并行测试
from numba import njit, prange
import numpy as np
import time
# Pure-Python implementation
def python_matrix_mult(A, B):
    """Return the matrix product A @ B computed with plain Python loops.

    This is the deliberately slow baseline of the benchmark.

    Parameters
    ----------
    A : 2-D array-like of shape (n, m) -- a NumPy array or nested lists.
    B : 2-D array-like of shape (m, p).

    Returns
    -------
    list of list
        n rows of p entries each. Every entry is the sum over k of
        A[i][k] * B[k][j], accumulated in increasing k starting from ``0``
        (so it is ``0`` when m == 0).

    Raises
    ------
    ValueError
        If A or B is not 2-D, or if the inner dimensions differ (the
        original silently dropped terms or hit an IndexError instead).
    """
    # Unpacking fails with ValueError for anything that is not 2-D.
    n_rows, n_inner = np.shape(A)
    n_inner_b, n_cols = np.shape(B)
    if n_inner != n_inner_b:
        raise ValueError(
            f"inner dimensions differ: A is {n_rows}x{n_inner}, "
            f"B is {n_inner_b}x{n_cols}"
        )
    # Gather each column of B once, so every entry is a zip of two rows
    # instead of repeated B[k][j] double indexing.
    columns = [[B[k][j] for k in range(n_inner)] for j in range(n_cols)]
    # sum() starts at 0 and adds the terms in increasing k, exactly like a
    # manual `acc += A[i][k] * B[k][j]` loop, so results are unchanged.
    return [
        [sum(a * b for a, b in zip(row, col)) for col in columns]
        for row in A
    ]
# Numba parallel implementation
@njit(parallel=True)
def numba_matrix_mult(A, B):
    """Return the matrix product A @ B as a new float64 array.

    Compiled by Numba with parallel=True. The outer loop over the rows of A
    uses prange, so rows are distributed across threads. Each iteration
    writes only row i of the result, so iterations do not depend on each
    other.

    Parameters
    ----------
    A : 2-D array of shape (n, m).
    B : 2-D array of shape (m, p).

    Returns
    -------
    numpy.ndarray
        Shape (n, p), dtype float64 (np.zeros' default dtype).

    Raises
    ------
    ValueError
        If the inner dimensions of A and B differ.
    """
    rows, inner = A.shape
    inner_b, cols = B.shape
    # Numba does not bounds-check array indexing by default, so without this
    # check a shape mismatch would read past the end of B or silently drop
    # terms instead of failing.
    if inner != inner_b:
        raise ValueError("inner dimensions of A and B must match")
    result = np.zeros((rows, cols))
    for i in prange(rows):
        # i-k-j order: the innermost loop walks row k of B and row i of the
        # result contiguously (for C-ordered inputs) instead of striding down
        # a column of B. For each (i, j) the terms are still added in
        # increasing k, so the computed values are unchanged.
        for k in range(inner):
            a_ik = A[i, k]
            for j in range(cols):
                result[i, j] += a_ik * B[k, j]
    return result
# Benchmark: multiply two size x size float64 matrices with each implementation.
# NOTE: with size = 1000 the pure-Python version performs size**3 = 10**9
# scalar multiply-adds and takes minutes; lower `size` for a quick run.
size = 1000
A = np.random.rand(size, size)
B = np.random.rand(size, size)

# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() is wall-clock time and can jump.

# Pure Python
start = time.perf_counter()
result_python = python_matrix_mult(A, B)
print("纯 Python 耗时:", time.perf_counter() - start, "秒")

# NumPy
start = time.perf_counter()
result_numpy = np.dot(A, B)
print("NumPy 耗时:", time.perf_counter() - start, "秒")

# Numba parallel.
# An @njit function is compiled on its first call for the given argument
# types. Warm it up on a tiny array with the same dtype (float64), ndim (2)
# and layout (C-contiguous) as A and B, so the timed call below reuses that
# compiled version instead of including compilation time. A slice such as
# A[:2, :2] would not do: it is non-contiguous and triggers a separate
# compilation.
warmup = np.ones((2, 2))
numba_matrix_mult(warmup, warmup)
start = time.perf_counter()
result_numba_parallel = numba_matrix_mult(A, B)
print("Numba 并行耗时:", time.perf_counter() - start, "秒")

# Sanity check: all three implementations should agree up to rounding.
results_agree = bool(
    np.allclose(result_python, result_numpy)
    and np.allclose(result_numba_parallel, result_numpy)
)
print("结果一致:", results_agree)
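说明:
@njit(parallel=True):启用 Numba 的自动并行化。Numba 会分析函数中的循环和数组运算，并把可并行的部分(如 prange 循环)分配到多核 CPU 上运行。
prange:Numba 提供的并行版 range,其各次迭代会被分配到多个 CPU 核心上执行。Numba 不会替你检查迭代之间是否存在数据竞争，需要由编写者保证各次迭代互不依赖(本例中每次迭代只写入 result 的第 i 行)。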
结果:
纯 Python 耗时: 257.57 秒
NumPy 耗时: 0.05 秒
Numba 并行耗时: 0.74 秒
(注:Numba 函数在首次调用时才进行 JIT 编译；若计时前没有先做一次预热调用，上面的 Numba 耗时会包含编译时间。)