mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 3 (#2723)
Support Wave32/Wave64 in all XDL Kernels 1. Add the following helper functions/macros in device_base.hpp - GET_NXDL_PER_WAVE_IMPL and GetNXdlPerWave2 - INVOKER_RUN_IMPL and INVOKER_RUN3_IMPL - IsValidGemmCompilationParameter and IS_VALID_COMPILATION_PARAMETER_IMPL 2. Replace GridwiseGemm with GridwiseGemm32 and GridwiseGemm64, and use one of them according to the current GPU target 3. Move gridwise-gemm-related variables from Argument members to local variables in RunImp - This avoids duplicated GridwiseGemm::CheckValidity calls 4. Add IsValidGemmCompilationParameter to all XDL kernels. Known issues: - DeviceBatchedGemmXdl and DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle are incorrect on gfx11. - DeviceGemmMultipleDLayernorm_Xdl_CShuffle is incorrect on both gfx11 and gfx12.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
@@ -76,28 +76,28 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
|
||||
// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
|
||||
// | | | | | | | | | | | Wave| Wave| Wave| |
|
||||
{ 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1},
|
||||
{ 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1},
|
||||
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1},
|
||||
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1},
|
||||
{ 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
|
||||
{ 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
|
||||
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
|
||||
{ 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
|
||||
{ 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
|
||||
{ 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
|
||||
{ 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
|
||||
{ 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
|
||||
{ 256, 256, 128, 32, 64, 32, 8, 8, 2, 16, 16, 4, 8, 4, 1},
|
||||
{ 256, 256, 128, 32, 128, 32, 8, 8, 2, 16, 16, 4, 8, 8, 1},
|
||||
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
|
||||
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
|
||||
{ 256, 128, 128, 64, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1},
|
||||
{ 256, 128, 128, 32, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1},
|
||||
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
|
||||
{ 256, 128, 128, 32, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
|
||||
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
|
||||
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
|
||||
{ 256, 128, 256, 64, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
|
||||
{ 256, 128, 256, 64, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
|
||||
// Padded fallback kernel
|
||||
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
|
||||
{ 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1},
|
||||
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
|
||||
{ 256, 128, 64, 32, 128, 32, 8, 8, 2, 16, 16, 2, 4, 8, 1},
|
||||
// Irregular k
|
||||
{ 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1},
|
||||
{ 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1},
|
||||
{ 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1},
|
||||
{ 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1},
|
||||
{ 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1},
|
||||
{ 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1},
|
||||
{ 256, 256, 128, 48, 64, 32, 4, 4, 2, 16, 16, 4, 8, 4, 1},
|
||||
{ 256, 256, 128, 48, 128, 32, 4, 4, 2, 16, 16, 4, 8, 8, 1},
|
||||
{ 256, 128, 256, 48, 64, 32, 4, 4, 2, 16, 16, 2, 16, 4, 1},
|
||||
{ 256, 128, 256, 48, 128, 32, 4, 4, 2, 16, 16, 2, 16, 8, 1},
|
||||
{ 256, 128, 128, 48, 64, 32, 4, 4, 2, 16, 16, 2, 8, 4, 1},
|
||||
{ 256, 128, 128, 48, 128, 32, 4, 4, 2, 16, 16, 2, 8, 8, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -200,28 +200,28 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
// _MBlock_MWaveMPerXdl| ScalarPerVector
|
||||
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
|
||||
// |
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1,16>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1,16>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 16, 1,16>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 16, 1,16>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
// Padded fallback kernel
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
// Irregular k
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_gemm_multiple_d/operation.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
@@ -81,16 +81,16 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch|
|
||||
// | | | | | | | | Wave| Wave| Stage|
|
||||
// | | | | | | | | | | |
|
||||
{ 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1},
|
||||
{ 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1},
|
||||
{ 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1},
|
||||
{ 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1},
|
||||
{ 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, 1},
|
||||
{ 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
|
||||
{ 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1},
|
||||
{ 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
|
||||
{ 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, 1},
|
||||
{ 256, 128, 256, 32, 8, 8, 16, 16, 4, 8, 1},
|
||||
{ 128, 128, 128, 32, 8, 8, 16, 16, 8, 4, 1},
|
||||
{ 256, 128, 128, 32, 8, 8, 16, 16, 4, 4, 1},
|
||||
{ 128, 128, 64, 32, 8, 8, 16, 16, 4, 4, 1},
|
||||
{ 128, 64, 128, 32, 8, 8, 16, 16, 4, 4, 1},
|
||||
{ 256, 128, 64, 32, 8, 8, 16, 16, 4, 2, 1},
|
||||
{ 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, 1},
|
||||
// Irregular tile
|
||||
{ 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1},
|
||||
{ 64, 32, 32, 32, 8, 8, 16, 16, 2, 2, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -194,14 +194,14 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
// _MBlock_MWaveMPerXdl| ScalarPerVector
|
||||
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
|
||||
// |
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 4>, 8},
|
||||
{ S<1, 16, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 16, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 4>, 4},
|
||||
{ S<1, 16, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
// Irregular tile
|
||||
{ S<1, 16, 1, 4>, 1},
|
||||
// clang-format on
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
|
||||
#include <iostream>
|
||||
@@ -55,12 +55,12 @@ std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> Operation_Conv_Fwd_Xdl_Cshuffle::Cr
|
||||
// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch|
|
||||
// | | | | | | | | Wave| Wave| Stage|
|
||||
// | | | | | | | | | | |
|
||||
{ 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 1},
|
||||
{ 256, 128, 256, 32, 8, 8, 32, 32, 4, 2, 1},
|
||||
{ 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1},
|
||||
{ 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, 1},
|
||||
{ 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1},
|
||||
{ 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1}
|
||||
{ 64, 64, 32, 32, 8, 8, 16, 16, 4, 2, 1},
|
||||
{ 256, 128, 256, 32, 8, 8, 16, 16, 8, 4, 1},
|
||||
{ 256, 128, 128, 32, 8, 8, 16, 16, 4, 4, 1},
|
||||
{ 64, 64, 64, 32, 8, 8, 16, 16, 4, 4, 1},
|
||||
{ 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, 1},
|
||||
{ 128, 128, 128, 32, 8, 8, 16, 16, 8, 4, 1}
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -116,11 +116,11 @@ std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> Operation_Conv_Fwd_Xdl_Cshuffle::Cr
|
||||
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
|
||||
// |
|
||||
{ S<1, 16, 1, 4>, 1},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1, 16>, 4},
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 16, 1, 4>, 1},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1, 8>, 8}
|
||||
{ S<1, 32, 1, 8>, 4},
|
||||
{ S<1, 16, 1, 8>, 4}
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -223,8 +223,9 @@ extern "C" __global__ void run_${name}(
|
||||
constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler();
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = DeviceConv::GridwiseGemm;
|
||||
|
||||
using GridwiseGemm = ck::conditional_t<ck::get_warp_size() == 64,
|
||||
typename DeviceConv::GridwiseGemm64,
|
||||
typename DeviceConv::GridwiseGemm32>;
|
||||
static constexpr auto I0 = ck::Number<0>{};
|
||||
|
||||
ck::tensor_operation::device::device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/utils.hpp"
|
||||
|
||||
@@ -13,7 +13,8 @@ std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
|
||||
|
||||
const std::unordered_set<std::string>& get_xdlop_archs()
|
||||
{
|
||||
static std::unordered_set<std::string> supported_archs{"gfx90a", "gfx908", "gfx942"};
|
||||
static std::unordered_set<std::string> supported_archs{
|
||||
"gfx90a", "gfx908", "gfx942", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"};
|
||||
return supported_archs;
|
||||
}
|
||||
|
||||
|
||||
@@ -160,9 +160,10 @@ struct Epilogue
|
||||
Epilogue{1.0f, 1.0f});
|
||||
out_host.SetZero();
|
||||
ref_invoker.Run(ref_argument);**/
|
||||
|
||||
int i = 0;
|
||||
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
|
||||
{
|
||||
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
|
||||
// substitute instance values into the template
|
||||
auto src = ck::host::InterpolateString(
|
||||
conv_compile_check,
|
||||
|
||||
@@ -160,9 +160,10 @@ struct Epilogue
|
||||
Epilogue{1.0f, 1.0f});
|
||||
out_host.SetZero();
|
||||
ref_invoker.Run(ref_argument);**/
|
||||
|
||||
int i = 0;
|
||||
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
|
||||
{
|
||||
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
|
||||
// substitute instance values into the template
|
||||
auto src = ck::host::InterpolateString(
|
||||
conv_compile_check,
|
||||
|
||||
@@ -160,9 +160,10 @@ struct Epilogue
|
||||
Epilogue{1.0f, 1.0f});
|
||||
out_host.SetZero();
|
||||
ref_invoker.Run(ref_argument);**/
|
||||
|
||||
int i = 0;
|
||||
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
|
||||
{
|
||||
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
|
||||
// substitute instance values into the template
|
||||
auto src = ck::host::InterpolateString(
|
||||
conv_compile_check,
|
||||
|
||||
@@ -160,9 +160,10 @@ struct Epilogue
|
||||
Epilogue{1.0f, 1.0f});
|
||||
out_host.SetZero();
|
||||
ref_invoker.Run(ref_argument);**/
|
||||
|
||||
int i = 0;
|
||||
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
|
||||
{
|
||||
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
|
||||
// substitute instance values into the template
|
||||
auto src = ck::host::InterpolateString(
|
||||
conv_compile_check,
|
||||
|
||||
Reference in New Issue
Block a user