mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Fixes after merge.
This commit is contained in:
@@ -48,14 +48,10 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
using GemmEpilogue = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<ck_tile::BlockGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
GemmShape,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>>;
|
||||
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
@@ -71,14 +67,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
GemmShape,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
kPadA,
|
||||
kPadB,
|
||||
kPadC,
|
||||
Traits,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
|
||||
@@ -164,7 +164,13 @@ int run_gemm_example(int argc, char* argv[])
|
||||
c_m_n_gpu_ref.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
|
||||
|
||||
@@ -18,7 +18,7 @@ struct BlockGemmASmemBSmemCRegV1
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
@@ -31,7 +31,7 @@ struct BlockGemmASmemBSmemCRegV1
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindow::DataType> &&
|
||||
std::is_same_v<AccDataType, typename CBlockTensor::DataType>,
|
||||
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
|
||||
@@ -195,7 +195,7 @@ struct BlockGemmASmemBSmemCRegV1
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<AccDataType>(c_block_dstr);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
{
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::AccDataType, float>)
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
@@ -45,7 +45,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::AccDataType, float>)
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -87,7 +87,7 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 0
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy>
|
||||
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
|
||||
struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
|
||||
|
||||
@@ -1,27 +1,25 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <bool kPadA_,
|
||||
bool kPadB_,
|
||||
bool kPadC_,
|
||||
typename LayoutA_,
|
||||
typename LayoutB_,
|
||||
typename LayoutC_>
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_>
|
||||
struct TileGemmTraits
|
||||
{
|
||||
static constexpr bool kPadA = kPadA_;
|
||||
static constexpr bool kPadB = kPadB_;
|
||||
static constexpr bool kPadC = kPadC_;
|
||||
|
||||
using LayoutA = LayoutA_;
|
||||
using LayoutB = LayoutB_;
|
||||
using LayoutC = LayoutC_;
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using CLayout = CLayout_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -69,14 +69,10 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
|
||||
using GemmEpilogue = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<ck_tile::BlockGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
GemmShape,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>>;
|
||||
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
@@ -90,14 +86,8 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
GemmShape,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
kPadA,
|
||||
kPadB,
|
||||
kPadC,
|
||||
Traits,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
|
||||
Reference in New Issue
Block a user