mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#5323 (commit 5454e9e)
CK Tile MX GEMM Packing Improvement ## Motivation Reduce the scale loading size and improve utilization of MFMA scale selection. ## Technical Details Add packing of MX scales into int32 words. ## Test Plan Use the existing test cases. ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
859acb5ae7
commit
5f90f69795
@@ -45,6 +45,55 @@ class TestMxGemmUtil : public ::testing::Test
|
||||
using ScaleM = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
using ScaleN = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
|
||||
// Pack [MN, K/32] e8m0_t scales into [MN/MNPack, K/32/KPack] int32_t
|
||||
// Each int32_t contains MNPack * KPack e8m0_t values with byte layout matching
|
||||
// the GPU tile distribution: values are XdlMNThread apart in M and XdlKThread apart in K.
|
||||
// byte[ik * MNPack + imn] = e8m0 at strided (mn, k) position
|
||||
// kLast=true for A scales (layout [M, K/32]), kLast=false for B scales (layout [K/32, N])
|
||||
template <ck_tile::index_t MNPack = 2,
|
||||
ck_tile::index_t KPack = 2,
|
||||
ck_tile::index_t XdlMNThread = 16,
|
||||
ck_tile::index_t XdlKThread = 4>
|
||||
static auto packScalesMNxK(const ck_tile::HostTensor<ck_tile::e8m0_t>& src, bool kLast)
|
||||
{
|
||||
auto src_lengths = src.get_lengths();
|
||||
const ck_tile::index_t MN = kLast ? src_lengths[0] : src_lengths[1];
|
||||
const ck_tile::index_t K_scale = kLast ? src_lengths[1] : src_lengths[0];
|
||||
const ck_tile::index_t MN_packed = MN / MNPack;
|
||||
const ck_tile::index_t K_packed = K_scale / KPack;
|
||||
const ck_tile::index_t total_packed = MN_packed * K_packed;
|
||||
|
||||
std::vector<int32_t> packed(total_packed);
|
||||
|
||||
for(ck_tile::index_t packed_mn = 0; packed_mn < MN_packed; packed_mn++)
|
||||
{
|
||||
for(ck_tile::index_t packed_k = 0; packed_k < K_packed; packed_k++)
|
||||
{
|
||||
int32_t val = 0;
|
||||
ck_tile::index_t mn_lane = packed_mn % XdlMNThread;
|
||||
ck_tile::index_t mn_group = packed_mn / XdlMNThread;
|
||||
ck_tile::index_t k_lane = packed_k % XdlKThread;
|
||||
ck_tile::index_t k_group = packed_k / XdlKThread;
|
||||
for(ck_tile::index_t ik = 0; ik < KPack; ik++)
|
||||
{
|
||||
for(ck_tile::index_t imn = 0; imn < MNPack; imn++)
|
||||
{
|
||||
ck_tile::index_t byteIdx = ik * MNPack + imn;
|
||||
ck_tile::index_t orig_mn =
|
||||
mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane;
|
||||
ck_tile::index_t orig_k =
|
||||
k_group * XdlKThread * KPack + ik * XdlKThread + k_lane;
|
||||
|
||||
ck_tile::e8m0_t v = kLast ? src(orig_mn, orig_k) : src(orig_k, orig_mn);
|
||||
val |= (static_cast<int32_t>(v.get()) << (byteIdx * 8));
|
||||
}
|
||||
}
|
||||
packed[packed_mn * K_packed + packed_k] = val;
|
||||
}
|
||||
}
|
||||
return packed;
|
||||
}
|
||||
|
||||
void Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, int seed = 1234)
|
||||
{
|
||||
const ck_tile::index_t scale_k_size = K / 32;
|
||||
@@ -75,17 +124,43 @@ class TestMxGemmUtil : public ::testing::Test
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_a_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_b_host);
|
||||
|
||||
// Compute effective XdlPack sizes based on GemmConfig tile dimensions
|
||||
constexpr ck_tile::index_t MPerXdl = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t NPerXdl = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t KPerXdl = GemmConfig::K_Warp_Tile;
|
||||
constexpr ck_tile::index_t MIterPerWarp =
|
||||
GemmConfig::M_Tile / (GemmConfig::M_Warp * MPerXdl);
|
||||
constexpr ck_tile::index_t NIterPerWarp =
|
||||
GemmConfig::N_Tile / (GemmConfig::N_Warp * NPerXdl);
|
||||
constexpr ck_tile::index_t KIterPerWarp = GemmConfig::K_Tile / KPerXdl;
|
||||
|
||||
constexpr ck_tile::index_t MXdlPackEff =
|
||||
(MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t NXdlPackEff =
|
||||
(NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t KXdlPackEff =
|
||||
(KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
|
||||
constexpr ck_tile::index_t XdlMNThread = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
// Pack scales into int32_t for GPU consumption
|
||||
auto scale_a_packed =
|
||||
packScalesMNxK<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_a_host, true);
|
||||
auto scale_b_packed =
|
||||
packScalesMNxK<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_b_host, false);
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_dev_buf(b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_packed.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_packed.size() * sizeof(int32_t));
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
b_dev_buf.ToDevice(b_host.data());
|
||||
c_dev_buf.SetZero();
|
||||
scale_a_dev_buf.ToDevice(scale_a_host.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_host.data());
|
||||
scale_a_dev_buf.ToDevice(scale_a_packed.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_packed.data());
|
||||
|
||||
ScaleM scale_m(reinterpret_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()));
|
||||
ScaleN scale_n(reinterpret_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()));
|
||||
|
||||
Reference in New Issue
Block a user