[CK TILE] Support fp8/fp16 with pk_int4_t as data types for tensors A and B (#2805)

- Add support for tensor A/B in both fp16+pk_int4_t and fp8+pk_int4_t formats
- Implement A(bf8) B(i4) support in universal GEMM
- Use new implementation for i4 to fp8 conversion in Block Scale

[ROCm/composable_kernel commit: 82890192dd]
This commit is contained in:
Cong Ma
2025-09-09 17:40:52 -06:00
committed by GitHub
parent 22490acf0b
commit f7ffd111ee
15 changed files with 320 additions and 135 deletions

View File

@@ -14,6 +14,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "test_gemm_aquant_utils.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
template <typename GemmConfig,
typename ADataType,
@@ -336,7 +337,17 @@ bool run_gemm_test_with_layouts(int argc,
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
}
else
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();