mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Wmma support for grouped convolution bwd weight (#2947)
* Convolution bwd weight device implementation
* Merge branch 'grouped_conv_bwd_weight_device_impl_wmma' into 'feature/conv_bwd_weight_wmma'
Convolution bwd weight device implementation
See merge request amd/ai/composable_kernel!38
* Fix bug and disable splitK=-1 tests for wmma
* Add generic instances for bf16 f32 bf16
* check gridwise level validity in device impl for 1 stage D0
* Fix bugs in device implementation:
- rdna3 compilation error
- gridwise layouts (need to be correct to ensure that CheckValidaity()
works correctly)
* Add padding in conv to gemm transformers for 1x1Stride1Pad0 specialization
* Remove workaround for 1x1Stride1Pad0 conv specialization
* Add instances for xdl parity (for pipeline v1)
* Add two stage instances (xdl parity)
* Add multiple Ds instances
* Add examples
* Uncomment scale instances
* Fix copyright
* Fix examples compilation
* Add atomic add float4
* Fix compilation error
* Fix instances
* Compute tolerances in examples instead of using default ones
* Compute tolerances instead of using default ones in bilinear and scale tests
* Merge branch 'grouped_conv_bwd_weight_instances_examples' into 'feature/conv_bwd_weight_wmma'
Grouped conv: Instances and example bwd weight
See merge request amd/ai/composable_kernel!47
* Device implementation of explicit gemm for grouped conv bwd weight
Based on batched gemm multiple D
* Add instances for pipeline v1 and v3
* Add support for occupancy-based splitk
* Fix ckProfiler dependencies
* Review fixes
* Merge branch 'explicit_bwd_weight' into 'feature/conv_bwd_weight_wmma'
Device implementation of explicit gemm for grouped conv bwd weight
See merge request amd/ai/composable_kernel!52
* Fix cmake file for tests
* fix clang format
* fix instance factory error
* Adapt all grouped conv bwd weight vanilla Xdl instances to 16x16. MRepeat doubled for all but 12 of them (some static assert failure). Also added custom reduced profiler target for building grouped conv bwd weight vanilla only profiler. Verified with gtest test.
* Revert "Adapt all grouped conv bwd weight vanilla Xdl instances to 16x16. MRepeat doubled for all but 12 of them (some static assert failure). Also added custom reduced profiler target for building grouped conv bwd weight vanilla only profiler. Verified with gtest test."
This reverts commit d20c869d3d.
* Disable splitk for 2stage xdl on rdna (bug to be fixed)
* Fix add_test_executable
* Always ForceThreadTileTransfer for now, WaveTileTransfer does not work for convolution yet.
* Grab device and gridwise files from bkp branch, this should enable splitK support for convolution and also we no longer ForceThreadTileTransfer for explicit gemm. Also grab some updates from 7e7243783008b11e904f127ecf1df55ef95e9af2 to fix building on clang20.
* Fix bug in various bwd wei device implementations / profiler where the occupancy based split_k value could not be found because the Argument did not derive from ArgumentSplitK, leading to incorrect error tolerances.
* Actually print the reason when a device implementation is not supported.
* Print number of valid instances in profiler and tests.
* Fix clang format for Two Stage implementation
* Fix copyright
* Address review comments
* Fix explicit conv bwd weight struct
* Fix gridwise common
* Fix gridwise ab scale
* Remove autodeduce 1 stage
* Restore example tolerance calculation
* Fix compilation error
* Fix gridwise common
* Fix gridwise gemm
* Fix typo
* Fix splitk
* Fix splitk ab scale
* Adapt all grouped conv bwd weight vanilla Xdl instances to 16x16. MRepeat doubled for all but 12 of them (some static assert failure). Also added custom reduced profiler target for building grouped conv bwd weight vanilla only profiler. Verified with gtest test.
* Reduce instances to only the tuned wmma V3 ones for implicit v1 intra and explicit v1 intra pad/nopad.
* Add explicit oddMN support with custom tuned instances
* Add two stage instances based on the parameters from the tuned cshuffle V3 instances. CShuffleBlockTranserScalarPerVector adapted to 4, and mergegroups fixed to 1 for now. No more special instance lists.
* Replace cshuffle non-v3 lists with v3 lists, making sure to not have duplications. Also removing stride1pad0 support for NHWGC since we can use explicit for those cases.
* Remove some instances that give incorrect results (f16 NHWGC)
* Add bf16 f32 bf16 instances based on tuned b16 NHWGC GKYXC instances.
* Add back some generic instances to make sure we have the same shape / layout / datatype support as before the instance selection process.
* Add instances for scale and bilinear based on the bf16 NHWGC GKYXC tuning. Keep generic instances for support.
* Disable two stage f16 instances which produce incorrect results.
* Remove more instances which fail verification, for bf16_f32_bf16 and for f16 scale / bilinear.
* Disable all non-generic two-stage instances in the instance lists for NHWGC. They are never faster and support is already carried by CShuffleV3 and Explicit.
* Remove unused instance lists and related add_x_instance() functions, fwd declarations, cmakelists entries. Also merge the "wmma" and "wmma v3" instance list files, which are both v3.
* Re-enable all xdl instances (un-16x16-adapted) and dl instances. Remove custom ckProfiler target.
* Remove straggler comments
* Remove [[maybe_unused]]
* Fix clang format
* Remove unwanted instances. This includes all instances which are not NHWGCxGKYXC and F16 or BF16 (no mixed in-out types).
* Add comment
---------
Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>
Co-authored-by: Kiefer van Teutem <50830967+krithalith@users.noreply.github.com>
This commit is contained in:
@@ -295,7 +295,7 @@ struct ABTransferThreadTiles
|
||||
BlockDescriptor& block_descriptor,
|
||||
ABElementwiseOperation& ab_element_op,
|
||||
const index_t block_mn_id,
|
||||
const index_t)
|
||||
const index_t k_id)
|
||||
{
|
||||
constexpr index_t NumABTensor = ABsDataType::Size();
|
||||
const index_t mn_block_data_idx_on_grid =
|
||||
@@ -304,7 +304,7 @@ struct ABTransferThreadTiles
|
||||
if constexpr(NumABTensor > 1)
|
||||
{
|
||||
const auto idx_as_block_begin = generate_tuple(
|
||||
[&](auto) { return make_multi_index(0, mn_block_data_idx_on_grid, 0); },
|
||||
[&](auto) { return make_multi_index(k_id, mn_block_data_idx_on_grid, 0); },
|
||||
Number<NumABTensor>{});
|
||||
|
||||
return ThreadGroupTensorSliceTransfer_v7r2<
|
||||
@@ -357,7 +357,7 @@ struct ABTransferThreadTiles
|
||||
ABThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
GlobalBufferNum>(grid_descriptor[I0],
|
||||
make_multi_index(0, mn_block_data_idx_on_grid, 0),
|
||||
make_multi_index(k_id, mn_block_data_idx_on_grid, 0),
|
||||
ab_element_op,
|
||||
block_descriptor,
|
||||
make_multi_index(0, 0, 0),
|
||||
|
||||
@@ -333,6 +333,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
struct Problem
|
||||
{
|
||||
__host__ Problem() = default;
|
||||
__host__ Problem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
@@ -409,6 +410,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
// Argument
|
||||
struct Argument : public tensor_operation::device::BaseArgument, public Problem
|
||||
{
|
||||
__host__ Argument() = default;
|
||||
__host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
|
||||
std::array<const void*, NumBTensor> p_bs_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
@@ -583,7 +585,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
const index_t A_k_id = 0,
|
||||
const index_t B_k_id = 0)
|
||||
{
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
|
||||
@@ -651,7 +654,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args,
|
||||
k_id);
|
||||
A_k_id,
|
||||
B_k_id);
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
@@ -700,7 +704,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
Argument& karg,
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
const index_t A_k_id = 0,
|
||||
const index_t B_k_id = 0)
|
||||
{
|
||||
// shift A matrices pointer for splitk
|
||||
AsGridPointer p_as_grid_splitk;
|
||||
@@ -735,7 +740,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
epilogue_args,
|
||||
k_id);
|
||||
A_k_id,
|
||||
B_k_id);
|
||||
}
|
||||
|
||||
// Wrapper function to have __global__ function in common
|
||||
@@ -748,20 +754,146 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
const index_t A_k_id = 0,
|
||||
const index_t B_k_id = 0)
|
||||
{
|
||||
Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
Block2CTileMap,
|
||||
EpilogueArgument>(
|
||||
p_shared, splitk_batch_offset, karg, DefaultBlock2CTileMap(karg), epilogue_args, k_id);
|
||||
EpilogueArgument>(p_shared,
|
||||
splitk_batch_offset,
|
||||
karg,
|
||||
DefaultBlock2CTileMap(karg),
|
||||
epilogue_args,
|
||||
A_k_id,
|
||||
B_k_id);
|
||||
}
|
||||
|
||||
__device__ static auto DefaultBlock2CTileMap(const Problem& problem)
|
||||
{
|
||||
return Block2CTileMap{problem.M, problem.N, 4};
|
||||
}
|
||||
|
||||
// Run method for convolution (grid descriptors are passed as arguments,
|
||||
// not generated internally)
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
index_t NumGroupsToMerge,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(void* p_shared,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const index_t num_k_per_block,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
AsGridPointer p_as_grid_;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
p_as_grid_(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset;
|
||||
});
|
||||
|
||||
BsGridPointer p_bs_grid_;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
p_bs_grid_(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
|
||||
});
|
||||
|
||||
const auto ds_grid_desc_m_n =
|
||||
MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs);
|
||||
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, karg.MBlock, karg.NBlock);
|
||||
|
||||
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
|
||||
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
|
||||
// Scale structs (Empty)
|
||||
using Scale = typename BlockwiseGemmPipe::Empty;
|
||||
auto b_scale_struct = Scale{};
|
||||
auto a_scale_struct = Scale{};
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid_,
|
||||
p_bs_grid_,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + e_batch_offset,
|
||||
p_shared,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args,
|
||||
k_idx,
|
||||
k_idx,
|
||||
karg.KBatch);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -723,7 +723,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
const index_t A_k_id = 0,
|
||||
const index_t B_k_id = 0)
|
||||
{
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
|
||||
@@ -793,7 +794,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args,
|
||||
k_id);
|
||||
A_k_id,
|
||||
B_k_id);
|
||||
}
|
||||
|
||||
// NOTE: Wrapper function to have __global__ function in common
|
||||
@@ -806,7 +808,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
const index_t A_k_id = 0,
|
||||
const index_t B_k_id = 0)
|
||||
{
|
||||
// shift A matrices pointer for splitk
|
||||
AsGridPointer p_as_grid_splitk;
|
||||
@@ -857,7 +860,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
epilogue_args,
|
||||
k_id);
|
||||
A_k_id,
|
||||
B_k_id);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -101,7 +101,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg, epilogue_args, k_id);
|
||||
p_shared,
|
||||
splitk_batch_offset,
|
||||
karg,
|
||||
epilogue_args,
|
||||
0, /* A_k_id == 0 (we shift the pointer for splitk) */
|
||||
k_id);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
@@ -344,11 +349,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
|
||||
// Calculate grid size taking into account splitk (KBatch)
|
||||
// 2D grid (x,z)
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
// Calculate grid size taking into account splitk (KBatch) and multiple groups (Batch)
|
||||
// 3D grid (x,y,z)
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
|
||||
{
|
||||
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateMPadded(index_t M)
|
||||
{
|
||||
return math::integer_least_multiple(M, MPerBlock);
|
||||
@@ -706,8 +720,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
ReduceTrait>;
|
||||
|
||||
template <typename DEGridDesc>
|
||||
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc& de_grid_desc_m_n,
|
||||
index_t MBlock,
|
||||
index_t NBlock)
|
||||
{
|
||||
const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
de_grid_desc_m_n,
|
||||
@@ -1004,6 +1020,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
}
|
||||
}
|
||||
|
||||
// Note: arguments k_batch and k_id should be set if splitk is used
|
||||
// with implicit gemm (no pointer shift but shift using tensor descriptors)
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -1034,7 +1052,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
AScaleStruct& a_scale_struct,
|
||||
BScaleStruct& b_scale_struct,
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
const index_t A_k_id = 0,
|
||||
const index_t B_k_id = 0,
|
||||
const index_t k_batch = 1)
|
||||
{
|
||||
const auto as_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -1066,7 +1086,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
AsDataType,
|
||||
AElementwiseOperation,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id, k_id);
|
||||
as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id, A_k_id);
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
@@ -1075,7 +1095,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
BsDataType,
|
||||
BElementwiseOperation,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id, k_id);
|
||||
bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id, B_k_id);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
@@ -1100,7 +1120,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / KPerBlock);
|
||||
ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / (KPerBlock * k_batch));
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),
|
||||
|
||||
Reference in New Issue
Block a user