mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Grouped batched attention + permute (#412)
* grouped attn without batch validates; now move toward grouped batched attn * grouped batched attention * working * remove debug logging clean up clean up * reintroduce g_ prefix back to host tensor variables * format * rename file * restore old file * rename * consolidate padded/non-padded attention example * harmonize padding specialization in attn examples
This commit is contained in:
@@ -486,4 +486,48 @@ __host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx,
|
||||
return is_valid;
|
||||
}
|
||||
|
||||
// This wrapper class is for grouped gemm where it subtracts blockIdx by a value so that the
|
||||
// workgroups assigned to a given gemm problem have top index offsetted to range [0,
|
||||
// grid_size_per_gemm]
|
||||
template <typename UnderlyingBlockToCTileMap>
|
||||
struct OffsettedBlockToCTileMap
|
||||
{
|
||||
using underlying_type = UnderlyingBlockToCTileMap;
|
||||
|
||||
OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
|
||||
{
|
||||
block_to_ctile_map_ = block_to_ctile_map;
|
||||
block_start_ = block_start;
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
return block_to_ctile_map_.CalculateBottomIndex(
|
||||
make_multi_index(idx_top[Number<0>{}] - block_start_));
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
|
||||
const CTileDim& c_tile_dim) const
|
||||
{
|
||||
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
UnderlyingBlockToCTileMap block_to_ctile_map_;
|
||||
index_t block_start_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user