save packing approach

Sami Remes
2026-02-06 15:54:57 +00:00
parent a8d48f9224
commit 061c9f9374
3 changed files with 205 additions and 220 deletions


@@ -59,15 +59,20 @@ auto pack_scales_for_k_dimension(const ScaleTensor& scale_unpacked,
for(ck_tile::index_t k = 0; k < pack_size; ++k)
{
// Strided packing: byte k corresponds to kIter=k
-// Stride by packed dimension (new_dim1 for dim1 packing, 1 for dim0 packing since it's linear)
-// Wait, we need to map unpacked logical positions to correct strided pattern
-// For K=512: 16 unpacked elements [0-15] map to 4 int32s strided:
-// int32[0] = {elem[0], elem[4], elem[8], elem[12]} (bytes 0,1,2,3 for kIter 0,1,2,3)
-// int32[1] = {elem[1], elem[5], elem[9], elem[13]}
-// ...
-// So: packed index j (or i), byte position k -> unpacked index = j (or i) + k * packed_size
-ck_tile::index_t src_i = pack_dim1 ? i : (i + k * packed_k_dim);
-ck_tile::index_t src_j = pack_dim1 ? (j + k * packed_k_dim) : j;
+// The stride is always pack_size (4), not packed_k_dim!
+// Packed int32 p, byte k -> unpacked index p*4 + k*4:
+// For K=512: 16 unpacked elements [0-15] -> 4 packed int32s
+// int32[0] = {unpacked[0], unpacked[4], unpacked[8], unpacked[12]}
+// For K=1024: 32 unpacked elements [0-31] -> 8 packed int32s
+// int32[4] = {unpacked[16], unpacked[20], unpacked[24], unpacked[28]}
+// pack_dim1 selects the packed K dimension: dim1 for the row-major
+// scale A ([M, K/32]), dim0 for the col-major scale B ([K/32, N])
+ck_tile::index_t src_i = pack_dim1 ? i : (i * pack_size + k * pack_size);
+ck_tile::index_t src_j = pack_dim1 ? (j * pack_size + k * pack_size) : j;
uint8_t scale_byte = *reinterpret_cast<const uint8_t*>(&scale_unpacked(src_i, src_j));
packed_value |= (static_cast<int32_t>(scale_byte) << (k * 8));
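For reference, the layout that the comments above (and the init_method 3 test pattern below) describe can be reproduced standalone. A minimal sketch, assuming pack_size = 4 and 16 scale bytes per K block (512 K elements at a scale granularity of 32); it follows the block-of-16 layout from the comments rather than the file's exact src_i/src_j expression:

#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    constexpr int pack_size          = 4;  // bytes per packed int32
    constexpr int scales_per_k_block = 16; // 512 K elements / 32-element scale blocks

    std::vector<uint8_t> unpacked(32); // K = 1024 -> 32 scale bytes, 2 K blocks
    for(std::size_t u = 0; u < unpacked.size(); ++u)
        unpacked[u] = static_cast<uint8_t>(u); // byte value == unpacked index

    std::vector<int32_t> packed(unpacked.size() / pack_size, 0);
    for(std::size_t p = 0; p < packed.size(); ++p)
    {
        const int block = static_cast<int>(p) / pack_size; // K block index
        const int lane  = static_cast<int>(p) % pack_size; // K_Lane within the block
        for(int k = 0; k < pack_size; ++k)
        {
            const int src = block * scales_per_k_block + lane + k * pack_size;
            packed[p] |= static_cast<int32_t>(unpacked[src]) << (k * 8);
        }
    }

    // Expect int32[0] = {0,4,8,12}, int32[1] = {1,5,9,13}, int32[4] = {16,20,24,28}
    for(int p : {0, 1, 4})
    {
        std::printf("int32[%d] = {", p);
        for(int k = 0; k < pack_size; ++k)
            std::printf("%s%d", k ? ", " : "", (packed[p] >> (k * 8)) & 0xff);
        std::printf("}\n");
    }
    return 0;
}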
@@ -140,13 +145,14 @@ int run_mx_gemm_with_layouts(int argc,
ck_tile::host_tensor_descriptor(M, scale_k_size, stride_scale_a, is_row_major(ALayout{})));
ck_tile::HostTensor<ScaleType> scale_b_host(
ck_tile::host_tensor_descriptor(scale_k_size, N, stride_scale_b, is_row_major(BLayout{})));
+int seed = 1234;
switch(init_method)
{
case 0:
-ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host);
-ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_host);
-ck_tile::FillUniformDistribution<ScaleType>{1.f, 10.f}(scale_a_host);
-ck_tile::FillUniformDistribution<ScaleType>{1.f, 10.f}(scale_b_host);
+ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f, seed++}(a_host);
+ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f, seed++}(b_host);
+ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_a_host);
+ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_b_host);
break;
case 1:
ck_tile::FillConstant<ADataType>{ADataType(1.f)}(a_host);
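ScaleType here is e8m0 (exponent-only, value 2^(e-127)), so widening the scale fill from [1, 10] to [0.001, 10] exercises negative exponents (down to roughly 2^-10) instead of only 2^0..2^3. A minimal sketch of that quantization, assuming round-to-nearest exponent and the usual bias of 127; illustrative only, not ck_tile's ScaleType:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative e8m0 quantizer; assumes x > 0 and an in-range exponent.
static uint8_t to_e8m0(float x)
{
    const long e = std::lround(std::log2(x)); // nearest power-of-2 exponent
    return static_cast<uint8_t>(e + 127);     // biased exponent byte
}

int main()
{
    for(float x : {0.001f, 1.f, 2.f, 10.f})
    {
        const int byte = to_e8m0(x);
        std::printf("%g -> e8m0 byte %d (2^%d)\n", x, byte, byte - 127);
    }
    return 0;
}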
@@ -155,11 +161,82 @@ int run_mx_gemm_with_layouts(int argc,
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_b_host);
break;
case 2:
-ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host);
-ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_host);
+ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f, seed++}(a_host);
+ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f, seed++}(b_host);
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_a_host);
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_b_host);
break;
+case 3:
+// Debug mode: simple power-of-2 pattern for scales (e8m0 format)
+ck_tile::FillConstant<ADataType>{ADataType(1.f)}(a_host);
+ck_tile::FillConstant<BDataType>{BDataType(1.f)}(b_host);
+// Fill scales with power-of-2 pattern: 1.0, 2.0, 4.0, 8.0, 16.0, ...
+// e8m0 is exponent-only, so these give clear distinct values
+// for(std::size_t i = 0; i < scale_a_host.mDesc.get_element_space_size(); ++i)
+// {
+//     float val = std::pow(2.0f, static_cast<float>(i % 16)); // cycle through 2^0 to 2^15
+//     scale_a_host.mData[i] = ScaleType(val);
+// }
+// for(std::size_t i = 0; i < scale_b_host.mDesc.get_element_space_size(); ++i)
+// {
+//     float val = std::pow(2.0f, static_cast<float>(i % 16)); // cycle through 2^0 to 2^15
+//     scale_b_host.mData[i] = ScaleType(val);
+// }
+ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_a_host);
+ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_b_host);
+// Test data to verify K block loading for K=1024 (2 K blocks)
+// K block 0: K indices 0-511, scale K indices 0-15, packed into K_packed indices 0-3
+// K block 1: K indices 512-1023, scale K indices 16-31, packed into K_packed indices 4-7
+// Scale A: [M, K/32] row-major (unpacked K indices in second dim)
+// Strided packing: int32[j] packs unpacked[j], unpacked[j+4], unpacked[j+8], unpacked[j+12]
+// K block 0: K indices 0-511 → unpacked K indices 0-15 → packed int32s 0-3
+// int32[0] = {unpacked[0], unpacked[4], unpacked[8], unpacked[12]} (K_Lane=0)
+// int32[1] = {unpacked[1], unpacked[5], unpacked[9], unpacked[13]} (K_Lane=1)
+scale_a_host(0, 0) = ScaleType(2.f);   // K block 0, int32[0] byte 0 (unpacked[0])
+scale_a_host(0, 4) = ScaleType(4.f);   // K block 0, int32[0] byte 1 (unpacked[4])
+scale_a_host(0, 8) = ScaleType(8.f);   // K block 0, int32[0] byte 2 (unpacked[8])
+scale_a_host(0, 12) = ScaleType(16.f); // K block 0, int32[0] byte 3 (unpacked[12])
+scale_a_host(1, 0) = ScaleType(32.f);  // K block 0, int32[1] byte 0 (unpacked[1])
+// K block 1: K indices 512-1023 → unpacked K indices 16-31 → packed int32s 4-7
+// int32[4] = {unpacked[16], unpacked[20], unpacked[24], unpacked[28]} (K_Lane=0)
+// int32[5] = {unpacked[17], unpacked[21], unpacked[25], unpacked[29]} (K_Lane=1)
+scale_a_host(0, 16) = ScaleType(256.f);  // K block 1, int32[4] byte 0 (unpacked[16])
+scale_a_host(0, 20) = ScaleType(512.f);  // K block 1, int32[4] byte 1 (unpacked[20])
+scale_a_host(0, 24) = ScaleType(1024.f); // K block 1, int32[4] byte 2 (unpacked[24])
+scale_a_host(0, 28) = ScaleType(2048.f); // K block 1, int32[4] byte 3 (unpacked[28])
+scale_a_host(1, 16) = ScaleType(4096.f); // K block 1, int32[5] byte 0 (unpacked[17])
+// mIter=1: M rows 16-31 (second XDL block)
+scale_a_host(16, 0) = ScaleType(64.f);    // K block 0, unpacked K=0, M=16
+scale_a_host(16, 16) = ScaleType(8192.f); // K block 1, unpacked K=16, M=16
+// Scale B: [K/32, N] col-major (unpacked K indices in first dim, N in second dim)
+// Strided packing: int32[i] packs unpacked[i], unpacked[i+4], unpacked[i+8], unpacked[i+12]
+// K block 0: K indices 0-511 → unpacked K indices 0-15 → packed int32s 0-3
+// int32[0] = {unpacked[0], unpacked[4], unpacked[8], unpacked[12]} (K_Lane=0)
+// int32[1] = {unpacked[1], unpacked[5], unpacked[9], unpacked[13]} (K_Lane=1)
+scale_b_host(0, 0) = ScaleType(2.f);   // K block 0, int32[0] byte 0 (unpacked[0])
+scale_b_host(4, 0) = ScaleType(4.f);   // K block 0, int32[0] byte 1 (unpacked[4])
+scale_b_host(8, 0) = ScaleType(8.f);   // K block 0, int32[0] byte 2 (unpacked[8])
+scale_b_host(12, 0) = ScaleType(16.f); // K block 0, int32[0] byte 3 (unpacked[12])
+scale_b_host(1, 0) = ScaleType(32.f);  // K block 0, int32[1] byte 0 (unpacked[1])
+// K block 1: K indices 512-1023 → unpacked K indices 16-31 → packed int32s 4-7
+// int32[4] = {unpacked[16], unpacked[20], unpacked[24], unpacked[28]} (K_Lane=0)
+// int32[5] = {unpacked[17], unpacked[21], unpacked[25], unpacked[29]} (K_Lane=1)
+scale_b_host(16, 0) = ScaleType(256.f);  // K block 1, int32[4] byte 0 (unpacked[16])
+scale_b_host(20, 0) = ScaleType(512.f);  // K block 1, int32[4] byte 1 (unpacked[20])
+scale_b_host(24, 0) = ScaleType(1024.f); // K block 1, int32[4] byte 2 (unpacked[24])
+scale_b_host(28, 0) = ScaleType(2048.f); // K block 1, int32[4] byte 3 (unpacked[28])
+scale_b_host(17, 0) = ScaleType(4096.f); // K block 1, int32[5] byte 0 (unpacked[17])
+// nIter=1: N columns 16-31 (second XDL block)
+scale_b_host(0, 16) = ScaleType(64.f);    // K block 0, unpacked K=0, N=16
+scale_b_host(16, 16) = ScaleType(8192.f); // K block 1, unpacked K=16, N=16
break;
}
// Pack scales: 4 consecutive e8m0_t in K dimension → 1 int32 for efficient 32-bit loads
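The spot values in case 3 above can be cross-checked by mapping each unpacked scale index to the packed int32 and byte that should carry it. A small sketch under the same block-of-16 assumption as earlier (hypothetical helper, not part of the example code):

#include <cstdio>

int main()
{
    constexpr int pack_size          = 4;
    constexpr int scales_per_k_block = 16;
    // Unpacked K indices touched by the init_method 3 pattern
    for(int u : {0, 4, 8, 12, 1, 16, 20, 17})
    {
        const int block  = u / scales_per_k_block;                 // K block
        const int within = u % scales_per_k_block;                 // offset inside the block
        const int word   = block * pack_size + within % pack_size; // packed int32 index
        const int byte_k = within / pack_size;                     // byte within that int32
        std::printf("unpacked[%2d] -> int32[%d] byte %d\n", u, word, byte_k);
    }
    return 0;
}

Under this mapping, unpacked[16] lands in int32[4] byte 0 and unpacked[17] in int32[5] byte 0, matching the K block 1 comments above.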
@@ -169,6 +246,48 @@ int run_mx_gemm_with_layouts(int argc,
auto scale_a_packed = pack_scales_for_k_dimension(scale_a_host, 4);
auto scale_b_packed = pack_scales_for_k_dimension(scale_b_host, 4);
+// DEBUG: Print first few packed scale values
+if(true || init_method == 3)
+{
+std::cout << "Host: ScaleA packed [0,0]: ";
+uint8_t* a_bytes = reinterpret_cast<uint8_t*>(&scale_a_packed(0, 0));
+std::cout << "[" << static_cast<int>(a_bytes[0]) << "," << static_cast<int>(a_bytes[1]) << ","
+<< static_cast<int>(a_bytes[2]) << "," << static_cast<int>(a_bytes[3]) << "]\n";
+std::cout << "Host: ScaleA packed [0,4]: ";
+uint8_t* a_bytes4 = reinterpret_cast<uint8_t*>(&scale_a_packed(0, 4));
+std::cout << "[" << static_cast<int>(a_bytes4[0]) << "," << static_cast<int>(a_bytes4[1]) << ","
+<< static_cast<int>(a_bytes4[2]) << "," << static_cast<int>(a_bytes4[3]) << "]\n";
+std::cout << "Host: ScaleB packed [0,0]: ";
+uint8_t* b_bytes = reinterpret_cast<uint8_t*>(&scale_b_packed(0, 0));
+std::cout << "[" << static_cast<int>(b_bytes[0]) << "," << static_cast<int>(b_bytes[1]) << ","
+<< static_cast<int>(b_bytes[2]) << "," << static_cast<int>(b_bytes[3]) << "]\n";
+std::cout << "Host: ScaleB packed [4,0]: ";
+uint8_t* b_bytes4 = reinterpret_cast<uint8_t*>(&scale_b_packed(4, 0));
+std::cout << "[" << static_cast<int>(b_bytes4[0]) << "," << static_cast<int>(b_bytes4[1]) << ","
+<< static_cast<int>(b_bytes4[2]) << "," << static_cast<int>(b_bytes4[3]) << "]\n";
+// Print unpacked first row/col for reference
+std::cout << "Host: ScaleA unpacked thread 0, every 4th element: [";
+for(int k = 0; k < std::min(32, static_cast<int>(scale_a_host.mDesc.get_lengths()[1])); k += 4)
+std::cout << static_cast<int>(*reinterpret_cast<uint8_t*>(&scale_a_host(0, k))) << ",";
+std::cout << "]\n";
+std::cout << "Host: ScaleB unpacked thread 0, every 4th element: [";
+for(int k = 0; k < std::min(32, static_cast<int>(scale_b_host.mDesc.get_lengths()[0])); k += 4)
+std::cout << static_cast<int>(*reinterpret_cast<uint8_t*>(&scale_b_host(k, 0))) << ",";
+std::cout << "]\n";
+// Threads 0-15: M rows 0-15, K_Lane cycles through 0,1,2,3
+// Thread 16 wraps back to M row 0 at the next K index (k = 1, 5, 9, ...)
+std::cout << "Host: ScaleA unpacked thread 16 (row 0, next K group), every 4th element: [";
+for(int k = 1; k < std::min(32, static_cast<int>(scale_a_host.mDesc.get_lengths()[1])); k += 4)
+std::cout << static_cast<int>(*reinterpret_cast<uint8_t*>(&scale_a_host(0, k))) << ",";
+std::cout << "]\n";
+std::cout << "Host: ScaleB unpacked thread 16 (row 0, next K group), every 4th element: [";
+for(int k = 1; k < std::min(32, static_cast<int>(scale_b_host.mDesc.get_lengths()[0])); k += 4)
+std::cout << static_cast<int>(*reinterpret_cast<uint8_t*>(&scale_b_host(k, 0))) << ",";
+std::cout << "]\n";
+}
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_dev_buf(b_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
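To eyeball the byte dumps printed above, a packed word can be decoded back into its four scales. A minimal sketch for the bytes the case 3 pattern intends at ScaleA packed [0,0], assuming e8m0 with bias 127 (so 2.0 encodes as byte 128):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    // e8m0 encodings of {2, 4, 8, 16}: biased exponents {128, 129, 130, 131}
    const uint32_t word = 128u | (129u << 8) | (130u << 16) | (131u << 24);
    for(int k = 0; k < 4; ++k)
    {
        const int byte = static_cast<int>((word >> (k * 8)) & 0xffu);
        std::printf("byte %d = %d -> scale %g\n", k, byte, std::exp2(byte - 127));
    }
    return 0;
}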