mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#5323 (commit 5454e9e)
CK Tile MX GEMM Packing Improvement ## Motivation Reduce the scale loading size and also has better utilization of MFMA scale selection. ## Technical Details Add up the packing of mx scales. ## Test Plan Use the existing test cases. ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
859acb5ae7
commit
5f90f69795
@@ -14,7 +14,56 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const float max_accumulated_v
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
// Use e8m0_t directly without packing - simpler and cleaner approach
|
||||
// Pack [MN, K/32] e8m0_t scales into [MN/MNPack, K/32/KPack] int32_t
|
||||
// Each int32_t contains MNPack * KPack e8m0_t values with byte layout matching
|
||||
// the GPU tile distribution: values are XdlMNThread apart in M and XdlKThread apart in K.
|
||||
// byte[ik * MNPack + imn] = e8m0 at strided (mn, k) position
|
||||
// kLast=true for A scales (layout [M, K/32]), kLast=false for B scales (layout [K/32, N])
|
||||
template <ck_tile::index_t MNPack = 2,
|
||||
ck_tile::index_t KPack = 2,
|
||||
ck_tile::index_t XdlMNThread = 16,
|
||||
ck_tile::index_t XdlKThread = 4>
|
||||
auto packScalesMNxK(const ck_tile::HostTensor<ck_tile::e8m0_t>& src, bool kLast)
|
||||
{
|
||||
auto src_lengths = src.get_lengths();
|
||||
const ck_tile::index_t MN = kLast ? src_lengths[0] : src_lengths[1];
|
||||
const ck_tile::index_t K_scale = kLast ? src_lengths[1] : src_lengths[0];
|
||||
const ck_tile::index_t MN_packed = MN / MNPack;
|
||||
const ck_tile::index_t K_packed = K_scale / KPack;
|
||||
const ck_tile::index_t total_packed = MN_packed * K_packed;
|
||||
|
||||
// Output as flat vector of int32_t (row-major [MN/MNPack, K/32/KPack])
|
||||
std::vector<int32_t> packed(total_packed);
|
||||
|
||||
for(ck_tile::index_t packed_mn = 0; packed_mn < MN_packed; packed_mn++)
|
||||
{
|
||||
for(ck_tile::index_t packed_k = 0; packed_k < K_packed; packed_k++)
|
||||
{
|
||||
int32_t val = 0;
|
||||
ck_tile::index_t mn_lane = packed_mn % XdlMNThread;
|
||||
ck_tile::index_t mn_group = packed_mn / XdlMNThread;
|
||||
ck_tile::index_t k_lane = packed_k % XdlKThread;
|
||||
ck_tile::index_t k_group = packed_k / XdlKThread;
|
||||
for(ck_tile::index_t ik = 0; ik < KPack; ik++)
|
||||
{
|
||||
for(ck_tile::index_t imn = 0; imn < MNPack; imn++)
|
||||
{
|
||||
ck_tile::index_t byteIdx = ik * MNPack + imn;
|
||||
ck_tile::index_t orig_mn =
|
||||
mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane;
|
||||
ck_tile::index_t orig_k =
|
||||
k_group * XdlKThread * KPack + ik * XdlKThread + k_lane;
|
||||
|
||||
ck_tile::e8m0_t v = kLast ? src(orig_mn, orig_k) : src(orig_k, orig_mn);
|
||||
val |= (static_cast<int32_t>(v.get()) << (byteIdx * 8));
|
||||
}
|
||||
}
|
||||
packed[packed_mn * K_packed + packed_k] = val;
|
||||
}
|
||||
}
|
||||
return packed;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -101,21 +150,43 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout)
|
||||
break;
|
||||
}
|
||||
|
||||
// Device buffers for A, B, C, and scale tensors
|
||||
// Compute effective XdlPack sizes based on GemmConfig tile dimensions
|
||||
constexpr ck_tile::index_t MPerXdl_ = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t NPerXdl_ = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t KPerXdl_ = GemmConfig::K_Warp_Tile;
|
||||
constexpr ck_tile::index_t MIterPerWarp_ = GemmConfig::M_Tile / (GemmConfig::M_Warp * MPerXdl_);
|
||||
constexpr ck_tile::index_t NIterPerWarp_ = GemmConfig::N_Tile / (GemmConfig::N_Warp * NPerXdl_);
|
||||
constexpr ck_tile::index_t KIterPerWarp_ = GemmConfig::K_Tile / KPerXdl_;
|
||||
|
||||
constexpr ck_tile::index_t MXdlPackEff = (MIterPerWarp_ >= 2 && MIterPerWarp_ % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t NXdlPackEff = (NIterPerWarp_ >= 2 && NIterPerWarp_ % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t KXdlPackEff = (KIterPerWarp_ >= 2 && KIterPerWarp_ % 2 == 0) ? 2 : 1;
|
||||
|
||||
// Pack scales: [M, K/32] e8m0_t → [M/MXdlPackEff, K/32/KXdlPackEff] int32_t
|
||||
// Original unpacked tensors are kept for CPU reference validation
|
||||
constexpr ck_tile::index_t XdlMNThread = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
auto scale_a_packed =
|
||||
packScalesMNxK<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_a_host, true);
|
||||
auto scale_b_packed =
|
||||
packScalesMNxK<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_b_host, false);
|
||||
|
||||
// Device buffers for A, B, C, and packed scale tensors
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_dev_buf(b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_packed.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_packed.size() * sizeof(int32_t));
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
b_dev_buf.ToDevice(b_host.data());
|
||||
c_dev_buf.SetZero(); // Initialize C buffer to zero
|
||||
scale_a_dev_buf.ToDevice(scale_a_host.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_host.data());
|
||||
c_dev_buf.SetZero();
|
||||
scale_a_dev_buf.ToDevice(scale_a_packed.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_packed.data());
|
||||
|
||||
// Scale pointers - use e8m0_t* directly
|
||||
using ScaleM = ck_tile::MXScalePointer<ScaleType, 1, 32>; // in blocks of 32 in K
|
||||
// Scale pointers - point to packed int32_t data, kernel reinterprets as int32_t*
|
||||
using ScaleM = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
using ScaleN = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
ScaleM scale_m(reinterpret_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()));
|
||||
ScaleN scale_n(reinterpret_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()));
|
||||
|
||||
Reference in New Issue
Block a user