[CK_TILE] Fix example batched_gemm, grouped_gemm, gemm_multi_d, convolution on gfx11 & gfx12 (#2808)

* [CK_TILE] Fix example batched_gemm, grouped_gemm, gemm_multi_d, convolution on gfx11 & gfx12

* fix gemm_splitk_two_stage

* revert .pre-commit-config.yaml
This commit is contained in:
linqunAMD
2025-09-11 22:27:33 +08:00
committed by GitHub
parent 0b9a638f26
commit 60d3e8f504
22 changed files with 439 additions and 192 deletions

View File

@@ -529,7 +529,10 @@ struct GroupedConvolutionBackwardDataKernel
return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.k_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize()
{
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
}
CK_TILE_HOST static constexpr GroupedConvBwdDataKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdDataHostArgs& hostArgs)

View File

@@ -392,7 +392,10 @@ struct GroupedConvolutionBackwardWeightKernel
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize()
{
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
}
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)

View File

@@ -398,7 +398,10 @@ struct GroupedConvolutionForwardKernel
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static auto BlockSize()
{
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
}
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs)