mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Grab the device and gridwise files from the bkp branch. This should enable splitK support for convolution, and we no longer ForceThreadTileTransfer for explicit GEMM. Also grab some updates from 7e7243783008b11e904f127ecf1df55ef95e9af2 to fix building on clang20.
This commit is contained in:
@@ -144,39 +144,18 @@ struct DeviceGroupedConvBwdWeight_Explicit
|
||||
end(e_g_k_c_xs_lengths),
|
||||
begin(filter_spatial_lengths_));
|
||||
|
||||
if constexpr(IsTwoStageNeeded)
|
||||
if(split_k < 0)
|
||||
{
|
||||
if(split_k < 0)
|
||||
{
|
||||
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
|
||||
const index_t grid_size = gdx * gdy * gdz;
|
||||
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
split_k_ = split_k;
|
||||
}
|
||||
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
|
||||
const index_t grid_size = gdx * gdy * gdz;
|
||||
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
|
||||
const index_t grid_size = gdx * gdy * gdz;
|
||||
split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
split_k_ = split_k;
|
||||
}
|
||||
split_k_ = split_k;
|
||||
}
|
||||
|
||||
if constexpr(IsTwoStageNeeded)
|
||||
@@ -339,16 +318,6 @@ struct DeviceGroupedConvBwdWeight_Explicit
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if constexpr(!IsTwoStageNeeded)
|
||||
{
|
||||
if(arg.split_k_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())
|
||||
|
||||
@@ -322,7 +322,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
false,
|
||||
false>;
|
||||
false,
|
||||
true>;
|
||||
|
||||
static constexpr auto MakeElementwiseInputSequence()
|
||||
{
|
||||
|
||||
@@ -374,7 +374,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
false,
|
||||
false>;
|
||||
false,
|
||||
true>;
|
||||
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
|
||||
|
||||
|
||||
@@ -396,7 +396,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
false,
|
||||
false>;
|
||||
false,
|
||||
true>;
|
||||
|
||||
// Argument
|
||||
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
|
||||
@@ -289,7 +289,8 @@ struct ABTransferThreadTiles
|
||||
__device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
|
||||
BlockDescriptor& block_descriptor,
|
||||
ABElementwiseOperation& ab_element_op,
|
||||
const index_t block_mn_id)
|
||||
const index_t block_mn_id,
|
||||
const index_t k_id)
|
||||
{
|
||||
constexpr index_t NumABTensor = ABsDataType::Size();
|
||||
const index_t mn_block_data_idx_on_grid =
|
||||
@@ -298,7 +299,7 @@ struct ABTransferThreadTiles
|
||||
if constexpr(NumABTensor > 1)
|
||||
{
|
||||
const auto idx_as_block_begin = generate_tuple(
|
||||
[&](auto) { return make_multi_index(0, mn_block_data_idx_on_grid, 0); },
|
||||
[&](auto) { return make_multi_index(k_id, mn_block_data_idx_on_grid, 0); },
|
||||
Number<NumABTensor>{});
|
||||
|
||||
return ThreadGroupTensorSliceTransfer_v7r2<
|
||||
@@ -351,7 +352,7 @@ struct ABTransferThreadTiles
|
||||
ABThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
GlobalBufferNum>(grid_descriptor[I0],
|
||||
make_multi_index(0, mn_block_data_idx_on_grid, 0),
|
||||
make_multi_index(k_id, mn_block_data_idx_on_grid, 0),
|
||||
ab_element_op,
|
||||
block_descriptor,
|
||||
make_multi_index(0, 0, 0),
|
||||
|
||||
@@ -264,7 +264,8 @@ struct ABTransferWaveTiles
|
||||
__device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
|
||||
BlockDescriptor& block_descriptor,
|
||||
ABElementwiseOperation& ab_element_op,
|
||||
const index_t block_mn_id)
|
||||
const index_t block_mn_id,
|
||||
const index_t)
|
||||
{
|
||||
// Note: GlobalBufferNum is currently not used but it will be needed
|
||||
// once we add other pipelines. It is currently needed only for
|
||||
|
||||
@@ -176,7 +176,7 @@ template <typename ALayout,
|
||||
typename ComputeTypeB,
|
||||
bool PermuteA,
|
||||
bool PermuteB,
|
||||
bool ForceThreadTileTransfer = true>
|
||||
bool ForceThreadTileTransfer = false>
|
||||
struct GridwiseGemm_wmma_cshuffle_v3
|
||||
: GridwiseGemm_wmma_cshuffle_v3_base<
|
||||
ALayout,
|
||||
@@ -327,8 +327,6 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
using typename Base::AsGridPointer;
|
||||
using typename Base::BsGridPointer;
|
||||
using typename Base::DsGridPointer;
|
||||
using AsDataType_ = AsDataType;
|
||||
using BsDataType_ = BsDataType;
|
||||
|
||||
struct Problem
|
||||
{
|
||||
|
||||
@@ -221,8 +221,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
using typename Base::AsGridPointer;
|
||||
using typename Base::BsGridPointer;
|
||||
using typename Base::DsGridPointer;
|
||||
using AsDataType_ = AsDataType;
|
||||
using BsDataType_ = BsDataType;
|
||||
|
||||
struct Problem
|
||||
{
|
||||
|
||||
@@ -3,11 +3,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#endif
|
||||
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
@@ -789,27 +784,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
{
|
||||
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop
|
||||
<< ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages
|
||||
<< ") for pipeline version != v1." << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<remove_cvref_t<EDataType>, int8_t>::value)
|
||||
{
|
||||
if(karg.KBatch > 1)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "int8_t does not support KBatch > 1. KBatch: " << karg.KBatch
|
||||
<< " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -873,6 +847,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
}
|
||||
}
|
||||
|
||||
// Note: arguments k_batch and k_id should be set if splitk is used
|
||||
// with implicit gemm (no pointer shift but shift using tensor descriptors)
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -899,8 +875,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
const index_t& block_n_id,
|
||||
const index_t& num_k_block_per_scale,
|
||||
BScaleStruct& b_scale_struct,
|
||||
[[maybe_unused]] const index_t k_batch = 1,
|
||||
[[maybe_unused]] const index_t k_id = 0)
|
||||
const index_t k_batch = 1,
|
||||
const index_t k_id = 0)
|
||||
{
|
||||
const auto as_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -942,7 +918,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
AsDataType,
|
||||
AElementwiseOperation,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id);
|
||||
as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id, k_id);
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
@@ -951,7 +927,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
BsDataType,
|
||||
BElementwiseOperation,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id);
|
||||
bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id, k_id);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
@@ -976,7 +952,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / KPerBlock);
|
||||
ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / (KPerBlock * k_batch));
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),
|
||||
|
||||
Reference in New Issue
Block a user