Extend gemm traits number for ck wrapper (#1153)

This commit is contained in:
Bartłomiej Kocot
2024-02-02 20:25:54 +01:00
committed by GitHub
parent 112b691bb7
commit 171ca260b5
2 changed files with 25 additions and 4 deletions

View File

@@ -225,10 +225,10 @@ TEST(TestGemm, Int8)
using DataType = int8_t;
const auto thread_layout = ck::make_tuple(ck::Number<64>{}, ck::Number<4>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 16>(
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 16>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1>(
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 1>(
129, 129, 67, tile_shape, thread_layout);
}
@@ -237,10 +237,10 @@ TEST(TestGemm, Half)
using DataType = ck::half_t;
const auto thread_layout = ck::make_tuple(ck::Number<32>{}, ck::Number<8>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 8>(
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 8>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1>(
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 1>(
129, 129, 67, tile_shape, thread_layout);
}