mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
initial stream-k implementation with example (#699)
* initial stream-k implementation with example * fix unexpected change in err * improve a little bit performance by reorganize pipeline. * improve perf a little bit by swizzle block idx * add profiler * update example * fix spelling * shrink karg for streamk * support dynamic buffer using memory coherence glc_slc bit from template * control memory coherence while construct dynamic buffer * update reduction for streamk(not ready yet) * Add template parameter to make_dynamic_buffer to support amd_buffer coherence setting * fix build issue * fix several bug * now result is correct, everything works (but has scratch) * remove scratch by manually reset coordinate * update device code * fix a bug in final reduce * fix something in example * update async memset * fix enum as camel case * modify coherence enum name * clean code and use atomic streamk by default * remove unused var * throw exception if have empty pointer * fix format * fix CI warning * fix type in init * modify CI error * filter out on gfx10+ * restore changed example code --------- Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>
This commit is contained in:
@@ -0,0 +1,10 @@
|
||||
add_instance_library(device_gemm_streamk_instance
|
||||
# device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f32_f32_f32_km_kn_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f32_f32_f32_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
)
|
||||
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
// static constexpr auto GemmMNPadding =
|
||||
// ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
|
||||
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
|
||||
using device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
|
||||
//##################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 48, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 24, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmStreamK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -10,20 +10,57 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
}
|
||||
|
||||
void DeviceMem::Realloc(std::size_t mem_size)
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
hip_check_error(hipFree(mpDeviceBuf));
|
||||
}
|
||||
mMemSize = mem_size;
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
}
|
||||
|
||||
void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
|
||||
|
||||
std::size_t DeviceMem::GetBufferSize() const { return mMemSize; }
|
||||
|
||||
void DeviceMem::ToDevice(const void* p) const
|
||||
{
|
||||
hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
hip_check_error(
|
||||
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("ToDevice with an empty pointer");
|
||||
}
|
||||
}
|
||||
|
||||
void DeviceMem::FromDevice(void* p) const
|
||||
{
|
||||
hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("FromDevice with an empty pointer");
|
||||
}
|
||||
}
|
||||
|
||||
void DeviceMem::SetZero() const { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); }
|
||||
void DeviceMem::SetZero() const
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize));
|
||||
}
|
||||
}
|
||||
|
||||
DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); }
|
||||
DeviceMem::~DeviceMem()
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
hip_check_error(hipFree(mpDeviceBuf));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user