[CK_TILE] Port hw independent changes from internal repo to develop branch (#3301)

* [CK_TILE] Port hw independent changes from internal repo to develop branch

It includes PR#96, #114, #120, #121.

* correct rebase error
This commit is contained in:
linqunAMD
2025-12-13 01:28:37 +08:00
committed by GitHub
parent 9869641324
commit fc7bf0ab1c
22 changed files with 465 additions and 310 deletions

View File

@@ -561,6 +561,7 @@ struct GroupedGemmKernel
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
block_sync_lds();
block_id = block_id + grid_size; // advance to next block
// NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
if(block_id >= cum_grid_size)

View File

@@ -631,6 +631,7 @@ struct StreamKKernel
tile_idx += kargs.tile_partitioner.get_grid())
{
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
block_sync_lds();
}
// Stream-K section
@@ -679,8 +680,8 @@ struct StreamKKernel
{
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
ck_tile::hip_check_error(hipGetDevice(&dev));
ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
int num_cu = dev_prop.multiProcessorCount;
return num_cu;
@@ -700,7 +701,7 @@ struct StreamKKernel
constexpr int min_block_per_cu = 1;
const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
hip_check_error(
ck_tile::hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
return max(occupancy, 1);

View File

@@ -280,7 +280,7 @@ struct UniversalGemmKernel
using Kernel = UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
const auto kernel = kentry<1, Kernel, KernelArgs>;
int occupancy;
hip_check_error(
ck_tile::hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
const int grid_size = get_available_compute_units(s) * occupancy;