mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 04:37:02 +00:00
355 update
This commit is contained in:
@@ -396,6 +396,12 @@ if(ENABLE_ASM_DUMP)
|
||||
message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}")
|
||||
endif()
|
||||
|
||||
option(ENABLE_KERNEL_RESOURCE_USAGE "Print kernel resource usage (VGPR, SGPR, scratch) during compilation." OFF)
|
||||
if(ENABLE_KERNEL_RESOURCE_USAGE)
|
||||
add_compile_options(-Rpass-analysis=kernel-resource-usage)
|
||||
message("CK compiled with ENABLE_KERNEL_RESOURCE_USAGE set to ${ENABLE_KERNEL_RESOURCE_USAGE}")
|
||||
endif()
|
||||
|
||||
if (ENABLE_JSON_DUMP)
|
||||
add_compile_definitions(CK_ENABLE_JSON_DUMP)
|
||||
message("CK compiled with ENABLE_JSON_DUMP set to ${ENABLE_JSON_DUMP}")
|
||||
|
||||
@@ -279,6 +279,14 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, std::
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Col{}, Row{}, Col{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R" && c_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Row{}, Row{}, Col{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A, B and C tensors!");
|
||||
|
||||
@@ -67,8 +67,8 @@ struct GemmConfigBase
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
|
||||
@@ -45,7 +45,10 @@ struct GemmPipelineAgBgCrImplBase
|
||||
else
|
||||
return std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
}();
|
||||
|
||||
|
||||
// TEMP: Disable transpose load for B matrix to test RRR performance regression
|
||||
static constexpr bool is_b_load_tr = false;
|
||||
#if 0 // Original code - disabled for testing
|
||||
static constexpr bool is_b_load_tr = []() {
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
@@ -57,6 +60,7 @@ struct GemmPipelineAgBgCrImplBase
|
||||
else
|
||||
return std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
}();
|
||||
#endif
|
||||
#else
|
||||
static constexpr bool is_a_load_tr = false;
|
||||
static constexpr bool is_b_load_tr = false;
|
||||
|
||||
@@ -57,7 +57,11 @@ struct UniversalGemmBasePolicy
|
||||
return std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
|
||||
tensor_layout::gemm::ColumnMajor>;
|
||||
}();
|
||||
|
||||
|
||||
// TEMP: Disable transpose load for B matrix to test RRR performance regression
|
||||
template <typename Problem>
|
||||
static constexpr bool is_b_load_tr = false;
|
||||
#if 0 // Original code - disabled for testing
|
||||
template <typename Problem>
|
||||
static constexpr bool is_b_load_tr = []() {
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
@@ -73,6 +77,7 @@ struct UniversalGemmBasePolicy
|
||||
return std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
|
||||
tensor_layout::gemm::RowMajor>;
|
||||
}();
|
||||
#endif
|
||||
#else
|
||||
template <typename Problem>
|
||||
static constexpr bool is_a_load_tr = false;
|
||||
|
||||
@@ -78,6 +78,21 @@ class GemmKernelBuilder:
|
||||
persistent,
|
||||
) = trait_combo
|
||||
|
||||
# Validate that this tile_config is valid for this specific pipeline
|
||||
if not self._validate_tile_config(
|
||||
tile_config['tile_m'],
|
||||
tile_config['tile_n'],
|
||||
tile_config['tile_k'],
|
||||
tile_config['warp_m'],
|
||||
tile_config['warp_n'],
|
||||
tile_config['warp_k'],
|
||||
tile_config['warp_tile_m'],
|
||||
tile_config['warp_tile_n'],
|
||||
tile_config['warp_tile_k'],
|
||||
pipeline,
|
||||
):
|
||||
continue
|
||||
|
||||
# Create kernel name with proper boolean capitalization
|
||||
kernel_name = f"{self.kernel_name_prefix}_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
|
||||
|
||||
@@ -160,14 +175,18 @@ class GemmKernelBuilder:
|
||||
warp_tile_n_values = tile_config.get("warp_tile_n").get("values")
|
||||
warp_tile_k_values = tile_config.get("warp_tile_k").get("values")
|
||||
|
||||
# Generate all combinations
|
||||
default_pipeline = ""
|
||||
if self.kernel_name_prefix == "gemm_universal":
|
||||
default_pipeline = "compv4"
|
||||
elif self.kernel_name_prefix == "gemm_multi_d":
|
||||
default_pipeline = "compv4"
|
||||
elif self.kernel_name_prefix == "gemm_preshuffle":
|
||||
default_pipeline = "preshufflev2"
|
||||
# Get pipelines from trait config to validate against
|
||||
trait_config = self.config.get("trait_config", {})
|
||||
pipelines = trait_config.get("pipeline", {}).get("values", [])
|
||||
|
||||
# Fallback to default pipeline if no pipelines in config
|
||||
if not pipelines:
|
||||
if self.kernel_name_prefix == "gemm_universal":
|
||||
pipelines = ["compv4"]
|
||||
elif self.kernel_name_prefix == "gemm_multi_d":
|
||||
pipelines = ["compv4"]
|
||||
elif self.kernel_name_prefix == "gemm_preshuffle":
|
||||
pipelines = ["preshufflev2"]
|
||||
|
||||
configs = []
|
||||
for tile_m in tile_m_values:
|
||||
@@ -179,19 +198,25 @@ class GemmKernelBuilder:
|
||||
for warp_tile_m in warp_tile_m_values:
|
||||
for warp_tile_n in warp_tile_n_values:
|
||||
for warp_tile_k in warp_tile_k_values:
|
||||
# Validate configuration
|
||||
if self._validate_tile_config(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
default_pipeline,
|
||||
):
|
||||
# Validate configuration against any pipeline
|
||||
# A tile config is valid if it works for at least one pipeline
|
||||
is_valid = False
|
||||
for pipeline in pipelines:
|
||||
if self._validate_tile_config(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
pipeline,
|
||||
):
|
||||
is_valid = True
|
||||
break
|
||||
if is_valid:
|
||||
configs.append(
|
||||
{
|
||||
"tile_m": tile_m,
|
||||
|
||||
@@ -63,19 +63,16 @@
|
||||
"values": [
|
||||
"compv3",
|
||||
"compv4",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle",
|
||||
"default"
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
@@ -95,7 +92,6 @@
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false,
|
||||
true
|
||||
]
|
||||
}
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [128, 256]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [128, 256]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [64]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [1, 2, 4]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [1, 2, 4]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [1]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [16, 32]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [16, 32]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [16, 32]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"compv4"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [false]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [false]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [false]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [true]
|
||||
}
|
||||
},
|
||||
"k_block_per_cu": 1
|
||||
}
|
||||
Reference in New Issue
Block a user