Update chunk size

This commit is contained in:
Rostyslav Geyyer
2025-03-19 20:37:05 +00:00
parent 221862c912
commit 5adf19ccb3
2 changed files with 10 additions and 9 deletions

View File

@@ -520,8 +520,8 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
using arg_type = int32x8_t;
printf("!!!!!!! %d %d %d %d ", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
printf("??????? %d %d %d %d ", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
@@ -591,8 +591,8 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
using arg_type = int32x8_t;
printf("!!!!!!! %d %d %d %d ", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
printf("??????? %d %d %d %d ", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
@@ -663,8 +663,8 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
using arg_type = int32x8_t;
printf("!!!!!!! %d %d %d %d ", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
printf("??????? %d %d %d %d ", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
@@ -731,8 +731,8 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
using arg_type = int32x8_t;
printf("!!!!!!! %d %d %d %d ", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
printf("??????? %d %d %d %d ", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
printf("!!!!!!! %d %d %d %d \n", arg_a[0], arg_a[1], arg_a[2], arg_a[3]);
printf("??????? %d %d %d %d \n", arg_b[0], arg_b[1], arg_b[2], arg_b[3]);
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(

View File

@@ -397,7 +397,8 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
static constexpr int32_t WAVE_SIZE = 64;
// Here we want to load from cols of B in chunks of 16 elements each.
static constexpr uint32_t chunk_size = 16;
static constexpr uint32_t chunk_size =
16 / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1);
// each chunk is separated by an offset
static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_N; // 32 or 64