Extend XDL kernel to Support RDNA3/4 - Part 3 (#2723)

Support Wave32/Wave64 in all XDL Kernels

1. Add the following helper functions/macros in device_base.hpp
- GET_NXDL_PER_WAVE_IMPL and GetNXdlPerWave2
- INVOKER_RUN_IMPL and INVOKER_RUN3_IMPL
- IsValidGemmCompilationParameter and IS_VALID_COMPILATION_PARAMETER_IMPL
2. Replace GridwiseGemm with GridwiseGemm32 and GridwiseGemm64, and use one of them according to the current GPU target
3. Move gridwise-gemm-related variables from Argument members to local variables in RunImp
- This avoids duplicated GridwiseGemm::CheckValidity calls
4. Add IsValidGemmCompilationParameter to all XDL kernels.

Known issues:
- DeviceBatchedGemmXdl and DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle are incorrect on gfx11.
- DeviceGemmMultipleDLayernorm_Xdl_CShuffle is incorrect on both gfx11 and gfx12.
This commit is contained in:
linqunAMD
2025-09-09 11:22:36 +08:00
committed by GitHub
parent e4a7728903
commit 0f8e33f811
131 changed files with 8731 additions and 5329 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/stringutils.hpp"
@@ -76,28 +76,28 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
// | | | | | | | | | | | Wave| Wave| Wave| |
{ 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1},
{ 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1},
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1},
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1},
{ 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
{ 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
{ 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
{ 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
{ 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
{ 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
{ 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
{ 256, 256, 128, 32, 64, 32, 8, 8, 2, 16, 16, 4, 8, 4, 1},
{ 256, 256, 128, 32, 128, 32, 8, 8, 2, 16, 16, 4, 8, 8, 1},
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
{ 256, 128, 128, 64, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1},
{ 256, 128, 128, 32, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1},
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
{ 256, 128, 128, 32, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
{ 256, 128, 256, 64, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
{ 256, 128, 256, 64, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
// Padded fallback kernel
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
{ 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1},
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
{ 256, 128, 64, 32, 128, 32, 8, 8, 2, 16, 16, 2, 4, 8, 1},
// Irregular k
{ 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1},
{ 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1},
{ 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1},
{ 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1},
{ 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1},
{ 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1},
{ 256, 256, 128, 48, 64, 32, 4, 4, 2, 16, 16, 4, 8, 4, 1},
{ 256, 256, 128, 48, 128, 32, 4, 4, 2, 16, 16, 4, 8, 8, 1},
{ 256, 128, 256, 48, 64, 32, 4, 4, 2, 16, 16, 2, 16, 4, 1},
{ 256, 128, 256, 48, 128, 32, 4, 4, 2, 16, 16, 2, 16, 8, 1},
{ 256, 128, 128, 48, 64, 32, 4, 4, 2, 16, 16, 2, 8, 4, 1},
{ 256, 128, 128, 48, 128, 32, 4, 4, 2, 16, 16, 2, 8, 8, 1},
// clang-format on
};
@@ -200,28 +200,28 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
// _MBlock_MWaveMPerXdl| ScalarPerVector
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
// |
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 16, 1,16>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 16, 1,16>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 16, 1,16>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 16, 1,16>, 4},
{ S<1, 32, 1, 8>, 4},
// Padded fallback kernel
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
// Irregular k
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
// clang-format on
};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/stringutils.hpp"
@@ -81,16 +81,16 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch|
// | | | | | | | | Wave| Wave| Stage|
// | | | | | | | | | | |
{ 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1},
{ 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1},
{ 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1},
{ 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1},
{ 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, 1},
{ 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
{ 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1},
{ 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
{ 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, 1},
{ 256, 128, 256, 32, 8, 8, 16, 16, 4, 8, 1},
{ 128, 128, 128, 32, 8, 8, 16, 16, 8, 4, 1},
{ 256, 128, 128, 32, 8, 8, 16, 16, 4, 4, 1},
{ 128, 128, 64, 32, 8, 8, 16, 16, 4, 4, 1},
{ 128, 64, 128, 32, 8, 8, 16, 16, 4, 4, 1},
{ 256, 128, 64, 32, 8, 8, 16, 16, 4, 2, 1},
{ 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, 1},
// Irregular tile
{ 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1},
{ 64, 32, 32, 32, 8, 8, 16, 16, 2, 2, 1},
// clang-format on
};
@@ -194,14 +194,14 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
// _MBlock_MWaveMPerXdl| ScalarPerVector
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
// |
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 16, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 4>, 8},
{ S<1, 16, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 16, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 4>, 4},
{ S<1, 16, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 32, 1, 8>, 4},
// Irregular tile
{ S<1, 16, 1, 4>, 1},
// clang-format on

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include <iostream>
@@ -55,12 +55,12 @@ std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> Operation_Conv_Fwd_Xdl_Cshuffle::Cr
// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch|
// | | | | | | | | Wave| Wave| Stage|
// | | | | | | | | | | |
{ 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 1},
{ 256, 128, 256, 32, 8, 8, 32, 32, 4, 2, 1},
{ 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1},
{ 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, 1},
{ 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1},
{ 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1}
{ 64, 64, 32, 32, 8, 8, 16, 16, 4, 2, 1},
{ 256, 128, 256, 32, 8, 8, 16, 16, 8, 4, 1},
{ 256, 128, 128, 32, 8, 8, 16, 16, 4, 4, 1},
{ 64, 64, 64, 32, 8, 8, 16, 16, 4, 4, 1},
{ 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, 1},
{ 128, 128, 128, 32, 8, 8, 16, 16, 8, 4, 1}
// clang-format on
};
@@ -116,11 +116,11 @@ std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> Operation_Conv_Fwd_Xdl_Cshuffle::Cr
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
// |
{ S<1, 16, 1, 4>, 1},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 16, 1, 16>, 4},
{ S<1, 32, 1, 8>, 4},
{ S<1, 16, 1, 4>, 1},
{ S<1, 32, 1, 8>, 8},
{ S<1, 16, 1, 8>, 8}
{ S<1, 32, 1, 8>, 4},
{ S<1, 16, 1, 8>, 4}
// clang-format on
};
@@ -223,8 +223,9 @@ extern "C" __global__ void run_${name}(
constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler();
// GridwiseGemm
using GridwiseGemm = DeviceConv::GridwiseGemm;
using GridwiseGemm = ck::conditional_t<ck::get_warp_size() == 64,
typename DeviceConv::GridwiseGemm64,
typename DeviceConv::GridwiseGemm32>;
static constexpr auto I0 = ck::Number<0>{};
ck::tensor_operation::device::device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/utils.hpp"
@@ -13,7 +13,8 @@ std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
// Returns a const reference to a function-local static set of GPU architecture
// name strings.
// NOTE(review): presumably these are the targets for which XDL-based kernels are
// offered by the host library -- confirm against the callers of this function.
const std::unordered_set<std::string>& get_xdlop_archs()
{
// NOTE(review): this span is a rendered diff without +/- markers. The one-line
// declaration below appears to be the pre-change version and the two-line
// declaration after it the replacement, which additionally lists the
// gfx1100/gfx1101/gfx1102/gfx1200/gfx1201 entries.
static std::unordered_set<std::string> supported_archs{"gfx90a", "gfx908", "gfx942"};
static std::unordered_set<std::string> supported_archs{
"gfx90a", "gfx908", "gfx942", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"};
// The set has static storage duration, so the returned reference stays valid
// after the call returns.
return supported_archs;
}

View File

@@ -160,9 +160,10 @@ struct Epilogue
Epilogue{1.0f, 1.0f});
out_host.SetZero();
ref_invoker.Run(ref_argument);**/
int i = 0;
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
{
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
// substitute instance values into the template
auto src = ck::host::InterpolateString(
conv_compile_check,

View File

@@ -160,9 +160,10 @@ struct Epilogue
Epilogue{1.0f, 1.0f});
out_host.SetZero();
ref_invoker.Run(ref_argument);**/
int i = 0;
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
{
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
// substitute instance values into the template
auto src = ck::host::InterpolateString(
conv_compile_check,

View File

@@ -160,9 +160,10 @@ struct Epilogue
Epilogue{1.0f, 1.0f});
out_host.SetZero();
ref_invoker.Run(ref_argument);**/
int i = 0;
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
{
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
// substitute instance values into the template
auto src = ck::host::InterpolateString(
conv_compile_check,

View File

@@ -160,9 +160,10 @@ struct Epilogue
Epilogue{1.0f, 1.0f});
out_host.SetZero();
ref_invoker.Run(ref_argument);**/
int i = 0;
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
{
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
// substitute instance values into the template
auto src = ck::host::InterpolateString(
conv_compile_check,