mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
[CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm (#1743)
* [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm
* [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm - review changes
* [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm - review fix
[ROCm/composable_kernel commit: f6c4d614e3]
This commit is contained in:
@@ -161,14 +161,39 @@ int run_gemm_example_with_layouts(int argc,
|
||||
c_m_n_gpu_ref.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A;
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
|
||||
ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(d_A,
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
M * K * sizeof(ADataType),
|
||||
hipMemcpyHostToDevice));
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
N * K * sizeof(BDataType),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
|
||||
d_C,
|
||||
M * N * sizeof(CDataType),
|
||||
hipMemcpyDeviceToHost));
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_A));
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
|
||||
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
|
||||
|
||||
@@ -188,15 +188,33 @@ int run_batched_gemm_example_with_layouts(int argc,
|
||||
c_m_n_gpu_ref.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A;
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
|
||||
ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType)));
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(d_A,
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
batch_count * M * K * sizeof(ADataType),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
batch_count * N * K * sizeof(BDataType),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
ck_tile::reference_batched_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_gpu_buf_ref,
|
||||
CLayout>(d_A,
|
||||
d_B,
|
||||
d_C,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
@@ -208,6 +226,15 @@ int run_batched_gemm_example_with_layouts(int argc,
|
||||
batch_stride_C,
|
||||
batch_count);
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
|
||||
d_C,
|
||||
batch_count * M * N * sizeof(CDataType),
|
||||
hipMemcpyDeviceToHost));
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_A));
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
|
||||
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
|
||||
|
||||
|
||||
@@ -97,9 +97,9 @@ template <typename ADataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_gemm_gpu(DeviceMem& a_device,
|
||||
DeviceMem& b_device,
|
||||
DeviceMem& c_device,
|
||||
void reference_gemm_gpu(ADataType* a_ptr,
|
||||
BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
@@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device,
|
||||
index_t stride_b,
|
||||
index_t stride_c)
|
||||
{
|
||||
|
||||
ADataType* d_A;
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
|
||||
hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType));
|
||||
hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType));
|
||||
hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType));
|
||||
if(errA != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
|
||||
<< std::endl;
|
||||
return; // Early exit on error
|
||||
}
|
||||
|
||||
if(errB != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
|
||||
<< std::endl;
|
||||
return; // Early exit on error
|
||||
}
|
||||
|
||||
if(errC != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
|
||||
<< std::endl;
|
||||
return; // Early exit on error
|
||||
}
|
||||
|
||||
errA = hipMemcpy(
|
||||
d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice);
|
||||
if(errA != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
|
||||
}
|
||||
|
||||
errB = hipMemcpy(
|
||||
d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice);
|
||||
if(errB != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
|
||||
}
|
||||
|
||||
int totalElements = M * N;
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
|
||||
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
|
||||
errC = hipMemcpy(
|
||||
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
|
||||
if(errC != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
|
||||
}
|
||||
|
||||
errA = hipFree(d_A);
|
||||
if(errA != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
|
||||
}
|
||||
|
||||
errB = hipFree(d_B);
|
||||
if(errB != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
|
||||
}
|
||||
|
||||
errC = hipFree(d_C);
|
||||
if(errC != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
|
||||
}
|
||||
<<<numBlocks, numThreadsPerBlock>>>(
|
||||
a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
|
||||
|
||||
return;
|
||||
}
|
||||
@@ -191,9 +125,9 @@ template <typename ADataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_batched_gemm_gpu(DeviceMem& a_device,
|
||||
DeviceMem& b_device,
|
||||
DeviceMem& c_device,
|
||||
void reference_batched_gemm_gpu(ADataType* a_ptr,
|
||||
BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
@@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device,
|
||||
index_t batch_stride_C,
|
||||
index_t batch_count)
|
||||
{
|
||||
|
||||
ADataType* d_A;
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
|
||||
hipError_t errA = hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType));
|
||||
hipError_t errB = hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType));
|
||||
hipError_t errC = hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType));
|
||||
if(errA != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
|
||||
<< std::endl;
|
||||
return; // Early exit on error
|
||||
}
|
||||
|
||||
if(errB != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
|
||||
<< std::endl;
|
||||
return; // Early exit on error
|
||||
}
|
||||
|
||||
if(errC != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
|
||||
<< std::endl;
|
||||
return; // Early exit on error
|
||||
}
|
||||
|
||||
errA = hipMemcpy(d_A,
|
||||
a_device.GetDeviceBuffer(),
|
||||
batch_count * M * K * sizeof(ADataType),
|
||||
hipMemcpyHostToDevice);
|
||||
if(errA != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
|
||||
}
|
||||
|
||||
errB = hipMemcpy(d_B,
|
||||
b_device.GetDeviceBuffer(),
|
||||
batch_count * N * K * sizeof(BDataType),
|
||||
hipMemcpyHostToDevice);
|
||||
if(errB != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
|
||||
}
|
||||
|
||||
int totalElements = M * N;
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
|
||||
{
|
||||
ADataType* d_ATemp = d_A + batch_id * batch_stride_A;
|
||||
BDataType* d_BTemp = d_B + batch_id * batch_stride_B;
|
||||
CDataType* d_CTemp = d_C + batch_id * batch_stride_C;
|
||||
ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
|
||||
BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
|
||||
CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
|
||||
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
|
||||
<<<numBlocks, numThreadsPerBlock>>>(
|
||||
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
|
||||
}
|
||||
|
||||
errC = hipMemcpy(c_device.GetDeviceBuffer(),
|
||||
d_C,
|
||||
batch_count * M * N * sizeof(CDataType),
|
||||
hipMemcpyDeviceToHost);
|
||||
if(errC != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
|
||||
}
|
||||
|
||||
errA = hipFree(d_A);
|
||||
if(errA != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
|
||||
}
|
||||
|
||||
errB = hipFree(d_B);
|
||||
if(errB != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
|
||||
}
|
||||
|
||||
errC = hipFree(d_C);
|
||||
if(errC != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user