mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Merge branch 'develop' into jenkins-skip-ci-non-relevant-files
This commit is contained in:
3
.github/workflows/therock-ci-linux.yml
vendored
3
.github/workflows/therock-ci-linux.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
ref: 409f43ad9d564454bb1b23f8c8aa15d6b9d25200
|
||||
ref: 3f62012a748df3a3099c51fa95d104db643a4588 # 10-03-2025 commit
|
||||
path: "TheRock"
|
||||
|
||||
- name: Runner Health Settings
|
||||
@@ -54,6 +54,7 @@ jobs:
|
||||
|
||||
- name: Patch rocm-libraries
|
||||
run: |
|
||||
rm ./TheRock/patches/amd-mainline/rocm-libraries/0009-Use-workgroupMappingDim-in-rocroller_host.patch
|
||||
git config --global --add safe.directory '*'
|
||||
git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
|
||||
|
||||
static constexpr auto DsVectorLengthSequence = generate_sequence_v2(
|
||||
[](auto i) {
|
||||
using DLayout = ::std::__remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(is_same<CLayout, DLayout>::value)
|
||||
return Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{};
|
||||
else
|
||||
@@ -253,7 +253,7 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
DsLengths[i] = out_lengths;
|
||||
|
||||
using DLayout = ::std::__remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(is_same<DLayout, ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
DsStrides[i] = {arg.StrideDs[i], 1};
|
||||
|
||||
@@ -433,8 +433,13 @@ struct CShuffleEpilogue
|
||||
const ScaleM& scale_m = {},
|
||||
const ScaleN& scale_n = {})
|
||||
{
|
||||
static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size();
|
||||
|
||||
static_assert(MPerXdl % RowsPerLane == 0,
|
||||
"CShuffle (permuteN): MPerXdl must be divisible by per-lane row count.");
|
||||
|
||||
constexpr int kM0 = MWave;
|
||||
constexpr int kM2 = 4;
|
||||
constexpr int kM2 = RowsPerLane;
|
||||
constexpr int kM1 = MPerXdl / kM2;
|
||||
|
||||
constexpr int kN0 = NWave;
|
||||
@@ -515,32 +520,25 @@ struct CShuffleEpilogue
|
||||
// Pack 4 “rows per lane” as you already do
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
// source indices in shuffle_acc: (n_idx * product(Y) + row)
|
||||
const index_t base = n_idx * c_warp_y_lengths.product();
|
||||
const index_t plane = c_warp_y_lengths.product();
|
||||
|
||||
// local lambda to fuse scale (if present) and convert
|
||||
auto emit = [&](index_t out_idx, index_t src_row) {
|
||||
AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];
|
||||
|
||||
static_for<0, kM2, 1>{}([&](auto m_lane) {
|
||||
const int src = n_idx * plane + m_lane; // source row in this N-plane
|
||||
const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output
|
||||
AccDataType v = shuffle_acc.get_thread_buffer()[src];
|
||||
if constexpr(has_scalar_scales)
|
||||
{
|
||||
v = static_cast<AccDataType>(v * scale_m * scale_n);
|
||||
}
|
||||
else if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
// same linear index mapping on the permuted distribution
|
||||
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
|
||||
const auto s_n = static_cast<float>(sn_tile.get_thread_buffer()[out_idx]);
|
||||
v = static_cast<AccDataType>(v * s_m * s_n);
|
||||
const auto sm = static_cast<float>(sm_tile.get_thread_buffer()[dst]);
|
||||
const auto sn = static_cast<float>(sn_tile.get_thread_buffer()[dst]);
|
||||
v = static_cast<AccDataType>(v * sm * sn);
|
||||
}
|
||||
|
||||
c_out_tensor.get_thread_buffer()[out_idx] = type_convert<ODataType>(v);
|
||||
};
|
||||
|
||||
// Your current packing pattern (rows 0..3, spaced by NRepeat)
|
||||
emit(n_idx + 0 * NRepeat, 0);
|
||||
emit(n_idx + 1 * NRepeat, 1);
|
||||
emit(n_idx + 2 * NRepeat, 2);
|
||||
emit(n_idx + 3 * NRepeat, 3);
|
||||
c_out_tensor.get_thread_buffer()[dst] = type_convert<ODataType>(v);
|
||||
});
|
||||
});
|
||||
|
||||
// store/update
|
||||
|
||||
Reference in New Issue
Block a user