mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
fix build (#246)
* fix build * Revert "fix build" This reverts commitd73102384b. * post PR #235 merge fix * amend Co-authored-by: Anthony Chang <ac.chang@outlook.com> [ROCm/composable_kernel commit:ba58a93f60]
This commit is contained in:
@@ -802,17 +802,16 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
|
||||
b_grid_desc_kbatch_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
M01_,
|
||||
N01_))
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -871,14 +870,14 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_))
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
|
||||
}
|
||||
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
|
||||
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
|
||||
@@ -1066,8 +1065,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_);
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "blockwise_gemm_xdlops.hpp"
|
||||
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
@@ -495,12 +496,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
template <typename Block2CTileMap>
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
index_t M01,
|
||||
index_t N01)
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
|
||||
"wrong! K1 need to be known at compile-time");
|
||||
@@ -532,31 +533,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
|
||||
return false;
|
||||
|
||||
// check M01, N01
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
if(!(M0 % M01 == 0 && N0 % N01 == 0))
|
||||
if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch;
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
|
||||
{
|
||||
// const bool has_main_k0_block_loop = K0 > K0PerBlock;
|
||||
@@ -588,37 +573,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
|
||||
const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto M00 = M0 / M01;
|
||||
const auto N00 = N0 / N01;
|
||||
|
||||
const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_pass_through_transform(KBatch),
|
||||
make_unmerge_transform(make_tuple(M00, M01)),
|
||||
make_unmerge_transform(make_tuple(N00, N01))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
|
||||
|
||||
const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor =
|
||||
chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
|
||||
c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);
|
||||
|
||||
return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor;
|
||||
return BlockToCTileMap_KSplit_M00_N00_M01_N01<MPerBlock, NPerBlock, CMNGridDesc>(
|
||||
c_m_n_grid_desc, M01, N01, KBatch);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
@@ -667,6 +623,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
|
||||
const index_t k_batch_id = block_work_idx[I0];
|
||||
|
||||
if(!c_block_cluster_adaptor.ValidCTileIndex(
|
||||
make_tuple(block_work_idx[I1], block_work_idx[I2]),
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
|
||||
|
||||
Reference in New Issue
Block a user