Update chunk size for f4x2

This commit is contained in:
Rostyslav Geyyer
2025-03-10 21:09:38 +00:00
parent facbaab7b7
commit d466096e25

View File

@@ -226,7 +226,8 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
static constexpr int32_t WAVE_SIZE = 64;
// Here we want to load from rows of A in chunks of 16 elements each.
static constexpr uint32_t chunk_size = 16;
static constexpr uint32_t chunk_size =
16 / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1);
// each chunk is separated by offset
static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M;