355 update

This commit is contained in:
kyle-256
2025-12-22 08:49:06 +00:00
parent 56aa5385c6
commit 5d34251e99
8 changed files with 137 additions and 31 deletions

View File

@@ -396,6 +396,12 @@ if(ENABLE_ASM_DUMP)
message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}")
endif()
# Opt-in diagnostic switch, OFF by default. When turned ON, the remark flag
# below is appended via add_compile_options() and the chosen value is echoed
# at configure time.
option(ENABLE_KERNEL_RESOURCE_USAGE "Print kernel resource usage (VGPR, SGPR, scratch) during compilation." OFF)
if(ENABLE_KERNEL_RESOURCE_USAGE)
# Per the option description above, this flag asks the compiler to report
# per-kernel VGPR/SGPR/scratch usage while compiling.
# NOTE(review): presumably requires a clang-based device compiler that
# understands -Rpass-analysis; confirm for other toolchains.
add_compile_options(-Rpass-analysis=kernel-resource-usage)
message("CK compiled with ENABLE_KERNEL_RESOURCE_USAGE set to ${ENABLE_KERNEL_RESOURCE_USAGE}")
endif()
if (ENABLE_JSON_DUMP)
add_compile_definitions(CK_ENABLE_JSON_DUMP)
message("CK compiled with ENABLE_JSON_DUMP set to ${ENABLE_JSON_DUMP}")

View File

@@ -279,6 +279,14 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, std::
CDataType,
AccDataType>(argc, argv, Col{}, Row{}, Col{});
}
else if(a_layout == "R" && b_layout == "R" && c_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Row{}, Col{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A, B and C tensors!");

View File

@@ -67,8 +67,8 @@ struct GemmConfigBase
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;

View File

@@ -45,7 +45,10 @@ struct GemmPipelineAgBgCrImplBase
else
return std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
}();
// TEMP: Disable transpose load for B matrix to test RRR performance regression
static constexpr bool is_b_load_tr = false;
#if 0 // Original code - disabled for testing
static constexpr bool is_b_load_tr = []() {
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
@@ -57,6 +60,7 @@ struct GemmPipelineAgBgCrImplBase
else
return std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
}();
#endif
#else
static constexpr bool is_a_load_tr = false;
static constexpr bool is_b_load_tr = false;

View File

@@ -57,7 +57,11 @@ struct UniversalGemmBasePolicy
return std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
tensor_layout::gemm::ColumnMajor>;
}();
// TEMP: Disable transpose load for B matrix to test RRR performance regression
template <typename Problem>
static constexpr bool is_b_load_tr = false;
#if 0 // Original code - disabled for testing
template <typename Problem>
static constexpr bool is_b_load_tr = []() {
using BDataType = remove_cvref_t<typename Problem::BDataType>;
@@ -73,6 +77,7 @@ struct UniversalGemmBasePolicy
return std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
tensor_layout::gemm::RowMajor>;
}();
#endif
#else
template <typename Problem>
static constexpr bool is_a_load_tr = false;

View File

@@ -78,6 +78,21 @@ class GemmKernelBuilder:
persistent,
) = trait_combo
# Skip this tile_config if it fails validation for this specific pipeline
if not self._validate_tile_config(
tile_config['tile_m'],
tile_config['tile_n'],
tile_config['tile_k'],
tile_config['warp_m'],
tile_config['warp_n'],
tile_config['warp_k'],
tile_config['warp_tile_m'],
tile_config['warp_tile_n'],
tile_config['warp_tile_k'],
pipeline,
):
continue
# Create kernel name with proper boolean capitalization
kernel_name = f"{self.kernel_name_prefix}_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
@@ -160,14 +175,18 @@ class GemmKernelBuilder:
warp_tile_n_values = tile_config.get("warp_tile_n").get("values")
warp_tile_k_values = tile_config.get("warp_tile_k").get("values")
# Generate all combinations
default_pipeline = ""
if self.kernel_name_prefix == "gemm_universal":
default_pipeline = "compv4"
elif self.kernel_name_prefix == "gemm_multi_d":
default_pipeline = "compv4"
elif self.kernel_name_prefix == "gemm_preshuffle":
default_pipeline = "preshufflev2"
# Get pipelines from trait config to validate against
trait_config = self.config.get("trait_config", {})
pipelines = trait_config.get("pipeline", {}).get("values", [])
# Fallback to default pipeline if no pipelines in config
if not pipelines:
if self.kernel_name_prefix == "gemm_universal":
pipelines = ["compv4"]
elif self.kernel_name_prefix == "gemm_multi_d":
pipelines = ["compv4"]
elif self.kernel_name_prefix == "gemm_preshuffle":
pipelines = ["preshufflev2"]
configs = []
for tile_m in tile_m_values:
@@ -179,19 +198,25 @@ class GemmKernelBuilder:
for warp_tile_m in warp_tile_m_values:
for warp_tile_n in warp_tile_n_values:
for warp_tile_k in warp_tile_k_values:
# Validate configuration
if self._validate_tile_config(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
default_pipeline,
):
# Validate configuration against any pipeline
# A tile config is valid if it works for at least one pipeline
is_valid = False
for pipeline in pipelines:
if self._validate_tile_config(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
pipeline,
):
is_valid = True
break
if is_valid:
configs.append(
{
"tile_m": tile_m,

View File

@@ -63,19 +63,16 @@
"values": [
"compv3",
"compv4",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
"intrawave"
]
},
"epilogue": {
"values": [
"cshuffle",
"default"
"cshuffle"
]
},
"pad_m": {
@@ -95,7 +92,6 @@
},
"persistent": {
"values": [
false,
true
]
}

View File

@@ -0,0 +1,62 @@
{
"tile_config": {
"tile_m": {
"values": [128, 256]
},
"tile_n": {
"values": [128, 256]
},
"tile_k": {
"values": [64]
},
"warp_m": {
"values": [1, 2, 4]
},
"warp_n": {
"values": [1, 2, 4]
},
"warp_k": {
"values": [1]
},
"warp_tile_m": {
"values": [16, 32]
},
"warp_tile_n": {
"values": [16, 32]
},
"warp_tile_k": {
"values": [16, 32]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3",
"compv4"
]
},
"scheduler": {
"values": [
"intrawave"
]
},
"epilogue": {
"values": [
"cshuffle"
]
},
"pad_m": {
"values": [false]
},
"pad_n": {
"values": [false]
},
"pad_k": {
"values": [false]
},
"persistent": {
"values": [true]
}
},
"k_block_per_cu": 1
}