[rocm-libraries] ROCm/rocm-libraries#7112 (commit a6e5eac)

Add asynchronous XOR shuffle support to the Async GEMM pipeline and the MX GEMM pipeline (#7112)

## Motivation

The goal of this work is to apply XOR shuffle (swizzle) to the current
`comp_async` GEMM pipeline and the `gemm_mx` pipeline.
XOR swizzling has been helpful to avoid LDS bank conflicts, as data are
redistributed across LDS banks, such that simultaneous threads accessing
different rows land on different LDS banks.

## Technical Details

A similar approach to the work in the existing eight-waves pipeline was
followed.
Currently, XOR swizzle support is available for FP8 and BF8 types.
FP4 support is also available for MX GEMM.
Should the types not match, or should the async vector width be of an
unsupported size, then the pipeline falls through to the previously
existing ('unswizzled') path.

## Test Plan

Execute `test_ck_tile_gemm_pipeline_comp_async` for the Async GEMM
pipeline.
Execute `test_ck_tile_mx_gemm_fp8` and `test_ck_tile_mx_gemm_fp4` for
the MX GEMM pipeline.

## Test Result

The tests passed successfully in the `Alola` cluster with MI350
hardware.

## Submission Checklist

- [X] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Fernando Jiménez <fernando.jimenez@streamhpc.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
JP-Fernando
2026-05-21 09:36:41 +02:00
committed by GitHub
parent c31fc4df52
commit e7798e9560
13 changed files with 4369 additions and 3634 deletions

View File

@@ -1,329 +1,329 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
#include "run_gemm_example_common.hpp"
#include "universal_gemm_invoker.hpp"
#include "ck_tile/core/utility/gemm_validation.hpp"
// Template function to run GEMM with optional prefetch comparison.
// GemmConfig takes (PrecType, DataCachePrefetchKind A, DataCachePrefetchKind B,
// ClusterM, ClusterN).
template <template <typename,
ck_tile::DataCachePrefetchKind,
ck_tile::DataCachePrefetchKind,
ck_tile::index_t,
ck_tile::index_t>
class GemmConfig,
ck_tile::index_t ClusterM,
ck_tile::index_t ClusterN,
typename ADataType,
typename... BCAccDataTypes>
bool run_gemm_with_prefetch_comparison(const std::string& a_layout,
const std::string& b_layout,
ck_tile::ArgParser& arg_parser,
bool compare_with_non_prefetch,
ck_tile::DataCachePrefetchKind prefetch_kind_a,
ck_tile::DataCachePrefetchKind prefetch_kind_b)
{
using Invoker = UniversalInvoker;
using Kind = ck_tile::DataCachePrefetchKind;
auto kind_str = [](Kind k) { return k == Kind::L1 ? "L1" : "L2"; };
std::cout << "\n=== Running with DataCache Prefetch ENABLED (A " << kind_str(prefetch_kind_a)
<< " / B " << kind_str(prefetch_kind_b) << ") ===\n"
<< std::endl;
bool pass_prefetch;
if(prefetch_kind_a == Kind::L1 && prefetch_kind_b == Kind::L1)
{
pass_prefetch = run_gemm_example_prec_type<
GemmConfig<ADataType, Kind::L1, Kind::L1, ClusterM, ClusterN>,
Invoker,
ADataType,
BCAccDataTypes...>(a_layout, b_layout, arg_parser);
}
else if(prefetch_kind_a == Kind::L1 && prefetch_kind_b == Kind::L2)
{
pass_prefetch = run_gemm_example_prec_type<
GemmConfig<ADataType, Kind::L1, Kind::L2, ClusterM, ClusterN>,
Invoker,
ADataType,
BCAccDataTypes...>(a_layout, b_layout, arg_parser);
}
else if(prefetch_kind_a == Kind::L2 && prefetch_kind_b == Kind::L1)
{
pass_prefetch = run_gemm_example_prec_type<
GemmConfig<ADataType, Kind::L2, Kind::L1, ClusterM, ClusterN>,
Invoker,
ADataType,
BCAccDataTypes...>(a_layout, b_layout, arg_parser);
}
else
{
pass_prefetch = run_gemm_example_prec_type<
GemmConfig<ADataType, Kind::L2, Kind::L2, ClusterM, ClusterN>,
Invoker,
ADataType,
BCAccDataTypes...>(a_layout, b_layout, arg_parser);
}
if(compare_with_non_prefetch)
{
std::cout << "\n=== Running with DataCache Prefetch DISABLED ===\n" << std::endl;
bool pass_no_prefetch = run_gemm_example_prec_type<
GemmConfig<ADataType, Kind::None, Kind::None, ClusterM, ClusterN>,
Invoker,
ADataType,
BCAccDataTypes...>(a_layout, b_layout, arg_parser);
std::cout << "\n=== Comparison Summary ===" << std::endl;
std::cout << "Note: Check the timing results above to compare performance." << std::endl;
std::cout << "With prefetch vs without prefetch - speedup can be observed in the "
"timing outputs."
<< std::endl;
return pass_prefetch && pass_no_prefetch;
}
return pass_prefetch;
}
// Common GEMM example runner
template <template <typename,
ck_tile::DataCachePrefetchKind,
ck_tile::DataCachePrefetchKind,
ck_tile::index_t,
ck_tile::index_t>
class GemmConfig,
ck_tile::index_t ClusterM,
ck_tile::index_t ClusterN>
int run_gemm_example_with_prefetch(ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
std::string c_layout = arg_parser.get_str("c_layout");
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t> gemm_sizes =
parse_gemm_size(arg_parser);
int m = std::get<0>(gemm_sizes);
int n = std::get<1>(gemm_sizes);
int k = std::get<2>(gemm_sizes);
int stride_a = arg_parser.get_int("stride_a");
int stride_b = arg_parser.get_int("stride_b");
int stride_c = arg_parser.get_int("stride_c");
bool compare_with_non_prefetch = arg_parser.get_int("compare") == 1;
auto prefetch_kind_a = arg_parser.get_int("prefetch_a_l1") == 1
? ck_tile::DataCachePrefetchKind::L1
: ck_tile::DataCachePrefetchKind::L2;
auto prefetch_kind_b = arg_parser.get_int("prefetch_b_l1") == 1
? ck_tile::DataCachePrefetchKind::L1
: ck_tile::DataCachePrefetchKind::L2;
ck_tile::validate_gemm_stride(
a_layout, b_layout, c_layout, m, n, k, stride_a, stride_b, stride_c);
if(data_type == "fp16")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ClusterM,
ClusterN,
ck_tile::half_t,
ck_tile::half_t>(a_layout,
b_layout,
arg_parser,
compare_with_non_prefetch,
prefetch_kind_a,
prefetch_kind_b);
}
else if(data_type == "bf16")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ClusterM,
ClusterN,
ck_tile::bf16_t,
ck_tile::bf16_t>(a_layout,
b_layout,
arg_parser,
compare_with_non_prefetch,
prefetch_kind_a,
prefetch_kind_b);
}
else if(data_type == "fp8")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ClusterM,
ClusterN,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(a_layout,
b_layout,
arg_parser,
compare_with_non_prefetch,
prefetch_kind_a,
prefetch_kind_b);
}
else if(data_type == "bf8")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ClusterM,
ClusterN,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(a_layout,
b_layout,
arg_parser,
compare_with_non_prefetch,
prefetch_kind_a,
prefetch_kind_b);
}
else if(data_type == "i8")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ClusterM,
ClusterN,
ck_tile::int8_t,
ck_tile::int8_t,
int32_t>(a_layout,
b_layout,
arg_parser,
compare_with_non_prefetch,
prefetch_kind_a,
prefetch_kind_b);
}
else
{
throw std::runtime_error("Unsupported data type for GEMM with prefetch!");
}
}
// TDM V1 GEMM Configuration with Data Cache Prefetch control
template <typename PrecType,
ck_tile::DataCachePrefetchKind DataCachePrefetchA_ = ck_tile::DataCachePrefetchKind::L2,
ck_tile::DataCachePrefetchKind DataCachePrefetchB_ = DataCachePrefetchA_,
ck_tile::index_t kClusterSizeM_ = 1,
ck_tile::index_t kClusterSizeN_ = 1>
struct GemmConfigTDMV1Prefetch : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_TDM_V1;
static constexpr ck_tile::DataCachePrefetchKind DataCachePrefetchA = DataCachePrefetchA_;
static constexpr ck_tile::DataCachePrefetchKind DataCachePrefetchB = DataCachePrefetchB_;
static constexpr ck_tile::index_t kClusterSizeM = kClusterSizeM_;
static constexpr ck_tile::index_t kClusterSizeN = kClusterSizeN_;
};
// TDM V2 GEMM Configuration with Data Cache Prefetch control
template <typename PrecType,
ck_tile::DataCachePrefetchKind DataCachePrefetchA_ = ck_tile::DataCachePrefetchKind::L2,
ck_tile::DataCachePrefetchKind DataCachePrefetchB_ = DataCachePrefetchA_,
ck_tile::index_t kClusterSizeM_ = 1,
ck_tile::index_t kClusterSizeN_ = 1>
struct GemmConfigTDMV2Prefetch : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
// TDM V2 (requires 4 waves): M_Warp * N_Warp * K_Warp == 4
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_TDM_V2;
static constexpr ck_tile::DataCachePrefetchKind DataCachePrefetchA = DataCachePrefetchA_;
static constexpr ck_tile::DataCachePrefetchKind DataCachePrefetchB = DataCachePrefetchB_;
static constexpr ck_tile::index_t kClusterSizeM = kClusterSizeM_;
static constexpr ck_tile::index_t kClusterSizeN = kClusterSizeN_;
};
int run_gemm_example(ck_tile::ArgParser& arg_parser)
{
const std::string pipeline = arg_parser.get_str("pipeline");
const bool use_cluster_2x2 = arg_parser.get_int("use_cluster_2x2") == 1;
const bool is_v2 = (pipeline == "v2");
if(!is_v2 && pipeline != "v1")
std::cerr << "Unknown pipeline '" << pipeline << "', defaulting to v1." << std::endl;
if(is_v2)
{
if(use_cluster_2x2)
return run_gemm_example_with_prefetch<GemmConfigTDMV2Prefetch, 2, 2>(arg_parser);
else
return run_gemm_example_with_prefetch<GemmConfigTDMV2Prefetch, 1, 1>(arg_parser);
}
else
{
if(use_cluster_2x2)
return run_gemm_example_with_prefetch<GemmConfigTDMV1Prefetch, 2, 2>(arg_parser);
else
return run_gemm_example_with_prefetch<GemmConfigTDMV1Prefetch, 1, 1>(arg_parser);
}
}
int main(int argc, char* argv[])
{
auto arg_parser = create_args();
arg_parser.insert(
"pipeline",
"v1",
"TDM pipeline version to use: v1 (8 waves) or v2 (4 waves, wave-specialized)");
arg_parser.insert("use_cluster_2x2",
"0",
"0: single workgroup, 1: enable 2x2 cluster launch for TDM multicast");
arg_parser.insert(
"compare",
"0",
"0: Run with data cache prefetch only, 1: Compare with/without data cache prefetch");
arg_parser.insert("prefetch_a_l1", "0", "0: Prefetch A to L2 cache, 1: Prefetch A to L1 cache");
arg_parser.insert("prefetch_b_l1", "1", "0: Prefetch B to L2 cache, 1: Prefetch B to L1 cache");
auto result = arg_parser.parse(argc, argv);
if(!result)
return -1;
try
{
return !run_gemm_example(arg_parser);
}
catch(std::exception& e)
{
std::cerr << e.what() << std::endl;
return -1;
}
}
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
#include "run_gemm_example_common.hpp"
#include "universal_gemm_invoker.hpp"
#include "ck_tile/core/utility/gemm_validation.hpp"
// Template function to run GEMM with optional prefetch comparison.
// GemmConfig takes (PrecType, DataCachePrefetchKind A, DataCachePrefetchKind B,
// ClusterM, ClusterN).
template <template <typename,
ck_tile::DataCachePrefetchKind,
ck_tile::DataCachePrefetchKind,
ck_tile::index_t,
ck_tile::index_t>
class GemmConfig,
ck_tile::index_t ClusterM,
ck_tile::index_t ClusterN,
typename ADataType,
typename... BCAccDataTypes>
bool run_gemm_with_prefetch_comparison(const std::string& a_layout,
const std::string& b_layout,
ck_tile::ArgParser& arg_parser,
bool compare_with_non_prefetch,
ck_tile::DataCachePrefetchKind prefetch_kind_a,
ck_tile::DataCachePrefetchKind prefetch_kind_b)
{
using Invoker = UniversalInvoker;
using Kind = ck_tile::DataCachePrefetchKind;
auto kind_str = [](Kind k) { return k == Kind::L1 ? "L1" : "L2"; };
std::cout << "\n=== Running with DataCache Prefetch ENABLED (A " << kind_str(prefetch_kind_a)
<< " / B " << kind_str(prefetch_kind_b) << ") ===\n"
<< std::endl;
bool pass_prefetch;
if(prefetch_kind_a == Kind::L1 && prefetch_kind_b == Kind::L1)
{
pass_prefetch = run_gemm_example_prec_type<
GemmConfig<ADataType, Kind::L1, Kind::L1, ClusterM, ClusterN>,
Invoker,
ADataType,
BCAccDataTypes...>(a_layout, b_layout, arg_parser);
}
else if(prefetch_kind_a == Kind::L1 && prefetch_kind_b == Kind::L2)
{
pass_prefetch = run_gemm_example_prec_type<
GemmConfig<ADataType, Kind::L1, Kind::L2, ClusterM, ClusterN>,
Invoker,
ADataType,
BCAccDataTypes...>(a_layout, b_layout, arg_parser);
}
else if(prefetch_kind_a == Kind::L2 && prefetch_kind_b == Kind::L1)
{
pass_prefetch = run_gemm_example_prec_type<
GemmConfig<ADataType, Kind::L2, Kind::L1, ClusterM, ClusterN>,
Invoker,
ADataType,
BCAccDataTypes...>(a_layout, b_layout, arg_parser);
}
else
{
pass_prefetch = run_gemm_example_prec_type<
GemmConfig<ADataType, Kind::L2, Kind::L2, ClusterM, ClusterN>,
Invoker,
ADataType,
BCAccDataTypes...>(a_layout, b_layout, arg_parser);
}
if(compare_with_non_prefetch)
{
std::cout << "\n=== Running with DataCache Prefetch DISABLED ===\n" << std::endl;
bool pass_no_prefetch = run_gemm_example_prec_type<
GemmConfig<ADataType, Kind::None, Kind::None, ClusterM, ClusterN>,
Invoker,
ADataType,
BCAccDataTypes...>(a_layout, b_layout, arg_parser);
std::cout << "\n=== Comparison Summary ===" << std::endl;
std::cout << "Note: Check the timing results above to compare performance." << std::endl;
std::cout << "With prefetch vs without prefetch - speedup can be observed in the "
"timing outputs."
<< std::endl;
return pass_prefetch && pass_no_prefetch;
}
return pass_prefetch;
}
// Common GEMM example runner
template <template <typename,
ck_tile::DataCachePrefetchKind,
ck_tile::DataCachePrefetchKind,
ck_tile::index_t,
ck_tile::index_t>
class GemmConfig,
ck_tile::index_t ClusterM,
ck_tile::index_t ClusterN>
int run_gemm_example_with_prefetch(ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
std::string c_layout = arg_parser.get_str("c_layout");
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t> gemm_sizes =
parse_gemm_size(arg_parser);
int m = std::get<0>(gemm_sizes);
int n = std::get<1>(gemm_sizes);
int k = std::get<2>(gemm_sizes);
int stride_a = arg_parser.get_int("stride_a");
int stride_b = arg_parser.get_int("stride_b");
int stride_c = arg_parser.get_int("stride_c");
bool compare_with_non_prefetch = arg_parser.get_int("compare") == 1;
auto prefetch_kind_a = arg_parser.get_int("prefetch_a_l1") == 1
? ck_tile::DataCachePrefetchKind::L1
: ck_tile::DataCachePrefetchKind::L2;
auto prefetch_kind_b = arg_parser.get_int("prefetch_b_l1") == 1
? ck_tile::DataCachePrefetchKind::L1
: ck_tile::DataCachePrefetchKind::L2;
ck_tile::validate_gemm_stride(
a_layout, b_layout, c_layout, m, n, k, stride_a, stride_b, stride_c);
if(data_type == "fp16")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ClusterM,
ClusterN,
ck_tile::half_t,
ck_tile::half_t>(a_layout,
b_layout,
arg_parser,
compare_with_non_prefetch,
prefetch_kind_a,
prefetch_kind_b);
}
else if(data_type == "bf16")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ClusterM,
ClusterN,
ck_tile::bf16_t,
ck_tile::bf16_t>(a_layout,
b_layout,
arg_parser,
compare_with_non_prefetch,
prefetch_kind_a,
prefetch_kind_b);
}
else if(data_type == "fp8")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ClusterM,
ClusterN,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(a_layout,
b_layout,
arg_parser,
compare_with_non_prefetch,
prefetch_kind_a,
prefetch_kind_b);
}
else if(data_type == "bf8")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ClusterM,
ClusterN,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(a_layout,
b_layout,
arg_parser,
compare_with_non_prefetch,
prefetch_kind_a,
prefetch_kind_b);
}
else if(data_type == "i8")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ClusterM,
ClusterN,
ck_tile::int8_t,
ck_tile::int8_t,
int32_t>(a_layout,
b_layout,
arg_parser,
compare_with_non_prefetch,
prefetch_kind_a,
prefetch_kind_b);
}
else
{
throw std::runtime_error("Unsupported data type for GEMM with prefetch!");
}
}
// TDM V1 GEMM Configuration with Data Cache Prefetch control
template <typename PrecType,
ck_tile::DataCachePrefetchKind DataCachePrefetchA_ = ck_tile::DataCachePrefetchKind::L2,
ck_tile::DataCachePrefetchKind DataCachePrefetchB_ = DataCachePrefetchA_,
ck_tile::index_t kClusterSizeM_ = 1,
ck_tile::index_t kClusterSizeN_ = 1>
struct GemmConfigTDMV1Prefetch : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_TDM_V1;
static constexpr ck_tile::DataCachePrefetchKind DataCachePrefetchA = DataCachePrefetchA_;
static constexpr ck_tile::DataCachePrefetchKind DataCachePrefetchB = DataCachePrefetchB_;
static constexpr ck_tile::index_t kClusterSizeM = kClusterSizeM_;
static constexpr ck_tile::index_t kClusterSizeN = kClusterSizeN_;
};
// TDM V2 GEMM Configuration with Data Cache Prefetch control
template <typename PrecType,
ck_tile::DataCachePrefetchKind DataCachePrefetchA_ = ck_tile::DataCachePrefetchKind::L2,
ck_tile::DataCachePrefetchKind DataCachePrefetchB_ = DataCachePrefetchA_,
ck_tile::index_t kClusterSizeM_ = 1,
ck_tile::index_t kClusterSizeN_ = 1>
struct GemmConfigTDMV2Prefetch : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
// TDM V2 (requires 4 waves): M_Warp * N_Warp * K_Warp == 4
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_TDM_V2;
static constexpr ck_tile::DataCachePrefetchKind DataCachePrefetchA = DataCachePrefetchA_;
static constexpr ck_tile::DataCachePrefetchKind DataCachePrefetchB = DataCachePrefetchB_;
static constexpr ck_tile::index_t kClusterSizeM = kClusterSizeM_;
static constexpr ck_tile::index_t kClusterSizeN = kClusterSizeN_;
};
int run_gemm_example(ck_tile::ArgParser& arg_parser)
{
const std::string pipeline = arg_parser.get_str("pipeline");
const bool use_cluster_2x2 = arg_parser.get_int("use_cluster_2x2") == 1;
const bool is_v2 = (pipeline == "v2");
if(!is_v2 && pipeline != "v1")
std::cerr << "Unknown pipeline '" << pipeline << "', defaulting to v1." << std::endl;
if(is_v2)
{
if(use_cluster_2x2)
return run_gemm_example_with_prefetch<GemmConfigTDMV2Prefetch, 2, 2>(arg_parser);
else
return run_gemm_example_with_prefetch<GemmConfigTDMV2Prefetch, 1, 1>(arg_parser);
}
else
{
if(use_cluster_2x2)
return run_gemm_example_with_prefetch<GemmConfigTDMV1Prefetch, 2, 2>(arg_parser);
else
return run_gemm_example_with_prefetch<GemmConfigTDMV1Prefetch, 1, 1>(arg_parser);
}
}
int main(int argc, char* argv[])
{
auto arg_parser = create_args();
arg_parser.insert(
"pipeline",
"v1",
"TDM pipeline version to use: v1 (8 waves) or v2 (4 waves, wave-specialized)");
arg_parser.insert("use_cluster_2x2",
"0",
"0: single workgroup, 1: enable 2x2 cluster launch for TDM multicast");
arg_parser.insert(
"compare",
"0",
"0: Run with data cache prefetch only, 1: Compare with/without data cache prefetch");
arg_parser.insert("prefetch_a_l1", "0", "0: Prefetch A to L2 cache, 1: Prefetch A to L1 cache");
arg_parser.insert("prefetch_b_l1", "1", "0: Prefetch B to L2 cache, 1: Prefetch B to L1 cache");
auto result = arg_parser.parse(argc, argv);
if(!result)
return -1;
try
{
return !run_gemm_example(arg_parser);
}
catch(std::exception& e)
{
std::cerr << e.what() << std::endl;
return -1;
}
}

