mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
[CK TILE] Support fp8/fp16 with pk_int4_t as data types for tensors A and B (#2805)
- Add support for tensor A/B in both fp16+pk_int4_t and fp8+pk_int4_t formats
- Implement A(bf8) B(i4) support in universal GEMM
- Use new implementation for i4 to fp8 conversion in Block Scale
[ROCm/composable_kernel commit: 82890192dd]
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_aquant_utils.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
@@ -336,7 +337,17 @@ bool run_gemm_test_with_layouts(int argc,
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
|
||||
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
|
||||
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
}
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
|
||||
Reference in New Issue
Block a user