mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Fix the Wrong Output Generated by Gemm Examples on GFX11/12 (#2713)
* Introduce macro CK_TILE_USE_WMMA Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com> * Make CK_TILE_USE_WMMA global for all examples Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com> * Remove CK_TILE_USE_WMMA from config.hpp Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com> --------- Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
This commit is contained in:
@@ -26,6 +26,15 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
constexpr ck_tile::index_t M_Warp = 4;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#else
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
@@ -33,6 +42,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#endif
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
|
||||
2
example/ck_tile/03_gemm/gemm_utils.hpp
Executable file → Normal file
2
example/ck_tile/03_gemm/gemm_utils.hpp
Executable file → Normal file
@@ -172,6 +172,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
{
|
||||
@@ -192,6 +193,7 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
|
||||
@@ -335,7 +335,11 @@ int main(int argc, char* argv[])
|
||||
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
|
||||
#else
|
||||
return !run_gemm_example<GemmConfigComputeV3>(arg_parser);
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user