mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Miscellaneous small updates and corrections
This commit is contained in:
@@ -463,7 +463,7 @@ bool run_no_group_hstu_forward(const ck_tile::ArgParser& arg_parser, bool is_jag
|
||||
params.p_drop = p_drop;
|
||||
params.philox_seed = PHILOX_SEED;
|
||||
params.philox_offset = PHILOX_OFFSET;
|
||||
};
|
||||
}
|
||||
|
||||
bool has_dropout = (params.p_drop > 0.0f);
|
||||
ck_tile::HostTensor<uint8_t> rand_vals_host(
|
||||
@@ -506,8 +506,8 @@ bool run_no_group_hstu_forward(const ck_tile::ArgParser& arg_parser, bool is_jag
|
||||
rv_params.stride_batch = rand_vals_host.get_strides()[0];
|
||||
rv_params.philox_seed = PHILOX_SEED;
|
||||
rv_params.philox_offset = PHILOX_OFFSET;
|
||||
};
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
hipStream_t stream;
|
||||
|
||||
@@ -1023,7 +1023,7 @@ bool run_group_hstu_forward(const ck_tile::ArgParser& arg_parser, int num_group)
|
||||
rv_params.stride_nhead = rand_vals_host.get_strides()[2];
|
||||
rv_params.philox_seed = PHILOX_SEED;
|
||||
rv_params.philox_offset = PHILOX_OFFSET;
|
||||
};
|
||||
}
|
||||
|
||||
hipStream_t stream;
|
||||
|
||||
|
||||
@@ -1057,10 +1057,10 @@ struct HstuAttentionFwdKernel
|
||||
}();
|
||||
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.template GetTileRangeAlongX<kHasDropout>(
|
||||
i_m0,
|
||||
number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN0>{});
|
||||
mask.GetTileRangeAlongX(bool_constant<kHasDropout>{},
|
||||
i_m0,
|
||||
number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN0>{});
|
||||
|
||||
if constexpr(!kUseSoftmax)
|
||||
{
|
||||
@@ -1113,10 +1113,9 @@ struct HstuAttentionFwdKernel
|
||||
}();
|
||||
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.template GetTileRangeAlongX<kHasDropout>(
|
||||
i_m0,
|
||||
number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN0>{});
|
||||
mask.GetTileRangeAlongX(i_m0,
|
||||
number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN0>{});
|
||||
|
||||
if constexpr(!kUseSoftmax)
|
||||
{
|
||||
|
||||
@@ -1060,10 +1060,10 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
}();
|
||||
|
||||
const auto [global_seqlen_k_start, global_seqlen_k_end] =
|
||||
mask.template GetTileRangeAlongX<kHasDropout>(
|
||||
i_m0,
|
||||
number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN0>{});
|
||||
mask.GetTileRangeAlongX(bool_constant<kHasDropout>{},
|
||||
i_m0,
|
||||
number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN0>{});
|
||||
|
||||
const auto [seqlen_k_start, seqlen_k_end] = CalculateTileRangeAlongXForSplit(
|
||||
global_seqlen_k_start, global_seqlen_k_end, kargs.num_splits, i_split);
|
||||
@@ -1119,10 +1119,9 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
}();
|
||||
|
||||
const auto [global_seqlen_k_start, global_seqlen_k_end] =
|
||||
mask.template GetTileRangeAlongX<kHasDropout>(
|
||||
i_m0,
|
||||
number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN0>{});
|
||||
mask.GetTileRangeAlongX(i_m0,
|
||||
number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN0>{});
|
||||
|
||||
const auto [seqlen_k_start, seqlen_k_end] = CalculateTileRangeAlongXForSplit(
|
||||
global_seqlen_k_start, global_seqlen_k_end, kargs.num_splits, i_split);
|
||||
|
||||
@@ -79,7 +79,7 @@ struct HstuCrossAttentionBlockMaskWithLocal
|
||||
// i_y is the start offset of the current tile along the seqlen_q dimension
|
||||
template <bool kHasDropout, index_t YTile, index_t XTile>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
GetTileRangeAlongX(bool_constant<kHasDropout>, index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
// handle two special cases first
|
||||
if(!is_tile_in_first_split)
|
||||
@@ -327,7 +327,7 @@ struct HstuSelfAttentionBlockMaskWithLocal
|
||||
// i_y is the start offset of the current tile along the seqlen_q dimension
|
||||
template <bool kHasDropout, index_t YTile, index_t XTile>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
GetTileRangeAlongX(bool_constant<kHasDropout>, index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
// handle two special cases first
|
||||
if(!is_tile_in_first_split)
|
||||
@@ -563,7 +563,7 @@ struct HstuCrossAttentionBlockMaskNoLocal
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over X axis tile by tile (eg. seqlen_k loop-over)
|
||||
// i_y is the start offset of the current tile along the seqlen_q dimension
|
||||
template <bool kHasDropout, index_t YTile, index_t XTile>
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
@@ -699,7 +699,7 @@ struct HstuSelfAttentionBlockMaskNoLocal
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over X axis tile by tile (eg. seqlen_k loop-over)
|
||||
// i_y is the start offset of the current tile along the seqlen_q dimension
|
||||
template <bool kHasDropout, index_t YTile, index_t XTile>
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
|
||||
@@ -22,8 +22,8 @@ void hstu_generate_batched_random_number_uint16(HstuGenerateRandUniformNumbersPa
|
||||
param.stride_batch,
|
||||
{param.philox_seed, param.philox_offset});
|
||||
|
||||
dim3 kGridSize = HstuRandUniformKernel_::GridSize(
|
||||
param.num_batch, param.num_head, param.seqlen_q, param.seqlen_k);
|
||||
dim3 kGridSize =
|
||||
HstuRandUniformKernel_::GridSize(param.num_batch, param.num_head, param.seqlen_q);
|
||||
dim3 kBlockSize = HstuRandUniformKernel_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -22,8 +22,8 @@ void hstu_generate_batched_random_number_uint8(HstuGenerateRandUniformNumbersPar
|
||||
param.stride_batch,
|
||||
{param.philox_seed, param.philox_offset});
|
||||
|
||||
dim3 kGridSize = HstuRandUniformKernel_::GridSize(
|
||||
param.num_batch, param.num_head, param.seqlen_q, param.seqlen_k);
|
||||
dim3 kGridSize =
|
||||
HstuRandUniformKernel_::GridSize(param.num_batch, param.num_head, param.seqlen_q);
|
||||
dim3 kBlockSize = HstuRandUniformKernel_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ void hstu_generate_jagged_random_number_uint16(HstuGenerateRandUniformNumbersPar
|
||||
param.seq_k_offsets_ptr,
|
||||
{param.philox_seed, param.philox_offset});
|
||||
|
||||
dim3 kGridSize = HstuRandUniformKernel_::GridSize(
|
||||
param.num_batch, param.num_head, param.max_seqlen_q, param.seqlen_k);
|
||||
dim3 kGridSize =
|
||||
HstuRandUniformKernel_::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
|
||||
dim3 kBlockSize = HstuRandUniformKernel_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ void hstu_generate_jagged_random_number_uint8(HstuGenerateRandUniformNumbersPara
|
||||
param.seq_k_offsets_ptr,
|
||||
{param.philox_seed, param.philox_offset});
|
||||
|
||||
dim3 kGridSize = HstuRandUniformKernel_::GridSize(
|
||||
param.num_batch, param.num_head, param.max_seqlen_q, param.seqlen_k);
|
||||
dim3 kGridSize =
|
||||
HstuRandUniformKernel_::GridSize(param.num_batch, param.num_head, param.max_seqlen_q);
|
||||
dim3 kBlockSize = HstuRandUniformKernel_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuRandUniformKernel_::kBlockPerCu;
|
||||
|
||||
|
||||
@@ -156,13 +156,9 @@ struct HstuRandUniformKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t seqlen_k_)
|
||||
__host__ static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
|
||||
{
|
||||
(void)seqlen_k_; // not used at present
|
||||
|
||||
// at present, seqlen_k is not splitted by thread-groups
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kMPerBlock), nhead_, batch_size_);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user