View File

@@ -1,210 +1,210 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <string>
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
#include "gemm_weight_preshuffle_invoker.hpp"
template <template <typename, ck_tile::DataCachePrefetchKind, ck_tile::DataCachePrefetchKind>
class GemmConfig,
typename APrecType,
typename BPrecType = APrecType,
typename CPrecType = APrecType>
bool run_gemm_with_prefetch_comparison(ck_tile::ArgParser& arg_parser,
bool compare_with_non_prefetch,
ck_tile::DataCachePrefetchKind prefetch_kind_a,
ck_tile::DataCachePrefetchKind prefetch_kind_b)
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Invoker = WeightPreshuffleInvoker;
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout != "R" || b_layout != "C")
{
throw std::runtime_error(
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
}
std::cout << "\n=== Running with DataCache Prefetch ENABLED (TDM ";
std::cout << (prefetch_kind_a == ck_tile::DataCachePrefetchKind::L1 ? "L1" : "L2")
<< " / Flat ";
std::cout << (prefetch_kind_b == ck_tile::DataCachePrefetchKind::L1 ? "L1" : "L2") << ") ===\n"
<< std::endl;
using Kind = ck_tile::DataCachePrefetchKind;
bool pass_prefetch = false;
if(prefetch_kind_a == Kind::L1 && prefetch_kind_b == Kind::L1)
{
pass_prefetch = run_gemm_example_with_layouts<GemmConfig<APrecType, Kind::L1, Kind::L1>,
Invoker,
APrecType,
BPrecType,
CPrecType>(arg_parser, Row{}, Col{}, Row{});
}
else if(prefetch_kind_a == Kind::L1 && prefetch_kind_b == Kind::L2)
{
pass_prefetch = run_gemm_example_with_layouts<GemmConfig<APrecType, Kind::L1, Kind::L2>,
Invoker,
APrecType,
BPrecType,
CPrecType>(arg_parser, Row{}, Col{}, Row{});
}
else if(prefetch_kind_a == Kind::L2 && prefetch_kind_b == Kind::L1)
{
pass_prefetch = run_gemm_example_with_layouts<GemmConfig<APrecType, Kind::L2, Kind::L1>,
Invoker,
APrecType,
BPrecType,
CPrecType>(arg_parser, Row{}, Col{}, Row{});
}
else
{
pass_prefetch = run_gemm_example_with_layouts<GemmConfig<APrecType, Kind::L2, Kind::L2>,
Invoker,
APrecType,
BPrecType,
CPrecType>(arg_parser, Row{}, Col{}, Row{});
}
if(compare_with_non_prefetch)
{
std::cout << "\n=== Running with DataCache Prefetch DISABLED ===\n" << std::endl;
bool pass_no_prefetch =
run_gemm_example_with_layouts<GemmConfig<APrecType,
ck_tile::DataCachePrefetchKind::None,
ck_tile::DataCachePrefetchKind::None>,
Invoker,
APrecType,
BPrecType,
CPrecType>(arg_parser, Row{}, Col{}, Row{});
std::cout << "\n=== Comparison Summary ===" << std::endl;
std::cout << "Note: Check the timing results above to compare performance." << std::endl;
std::cout << "With prefetch vs without prefetch - speedup can be observed in the "
"timing outputs."
<< std::endl;
return pass_prefetch && pass_no_prefetch;
}
return pass_prefetch;
}
template <template <typename, ck_tile::DataCachePrefetchKind, ck_tile::DataCachePrefetchKind>
class GemmConfig>
int run_gemm_example(ck_tile::ArgParser& arg_parser)
{
const std::string data_type = arg_parser.get_str("prec");
const bool compare_with_non_prefetch = arg_parser.get_int("compare") == 1;
const auto prefetch_kind_a = arg_parser.get_int("prefetch_l1_a") == 1
? ck_tile::DataCachePrefetchKind::L1
: ck_tile::DataCachePrefetchKind::L2;
const auto prefetch_kind_b = arg_parser.get_int("prefetch_l1_b") == 1
? ck_tile::DataCachePrefetchKind::L1
: ck_tile::DataCachePrefetchKind::L2;
if(data_type == "fp16")
{
return run_gemm_with_prefetch_comparison<GemmConfig, ck_tile::half_t>(
arg_parser, compare_with_non_prefetch, prefetch_kind_a, prefetch_kind_b);
}
else if(data_type == "bf16")
{
return run_gemm_with_prefetch_comparison<GemmConfig, ck_tile::bf16_t>(
arg_parser, compare_with_non_prefetch, prefetch_kind_a, prefetch_kind_b);
}
else if(data_type == "fp8")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(
arg_parser, compare_with_non_prefetch, prefetch_kind_a, prefetch_kind_b);
}
else if(data_type == "bf8")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(
arg_parser, compare_with_non_prefetch, prefetch_kind_a, prefetch_kind_b);
}
else if(data_type == "int4")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(
arg_parser, compare_with_non_prefetch, prefetch_kind_a, prefetch_kind_b);
}
else
{
throw std::runtime_error("Unsupported data type for GEMM weight preshuffle TDM prefetch!");
}
}
template <typename PrecType,
ck_tile::DataCachePrefetchKind DataCachePrefetchA_ = ck_tile::DataCachePrefetchKind::None,
ck_tile::DataCachePrefetchKind DataCachePrefetchB_ = DataCachePrefetchA_>
struct GemmConfigWeightPreshuffleTDMPrefetch : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_TDM;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::DataCachePrefetchKind DataCachePrefetchA = DataCachePrefetchA_;
static constexpr ck_tile::DataCachePrefetchKind DataCachePrefetchB = DataCachePrefetchB_;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
int main(int argc, char* argv[])
{
auto arg_parser = create_args();
arg_parser.insert(
"compare",
"0",
"0: Run with data cache prefetch only, 1: Compare with/without data cache prefetch");
arg_parser.insert("prefetch_l1_a", "0", "0: Prefetch A to L2 cache, 1: Prefetch A to L1 cache");
arg_parser.insert("prefetch_l1_b", "1", "0: Prefetch B to L2 cache, 1: Prefetch B to L1 cache");
auto result = arg_parser.parse(argc, argv);
if(!result)
return -1;
try
{
return !run_gemm_example<GemmConfigWeightPreshuffleTDMPrefetch>(arg_parser);
}
catch(std::exception& e)
{
std::cerr << e.what() << std::endl;
return -1;
}
}
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <string>
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
#include "gemm_weight_preshuffle_invoker.hpp"
template <template <typename, ck_tile::DataCachePrefetchKind, ck_tile::DataCachePrefetchKind>
class GemmConfig,
typename APrecType,
typename BPrecType = APrecType,
typename CPrecType = APrecType>
bool run_gemm_with_prefetch_comparison(ck_tile::ArgParser& arg_parser,
bool compare_with_non_prefetch,
ck_tile::DataCachePrefetchKind prefetch_kind_a,
ck_tile::DataCachePrefetchKind prefetch_kind_b)
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Invoker = WeightPreshuffleInvoker;
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout != "R" || b_layout != "C")
{
throw std::runtime_error(
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
}
std::cout << "\n=== Running with DataCache Prefetch ENABLED (TDM ";
std::cout << (prefetch_kind_a == ck_tile::DataCachePrefetchKind::L1 ? "L1" : "L2")
<< " / Flat ";
std::cout << (prefetch_kind_b == ck_tile::DataCachePrefetchKind::L1 ? "L1" : "L2") << ") ===\n"
<< std::endl;
using Kind = ck_tile::DataCachePrefetchKind;
bool pass_prefetch = false;
if(prefetch_kind_a == Kind::L1 && prefetch_kind_b == Kind::L1)
{
pass_prefetch = run_gemm_example_with_layouts<GemmConfig<APrecType, Kind::L1, Kind::L1>,
Invoker,
APrecType,
BPrecType,
CPrecType>(arg_parser, Row{}, Col{}, Row{});
}
else if(prefetch_kind_a == Kind::L1 && prefetch_kind_b == Kind::L2)
{
pass_prefetch = run_gemm_example_with_layouts<GemmConfig<APrecType, Kind::L1, Kind::L2>,
Invoker,
APrecType,
BPrecType,
CPrecType>(arg_parser, Row{}, Col{}, Row{});
}
else if(prefetch_kind_a == Kind::L2 && prefetch_kind_b == Kind::L1)
{
pass_prefetch = run_gemm_example_with_layouts<GemmConfig<APrecType, Kind::L2, Kind::L1>,
Invoker,
APrecType,
BPrecType,
CPrecType>(arg_parser, Row{}, Col{}, Row{});
}
else
{
pass_prefetch = run_gemm_example_with_layouts<GemmConfig<APrecType, Kind::L2, Kind::L2>,
Invoker,
APrecType,
BPrecType,
CPrecType>(arg_parser, Row{}, Col{}, Row{});
}
if(compare_with_non_prefetch)
{
std::cout << "\n=== Running with DataCache Prefetch DISABLED ===\n" << std::endl;
bool pass_no_prefetch =
run_gemm_example_with_layouts<GemmConfig<APrecType,
ck_tile::DataCachePrefetchKind::None,
ck_tile::DataCachePrefetchKind::None>,
Invoker,
APrecType,
BPrecType,
CPrecType>(arg_parser, Row{}, Col{}, Row{});
std::cout << "\n=== Comparison Summary ===" << std::endl;
std::cout << "Note: Check the timing results above to compare performance." << std::endl;
std::cout << "With prefetch vs without prefetch - speedup can be observed in the "
"timing outputs."
<< std::endl;
return pass_prefetch && pass_no_prefetch;
}
return pass_prefetch;
}
template <template <typename, ck_tile::DataCachePrefetchKind, ck_tile::DataCachePrefetchKind>
class GemmConfig>
int run_gemm_example(ck_tile::ArgParser& arg_parser)
{
const std::string data_type = arg_parser.get_str("prec");
const bool compare_with_non_prefetch = arg_parser.get_int("compare") == 1;
const auto prefetch_kind_a = arg_parser.get_int("prefetch_l1_a") == 1
? ck_tile::DataCachePrefetchKind::L1
: ck_tile::DataCachePrefetchKind::L2;
const auto prefetch_kind_b = arg_parser.get_int("prefetch_l1_b") == 1
? ck_tile::DataCachePrefetchKind::L1
: ck_tile::DataCachePrefetchKind::L2;
if(data_type == "fp16")
{
return run_gemm_with_prefetch_comparison<GemmConfig, ck_tile::half_t>(
arg_parser, compare_with_non_prefetch, prefetch_kind_a, prefetch_kind_b);
}
else if(data_type == "bf16")
{
return run_gemm_with_prefetch_comparison<GemmConfig, ck_tile::bf16_t>(
arg_parser, compare_with_non_prefetch, prefetch_kind_a, prefetch_kind_b);
}
else if(data_type == "fp8")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(
arg_parser, compare_with_non_prefetch, prefetch_kind_a, prefetch_kind_b);
}
else if(data_type == "bf8")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(
arg_parser, compare_with_non_prefetch, prefetch_kind_a, prefetch_kind_b);
}
else if(data_type == "int4")
{
return run_gemm_with_prefetch_comparison<GemmConfig,
ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(
arg_parser, compare_with_non_prefetch, prefetch_kind_a, prefetch_kind_b);
}
else
{
throw std::runtime_error("Unsupported data type for GEMM weight preshuffle TDM prefetch!");
}
}
template <typename PrecType,
ck_tile::DataCachePrefetchKind DataCachePrefetchA_ = ck_tile::DataCachePrefetchKind::None,
ck_tile::DataCachePrefetchKind DataCachePrefetchB_ = DataCachePrefetchA_>
struct GemmConfigWeightPreshuffleTDMPrefetch : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_TDM;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::DataCachePrefetchKind DataCachePrefetchA = DataCachePrefetchA_;
static constexpr ck_tile::DataCachePrefetchKind DataCachePrefetchB = DataCachePrefetchB_;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
int main(int argc, char* argv[])
{
auto arg_parser = create_args();
arg_parser.insert(
"compare",
"0",
"0: Run with data cache prefetch only, 1: Compare with/without data cache prefetch");
arg_parser.insert("prefetch_l1_a", "0", "0: Prefetch A to L2 cache, 1: Prefetch A to L1 cache");
arg_parser.insert("prefetch_l1_b", "1", "0: Prefetch B to L2 cache, 1: Prefetch B to L1 cache");
auto result = arg_parser.parse(argc, argv);
if(!result)
return -1;
try
{
return !run_gemm_example<GemmConfigWeightPreshuffleTDMPrefetch>(arg_parser);
}
catch(std::exception& e)
{
std::cerr << e.what() << std::endl;
return -1;
}
}

