mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
[CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm (#1743)
* [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm * [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm - review changes * [CK_TILE] Move hipmalloc/memcpy calls out of gpu reference gemm - review fix
This commit is contained in:
@@ -97,9 +97,9 @@ template <typename ADataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_gemm_gpu(DeviceMem& a_device,
|
||||
DeviceMem& b_device,
|
||||
DeviceMem& c_device,
|
||||
void reference_gemm_gpu(ADataType* a_ptr,
|
||||
BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
@@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device,
|
||||
index_t stride_b,
|
||||
index_t stride_c)
|
||||
{
|
||||
|
||||
ADataType* d_A;
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
|
||||
hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType));
|
||||
hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType));
|
||||
hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType));
|
||||
if(errA != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
|
||||
<< std::endl;
|
||||
return; // Early exit on error
|
||||
}
|
||||
|
||||
if(errB != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
|
||||
<< std::endl;
|
||||
return; // Early exit on error
|
||||
}
|
||||
|
||||
if(errC != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
|
||||
<< std::endl;
|
||||
return; // Early exit on error
|
||||
}
|
||||
|
||||
errA = hipMemcpy(
|
||||
d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice);
|
||||
if(errA != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
|
||||
}
|
||||
|
||||
errB = hipMemcpy(
|
||||
d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice);
|
||||
if(errB != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
|
||||
}
|
||||
|
||||
int totalElements = M * N;
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
|
||||
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
|
||||
errC = hipMemcpy(
|
||||
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
|
||||
if(errC != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
|
||||
}
|
||||
|
||||
errA = hipFree(d_A);
|
||||
if(errA != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
|
||||
}
|
||||
|
||||
errB = hipFree(d_B);
|
||||
if(errB != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
|
||||
}
|
||||
|
||||
errC = hipFree(d_C);
|
||||
if(errC != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
|
||||
}
|
||||
<<<numBlocks, numThreadsPerBlock>>>(
|
||||
a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
|
||||
|
||||
return;
|
||||
}
|
||||
@@ -191,9 +125,9 @@ template <typename ADataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_batched_gemm_gpu(DeviceMem& a_device,
|
||||
DeviceMem& b_device,
|
||||
DeviceMem& c_device,
|
||||
void reference_batched_gemm_gpu(ADataType* a_ptr,
|
||||
BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
@@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device,
|
||||
index_t batch_stride_C,
|
||||
index_t batch_count)
|
||||
{
|
||||
|
||||
ADataType* d_A;
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
|
||||
hipError_t errA = hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType));
|
||||
hipError_t errB = hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType));
|
||||
hipError_t errC = hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType));
|
||||
if(errA != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
|
||||
<< std::endl;
|
||||
return; // Early exit on error
|
||||
}
|
||||
|
||||
if(errB != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
|
||||
<< std::endl;
|
||||
return; // Early exit on error
|
||||
}
|
||||
|
||||
if(errC != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
|
||||
<< std::endl;
|
||||
return; // Early exit on error
|
||||
}
|
||||
|
||||
errA = hipMemcpy(d_A,
|
||||
a_device.GetDeviceBuffer(),
|
||||
batch_count * M * K * sizeof(ADataType),
|
||||
hipMemcpyHostToDevice);
|
||||
if(errA != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
|
||||
}
|
||||
|
||||
errB = hipMemcpy(d_B,
|
||||
b_device.GetDeviceBuffer(),
|
||||
batch_count * N * K * sizeof(BDataType),
|
||||
hipMemcpyHostToDevice);
|
||||
if(errB != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
|
||||
}
|
||||
|
||||
int totalElements = M * N;
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
|
||||
{
|
||||
ADataType* d_ATemp = d_A + batch_id * batch_stride_A;
|
||||
BDataType* d_BTemp = d_B + batch_id * batch_stride_B;
|
||||
CDataType* d_CTemp = d_C + batch_id * batch_stride_C;
|
||||
ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
|
||||
BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
|
||||
CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
|
||||
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
|
||||
<<<numBlocks, numThreadsPerBlock>>>(
|
||||
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
|
||||
}
|
||||
|
||||
errC = hipMemcpy(c_device.GetDeviceBuffer(),
|
||||
d_C,
|
||||
batch_count * M * N * sizeof(CDataType),
|
||||
hipMemcpyDeviceToHost);
|
||||
if(errC != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
|
||||
}
|
||||
|
||||
errA = hipFree(d_A);
|
||||
if(errA != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
|
||||
}
|
||||
|
||||
errB = hipFree(d_B);
|
||||
if(errB != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
|
||||
}
|
||||
|
||||
errC = hipFree(d_C);
|
||||
if(errC != hipSuccess)
|
||||
{
|
||||
std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user