mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add optimized blockwise gemm using ck wrapper (#1157)
* Add optimized blockwise gemm using ck wrapper
* Add basic gemm example
* Update docs
* Add tutorial for gemm using ck wrapper
* Add perf note
* edits
* Fix cmake
* Fixes

---------

Co-authored-by: Lisa Delaney <lisa.delaney@amd.com>
This commit is contained in:
@@ -20,48 +20,57 @@ namespace wrapper {
|
||||
* \tparam K1Value The number of K-dim elements that are packed together as
|
||||
* a separate logical dimension. Usually aligns with vector load size.
|
||||
*/
|
||||
template <index_t MPerXDLValue,
|
||||
index_t NPerXDLValue,
|
||||
index_t MXdlPerWaveValue,
|
||||
index_t NXdlPerWaveValue,
|
||||
index_t K1Value>
|
||||
template <typename MPerXDLValue,
|
||||
typename NPerXDLValue,
|
||||
typename MXdlPerWaveValue,
|
||||
typename NXdlPerWaveValue,
|
||||
typename K1Value>
|
||||
struct BlockwisGemmXdlTraits
|
||||
{
|
||||
static constexpr index_t MPerXDL = MPerXDLValue;
|
||||
static constexpr index_t NPerXDL = NPerXDLValue;
|
||||
static constexpr index_t MXdlPerWave = MXdlPerWaveValue;
|
||||
static constexpr index_t NXdlPerWave = NXdlPerWaveValue;
|
||||
static constexpr index_t K1 = K1Value;
|
||||
static constexpr auto MPerXDL = MPerXDLValue{};
|
||||
static constexpr auto NPerXDL = NPerXDLValue{};
|
||||
static constexpr auto MXdlPerWave = MXdlPerWaveValue{};
|
||||
static constexpr auto NXdlPerWave = NXdlPerWaveValue{};
|
||||
static constexpr auto K1 = K1Value{};
|
||||
};
|
||||
|
||||
// K1 = 4
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4>
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<4>>
|
||||
{
|
||||
};
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4>
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<4>>
|
||||
{
|
||||
};
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4>
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<4>>
|
||||
{
|
||||
};
|
||||
// K1 = 8
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8>
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<8>>
|
||||
{
|
||||
};
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8>
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<8>>
|
||||
{
|
||||
};
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8>
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<8>>
|
||||
{
|
||||
};
|
||||
// K1 = 16
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16>
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<16>>
|
||||
{
|
||||
};
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16>
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<16>>
|
||||
{
|
||||
};
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16>
|
||||
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1
|
||||
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<16>>
|
||||
{
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user