mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Merge commit '7e93eed8787afd175d3a045303096a4a98638f4b' into develop
This commit is contained in:
@@ -233,7 +233,20 @@ int run_contraction_bilinear_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result,
|
||||
e_ms_ns_host_result,
|
||||
"Error: Incorrect results!",
|
||||
1e-4,
|
||||
1e-4)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -216,7 +216,20 @@ int run_contraction_scale_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result,
|
||||
e_ms_ns_host_result,
|
||||
"Error: Incorrect results!",
|
||||
1e-4,
|
||||
1e-4)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -12,40 +12,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
@@ -122,7 +88,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
@@ -141,7 +108,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
@@ -160,7 +128,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
@@ -204,7 +173,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
@@ -223,7 +193,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
@@ -242,7 +213,8 @@ struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
|
||||
@@ -282,7 +254,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
@@ -306,7 +279,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
|
||||
@@ -11,40 +11,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
@@ -111,7 +77,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
@@ -134,7 +101,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
@@ -157,7 +125,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
@@ -178,7 +147,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
@@ -203,7 +173,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
|
||||
@@ -11,24 +11,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
|
||||
@@ -10,40 +10,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
@@ -100,7 +66,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType, bool Persistent>
|
||||
@@ -117,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
|
||||
@@ -24,39 +24,6 @@ inline size_t hash_multiple_strings(const std::vector<std::string>& inputs)
|
||||
return combined_hash;
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
@@ -124,7 +91,8 @@ struct GemmConfigQuantDecode : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -140,7 +108,8 @@ struct GemmConfigRowColQuant : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -157,7 +126,7 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
@@ -176,7 +145,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
@@ -206,7 +175,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
@@ -236,7 +205,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
|
||||
@@ -72,7 +72,12 @@ inline bool is_xdl_supported()
|
||||
is_gfx12_supported() || is_gfx11_supported();
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, index_t MPerXDL, index_t NPerXDL>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
index_t MPerXDL64,
|
||||
index_t NPerXDL64,
|
||||
index_t MPerXDL32 = MPerXDL64,
|
||||
index_t NPerXDL32 = NPerXDL64>
|
||||
inline bool is_xdl_wmma_supported()
|
||||
{
|
||||
if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
@@ -82,7 +87,7 @@ inline bool is_xdl_wmma_supported()
|
||||
}
|
||||
else if(is_gfx12_supported() || is_gfx11_supported())
|
||||
{
|
||||
if constexpr((MPerXDL != 16) || (NPerXDL != 16))
|
||||
if constexpr((MPerXDL32 != 16) || (NPerXDL32 != 16))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#endif
|
||||
#endif
|
||||
#include "ck/utility/get_id.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -96,6 +97,57 @@ static constexpr auto GetNXdlPerWave2()
|
||||
IsWave64>(); \
|
||||
}
|
||||
|
||||
template <index_t BlockSize_,
|
||||
index_t MPerBlock_,
|
||||
index_t NPerBlock_,
|
||||
index_t MPerXDL_,
|
||||
index_t NPerXDL_,
|
||||
index_t MXdlPerWave_,
|
||||
index_t CShuffleMXdlPerWavePerShuffle_,
|
||||
index_t CShuffleNXdlPerWavePerShuffle_,
|
||||
bool IsWave64>
|
||||
static constexpr auto GetWarpTileConfig()
|
||||
{
|
||||
constexpr auto MXdlPerWave64 = MXdlPerWave_;
|
||||
constexpr auto MXdlPerWave32 = MXdlPerWave_ * MPerXDL_ / 16;
|
||||
constexpr auto CShuffleMXdlPerWavePerShuffle32 = CShuffleMXdlPerWavePerShuffle_ * MPerXDL_ / 16;
|
||||
|
||||
constexpr auto NXdlPerWave =
|
||||
IsWave64
|
||||
? GetNXdlPerWave2<BlockSize_,
|
||||
MPerBlock_,
|
||||
NPerBlock_,
|
||||
MPerXDL_,
|
||||
NPerXDL_,
|
||||
MXdlPerWave_,
|
||||
true>()
|
||||
: GetNXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
|
||||
|
||||
if constexpr(IsWave64 == false && NXdlPerWave != 0)
|
||||
{
|
||||
constexpr auto CShuffleNXdlPerWavePerShuffle32 =
|
||||
NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
|
||||
? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
|
||||
: CShuffleNXdlPerWavePerShuffle_;
|
||||
static_assert(CShuffleNXdlPerWavePerShuffle32 > 0);
|
||||
return Sequence<16,
|
||||
16,
|
||||
MXdlPerWave32,
|
||||
NXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle32,
|
||||
CShuffleNXdlPerWavePerShuffle32>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<MPerXDL_,
|
||||
NPerXDL_,
|
||||
MXdlPerWave64,
|
||||
NXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle_,
|
||||
CShuffleNXdlPerWavePerShuffle_>{};
|
||||
}
|
||||
}
|
||||
|
||||
#define INVOKER_RUN_IMPL \
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
|
||||
{ \
|
||||
|
||||
@@ -166,11 +166,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
{
|
||||
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
|
||||
|
||||
GET_NXDL_PER_WAVE_IMPL
|
||||
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
|
||||
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr auto WarpTileConfig64 = GetWarpTileConfig<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
true>();
|
||||
static constexpr auto WarpTileConfig32 = GetWarpTileConfig<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
false>();
|
||||
static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3);
|
||||
static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3);
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -321,7 +337,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
|
||||
|
||||
// GridwiseGemm
|
||||
template <index_t NXdlPerWave_>
|
||||
template <typename WarpTileConfig>
|
||||
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
@@ -340,10 +356,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave_,
|
||||
WarpTileConfig::At(0),
|
||||
WarpTileConfig::At(1),
|
||||
WarpTileConfig::At(2),
|
||||
WarpTileConfig::At(3),
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -360,13 +376,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
WarpTileConfig::At(4),
|
||||
WarpTileConfig::At(5),
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<decltype(WarpTileConfig64)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<decltype(WarpTileConfig32)>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
@@ -588,7 +604,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_wmma_supported<ComputeDataType, ComputeDataType, MPerXDL, NPerXDL>())
|
||||
if(!ck::is_xdl_wmma_supported<ComputeDataType,
|
||||
ComputeDataType,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
WarpTileConfig32.At(0),
|
||||
WarpTileConfig32.At(1)>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -783,6 +804,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< ABlockTransferSrcVectorDim << ", "
|
||||
|
||||
@@ -68,7 +68,7 @@ auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
@@ -78,10 +78,10 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
@@ -98,18 +98,24 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
gemmConfig.K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
return shuffle_b(t, GemmConfig{});
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
|
||||
{
|
||||
@@ -129,22 +135,22 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
|
||||
gemmConfig.N_Warp,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
@@ -161,17 +167,23 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
|
||||
gemmConfig.N_Warp,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
gemmConfig.K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
return shuffle_b_permuteN(t, GemmConfig{});
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -43,4 +43,26 @@ struct TileGemmShape
|
||||
}
|
||||
};
|
||||
|
||||
template <typename PrecType, index_t M_Warp_Tile, bool IsFlatMM = false>
|
||||
constexpr index_t get_k_warp_tile()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 16 : 32;
|
||||
else
|
||||
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 32 : 64;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -11,26 +11,6 @@
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
{
|
||||
@@ -67,7 +47,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 16;
|
||||
static const ck_tile::index_t N_Warp_Tile = 16;
|
||||
static const ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<BDataType, M_Warp_Tile>();
|
||||
static const ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<BDataType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem
|
||||
static constexpr bool TransposeC = false; // transpose c is not supported
|
||||
@@ -101,46 +82,6 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / N_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
k_ / K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
else
|
||||
{
|
||||
int divisor = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
@@ -340,6 +281,14 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
}
|
||||
}
|
||||
|
||||
struct BShuffleGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Warp_Tile =
|
||||
TestCkTileGroupedGemmPreshuffle::N_Warp_Tile;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
TestCkTileGroupedGemmPreshuffle::K_Warp_Tile;
|
||||
};
|
||||
|
||||
public:
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
@@ -424,7 +373,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
|
||||
// Host-side preshuffle of B
|
||||
auto b_shuffle_host = shuffle_b(b_k_n_tensors[i]);
|
||||
auto b_shuffle_host = ck_tile::shuffle_b<BShuffleGemmConfig>(b_k_n_tensors[i]);
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
@@ -81,42 +81,3 @@ inline KernelTraits extract_traits_from_name(const std::string& kernel_name)
|
||||
|
||||
return traits;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
int divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile,
|
||||
ck_tile::index_t N_Tile,
|
||||
ck_tile::index_t N_Warp)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
int divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
int NRepeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
ck_tile::HostTensor<T> t_view({n_ / N_Tile,
|
||||
N_Warp,
|
||||
N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / K_Warp_Tile,
|
||||
divisor,
|
||||
K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
}
|
||||
|
||||
@@ -111,21 +111,30 @@ class GemmProfiler
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
struct GemmConfig
|
||||
{
|
||||
ck_tile::index_t N_Warp_Tile;
|
||||
ck_tile::index_t K_Warp_Tile;
|
||||
ck_tile::index_t N_Tile;
|
||||
ck_tile::index_t N_Warp;
|
||||
};
|
||||
|
||||
for(const auto& callable : callables)
|
||||
{
|
||||
ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims);
|
||||
ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims);
|
||||
ck_tile::index_t N_Tile = std::get<1>(config.tile_dims);
|
||||
ck_tile::index_t N_Warp = std::get<1>(config.warp_dims);
|
||||
GemmConfig gemmConfig = {};
|
||||
gemmConfig.N_Warp_Tile = std::get<1>(config.warp_tile_dims);
|
||||
gemmConfig.K_Warp_Tile = std::get<2>(config.warp_tile_dims);
|
||||
gemmConfig.N_Tile = std::get<1>(config.tile_dims);
|
||||
gemmConfig.N_Warp = std::get<1>(config.warp_dims);
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
|
||||
if(config.permuteN)
|
||||
{
|
||||
return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp);
|
||||
return ck_tile::shuffle_b_permuteN(b_k_n, gemmConfig);
|
||||
}
|
||||
else
|
||||
{
|
||||
return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile);
|
||||
return ck_tile::shuffle_b(b_k_n, gemmConfig);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user