mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
save packing approach
This commit is contained in:
@@ -59,15 +59,20 @@ auto pack_scales_for_k_dimension(const ScaleTensor& scale_unpacked,
|
||||
for(ck_tile::index_t k = 0; k < pack_size; ++k)
|
||||
{
|
||||
// Strided packing: byte k corresponds to kIter=k
|
||||
// Stride by packed dimension (new_dim1 for dim1 packing, 1 for dim0 packing since it's linear)
|
||||
// Wait, we need to map unpacked logical positions to correct strided pattern
|
||||
// For K=512: 16 unpacked elements [0-15] map to 4 int32s strided:
|
||||
// int32[0] = {elem[0], elem[4], elem[8], elem[12]} (bytes 0,1,2,3 for kIter 0,1,2,3)
|
||||
// int32[1] = {elem[1], elem[5], elem[9], elem[13]}
|
||||
// ...
|
||||
// So: packed_index j (or i), byte position k -> unpacked_index = j/i + k * packed_size
|
||||
ck_tile::index_t src_i = pack_dim1 ? i : (i + k * packed_k_dim);
|
||||
ck_tile::index_t src_j = pack_dim1 ? (j + k * packed_k_dim) : j;
|
||||
// The stride is always pack_size (4), not packed_k_dim!
|
||||
// For K=512: 16 unpacked elements [0-15] -> 4 packed int32s
|
||||
// int32[0] = {unpacked[0], unpacked[4], unpacked[8], unpacked[12]} (stride=4)
|
||||
// int32[1] = {unpacked[1], unpacked[5], unpacked[9], unpacked[13]} (stride=4)
|
||||
// For K=1024: 32 unpacked elements [0-31] -> 8 packed int32s
|
||||
// int32[0] = {unpacked[0], unpacked[4], unpacked[8], unpacked[12]} (stride=4)
|
||||
// int32[4] = {unpacked[16], unpacked[20], unpacked[24], unpacked[28]} (stride=4)
|
||||
// For row-major (pack_dim1=true): packed index j, byte k -> unpacked[j + k*4]
|
||||
// For col-major (pack_dim1=false): packed index i, byte k -> unpacked[i*4 + k*4] = unpacked[(i + k)*4]
|
||||
// But we want: packed index i, byte k -> unpacked[i*4 + k] (base i*4, then stride 4)
|
||||
// Actually: int32[i] should pack {unpacked[i*4 + 0*4], unpacked[i*4 + 1*4], unpacked[i*4 + 2*4], unpacked[i*4 + 3*4]}
|
||||
// = {unpacked[i*4], unpacked[i*4 + 4], unpacked[i*4 + 8], unpacked[i*4 + 12]}
|
||||
ck_tile::index_t src_i = pack_dim1 ? i : (i * pack_size + k * pack_size);
|
||||
ck_tile::index_t src_j = pack_dim1 ? (j * pack_size + k * pack_size) : j;
|
||||
|
||||
uint8_t scale_byte = *reinterpret_cast<const uint8_t*>(&scale_unpacked(src_i, src_j));
|
||||
packed_value |= (static_cast<int32_t>(scale_byte) << (k * 8));
|
||||
@@ -140,13 +145,14 @@ int run_mx_gemm_with_layouts(int argc,
|
||||
ck_tile::host_tensor_descriptor(M, scale_k_size, stride_scale_a, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<ScaleType> scale_b_host(
|
||||
ck_tile::host_tensor_descriptor(scale_k_size, N, stride_scale_b, is_row_major(BLayout{})));
|
||||
int seed = 1234;
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.f, 10.f}(scale_a_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.f, 10.f}(scale_b_host);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f, seed++}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f, seed++}(b_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_a_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_b_host);
|
||||
break;
|
||||
case 1:
|
||||
ck_tile::FillConstant<ADataType>{ADataType(1.f)}(a_host);
|
||||
@@ -155,11 +161,82 @@ int run_mx_gemm_with_layouts(int argc,
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_b_host);
|
||||
break;
|
||||
case 2:
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_host);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f, seed++}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f, seed++}(b_host);
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_a_host);
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_b_host);
|
||||
break;
|
||||
case 3:
|
||||
// Debug mode: simple power-of-2 pattern for scales (e8m0 format)
|
||||
ck_tile::FillConstant<ADataType>{ADataType(1.f)}(a_host);
|
||||
ck_tile::FillConstant<BDataType>{BDataType(1.f)}(b_host);
|
||||
// Fill scales with power-of-2 pattern: 1.0, 2.0, 4.0, 8.0, 16.0, ...
|
||||
// e8m0 is exponent-only, so these give clear distinct values
|
||||
// for(std::size_t i = 0; i < scale_a_host.mDesc.get_element_space_size(); ++i)
|
||||
// {
|
||||
// float val = std::pow(2.0f, static_cast<float>(i % 16)); // cycle through 2^0 to 2^15
|
||||
// scale_a_host.mData[i] = ScaleType(val);
|
||||
// }
|
||||
// for(std::size_t i = 0; i < scale_b_host.mDesc.get_element_space_size(); ++i)
|
||||
// {
|
||||
// float val = std::pow(2.0f, static_cast<float>(i % 16)); // cycle through 2^0 to 2^15
|
||||
// scale_b_host.mData[i] = ScaleType(val);
|
||||
// }
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_a_host);
|
||||
ck_tile::FillConstant<ScaleType>{ScaleType(1.f)}(scale_b_host);
|
||||
|
||||
// Test data to verify K block loading for K=1024 (2 K blocks)
|
||||
// K block 0: K indices 0-511, scale K indices 0-15, packed into K_packed indices 0-3
|
||||
// K block 1: K indices 512-1023, scale K indices 16-31, packed into K_packed indices 4-7
|
||||
|
||||
// Scale A: [M, K/32] row-major (unpacked K indices in second dim)
|
||||
// Strided packing: int32[j] packs unpacked[j], unpacked[j+4], unpacked[j+8], unpacked[j+12]
|
||||
// K block 0: K indices 0-511 → unpacked K indices 0-15 → packed int32s 0-3
|
||||
// int32[0] = {unpacked[0], unpacked[4], unpacked[8], unpacked[12]} (K_Lane=0)
|
||||
// int32[1] = {unpacked[1], unpacked[5], unpacked[9], unpacked[13]} (K_Lane=1)
|
||||
scale_a_host(0, 0) = ScaleType(2.f); // K block 0, int32[0] byte 0 (unpacked[0])
|
||||
scale_a_host(0, 4) = ScaleType(4.f); // K block 0, int32[0] byte 1 (unpacked[4])
|
||||
scale_a_host(0, 8) = ScaleType(8.f); // K block 0, int32[0] byte 2 (unpacked[8])
|
||||
scale_a_host(0, 12) = ScaleType(16.f); // K block 0, int32[0] byte 3 (unpacked[12])
|
||||
scale_a_host(1, 0) = ScaleType(32.f); // K block 0, int32[1] byte 0 (unpacked[1])
|
||||
|
||||
// K block 1: K indices 512-1023 → unpacked K indices 16-31 → packed int32s 4-7
|
||||
// int32[4] = {unpacked[16], unpacked[20], unpacked[24], unpacked[28]} (K_Lane=0)
|
||||
// int32[5] = {unpacked[17], unpacked[21], unpacked[25], unpacked[29]} (K_Lane=1)
|
||||
scale_a_host(0, 16) = ScaleType(256.f); // K block 1, int32[4] byte 0 (unpacked[16])
|
||||
scale_a_host(0, 20) = ScaleType(512.f); // K block 1, int32[4] byte 1 (unpacked[20])
|
||||
scale_a_host(0, 24) = ScaleType(1024.f); // K block 1, int32[4] byte 2 (unpacked[24])
|
||||
scale_a_host(0, 28) = ScaleType(2048.f); // K block 1, int32[4] byte 3 (unpacked[28])
|
||||
scale_a_host(1, 16) = ScaleType(4096.f); // K block 1, int32[5] byte 0 (unpacked[17])
|
||||
|
||||
// mIter=1: M rows 16-31 (second XDL block)
|
||||
scale_a_host(16, 0) = ScaleType(64.f); // K block 0, unpacked K=0, M=16
|
||||
scale_a_host(16, 16) = ScaleType(8192.f); // K block 1, unpacked K=16, M=16
|
||||
|
||||
// Scale B: [K/32, N] col-major (unpacked K indices in first dim, N in second dim)
|
||||
// Strided packing: int32[i] packs unpacked[i], unpacked[i+4], unpacked[i+8], unpacked[i+12]
|
||||
// K block 0: K indices 0-511 → unpacked K indices 0-15 → packed int32s 0-3
|
||||
// int32[0] = {unpacked[0], unpacked[4], unpacked[8], unpacked[12]} (K_Lane=0)
|
||||
// int32[1] = {unpacked[1], unpacked[5], unpacked[9], unpacked[13]} (K_Lane=1)
|
||||
scale_b_host(0, 0) = ScaleType(2.f); // K block 0, int32[0] byte 0 (unpacked[0])
|
||||
scale_b_host(4, 0) = ScaleType(4.f); // K block 0, int32[0] byte 1 (unpacked[4])
|
||||
scale_b_host(8, 0) = ScaleType(8.f); // K block 0, int32[0] byte 2 (unpacked[8])
|
||||
scale_b_host(12, 0) = ScaleType(16.f); // K block 0, int32[0] byte 3 (unpacked[12])
|
||||
scale_b_host(1, 0) = ScaleType(32.f); // K block 0, int32[1] byte 0 (unpacked[1])
|
||||
|
||||
// K block 1: K indices 512-1023 → unpacked K indices 16-31 → packed int32s 4-7
|
||||
// int32[4] = {unpacked[16], unpacked[20], unpacked[24], unpacked[28]} (K_Lane=0)
|
||||
// int32[5] = {unpacked[17], unpacked[21], unpacked[25], unpacked[29]} (K_Lane=1)
|
||||
scale_b_host(16, 0) = ScaleType(256.f); // K block 1, int32[4] byte 0 (unpacked[16])
|
||||
scale_b_host(20, 0) = ScaleType(512.f); // K block 1, int32[4] byte 1 (unpacked[20])
|
||||
scale_b_host(24, 0) = ScaleType(1024.f); // K block 1, int32[4] byte 2 (unpacked[24])
|
||||
scale_b_host(28, 0) = ScaleType(2048.f); // K block 1, int32[4] byte 3 (unpacked[28])
|
||||
scale_b_host(17, 0) = ScaleType(4096.f); // K block 1, int32[5] byte 0 (unpacked[17])
|
||||
|
||||
// nIter=1: N rows 16-31 (second XDL block)
|
||||
scale_b_host(0, 16) = ScaleType(64.f); // K block 0, unpacked K=0, N=16
|
||||
scale_b_host(16, 16) = ScaleType(8192.f); // K block 1, unpacked K=16, N=16
|
||||
break;
|
||||
}
|
||||
|
||||
// Pack scales: 4 e8m0_t values strided along the K dimension → 1 int32 for efficient 32-bit loads
|
||||
@@ -169,6 +246,48 @@ int run_mx_gemm_with_layouts(int argc,
|
||||
auto scale_a_packed = pack_scales_for_k_dimension(scale_a_host, 4);
|
||||
auto scale_b_packed = pack_scales_for_k_dimension(scale_b_host, 4);
|
||||
|
||||
// DEBUG: Print first few packed scale values
|
||||
if (true ||init_method == 3)
|
||||
{
|
||||
std::cout << "Host: ScaleA packed [0,0]: ";
|
||||
uint8_t* a_bytes = reinterpret_cast<uint8_t*>(&scale_a_packed(0, 0));
|
||||
std::cout << "[" << static_cast<int>(a_bytes[0]) << "," << static_cast<int>(a_bytes[1]) << ","
|
||||
<< static_cast<int>(a_bytes[2]) << "," << static_cast<int>(a_bytes[3]) << "]\n";
|
||||
std::cout << "Host: ScaleA packed [0,4]: ";
|
||||
uint8_t* a_bytes4 = reinterpret_cast<uint8_t*>(&scale_a_packed(0, 4));
|
||||
std::cout << "[" << static_cast<int>(a_bytes4[0]) << "," << static_cast<int>(a_bytes4[1]) << ","
|
||||
<< static_cast<int>(a_bytes4[2]) << "," << static_cast<int>(a_bytes4[3]) << "]\n";
|
||||
std::cout << "Host: ScaleB packed [0,0]: ";
|
||||
uint8_t* b_bytes = reinterpret_cast<uint8_t*>(&scale_b_packed(0, 0));
|
||||
std::cout << "[" << static_cast<int>(b_bytes[0]) << "," << static_cast<int>(b_bytes[1]) << ","
|
||||
<< static_cast<int>(b_bytes[2]) << "," << static_cast<int>(b_bytes[3]) << "]\n";
|
||||
std::cout << "Host: ScaleB packed [4,0]: ";
|
||||
uint8_t* b_bytes4 = reinterpret_cast<uint8_t*>(&scale_b_packed(4, 0));
|
||||
std::cout << "[" << static_cast<int>(b_bytes4[0]) << "," << static_cast<int>(b_bytes4[1]) << ","
|
||||
<< static_cast<int>(b_bytes4[2]) << "," << static_cast<int>(b_bytes4[3]) << "]\n";
|
||||
|
||||
// Print unpacked first row/col for reference
|
||||
std::cout << "Host: ScaleA unpacked thread 0, every 4th element: [";
|
||||
for (int k = 0; k < std::min(32, static_cast<int>(scale_a_host.mDesc.get_lengths()[1])); k += 4)
|
||||
std::cout << static_cast<int>(*reinterpret_cast<uint8_t*>(&scale_a_host(0, k))) << ",";
|
||||
std::cout << "]\n";
|
||||
std::cout << "Host: ScaleB unpacked thread 0, every 4th element: [";
|
||||
for (int k = 0; k < std::min(32, static_cast<int>(scale_b_host.mDesc.get_lengths()[0])); k += 4)
|
||||
std::cout << static_cast<int>(*reinterpret_cast<uint8_t*>(&scale_b_host(k, 0))) << ",";
|
||||
std::cout << "]\n";
|
||||
// Threads 0-15: M rows 0-15, K_Lane cycles through 0,1,2,3
|
||||
// Thread 16: M row 0 again, but next K_Lane group (K_Lane=1 if cycling, or next K group)
|
||||
// Actually, thread 16 goes back to row 0 with a different K index
|
||||
std::cout << "Host: ScaleA unpacked thread 16 (row 0, next K group), every 4th element: [";
|
||||
for (int k = 1; k < std::min(32, static_cast<int>(scale_a_host.mDesc.get_lengths()[1])); k += 4)
|
||||
std::cout << static_cast<int>(*reinterpret_cast<uint8_t*>(&scale_a_host(0, k))) << ",";
|
||||
std::cout << "]\n";
|
||||
std::cout << "Host: ScaleB unpacked thread 16 (row 0, next K group), every 4th element: [";
|
||||
for (int k = 1; k < std::min(32, static_cast<int>(scale_b_host.mDesc.get_lengths()[0])); k += 4)
|
||||
std::cout << static_cast<int>(*reinterpret_cast<uint8_t*>(&scale_b_host(k, 0))) << ",";
|
||||
std::cout << "]\n";
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_dev_buf(b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
|
||||
|
||||
Reference in New Issue
Block a user