adjust A_LDS descriptor to avoid bank conflicts

This commit is contained in:
root
2025-08-22 03:20:46 -05:00
parent 65989e940c
commit d69cab7f0c

View File

@@ -32,38 +32,29 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr index_t XDL_PerWeightK = 4;
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto a_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<KPerBlock / KPack / XDL_PerWeightK>{},
number<MPerBlock>{},
number<XDL_PerWeightK>{},
number<KPack>{}),
make_tuple(number<KPack * XDL_PerWeightK>{},
number<KPerBlock>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
number<KPerBlock / KPack / XDL_PerWeightK>{})),
make_pass_through_transform(number<XDL_PerWeightK>{}),
number<ContiguousThreadsCntInDS_READ_16B>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}, sequence<3>{}));
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc =
transform_tensor_descriptor(a_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform_v3_division_mod(make_tuple(
number<KPerBlock / KPack / XDL_PerWeightK>{},
number<XDL_PerWeightK>{},
number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
@@ -194,4 +185,4 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
};
} // namespace ck_tile
} // namespace ck_tile