mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
test(grouped_gemm_multi_d): add unit test for bf16 support
This commit is contained in:
committed by
Aviral Goel
parent
8d8b49dec2
commit
ca1ab083a7
@@ -8,9 +8,10 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_multi_d_util.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using F32 = float;
|
||||
using F16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using F32 = float;
|
||||
|
||||
// Custom tuple-like structure for kernel configuration
|
||||
template <typename ALayout_,
|
||||
@@ -18,6 +19,8 @@ template <typename ALayout_,
|
||||
typename ELayout_,
|
||||
typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename D0DataType_,
|
||||
typename D1DataType_,
|
||||
typename AccDataType_,
|
||||
typename EDataType_,
|
||||
int M_Tile_val_,
|
||||
@@ -41,9 +44,11 @@ struct KernelConfig
|
||||
using DsLayoutType = ck_tile::tuple<Row, Row>;
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using D0DataType = D0DataType_;
|
||||
using D1DataType = D1DataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using EDataType = EDataType_;
|
||||
using DsDataType = ck_tile::tuple<F16, F16>;
|
||||
using DsDataType = ck_tile::tuple<D0DataType_, D1DataType_>;
|
||||
|
||||
static constexpr int M_Tile_ = M_Tile_val_;
|
||||
static constexpr int N_Tile_ = N_Tile_val_;
|
||||
@@ -63,13 +68,21 @@ struct KernelConfig
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline, Persistent
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4
|
||||
// ALayout, BLayout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline, Persistent
|
||||
// FP16 A/B/D/E
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true>, // v4
|
||||
// BF16 A/B/D/E
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user