[CK TILE] Support fp8/fp16 with pk_int4_t as data types for tensors A and B (#2805)

- Add support for tensor A/B in both fp16+pk_int4_t and fp8+pk_int4_t formats
- Implement A(bf8) B(i4) support in universal GEMM
- Use new implementation for i4 to fp8 conversion in Block Scale

[ROCm/composable_kernel commit: 82890192dd]
This commit is contained in:
Cong Ma
2025-09-09 17:40:52 -06:00
committed by GitHub
parent 22490acf0b
commit f7ffd111ee
15 changed files with 320 additions and 135 deletions

View File

@@ -228,4 +228,4 @@ int run_gemm_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigDecode>(argc, argv); }
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigQuant>(argc, argv); }

View File

@@ -5,6 +5,7 @@
#pragma once
#include <random>
#include <stdexcept>
#include "../00_shared/host_tensor_utils.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
@@ -217,7 +218,16 @@ int run_gemm_example_with_layouts(int argc,
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
}
a_m_k_dev_buf.ToDevice(a_m_k.data());
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
}
else
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();

View File

@@ -3,6 +3,7 @@
#pragma once
#include <random>
#include "ck_tile/host/permute_pk_int4.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
@@ -208,7 +209,17 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();

View File

@@ -4,6 +4,7 @@
#pragma once
#include <random>
#include <stdexcept>
#include "ck_tile/host/permute_pk_int4.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
@@ -308,7 +309,17 @@ int run_gemm_example_with_layouts(int argc,
aq_dev_buf.ToDevice(aq_tensor.data());
}
a_m_k_dev_buf.ToDevice(a_m_k.data());
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
}
else
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();