Miscellaneous small updates and corrections

This commit is contained in:
Qianfeng Zhang
2026-06-28 13:24:05 +00:00
parent 52eff34d21
commit 6cf17bc827
9 changed files with 32 additions and 38 deletions

View File

@@ -463,7 +463,7 @@ bool run_no_group_hstu_forward(const ck_tile::ArgParser& arg_parser, bool is_jag
params.p_drop = p_drop;
params.philox_seed = PHILOX_SEED;
params.philox_offset = PHILOX_OFFSET;
};
}
bool has_dropout = (params.p_drop > 0.0f);
ck_tile::HostTensor<uint8_t> rand_vals_host(
@@ -506,8 +506,8 @@ bool run_no_group_hstu_forward(const ck_tile::ArgParser& arg_parser, bool is_jag
rv_params.stride_batch = rand_vals_host.get_strides()[0];
rv_params.philox_seed = PHILOX_SEED;
rv_params.philox_offset = PHILOX_OFFSET;
};
};
}
}
hipStream_t stream;
@@ -1023,7 +1023,7 @@ bool run_group_hstu_forward(const ck_tile::ArgParser& arg_parser, int num_group)
rv_params.stride_nhead = rand_vals_host.get_strides()[2];
rv_params.philox_seed = PHILOX_SEED;
rv_params.philox_offset = PHILOX_OFFSET;
};
}
hipStream_t stream;

View File

@@ -1057,10 +1057,10 @@ struct HstuAttentionFwdKernel
}();
const auto [seqlen_k_start, seqlen_k_end] =
mask.template GetTileRangeAlongX<kHasDropout>(
i_m0,
number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kN0>{});
mask.GetTileRangeAlongX(bool_constant<kHasDropout>{},
i_m0,
number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kN0>{});
if constexpr(!kUseSoftmax)
{
@@ -1113,10 +1113,9 @@ struct HstuAttentionFwdKernel
}();
const auto [seqlen_k_start, seqlen_k_end] =
mask.template GetTileRangeAlongX<kHasDropout>(
i_m0,
number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kN0>{});
mask.GetTileRangeAlongX(i_m0,
number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kN0>{});
if constexpr(!kUseSoftmax)
{

View File

@@ -1060,10 +1060,10 @@ struct HstuAttentionFwdSplitKVKernel
}();
const auto [global_seqlen_k_start, global_seqlen_k_end] =
mask.template GetTileRangeAlongX<kHasDropout>(
i_m0,
number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kN0>{});
mask.GetTileRangeAlongX(bool_constant<kHasDropout>{},
i_m0,
number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kN0>{});
const auto [seqlen_k_start, seqlen_k_end] = CalculateTileRangeAlongXForSplit(
global_seqlen_k_start, global_seqlen_k_end, kargs.num_splits, i_split);
@@ -1119,10 +1119,9 @@ struct HstuAttentionFwdSplitKVKernel
}();
const auto [global_seqlen_k_start, global_seqlen_k_end] =
mask.template GetTileRangeAlongX<kHasDropout>(
i_m0,
number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kN0>{});
mask.GetTileRangeAlongX(i_m0,
number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kN0>{});
const auto [seqlen_k_start, seqlen_k_end] = CalculateTileRangeAlongXForSplit(
global_seqlen_k_start, global_seqlen_k_end, kargs.num_splits, i_split);

View File

@@ -79,7 +79,7 @@ struct HstuCrossAttentionBlockMaskWithLocal
// i_y is the start offset of the current tile along the seqlen_q dimension
template <bool kHasDropout, index_t YTile, index_t XTile>
CK_TILE_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
GetTileRangeAlongX(bool_constant<kHasDropout>, index_t i_y, number<YTile>, number<XTile>) const
{
// handle two special cases first
if(!is_tile_in_first_split)
@@ -327,7 +327,7 @@ struct HstuSelfAttentionBlockMaskWithLocal
// i_y is the start offset of the current tile along the seqlen_q dimension
template <bool kHasDropout, index_t YTile, index_t XTile>
CK_TILE_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
GetTileRangeAlongX(bool_constant<kHasDropout>, index_t i_y, number<YTile>, number<XTile>) const
{
// handle two special cases first
if(!is_tile_in_first_split)
@@ -563,7 +563,7 @@ struct HstuCrossAttentionBlockMaskNoLocal
// Returns the half-open index range [start, end) to loop over along the X axis; end - start is the loop length.
// Use this when looping over the X axis tile by tile (e.g. the seqlen_k loop).
// i_y is the start offset of the current tile along the seqlen_q dimension
template <bool kHasDropout, index_t YTile, index_t XTile>
template <index_t YTile, index_t XTile>
CK_TILE_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
@@ -699,7 +699,7 @@ struct HstuSelfAttentionBlockMaskNoLocal
// Returns the half-open index range [start, end) to loop over along the X axis; end - start is the loop length.
// Use this when looping over the X axis tile by tile (e.g. the seqlen_k loop).
// i_y is the start offset of the current tile along the seqlen_q dimension
template <bool kHasDropout, index_t YTile, index_t XTile>
template <index_t YTile, index_t XTile>
CK_TILE_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{

View File

@@ -22,8 +22,8 @@ void hstu_generate_batched_random_number_uint16(HstuGenerateRandUniformNumbersPa
param.stride_batch,
{param.philox_seed, param.philox_offset});
dim3 kGridSize = HstuRandUniformKernel_::GridSize(
param.num_batch, param.num_head, param.seqlen_q, param.seqlen_k);
dim3 kGridSize =
HstuRandUniformKernel_::GridSize(param.num_batch, param.num_head, param.seqlen_q);
dim3 kBlockSize = HstuRandUniformKernel_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu;

View File

@@ -22,8 +22,8 @@ void hstu_generate_batched_random_number_uint8(HstuGenerateRandUniformNumbersPar
param.stride_batch,
{param.philox_seed, param.philox_offset});
dim3 kGridSize = HstuRandUniformKernel_::GridSize(
param.num_batch, param.num_head, param.seqlen_q, param.seqlen_k);
dim3 kGridSize =
HstuRandUniformKernel_::GridSize(param.num_batch, param.num_head, param.seqlen_q);
dim3 kBlockSize = HstuRandUniformKernel_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu;

View File

@@ -21,8 +21,8 @@ void hstu_generate_jagged_random_number_uint16(HstuGenerateRandUniformNumbersPar
param.seq_k_offsets_ptr,
{param.philox_seed, param.philox_offset});
dim3 kGridSize = HstuRandUniformKernel_::GridSize(
param.num_batch, param.num_head, param.max_seqlen_q, param.seqlen_k);
dim3 kGridSize =
HstuRandUniformKernel_::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
dim3 kBlockSize = HstuRandUniformKernel_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu;

View File

@@ -21,8 +21,8 @@ void hstu_generate_jagged_random_number_uint8(HstuGenerateRandUniformNumbersPara
param.seq_k_offsets_ptr,
{param.philox_seed, param.philox_offset});
dim3 kGridSize = HstuRandUniformKernel_::GridSize(
param.num_batch, param.num_head, param.max_seqlen_q, param.seqlen_k);
dim3 kGridSize =
HstuRandUniformKernel_::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
dim3 kBlockSize = HstuRandUniformKernel_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu;

View File

@@ -156,13 +156,9 @@ struct HstuRandUniformKernel
return kargs;
}
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t seqlen_k_)
__host__ static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
(void)seqlen_k_; // not used at present
// at present, seqlen_k is not split across thread-groups
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kMPerBlock), nhead_, batch_size_);
}