Files
composable_kernel/rocm_ck/tests/unit/unit_fmha_bwd_consteval.cpp
chris-tsiaousis-hpc dc3c1cffd5 [rocm-libraries] ROCm/rocm-libraries#7891 (commit 4dee41d)
Porting existing FMHA infra from users/shumway/ck/exp-kpack
 to develop (#7891)

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
Co-authored-by: Adam Osewski <Adam.Osewski@amd.com>
2026-06-09 14:00:32 +00:00

231 lines
11 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Consteval tests for sliding-window attention (SWA) and bottom-right causal
// mask configurations.
//
// These tests verify compile-time guarantees for window_size_left/right and
// mask_type slot layouts, and that the SWA/CMaskBR variant specs are correctly
// derived from the underlying cmask configuration.
//
// NOTE: The SWA/CMaskBR/CMask variants share the same compiled kernel binary.
// The mask_type and window sizes are runtime-parametrized via scalar slots.
// Consteval tests verify the slot infrastructure, not the runtime values.
#include "fmha_bwd_variants.hpp"
#include <rocm_ck/ops/fmha_bwd/dqdkdv_spec.hpp>
#include <gtest/gtest.h>
using ::rocm_ck::DataType;
using ::rocm_ck::FmhaBiasType;
using ::rocm_ck::FmhaBwdDQDKDVConfig;
using ::rocm_ck::FmhaBwdDQDKDVSpec;
using ::rocm_ck::FmhaMaskType;
using ::rocm_ck::FmhaMode;
using ::rocm_ck::makeSpec;
using ::rocm_ck::usesBatchSizeSlot;
namespace S = ::rocm_ck::fmha_bwd_dqdkdv_slots;
// Pin the scalar-slot layout for the SWA / mask infrastructure. The runtime
// argument-packing code in the .hip variants writes window sizes and mask_type
// into these exact indices; if anyone renumbers the slots, the build must fail
// here rather than silently corrupting kernel inputs.
static_assert(S::WINDOW_SIZE_LEFT == 8, "WINDOW_SIZE_LEFT slot index drifted");
static_assert(S::WINDOW_SIZE_RIGHT == 9, "WINDOW_SIZE_RIGHT slot index drifted");
static_assert(S::MASK_TYPE == 10, "MASK_TYPE slot index drifted");
// ============================================================================
// Spec equivalence: SWA / CMaskBR / CMask share compiled spec (all fields)
// ============================================================================
TEST(FmhaBwdConsteval, SWA_AllFieldsMatchCMask)
{
// SWA, CMaskBR, and CMask produce identical compiled specs.
// The difference is purely runtime (mask_type + window sizes).
constexpr auto k_cmask =
::rocm_ck::fmha_bwd_dqdkdv_variant_spec("fmha_bwd_dqdkdv_fp16_d128_batch_cmask");
constexpr auto k_swa =
::rocm_ck::fmha_bwd_dqdkdv_variant_spec("fmha_bwd_dqdkdv_fp16_d128_batch_swa");
// Every field must match -- these share a compiled kernel binary.
EXPECT_EQ(k_swa.dtype, k_cmask.dtype);
EXPECT_EQ(k_swa.hdim_q, k_cmask.hdim_q);
EXPECT_EQ(k_swa.hdim_v, k_cmask.hdim_v);
EXPECT_EQ(k_swa.mode, k_cmask.mode);
EXPECT_EQ(k_swa.bias_type, k_cmask.bias_type);
EXPECT_EQ(k_swa.has_bias_grad, k_cmask.has_bias_grad);
EXPECT_EQ(k_swa.mask_type, k_cmask.mask_type);
EXPECT_EQ(k_swa.has_dropout, k_cmask.has_dropout);
EXPECT_EQ(k_swa.is_deterministic, k_cmask.is_deterministic);
EXPECT_EQ(k_swa.pad_hdim_q, k_cmask.pad_hdim_q);
EXPECT_EQ(k_swa.pad_hdim_v, k_cmask.pad_hdim_v);
EXPECT_EQ(k_swa.block_per_cu, k_cmask.block_per_cu);
EXPECT_EQ(k_swa.block_size, k_cmask.block_size);
EXPECT_EQ(k_swa.block_n0, k_cmask.block_n0);
}
TEST(FmhaBwdConsteval, CMaskBR_AllFieldsMatchCMask)
{
constexpr auto k_cmask =
::rocm_ck::fmha_bwd_dqdkdv_variant_spec("fmha_bwd_dqdkdv_fp16_d128_batch_cmask");
constexpr auto k_br =
::rocm_ck::fmha_bwd_dqdkdv_variant_spec("fmha_bwd_dqdkdv_fp16_d128_batch_cmask_br");
EXPECT_EQ(k_br.dtype, k_cmask.dtype);
EXPECT_EQ(k_br.hdim_q, k_cmask.hdim_q);
EXPECT_EQ(k_br.hdim_v, k_cmask.hdim_v);
EXPECT_EQ(k_br.mode, k_cmask.mode);
EXPECT_EQ(k_br.bias_type, k_cmask.bias_type);
EXPECT_EQ(k_br.has_bias_grad, k_cmask.has_bias_grad);
EXPECT_EQ(k_br.mask_type, k_cmask.mask_type);
EXPECT_EQ(k_br.has_dropout, k_cmask.has_dropout);
EXPECT_EQ(k_br.is_deterministic, k_cmask.is_deterministic);
EXPECT_EQ(k_br.pad_hdim_q, k_cmask.pad_hdim_q);
EXPECT_EQ(k_br.pad_hdim_v, k_cmask.pad_hdim_v);
EXPECT_EQ(k_br.block_per_cu, k_cmask.block_per_cu);
EXPECT_EQ(k_br.block_size, k_cmask.block_size);
EXPECT_EQ(k_br.block_n0, k_cmask.block_n0);
}
// ============================================================================
// Spec generation with mask: scalar slot requirements
// (These test the underlying makeSpec() system. The SWA/CMaskBR/CMask variants
// all use this same spec infrastructure.)
// ============================================================================
TEST(FmhaBwdConsteval, MaskedSpec_RequiredScalars_DeterministicBatch)
{
// Deterministic batch mode adds BATCH_SIZE slot (index 11) -> 12 scalars.
// This dominates the mask slots (index 10) -> 11 scalars.
constexpr auto k =
makeSpec(FmhaBwdDQDKDVConfig{.signature = {.dtype = DataType::FP16,
.hdim_q = 128,
.hdim_v = 128,
.mode = FmhaMode::BATCH},
.algorithm = {.mask_type = FmhaMaskType::TOP_LEFT_CAUSAL,
.is_deterministic = true,
.pad_hdim_q = 8,
.pad_hdim_v = 8}});
EXPECT_EQ(S::requiredScalars(k), 12);
}
TEST(FmhaBwdConsteval, MaskedSpec_RequiredScalars_DeterministicGroup)
{
// Group mode does not use the BATCH_SIZE slot (kernel derives batch count
// from seqstart pointers), so mask slots dominate -> 11 scalars.
constexpr auto k =
makeSpec(FmhaBwdDQDKDVConfig{.signature = {.dtype = DataType::FP16,
.hdim_q = 128,
.hdim_v = 128,
.mode = FmhaMode::GROUP},
.algorithm = {.mask_type = FmhaMaskType::TOP_LEFT_CAUSAL,
.is_deterministic = true,
.pad_hdim_q = 8,
.pad_hdim_v = 8}});
EXPECT_EQ(S::requiredScalars(k), 11);
}
// ============================================================================
// usesBatchSizeSlot with mask combinations
// ============================================================================
TEST(FmhaBwdConsteval, MaskedSpec_UsesBatchSizeSlot_WhenDeterministicBatch)
{
// Deterministic + batch mode -> uses BATCH_SIZE slot.
constexpr auto k =
makeSpec(FmhaBwdDQDKDVConfig{.signature = {.dtype = DataType::FP16,
.hdim_q = 128,
.hdim_v = 128,
.mode = FmhaMode::BATCH},
.algorithm = {.mask_type = FmhaMaskType::TOP_LEFT_CAUSAL,
.is_deterministic = true,
.pad_hdim_q = 8,
.pad_hdim_v = 8}});
EXPECT_TRUE(usesBatchSizeSlot(k));
}
TEST(FmhaBwdConsteval, MaskedSpec_DoesNotUseBatchSizeSlot_WhenDeterministicGroup)
{
// Deterministic + group mode -> does NOT use BATCH_SIZE slot.
constexpr auto k =
makeSpec(FmhaBwdDQDKDVConfig{.signature = {.dtype = DataType::FP16,
.hdim_q = 128,
.hdim_v = 128,
.mode = FmhaMode::GROUP},
.algorithm = {.mask_type = FmhaMaskType::TOP_LEFT_CAUSAL,
.is_deterministic = true,
.pad_hdim_q = 8,
.pad_hdim_v = 8}});
EXPECT_FALSE(usesBatchSizeSlot(k));
}
TEST(FmhaBwdConsteval, MaskedSpec_DoesNotUseBatchSizeSlot_WhenNonDeterministic)
{
// Mask without deterministic flag: batch slot is NOT required.
// Only mask+deterministic+batch uses the batch slot.
constexpr auto k = makeSpec(FmhaBwdDQDKDVConfig{
.signature =
{.dtype = DataType::FP16, .hdim_q = 128, .hdim_v = 128, .mode = FmhaMode::BATCH},
.algorithm = {
.mask_type = FmhaMaskType::TOP_LEFT_CAUSAL, .pad_hdim_q = 8, .pad_hdim_v = 8}});
EXPECT_FALSE(usesBatchSizeSlot(k));
}
// ============================================================================
// Tensor slot invariance: mask does not add tensor slots
// ============================================================================
TEST(FmhaBwdConsteval, MaskedSpec_RequiredTensors_UnchangedForGroup)
{
// Group mode always requires 16 tensor slots, regardless of mask.
constexpr auto k = makeSpec(FmhaBwdDQDKDVConfig{
.signature =
{.dtype = DataType::FP16, .hdim_q = 128, .hdim_v = 128, .mode = FmhaMode::GROUP},
.algorithm = {
.mask_type = FmhaMaskType::TOP_LEFT_CAUSAL, .pad_hdim_q = 8, .pad_hdim_v = 8}});
EXPECT_EQ(S::requiredTensors(k), 16);
}
// ============================================================================
// BF16 mask spec: SWA configurations work with both dtypes
// ============================================================================
TEST(FmhaBwdConsteval, MaskedSpec_BF16_WithMask)
{
constexpr auto k = makeSpec(FmhaBwdDQDKDVConfig{
.signature =
{.dtype = DataType::BF16, .hdim_q = 128, .hdim_v = 128, .mode = FmhaMode::BATCH},
.algorithm = {
.mask_type = FmhaMaskType::TOP_LEFT_CAUSAL, .pad_hdim_q = 8, .pad_hdim_v = 8}});
EXPECT_EQ(k.dtype, DataType::BF16);
EXPECT_NE(k.mask_type, FmhaMaskType::NO_MASK);
EXPECT_EQ(S::requiredScalars(k), 11);
}
TEST(FmhaBwdConsteval, MaskedSpec_BF16_WithMaskAndDeterministic)
{
constexpr auto k =
makeSpec(FmhaBwdDQDKDVConfig{.signature = {.dtype = DataType::BF16,
.hdim_q = 128,
.hdim_v = 128,
.mode = FmhaMode::BATCH},
.algorithm = {.mask_type = FmhaMaskType::TOP_LEFT_CAUSAL,
.is_deterministic = true,
.pad_hdim_q = 8,
.pad_hdim_v = 8}});
EXPECT_EQ(k.dtype, DataType::BF16);
EXPECT_NE(k.mask_type, FmhaMaskType::NO_MASK);
EXPECT_TRUE(k.is_deterministic);
EXPECT_EQ(S::requiredScalars(k), 12);
}