Fix in comments

This commit is contained in:
Qianfeng Zhang
2026-02-24 07:53:50 +00:00
parent 62cf370749
commit 4d83c92fc4
3 changed files with 8 additions and 7 deletions

View File

@@ -78,7 +78,7 @@ struct has_naive_hdim_load_flag<
template <typename T>
static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag<T>::value;
// A helper struct for detechting kUseTrLoad
// A helper struct for detecting kUseTrLoad
template <typename T, typename = void>
struct has_use_trload_flag : std::false_type
{

View File

@@ -318,7 +318,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
Policy::template MakeBiasDramTileDistribution<Problem>());
// assuming no random values need be saved, this is true when the pipeline is called from
// xformers, since we have a separate kernel to generated randomm values
// xformers, since we have a separate kernel to generatd random values
auto null_randval_window = [&]() {
if constexpr(kHasDropout)
{

View File

@@ -50,9 +50,10 @@ struct TileFmhaShape
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
static constexpr index_t kN0Sub = BlockTile::at(number<2>{}); // tile size for dividing kN0
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
static constexpr index_t kN0Sub = BlockTile::at(
number<2>{}); // same index as kK0; used as subdivision factor when dividing kN0
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
static constexpr index_t kQKHeaddim =
BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
// once (or repeately load Q as a whole tile)
@@ -63,8 +64,8 @@ struct TileFmhaShape
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
using VLayout = std::conditional_t<IsVLayoutRowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
};
template <typename BlockTile_, // sequence<...