mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 4 (#2724)
* Fix example * fix build error * update pk_i4 & moe test case * fix all instance build (examples) * fix batched_gemm_gemm (example) * disable example_gemm_bias_softmax_gemm_permute on gfx11 * remove unnecessary disable gfx11 * update tests * update tests2
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <numeric>
|
||||
#include <cstdlib>
|
||||
@@ -81,6 +81,7 @@ __global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
|
||||
const BlockShape tile_shape,
|
||||
const ThreadLayout thread_layout)
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
|
||||
constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
|
||||
constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);
|
||||
@@ -256,6 +257,16 @@ __global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
|
||||
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
|
||||
|
||||
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
|
||||
#else
|
||||
ck::ignore = p_a;
|
||||
ck::ignore = p_b;
|
||||
ck::ignore = p_c;
|
||||
ck::ignore = M;
|
||||
ck::ignore = N;
|
||||
ck::ignore = K;
|
||||
ck::ignore = tile_shape;
|
||||
ck::ignore = thread_layout;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DataType,
|
||||
@@ -374,3 +385,14 @@ TEST(TestGemm, Float_2x4_4x2_XdlPerWave)
|
||||
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1, 4, false>(
|
||||
512, 512, 128, tile_shape, thread_layout);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
std::cout << "This test support gfx9 only" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user