fix settings for example, fix some things in pipeline

This commit is contained in:
Sami Remes
2025-12-19 12:35:03 -05:00
parent 6a4951cf8c
commit 86cc59e754
9 changed files with 105 additions and 115 deletions

View File

@@ -242,7 +242,7 @@ struct MXGemmPipelineAgBgCrV1
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
};
// Helper for Math Loop
// Helper for Main Loop
auto warp_gemm_loop = [&](auto& a_warp_window, auto& b_warp_window, auto& scale_a, auto& scale_b) {
// Define register tiles types for double buffering
using AValType = decltype(load_tile_with_offset(a_warp_window, tuple<number<0>, number<0>>{}));

View File

@@ -227,31 +227,10 @@ struct MXGemmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
sequence<K_Thread / AK1, K_Lane, AK1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 2>>,
sequence<2, 2>,
sequence<2, 2>, // K_Thread/AK1, AK1
sequence<0, 2>>{});
}
CK_TILE_DEVICE static constexpr auto MakeMX_BDramTileDistribution()
{
constexpr index_t K2 = BK1; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * BPackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
constexpr index_t N2 = WaveSize / K1; // 8
constexpr index_t N1 = BlockSize / WaveSize; // 4
constexpr index_t N0 = NPerBlock / (N2 * N1);
static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!");
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>, // N0,K0,K2
sequence<0, 0, 2>>{});
}
template <typename TensorView>
CK_TILE_DEVICE static constexpr auto
@@ -294,6 +273,29 @@ struct MXGemmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
TensorView::DstInMemOp>{naive_view.buf_, desc};
}
CK_TILE_DEVICE static constexpr auto MakeMX_BDramTileDistribution()
{
// TODO: these could be replaced by the standard UniversalGEMM tile distributions??
constexpr index_t K2 = BK1; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * BPackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
constexpr index_t N2 = WaveSize / K1; // 8
constexpr index_t N1 = BlockSize / WaveSize; // 4
constexpr index_t N0 = NPerBlock / (N2 * N1);
static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!");
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>, // N0,K0,K2
sequence<0, 0, 2>>{});
}
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLdsBlockDescriptor()
{
constexpr index_t K2 = BK1; // f4=32; f8=16