mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4704 (commit 17662f9)
[CK_TILE] Fix FP8 MXGEMM numerical error in async load path (#4704)

## Summary

Fixes FP8 MXGEMM producing half the expected result (e.g., 128 instead of 256 with all-1s input).

**Bug introduced in:** `b7de1e14cea70681a23cd1a136df42910c776e4a` — "[CK_TILE] Add blockscale GEMM support for EightWarps on gfx950 (#4280)"

## Root Cause

In the `static_move_ys=true` code path in `tile_window.hpp`, the IMM optimization computes `lds_ys_offset` using a default-constructed tensor descriptor:

```cpp
make_tensor_coordinate(decltype(tensor_descriptor){}, idx_ys_offset)
```

This default-constructed descriptor has different strides than the actual DRAM tensor descriptor used for `dram_ys_offset`. When these offsets are mixed in the address calculation:

```cpp
imm_valid = lds_ys_offset % IMM_RANGE;    // From wrong descriptor
wave_offset = dram_ys_offset - imm_valid; // From correct descriptor
```

the final address `wave_offset + imm_valid` ≠ `dram_ys_offset`, causing incorrect memory accesses.

## Fix

Set `imm_valid = 0` to bypass the IMM optimization and ensure the full offset is passed through `wave_offset`:

```cpp
constexpr auto imm_valid = 0; // Avoids inconsistency between lds_ys_offset and dram_ys_offset
```

This disables the 12-bit immediate field optimization in the `buffer_load_lds` instruction but guarantees correctness. A proper fix would require making the DRAM tensor descriptor `constexpr`, which is not feasible since tensor strides depend on runtime parameters (LDA, LDB).
This commit is contained in:
committed by
assistant-librarian[bot]
parent
816abdcf9f
commit
6aa1cd8212
@@ -627,16 +627,9 @@ struct tile_window_with_static_distribution
             const auto lds_coord =
                 make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);

-            constexpr auto IMM_RANGE =
-                (1 << 12) / sizeof(typename Base::DataType) * Traits::PackedSize;
-            constexpr auto imm_total = lds_ys_offset;
-            constexpr auto imm_valid = imm_total % IMM_RANGE;
-            constexpr auto imm_overflow = imm_total - imm_valid;
-
             // Calculate SMEM address using base pointer
-            CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr +
-                                                 lds_coord.get_offset() / Traits::PackedSize +
-                                                 imm_overflow / Traits::PackedSize;
+            CK_TILE_LDS_ADDR LdsDataType* smem =
+                lds_base_ptr + (lds_coord.get_offset() + lds_ys_offset) / Traits::PackedSize;

             const auto dram_ys_offset = [&]() {
                 if constexpr(static_move_ys)
@@ -656,13 +649,14 @@ struct tile_window_with_static_distribution
                     offset + dram_ys_offset,
                     bool_constant<oob_conditional_check>{});
             else
             {
                 this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                     smem,
                     bottom_tensor_thread_coord.get_offset() + offset,
-                    dram_ys_offset - imm_valid,
-                    number<imm_valid>{},
+                    dram_ys_offset,
+                    number<0>{},
                     bool_constant<oob_conditional_check>{});
+
             }
             // Move thread coordinate if not last access
             if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
             {
Reference in New Issue
Block a user