View File

@@ -10,6 +10,7 @@
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
#include "ck_tile/core/arch/amd_buffer_coherence.hpp"
#include "ck_tile/core/arch/amd_cluster_load.hpp"
#include "ck_tile/core/arch/amd_tdm_descriptor.hpp"
#include "ck_tile/core/arch/amd_transpose_load_encoding.hpp"
#include "ck_tile/core/arch/amd_wave_read_first_lane.hpp"
@@ -86,6 +87,7 @@
#include "ck_tile/core/numeric/pk_f6.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/scale_util.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"
@@ -114,6 +116,7 @@
#include "ck_tile/core/tensor/transpose_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/data_cache_prefetch.hpp"
#include "ck_tile/core/utility/debug.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/core/utility/functional.hpp"

View File

@@ -1,15 +1,15 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile {
enum class DataCachePrefetchKind
{
None,
L1,
L2
};
} // namespace ck_tile
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile {
enum class DataCachePrefetchKind
{
None,
L1,
L2
};
} // namespace ck_tile

View File

@@ -17,6 +17,7 @@
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/pinned_host_releaser.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/reference/reference_batched_contraction.hpp"
#include "ck_tile/host/reference/reference_batched_dropout.hpp"

View File

@@ -56,8 +56,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"

View File

@@ -17,7 +17,10 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"
@@ -37,6 +40,8 @@
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
@@ -85,6 +90,7 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_highprec_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp"

