mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Extend gemm traits number for ck wrapper (#1153)
This commit is contained in:
@@ -225,10 +225,10 @@ TEST(TestGemm, Int8)
|
||||
using DataType = int8_t;
|
||||
const auto thread_layout = ck::make_tuple(ck::Number<64>{}, ck::Number<4>{});
|
||||
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
|
||||
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 16>(
|
||||
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 16>(
|
||||
512, 512, 128, tile_shape, thread_layout);
|
||||
// Irregular case
|
||||
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1>(
|
||||
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 1>(
|
||||
129, 129, 67, tile_shape, thread_layout);
|
||||
}
|
||||
|
||||
@@ -237,10 +237,10 @@ TEST(TestGemm, Half)
|
||||
using DataType = ck::half_t;
|
||||
const auto thread_layout = ck::make_tuple(ck::Number<32>{}, ck::Number<8>{});
|
||||
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
|
||||
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 8>(
|
||||
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 8>(
|
||||
512, 512, 128, tile_shape, thread_layout);
|
||||
// Irregular case
|
||||
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1>(
|
||||
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 1>(
|
||||
129, 129, 67, tile_shape, thread_layout);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user