Re-enable f8 x bf8 tests on CompV3 and CompV4 (#3605)

* Re-enable f8 x bf8 tests on CompV3 as they now pass

* On CompV4, f8 x bf8 tests now pass with K_BlockSize I32

* Add a changelog entry

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

[ROCm/composable_kernel commit: 834642202c]
This commit is contained in:
SamiAario-AMD
2026-01-26 20:23:26 +02:00
committed by GitHub
parent ea30b43692
commit b07fbbc33a
3 changed files with 7 additions and 11 deletions

View File

@@ -25,6 +25,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
## Composable Kernel 1.2.0 for ROCm 7.2.0
### Added
* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4
* Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with C++ and Python frontends starting with GEMM support.
* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
* Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM.

View File

@@ -13,13 +13,8 @@ class TestCkTileGemmPipelineCompV3
static constexpr bool check_data_type()
{
using Base = TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompV3<T>>;
if constexpr(std::is_same_v<typename Base::ADataType, F8> &&
std::is_same_v<typename Base::BDataType, BF8>)
{
return false;
}
else if constexpr(std::is_same_v<typename Base::BLayout, Row> &&
std::is_same_v<typename Base::BDataType, I4>)
if constexpr(std::is_same_v<typename Base::BLayout, Row> &&
std::is_same_v<typename Base::BDataType, I4>)
{
return false;
}

View File

@@ -170,7 +170,7 @@ using KernelTypesCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
@@ -180,7 +180,7 @@ using KernelTypesCompV4 = ::testing::Types<
std::tuple< Row, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
@@ -190,7 +190,7 @@ using KernelTypesCompV4 = ::testing::Types<
std::tuple< Col, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
@@ -200,7 +200,7 @@ using KernelTypesCompV4 = ::testing::Types<
std::tuple< Col, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>