View File

@@ -311,6 +311,24 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
},
number<BsLayout::size()>{});
// for XOR swizzle: policy makes async global-to-LDS stores match LDS reads
// otherwise: no change to view
auto a_async_tile_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(Policy::template MakeAsyncLoadADramWindow<Problem>(
a_tile_windows[number<idx>{}]),
Policy::template MakeADramTileDistribution<Problem>());
},
number<AsLayout::size()>{});
auto b_async_tile_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(Policy::template MakeAsyncLoadBDramWindow<Problem>(
b_tile_windows[number<idx>{}]),
Policy::template MakeBDramTileDistribution<Problem>());
},
number<BsLayout::size()>{});
// this pipeline has a pair of LDS buffers per logical tile
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem);
@@ -353,9 +371,9 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
// read A(0), B(0) from DRAM to LDS window(0)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(
a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
a_copy_lds_window0, a_async_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
b_copy_lds_window0, b_async_tile_windows[number<0>{}], b_dram_tile_window_step);
// initialize block gemm
auto block_gemm = BlockGemm();
@@ -367,9 +385,9 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
// read A(1), B(1) from DRAM to LDS window(1)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(
a_copy_lds_window1, a_tile_windows[number<0>{}], a_dram_tile_window_step);
a_copy_lds_window1, a_async_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step);
b_copy_lds_window1, b_async_tile_windows[number<0>{}], b_dram_tile_window_step);
// tile distribution for the register tiles
constexpr auto ALdsTileDistr =
@@ -433,9 +451,9 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
if constexpr((!HasHotLoop && (TailNum == TailNumber::Three)) || HasHotLoop)
{
Base::GlobalPrefetchAsync(
a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
a_copy_lds_window0, a_async_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
b_copy_lds_window0, b_async_tile_windows[number<0>{}], b_dram_tile_window_step);
}
if constexpr(HasHotLoop)
@@ -456,10 +474,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
// read A(i), B(i) from DRAM to LDS window(1)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(a_copy_lds_window1,
a_tile_windows[number<0>{}],
a_async_tile_windows[number<0>{}],
a_dram_tile_window_step);
Base::GlobalPrefetchAsync(b_copy_lds_window1,
b_tile_windows[number<0>{}],
b_async_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-3) = A(i-3) @ B(i-3)
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
@@ -477,10 +495,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
// read A(i+1), B(i+1) from DRAM to LDS window(0)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(a_copy_lds_window0,
a_tile_windows[number<0>{}],
a_async_tile_windows[number<0>{}],
a_dram_tile_window_step);
Base::GlobalPrefetchAsync(b_copy_lds_window0,
b_tile_windows[number<0>{}],
b_async_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-2) = A(i-2) @ B(i-2)
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);

View File

@@ -19,13 +19,202 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked;
static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked;
using Base = UniversalGemmBasePolicy<GemmPipelineAgBgCrCompAsyncDefaultPolicy<EnableSubTile>>;
using Base::GetSmemPackA;
using Base::GetSmemPackB;
using Base::I0;
using Base::I1;
using Base::I2;
using Base::is_a_load_tr;
using Base::is_b_load_tr;
// Async copy supports 32-bit, 96-bit, or 128-bit transfers (4, 12, 16 bytes)
template <typename DataType, index_t KPack>
static constexpr bool IsSupportedAsyncVectorWidth =
sizeof(DataType) * KPack == 4 || sizeof(DataType) * KPack == 12 ||
sizeof(DataType) * KPack == 16;
// XOR Swizzle: support FP8 / BF8
template <typename Problem>
static constexpr bool IsSupportedXorSwizzleDataType =
(std::is_same_v<remove_cvref_t<typename Problem::ADataType>, fp8_t> || // A FP8
std::is_same_v<remove_cvref_t<typename Problem::ADataType>, bf8_t>) && // A BF8
(std::is_same_v<remove_cvref_t<typename Problem::BDataType>, fp8_t> || // B FP8
std::is_same_v<remove_cvref_t<typename Problem::BDataType>, bf8_t>); // B BF8
// Check that async vector store to LDS is supported
template <typename Problem>
static constexpr bool IsSupportedXorSwizzleAsyncWidth =
IsSupportedAsyncVectorWidth<typename Problem::ADataType,
Base::template GetSmemPackA<Problem>()> &&
IsSupportedAsyncVectorWidth<typename Problem::BDataType,
Base::template GetSmemPackB<Problem>()>;
// Assume normal LDS layout, not transpose-load
template <typename Problem>
static constexpr bool UseXorSwizzle =
!Base::template is_a_load_tr<Problem> && !Base::template is_b_load_tr<Problem> &&
IsSupportedXorSwizzleDataType<Problem> && IsSupportedXorSwizzleAsyncWidth<Problem>;
// Compute the number of LDS read accesses for A or B
// IsLoadTr=true if ds_read_tr is used
template <bool IsLoadTr, typename DataType, index_t ThreadElements>
CK_TILE_HOST_DEVICE static constexpr auto CalculateWGAttrNumAccess()
{
if constexpr(IsLoadTr)
{
// Transpose-load path: ds_read_tr reads DS_READ_TR_SIZE bytes per instruction.
constexpr index_t vector_size =
DS_READ_TR_SIZE() / sizeof(DataType) * numeric_traits<DataType>::PackedSize;
if constexpr(vector_size == ThreadElements)
return WGAttrNumAccessEnum::Single;
else if constexpr(vector_size * 2 == ThreadElements)
return WGAttrNumAccessEnum::Double;
else if constexpr(vector_size * 4 == ThreadElements)
return WGAttrNumAccessEnum::Quad;
else
return WGAttrNumAccessEnum::Invalid;
}
else
{
// Non-transpose path: ds_read_b128 reads 16 bytes per instruction
constexpr index_t bytes_per_lane =
sizeof(DataType) * ThreadElements / numeric_traits<DataType>::PackedSize;
constexpr index_t ds_read_b128_width = 16;
if constexpr(bytes_per_lane <= ds_read_b128_width)
return WGAttrNumAccessEnum::Single;
else if constexpr(bytes_per_lane <= ds_read_b128_width * 2)
return WGAttrNumAccessEnum::Double;
else if constexpr(bytes_per_lane <= ds_read_b128_width * 4)
return WGAttrNumAccessEnum::Quad;
else
return WGAttrNumAccessEnum::Invalid;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAWGAttrNumAccess()
{
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t thread_elements = WarpTile::at(I0) * WarpTile::at(I2) / get_warp_size();
return CalculateWGAttrNumAccess<Base::template is_a_load_tr<Problem>,
typename Problem::ADataType,
thread_elements>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBWGAttrNumAccess()
{
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
return CalculateWGAttrNumAccess<Base::template is_b_load_tr<Problem>,
typename Problem::BDataType,
thread_elements>();
}
// Get number of accesses
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWGAttrNumAccess()
{
constexpr auto num_access_a = GetAWGAttrNumAccess<Problem>();
constexpr auto num_access_b = GetBWGAttrNumAccess<Problem>();
if constexpr(num_access_a == WGAttrNumAccessEnum::Invalid ||
num_access_b == WGAttrNumAccessEnum::Invalid)
return WGAttrNumAccessEnum::Invalid;
else if constexpr(static_cast<index_t>(num_access_a) >= static_cast<index_t>(num_access_b))
return num_access_a;
else
return num_access_b;
}
template <typename Problem,
index_t MNPerBlock,
index_t WarpTileMN,
index_t K2,
WGAttrNumAccessEnum WGAttrNumAccess>
CK_TILE_HOST_DEVICE static constexpr auto MakeXorSwizzledABLdsBlockDescriptor()
{
using BlockGemmShape = typename Problem::BlockGemmShape;
using BlockWarps = typename BlockGemmShape::BlockWarps;
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr index_t KWarps = BlockWarps::at(I2);
constexpr index_t K1 = WarpTile::at(I2) / K2;
constexpr index_t K0 = KPerBlock / (KWarps * K1 * K2);
constexpr index_t warp_size = get_warp_size();
constexpr index_t warp_num = BlockSize / warp_size;
constexpr index_t wg_attr_num_access_v = static_cast<index_t>(WGAttrNumAccess);
static_assert(warp_num * warp_size == BlockSize, "Wrong!");
static_assert(KWarps * K0 * K1 * K2 == KPerBlock, "Wrong!");
static_assert(WGAttrNumAccess != WGAttrNumAccessEnum::Invalid,
"XOR swizzle: unsupported LDS read access count for this configuration");
constexpr index_t M4 = warp_size / wg_attr_num_access_v / K1;
constexpr index_t M3 = wg_attr_num_access_v;
constexpr index_t M2 = WarpTileMN / M4 / M3;
constexpr index_t M1 = (warp_num / Problem::NumWaveGroups) / M2;
constexpr index_t M0 = MNPerBlock / M1 / M2 / M3 / M4;
static_assert(M1 * M0 * M2 * M3 * M4 == MNPerBlock, "Wrong!");
constexpr index_t PadSize = 16;
constexpr auto desc_0 = make_naive_tensor_descriptor(
number_tuple<M2, KWarps, M1, M0, K0, M3, M4, K1, K2>{},
number_tuple<KWarps * M1 * M0 * K0 * M3 * M4 * K1 * K2 + PadSize,
M1 * M0 * K0 * M3 * M4 * K1 * K2,
M0 * K0 * M3 * M4 * K1 * K2,
K0 * M3 * M4 * K1 * K2,
M3 * M4 * K1 * K2,
M4 * K1 * K2,
K1 * K2,
K2,
1>{},
number<K2>{},
number<1>{});
constexpr auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(number<M2>{}),
make_pass_through_transform(number<KWarps>{}),
make_pass_through_transform(number<M1>{}),
make_pass_through_transform(number<M0>{}),
make_pass_through_transform(number<K0>{}),
make_pass_through_transform(number<M3>{}),
make_xor_transform(make_tuple(number<M4>{}, number<K1>{})),
make_pass_through_transform(number<K2>{})),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{},
sequence<6, 7>{},
sequence<8>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{},
sequence<6, 7>{},
sequence<8>{}));
constexpr auto desc_2 = transform_tensor_descriptor(
desc_1,
make_tuple(make_merge_transform_v3_division_mod(number_tuple<M0, M1, M2, M3, M4>{}),
make_merge_transform_v3_division_mod(number_tuple<KWarps, K0, K1, K2>{})),
make_tuple(sequence<3, 2, 0, 5, 6>{}, sequence<1, 4, 7, 8>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return desc_2;
}
template <typename Problem,
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
#if defined(__gfx125__)
@@ -47,21 +236,38 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
}
else
{
constexpr index_t KPack = Base::template GetSmemPackA<Problem>();
if constexpr(UseXorSwizzle<Problem>)
{
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t KPack = Base::template GetSmemPackA<Problem>();
constexpr auto desc =
MakeXorSwizzledABLdsBlockDescriptor<Problem,
MPerBlock,
WarpTile::at(I0),
KPack,
GetWGAttrNumAccess<Problem>()>();
static_assert(desc.get_element_space_size() >= MPerBlock * KPerBlock,
"XOR swizzle LDS allocation must cover the A tile");
return desc;
}
else
{
constexpr index_t KPack = Base::template GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
return transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
#endif
}
@@ -88,25 +294,137 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
}
else
{
constexpr index_t KPack = Base::template GetSmemPackB<Problem>();
if constexpr(UseXorSwizzle<Problem>)
{
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t KPack = Base::template GetSmemPackB<Problem>();
constexpr auto desc =
MakeXorSwizzledABLdsBlockDescriptor<Problem,
NPerBlock,
WarpTile::at(I1),
KPack,
GetWGAttrNumAccess<Problem>()>();
static_assert(desc.get_element_space_size() >= NPerBlock * KPerBlock,
"XOR swizzle LDS allocation must cover the B tile");
return desc;
}
else
{
constexpr index_t KPack = Base::template GetSmemPackB<Problem>();
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<NPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<NPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
return transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<NPerBlock>{}),
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NPerBlock>{}),
make_merge_transform(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
#endif
}
template <typename Problem, index_t K2, WGAttrNumAccessEnum WGAttrNumAccess, typename Window>
CK_TILE_DEVICE static constexpr auto MakeAsyncLoadABDramWindow(const Window& window)
{
using BlockGemmShape = typename Problem::BlockGemmShape;
using BlockWarps = typename BlockGemmShape::BlockWarps;
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr auto ndims = std::decay_t<decltype(window)>::get_num_of_dimension();
static_assert(ndims == 2, "only support 2D tensor");
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr index_t KWarps = BlockWarps::at(I2);
constexpr index_t K1 = WarpTile::at(I2) / K2;
static_assert(K1 * K2 == WarpTile::at(I2), "Wrong!");
static_assert(KPerBlock % (KWarps * K1 * K2) == 0, "Wrong!");
constexpr index_t wg_attr_num_access_v = static_cast<index_t>(WGAttrNumAccess);
constexpr index_t M4 = get_warp_size() / wg_attr_num_access_v / K1;
static_assert(get_warp_size() % (wg_attr_num_access_v * K1) == 0,
"warp_size must be divisible by (wg_attr_num_access_v * K1)");
auto&& tensor_view = window.get_bottom_tensor_view();
const auto [rows, cols] = tensor_view.get_tensor_descriptor().get_lengths();
const index_t k_tiles = cols / (KWarps * K1 * K2);
const auto col_lens = make_tuple(k_tiles, number<KWarps>{}, number<K1>{}, number<K2>{});
const index_t M0 = integer_divide_ceil(rows, M4);
const auto row_lens = make_tuple(M0, number<M4>{});
const auto desc_0 = transform_tensor_descriptor(
tensor_view.get_tensor_descriptor(),
make_tuple(make_unmerge_transform(row_lens), make_unmerge_transform(col_lens)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4, 5>{}));
const auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(M0),
make_xor_transform(make_tuple(number<M4>{}, number<K1>{})),
make_pass_through_transform(k_tiles),
make_pass_through_transform(number<KWarps>{}),
make_pass_through_transform(number<K2>{})),
make_tuple(
sequence<0>{}, sequence<1, 4>{}, sequence<2>{}, sequence<3>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1, 4>{}, sequence<2>{}, sequence<3>{}, sequence<5>{}));
const auto desc =
transform_tensor_descriptor(desc_1,
make_tuple(make_merge_transform_v3_division_mod(row_lens),
make_merge_transform_v3_division_mod(col_lens)),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4, 5>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tile_window(
make_tensor_view<address_space_enum::global>(&tensor_view.get_buffer_view()(0), desc),
window.get_window_lengths(),
window.get_window_origin());
}
template <typename Problem, typename Window>
CK_TILE_DEVICE static constexpr auto MakeAsyncLoadADramWindow(const Window& window)
{
if constexpr(UseXorSwizzle<Problem>)
{
constexpr index_t KPack = Base::template GetSmemPackA<Problem>();
return MakeAsyncLoadABDramWindow<Problem, KPack, GetWGAttrNumAccess<Problem>()>(window);
}
else
{
return make_tile_window(window.get_bottom_tensor_view(),
window.get_window_lengths(),
window.get_window_origin());
}
}
template <typename Problem, typename Window>
CK_TILE_DEVICE static constexpr auto MakeAsyncLoadBDramWindow(const Window& window)
{
if constexpr(UseXorSwizzle<Problem>)
{
constexpr index_t KPack = Base::template GetSmemPackB<Problem>();
return MakeAsyncLoadABDramWindow<Problem, KPack, GetWGAttrNumAccess<Problem>()>(window);
}
else
{
return make_tile_window(window.get_bottom_tensor_view(),
window.get_window_lengths(),
window.get_window_origin());
}
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetEstimatedVgprCount()
{
@@ -167,20 +485,7 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
#if defined(__gfx950__)
constexpr index_t vector_size =
DS_READ_TR_SIZE() / sizeof(typename Problem::AComputeDataType);
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
constexpr auto wg_attr_num_access =
!(Base::template is_a_load_tr<Problem> || Base::template is_b_load_tr<Problem>)
? WGAttrNumAccessEnum::Single
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;
#else
constexpr auto wg_attr_num_access = WGAttrNumAccessEnum::Default;
#endif
constexpr auto wg_attr_num_access = GetWGAttrNumAccess<Problem>();
constexpr auto pipeline_tune_params = GetPipelineSubTileNum<Problem>();
constexpr index_t sub_tile_num = EnableSubTile ? pipeline_tune_params.value : 1;

View File

@@ -316,6 +316,24 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
},
number<BsLayout::size()>{});
// for XOR swizzle: policy makes async global-to-LDS stores match LDS reads
// otherwise: no change to view
auto a_async_tile_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(Policy::template MakeAsyncLoadADramWindow<Problem>(
a_tile_windows[number<idx>{}]),
Policy::template MakeADramTileDistribution<Problem>());
},
number<AsLayout::size()>{});
auto b_async_tile_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(Policy::template MakeAsyncLoadBDramWindow<Problem>(
b_tile_windows[number<idx>{}]),
Policy::template MakeBDramTileDistribution<Problem>());
},
number<BsLayout::size()>{});
////////////// MX Scale windows (pre-packed int32_t) /////////////////
// Get WarpGemm configuration
using BlockWarps = typename BlockGemmShape::BlockWarps;
@@ -404,9 +422,9 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
// read A(0), B(0) from DRAM to LDS window(0)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(
a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
a_copy_lds_window0, a_async_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
b_copy_lds_window0, b_async_tile_windows[number<0>{}], b_dram_tile_window_step);
// Initialize block gemm and C block tile
auto block_gemm = BlockGemm();
@@ -416,9 +434,9 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
// read A(1), B(1) from DRAM to LDS window(1)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(
a_copy_lds_window1, a_tile_windows[number<0>{}], a_dram_tile_window_step);
a_copy_lds_window1, a_async_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step);
b_copy_lds_window1, b_async_tile_windows[number<0>{}], b_dram_tile_window_step);
// tile distribution for the register tiles
constexpr auto ALdsTileDistr =
@@ -442,10 +460,10 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
(KPerBlock * sizeof(BDataType) / BPackedSize) *
MWarp / BlockSize,
"BLdsTile size is wrong!");
static_assert(Policy::template GetSmemSizeA<Problem>() ==
static_assert(Policy::template GetSmemSizeA<Problem>() >=
MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize),
"SmemSizeA size is wrong!");
static_assert(Policy::template GetSmemSizeB<Problem>() ==
static_assert(Policy::template GetSmemSizeB<Problem>() >=
(KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock,
"SmemSizeB size is wrong!");
@@ -522,9 +540,9 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
// read A(2), B(2) from DRAM to LDS window(0)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(
a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
a_copy_lds_window0, a_async_tile_windows[number<0>{}], a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
b_copy_lds_window0, b_async_tile_windows[number<0>{}], b_dram_tile_window_step);
// Load scales for iteration 0 (ping)
load_scales_from_dram(scale_a_tile_ping, scale_b_tile_ping);
@@ -553,10 +571,10 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
// read A(i), B(i) from DRAM to LDS window(1)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(a_copy_lds_window1,
a_tile_windows[number<0>{}],
a_async_tile_windows[number<0>{}],
a_dram_tile_window_step);
Base::GlobalPrefetchAsync(b_copy_lds_window1,
b_tile_windows[number<0>{}],
b_async_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-3) = A(i-3) @ B(i-3) with MX scaling
block_gemm(c_block_tile,
@@ -580,10 +598,10 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
// read A(i+1), B(i+1) from DRAM to LDS window(0)
// and advance the DRAM windows
Base::GlobalPrefetchAsync(a_copy_lds_window0,
a_tile_windows[number<0>{}],
a_async_tile_windows[number<0>{}],
a_dram_tile_window_step);
Base::GlobalPrefetchAsync(b_copy_lds_window0,
b_tile_windows[number<0>{}],
b_async_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-2) = A(i-2) @ B(i-2) with MX scaling
block_gemm(c_block_tile,

View File

@@ -22,6 +22,225 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked;
static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked;
// Async copy supports 32-bit, 96-bit, or 128-bit transfers (4, 12, 16 bytes)
// Take PackedSize into consideration (for example for FP4 support)
template <typename DataType, index_t KPack>
static constexpr index_t AsyncVectorBytes =
sizeof(DataType) * KPack / numeric_traits<remove_cvref_t<DataType>>::PackedSize;
template <typename DataType, index_t KPack>
static constexpr bool IsSupportedAsyncVectorWidth =
AsyncVectorBytes<DataType, KPack> == 4 || AsyncVectorBytes<DataType, KPack> == 12 ||
AsyncVectorBytes<DataType, KPack> == 16;
template <typename DataType>
static constexpr bool IsF8XorSwizzleDataType =
std::is_same_v<remove_cvref_t<DataType>, fp8_t> ||
std::is_same_v<remove_cvref_t<DataType>, bf8_t>;
template <typename DataType>
static constexpr bool IsFP4XorSwizzleDataType =
std::is_same_v<remove_cvref_t<DataType>, pk_fp4_t>;
// XOR Swizzle: support F8/F8 and FP4/FP4. Mixed F8/FP4 stays on the plain path.
template <typename Problem>
static constexpr bool IsSupportedXorSwizzleDataType =
(IsF8XorSwizzleDataType<typename Problem::ADataType> &&
IsF8XorSwizzleDataType<typename Problem::BDataType>) ||
(IsFP4XorSwizzleDataType<typename Problem::ADataType> &&
IsFP4XorSwizzleDataType<typename Problem::BDataType>);
// FP4 needs the XOR KPack in logical elements
// so the async transaction remains 16 bytes
template <typename DataType, index_t SmemPack>
static constexpr index_t GetXorSwizzleKPack()
{
return SmemPack * numeric_traits<remove_cvref_t<DataType>>::PackedSize;
}
template <typename Problem>
static constexpr index_t GetXorSwizzleKPackA()
{
return GetXorSwizzleKPack<typename Problem::ADataType, GetSmemPackA<Problem>()>();
}
template <typename Problem>
static constexpr index_t GetXorSwizzleKPackB()
{
return GetXorSwizzleKPack<typename Problem::BDataType, GetSmemPackB<Problem>()>();
}
// Check that async vector store to LDS is supported
template <typename Problem>
static constexpr bool IsSupportedXorSwizzleAsyncWidth =
IsSupportedAsyncVectorWidth<typename Problem::ADataType, GetXorSwizzleKPackA<Problem>()> &&
IsSupportedAsyncVectorWidth<typename Problem::BDataType, GetXorSwizzleKPackB<Problem>()>;
// gfx950 scales:16x16x128 warp tile, 16-element smem pack, KWarps==1
template <typename Problem>
static constexpr bool IsSupportedXorSwizzleShape = []() {
using BlockGemmShape = typename Problem::BlockGemmShape;
using BlockWarps = typename BlockGemmShape::BlockWarps;
using WarpTile = typename BlockGemmShape::WarpTile;
return Problem::NumWaveGroups == 1 && BlockWarps::at(number<2>{}) == 1 &&
WarpTile::at(number<0>{}) == 16 && WarpTile::at(number<1>{}) == 16 &&
WarpTile::at(number<2>{}) == 128 && GetSmemPackA<Problem>() == 16 &&
GetSmemPackB<Problem>() == 16;
}();
// Assume normal LDS layout, not transpose-load
template <typename Problem>
static constexpr bool UseXorSwizzle =
!is_a_load_tr<Problem> && !is_b_load_tr<Problem> &&
IsSupportedXorSwizzleDataType<Problem> && IsSupportedXorSwizzleAsyncWidth<Problem> &&
IsSupportedXorSwizzleShape<Problem>;
template <typename Problem, index_t MNPerBlock, index_t K2>
CK_TILE_HOST_DEVICE static constexpr auto MakeXorSwizzleABDramTileDistribution()
{
using BlockGemmShape = typename Problem::BlockGemmShape;
using BlockWarps = typename BlockGemmShape::BlockWarps;
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr index_t KWarps = BlockWarps::at(I2);
constexpr index_t K1 = WarpTile::at(I2) / K2;
constexpr index_t K0 = KPerBlock / (KWarps * K1 * K2);
constexpr index_t warp_size = get_warp_size();
constexpr index_t warp_num = BlockSize / warp_size;
static_assert(KWarps == 1, "MX XOR swizzle currently supports KWarps == 1");
static_assert(KWarps * K0 * K1 * K2 == KPerBlock, "Wrong!");
constexpr index_t M2 = warp_size / K1;
constexpr index_t M1 = warp_num / Problem::NumWaveGroups;
constexpr index_t M0 = MNPerBlock / (M1 * M2);
static_assert(M0 * M1 * M2 == MNPerBlock, "Wrong!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
if constexpr(UseXorSwizzle<Problem>)
{
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPack = GetXorSwizzleKPackA<Problem>();
return MakeXorSwizzleABDramTileDistribution<Problem, MPerBlock, KPack>();
}
else
{
return UniversalGemmBasePolicy<MXGemmPipelineAgBgCrCompAsyncDefaultPolicy>::
template MakeADramTileDistribution<Problem>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
if constexpr(UseXorSwizzle<Problem>)
{
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPack = GetXorSwizzleKPackB<Problem>();
return MakeXorSwizzleABDramTileDistribution<Problem, NPerBlock, KPack>();
}
else
{
return UniversalGemmBasePolicy<MXGemmPipelineAgBgCrCompAsyncDefaultPolicy>::
template MakeBDramTileDistribution<Problem>();
}
}
template <typename Problem,
index_t MNPerBlock,
index_t WarpTileMN,
index_t K2,
WGAttrNumAccessEnum WGAttrNumAccess>
CK_TILE_HOST_DEVICE static constexpr auto MakeXorSwizzledABLdsBlockDescriptor()
{
using BlockGemmShape = typename Problem::BlockGemmShape;
using BlockWarps = typename BlockGemmShape::BlockWarps;
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr index_t KWarps = BlockWarps::at(I2);
constexpr index_t K1 = WarpTile::at(I2) / K2;
constexpr index_t K0 = KPerBlock / (KWarps * K1 * K2);
constexpr index_t warp_size = get_warp_size();
constexpr index_t warp_num = BlockSize / warp_size;
constexpr index_t wg_attr_num_access_v = static_cast<index_t>(WGAttrNumAccess);
static_assert(warp_num * warp_size == BlockSize, "Wrong!");
static_assert(KWarps * K0 * K1 * K2 == KPerBlock, "Wrong!");
static_assert(KWarps == 1, "MX XOR swizzle currently supports KWarps == 1");
static_assert(wg_attr_num_access_v == 1 || wg_attr_num_access_v == 2,
"MX XOR swizzle currently supports FP8, BF8, FP4");
constexpr index_t K2Pad = K2 < 16 ? 16 : K2;
constexpr index_t M3 = 4;
constexpr index_t M2 = warp_size / K1 / M3;
constexpr index_t M1 = WarpTileMN / (M2 * M3);
constexpr index_t M0 = MNPerBlock / (M1 * M2 * M3);
static_assert(M0 * M1 * M2 * M3 == MNPerBlock, "Wrong!");
constexpr index_t PadSize = 4 * K2;
constexpr auto desc_0 = make_naive_tensor_descriptor(
number_tuple<M0, K0, M1, M2, M3, K1, K2>{},
number_tuple<K0*(M1 * (M2 * M3 * K1 * K2Pad) + (M1 - 1) * PadSize),
M1*(M2 * M3 * K1 * K2Pad) + (M1 - 1) * PadSize,
M2 * M3 * K1 * K2Pad + PadSize,
M3 * K1 * K2Pad,
K1 * K2Pad,
K2Pad,
1>{},
number<K2>{},
number<1>{});
constexpr auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(number<M0>{}),
make_pass_through_transform(number<K0>{}),
make_pass_through_transform(number<M1>{}),
make_pass_through_transform(number<M2>{}),
make_xor_transform(make_tuple(number<M3>{}, number<K1>{})),
make_pass_through_transform(number<K2>{})),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4, 5>{},
sequence<6>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4, 5>{},
sequence<6>{}));
constexpr auto desc_2 = transform_tensor_descriptor(
desc_1,
make_tuple(make_merge_transform_v3_division_mod(number_tuple<M0, M1, M2, M3>{}),
make_merge_transform_v3_division_mod(number_tuple<K0, K1, K2>{})),
make_tuple(sequence<0, 2, 3, 4>{}, sequence<1, 5, 6>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return desc_2;
}
// MX scaling configuration: each e8m0 scale covers 32 elements in K
static constexpr int BlockScaleSize = 32;
@@ -43,21 +262,38 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
}
else
{
constexpr index_t KPack = GetSmemPackA<Problem>();
if constexpr(UseXorSwizzle<Problem>)
{
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t KPack = GetXorSwizzleKPackA<Problem>();
constexpr auto desc =
MakeXorSwizzledABLdsBlockDescriptor<Problem,
MPerBlock,
WarpTile::at(I0),
KPack,
GetWGAttrNumAccess<Problem>()>();
static_assert(desc.get_element_space_size() >= MPerBlock * KPerBlock,
"XOR swizzle LDS allocation must cover the A tile");
return desc;
}
else
{
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
return transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
}
@@ -78,21 +314,172 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
}
else
{
constexpr index_t KPack = GetSmemPackB<Problem>();
if constexpr(UseXorSwizzle<Problem>)
{
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr index_t KPack = GetXorSwizzleKPackB<Problem>();
constexpr auto desc =
MakeXorSwizzledABLdsBlockDescriptor<Problem,
NPerBlock,
WarpTile::at(I1),
KPack,
GetWGAttrNumAccess<Problem>()>();
static_assert(desc.get_element_space_size() >= NPerBlock * KPerBlock,
"XOR swizzle LDS allocation must cover the B tile");
return desc;
}
else
{
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<NPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<NPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
return transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<NPerBlock>{}),
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NPerBlock>{}),
make_merge_transform(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
}
// MX GEMM: Double access for FP8/BF8, Single for FP4
template <typename DataType_>
CK_TILE_HOST_DEVICE static constexpr auto CalculateWGAttrNumAccess()
{
using DataType = remove_cvref_t<DataType_>;
if constexpr(std::is_same_v<DataType, fp8_t> || std::is_same_v<DataType, bf8_t>)
{
return WGAttrNumAccessEnum::Double;
}
else if constexpr(std::is_same_v<DataType, pk_fp4_t>)
{
return WGAttrNumAccessEnum::Single;
}
else
{
static_assert(sizeof(DataType) == 0,
"CalculateWGAttrNumAccess(): unsupported data type");
return WGAttrNumAccessEnum::Invalid;
}
}
// Get number of accesses
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWGAttrNumAccess()
{
constexpr auto num_access_a = CalculateWGAttrNumAccess<typename Problem::ADataType>();
constexpr auto num_access_b = CalculateWGAttrNumAccess<typename Problem::BDataType>();
if constexpr(static_cast<index_t>(num_access_a) >= static_cast<index_t>(num_access_b))
{
return num_access_a;
}
else
{
return num_access_b;
}
}
template <typename Problem, index_t K2, WGAttrNumAccessEnum WGAttrNumAccess, typename Window>
CK_TILE_DEVICE static constexpr auto MakeAsyncLoadABDramWindow(const Window& window)
{
using BlockGemmShape = typename Problem::BlockGemmShape;
using BlockWarps = typename BlockGemmShape::BlockWarps;
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr auto ndims = std::decay_t<decltype(window)>::get_num_of_dimension();
static_assert(ndims == 2, "only support 2D tensor");
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr index_t KWarps = BlockWarps::at(I2);
constexpr index_t K1 = WarpTile::at(I2) / K2;
static_assert(K1 * K2 == WarpTile::at(I2), "Wrong!");
static_assert(KPerBlock % (KWarps * K1 * K2) == 0, "Wrong!");
constexpr index_t wg_attr_num_access_v = static_cast<index_t>(WGAttrNumAccess);
constexpr index_t M4 = 4; // same as MakeXorSwizzledABLdsBlockDescriptor::M3
static_assert(get_warp_size() % (wg_attr_num_access_v * K1 * M4) == 0,
"warp_size must be divisible by (wg_attr_num_access_v * K1 * M4)");
auto&& tensor_view = window.get_bottom_tensor_view();
const auto [rows, cols] = tensor_view.get_tensor_descriptor().get_lengths();
const index_t k_tiles = cols / (KWarps * K1 * K2);
const auto col_lens = make_tuple(k_tiles, number<KWarps>{}, number<K1>{}, number<K2>{});
const index_t M0 = integer_divide_ceil(rows, M4);
const auto row_lens = make_tuple(M0, number<M4>{});
const auto desc_0 = transform_tensor_descriptor(
tensor_view.get_tensor_descriptor(),
make_tuple(make_unmerge_transform(row_lens), make_unmerge_transform(col_lens)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4, 5>{}));
const auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(M0),
make_xor_transform(make_tuple(number<M4>{}, number<K1>{})),
make_pass_through_transform(k_tiles),
make_pass_through_transform(number<KWarps>{}),
make_pass_through_transform(number<K2>{})),
make_tuple(
sequence<0>{}, sequence<1, 4>{}, sequence<2>{}, sequence<3>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1, 4>{}, sequence<2>{}, sequence<3>{}, sequence<5>{}));
const auto desc =
transform_tensor_descriptor(desc_1,
make_tuple(make_merge_transform_v3_division_mod(row_lens),
make_merge_transform_v3_division_mod(col_lens)),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4, 5>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tile_window(
make_tensor_view<address_space_enum::global>(&tensor_view.get_buffer_view()(0), desc),
window.get_window_lengths(),
window.get_window_origin());
}
template <typename Problem, typename Window>
CK_TILE_DEVICE static constexpr auto MakeAsyncLoadADramWindow(const Window& window)
{
if constexpr(UseXorSwizzle<Problem>)
{
constexpr index_t KPack = GetXorSwizzleKPackA<Problem>();
return MakeAsyncLoadABDramWindow<Problem, KPack, GetWGAttrNumAccess<Problem>()>(window);
}
else
{
return make_tile_window(window.get_bottom_tensor_view(),
window.get_window_lengths(),
window.get_window_origin());
}
}
template <typename Problem, typename Window>
CK_TILE_DEVICE static constexpr auto MakeAsyncLoadBDramWindow(const Window& window)
{
if constexpr(UseXorSwizzle<Problem>)
{
constexpr index_t KPack = GetXorSwizzleKPackB<Problem>();
return MakeAsyncLoadABDramWindow<Problem, KPack, GetWGAttrNumAccess<Problem>()>(window);
}
else
{
return make_tile_window(window.get_bottom_tensor_view(),
window.get_window_lengths(),
window.get_window_origin());
}
}
@@ -107,10 +494,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
using CDataType = typename Problem::CDataType;
// FP4 and FP8 require different layouts for the scaled mfma instructions
constexpr auto wg_attr_num_access =
(std::is_same_v<ADataType, fp8_t> || std::is_same_v<BDataType, fp8_t>)
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single;
constexpr auto wg_attr_num_access = GetWGAttrNumAccess<Problem>();
using WarpGemm = WarpGemmDispatcher<ADataType,
BDataType,