mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[rocm-libraries] ROCm/rocm-libraries#6302 (commit 8d419e8)
CK: Remove 41 commented-out dead code blocks (~200 lines) (#6302)

Depends on #6300

## Summary

Remove 41 commented-out code blocks across 33 files in Composable Kernel, totaling ~200 lines. The blocks were identified using an automated dead-code scanning skill (`ck-dead-code`) with a calibrated two-stage pipeline:

1. **Pre-filter**: A keyword-based scan found 1,338 `//`-commented blocks. Calibrated heuristics (trained on a 50-sample expert classification) reduced these to 89 high-confidence candidates — a 93% noise reduction.
2. **Expert triage**: An LLM expert classified each block in context as CODE_REMOVE, CODE_KEEP, or NOT_CODE.

| Classification | Count |
|----------------|-------|
| Removed (this PR) | 41 |
| Kept (debug helpers, alt configs, reference impls) | 32 |
| Not code (false positives) | 16 |

Removed blocks include: superseded implementations, old test data, abandoned stubs, unreachable code, and buggy dead code.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
4d0bbe5d17
commit
e0dfe58d66
@@ -745,14 +745,6 @@ struct PassThroughPack2
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
|
||||
#if 0
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::fp16x2_t& y, const ck_tile::f8x2_t& x) const
|
||||
{
|
||||
auto t = type_convert<float2_t>(x);
|
||||
y = type_convert<fp16x2_t>(t);
|
||||
}
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(fp16x2_t& y, const pk_int4_t& x) const
|
||||
{
|
||||
uint8_t x_u8 = bit_cast<uint8_t>(x);
|
||||
@@ -871,61 +863,6 @@ struct UnaryConvert
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
struct ConvertBF16RTN
|
||||
{
|
||||
// convert to bf16 using round to nearest (rtn)
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
|
||||
{
|
||||
// check Y datatype
|
||||
static_assert(std::is_same_v<Y, ck_tile::bf16_t>, "Data type is not supported by this operation!");
|
||||
|
||||
// check X datatype
|
||||
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = bf16_convert_rtn<Y>(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertF8SR
|
||||
{
|
||||
// convert to fp8 using stochastic rounding (SR)
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
|
||||
{
|
||||
// check Y datatype
|
||||
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
// check X datatype
|
||||
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = f8_convert_sr<Y>(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertF8RNE
|
||||
{
|
||||
// convert to fp8 using rounding to nearest even
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
|
||||
{
|
||||
// check Y datatype
|
||||
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
// check X datatype
|
||||
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = f8_convert_rne<Y>(x);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
struct Scale
|
||||
{
|
||||
static constexpr const char* name = "Scale";
|
||||
|
||||
@@ -339,16 +339,6 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
|
||||
{
|
||||
return hostArgs;
|
||||
}
|
||||
// CK_TILE_HOST static constexpr auto
|
||||
// MakeKernelArgs(const ContiguousGroupedFlatmmHostArgs& hostArgs)
|
||||
// {
|
||||
// return hostArgs;
|
||||
// }
|
||||
// CK_TILE_HOST static constexpr auto
|
||||
// MakeKernelArgs(const MaskedGroupedFlatmmHostArgs& hostArgs)
|
||||
// {
|
||||
// return hostArgs;
|
||||
// }
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
|
||||
@@ -483,13 +483,6 @@ struct MoeFlatmmKernel
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
// if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
|
||||
// {
|
||||
// std::cerr << "Can't support N that is not a multiple of NPerBlock"
|
||||
// " without padding!"
|
||||
// << std::endl;
|
||||
// return false;
|
||||
// }
|
||||
if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
|
||||
{
|
||||
std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
|
||||
|
||||
@@ -392,10 +392,6 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// "Incorrect M0, M2, M1 configuration! "
|
||||
// "M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
|
||||
@@ -1151,11 +1151,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
// barrier
|
||||
// if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
// {
|
||||
// block_sync_lds();
|
||||
// }
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -1636,10 +1631,6 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
? Aload_rep
|
||||
: 0;
|
||||
}
|
||||
// if((kIter % KPerScaleLoad == 0) && (mIter == 0))
|
||||
// {
|
||||
// load_perM = load_perM + 1;
|
||||
// }
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,13 +103,8 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t Aload_num_perK = dswrite_num_perK;
|
||||
static constexpr index_t Aload_rep = dswrite_rep;
|
||||
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
|
||||
// static constexpr index_t ScaleBload_K1 = ContinuousScaleNPerThread *
|
||||
// ContinuousScaleKPerThread; static constexpr index_t ScaleBload_num =
|
||||
// kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 /
|
||||
// WaveSize; // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize
|
||||
// static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
|
||||
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
|
||||
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
|
||||
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
|
||||
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
|
||||
|
||||
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
|
||||
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
|
||||
@@ -352,10 +347,6 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
? Aload_rep
|
||||
: 0;
|
||||
}
|
||||
// if((kIter % KPerScaleLoad == 0) && (mIter == 0))
|
||||
// {
|
||||
// load_perM = load_perM + 1;
|
||||
// }
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -390,10 +390,6 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
? Aload_rep
|
||||
: 0;
|
||||
}
|
||||
// if((kIter % KPerScaleLoad == 0) && (mIter == 0))
|
||||
// {
|
||||
// load_perM = load_perM + 1;
|
||||
// }
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -692,9 +692,6 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
// constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad);
|
||||
// constexpr index_t SingleVSize =
|
||||
// MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
constexpr index_t BufferSize =
|
||||
GetSingleSmemElementSpaceSize<Problem>(); // max(SingleKSize, SingleVSize);
|
||||
|
||||
|
||||
@@ -456,9 +456,6 @@ struct MoeSortingKernel
|
||||
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
|
||||
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
// constexpr int reduce_stage = 6; // 1<<6=64
|
||||
// clang-format off
|
||||
constexpr int reduce_stage = [](){
|
||||
if constexpr(wave_size_ == 2) return 1;
|
||||
else if constexpr(wave_size_ == 4) return 2;
|
||||
@@ -1206,17 +1203,21 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
|
||||
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
|
||||
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
// constexpr int reduce_stage = 6; // 1<<6=64
|
||||
// clang-format off
|
||||
constexpr int reduce_stage = [](){
|
||||
if constexpr(wave_size_ == 2) return 1;
|
||||
else if constexpr(wave_size_ == 4) return 2;
|
||||
else if constexpr(wave_size_ == 8) return 3;
|
||||
else if constexpr(wave_size_ == 16) return 4;
|
||||
else if constexpr(wave_size_ == 32) return 5;
|
||||
else if constexpr(wave_size_ == 64) return 6;
|
||||
else return 0;
|
||||
constexpr int reduce_stage = []() {
|
||||
if constexpr(wave_size_ == 2)
|
||||
return 1;
|
||||
else if constexpr(wave_size_ == 4)
|
||||
return 2;
|
||||
else if constexpr(wave_size_ == 8)
|
||||
return 3;
|
||||
else if constexpr(wave_size_ == 16)
|
||||
return 4;
|
||||
else if constexpr(wave_size_ == 32)
|
||||
return 5;
|
||||
else if constexpr(wave_size_ == 64)
|
||||
return 6;
|
||||
else
|
||||
return 0;
|
||||
}();
|
||||
// clang-format on
|
||||
T v_local = local;
|
||||
@@ -3047,53 +3048,6 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
x_r = x_v;
|
||||
#endif
|
||||
{
|
||||
#if 0
|
||||
#pragma unroll
|
||||
for(int j = 0; j < index_pack / 2; j++)
|
||||
{
|
||||
int i_token = i * kBlockSize * index_pack + threadIdx.x + j * kBlockSize;
|
||||
index_t x = x_d[j];
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show;
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int position = cumsum - i_show;
|
||||
prev_cumsum = s[0]; // update the last cumsum
|
||||
|
||||
if(i_show)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[e_start + position] =
|
||||
MOE_SORTING_MOCK_ID(i_token, i_topk);
|
||||
#else
|
||||
p_sorted_token_ids[e_start + position] = i_token;
|
||||
#endif
|
||||
p_sorted_weights[e_start + position] =
|
||||
p_weights[i_token * kargs.topk_mdiv.divisor + i_topk];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
{
|
||||
d_t i_topk;
|
||||
d_t i_show;
|
||||
@@ -3151,68 +3105,6 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
}
|
||||
position += i_show[j];
|
||||
});
|
||||
|
||||
#if 0
|
||||
int i_token = i * kBlockSize * index_pack + threadIdx.x * 2 + j * kBlockSize * 2;
|
||||
index_t x = x_d[j];
|
||||
index_t x0 = static_cast<index_t>(x & 0xffff);
|
||||
index_t x1 = static_cast<index_t>(x >> 16);
|
||||
int i_topk_0 = x0 - 1; // topk of this token
|
||||
int i_show_0 = x0 != 0 ? 1 : 0; // has this token or not
|
||||
int i_topk_1 = x1 - 1; // topk of this token
|
||||
int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show_0 + i_show_1;
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
});
|
||||
cumsum += prev_cumsum; // add previous round cumsum
|
||||
if(threadIdx.x == kBlockSize - 1)
|
||||
{
|
||||
s[0] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int position_0 = cumsum - i_show_0 - i_show_1;
|
||||
prev_cumsum = s[0]; // update the last cumsum
|
||||
|
||||
if(i_show_0)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[e_start + position_0] =
|
||||
MOE_SORTING_MOCK_ID(i_token, i_topk_0);
|
||||
#else
|
||||
p_sorted_token_ids[e_start + position_0] = i_token;
|
||||
#endif
|
||||
p_sorted_weights[e_start + position_0] =
|
||||
p_weights[i_token * kargs.topk_mdiv.divisor + i_topk_0];
|
||||
}
|
||||
|
||||
int position_1 = cumsum - i_show_1;
|
||||
|
||||
if(i_show_1)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[e_start + position_1] =
|
||||
MOE_SORTING_MOCK_ID(i_token + 1, i_topk_1);
|
||||
#else
|
||||
p_sorted_token_ids[e_start + position_1] = i_token + 1;
|
||||
#endif
|
||||
p_sorted_weights[e_start + position_1] =
|
||||
p_weights[(i_token + 1) * kargs.topk_mdiv.divisor + i_topk_1];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,14 +14,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// template <typename Problem_, typename Policy_ = MoeSortingPolicy>
|
||||
// struct MoeSortingPipeline
|
||||
// {
|
||||
// // TODO: this kernel only support warp per row
|
||||
// using Problem = remove_cvref_t<Problem_>;
|
||||
// using Policy = remove_cvref_t<Policy_>;
|
||||
// using WeightType = typename Problem::WeightType;
|
||||
|
||||
// template <typename TopkIdWindow, typename WeightWindow>
|
||||
// CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
|
||||
// const WeightWindow& weight_window,
|
||||
|
||||
@@ -36,9 +36,6 @@ struct BlockGemmARegBSmemCRegOneWarpV1
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
@@ -19,30 +19,7 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
|
||||
|
||||
constexpr index_t NumWarp = kBlockSize / get_warp_size();
|
||||
|
||||
// FIXME
|
||||
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
|
||||
kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
|
||||
}
|
||||
#else
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
|
||||
@@ -16,30 +16,7 @@ struct BlockGemmARegBSmemCRegV2DefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
|
||||
#if 0
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
|
||||
|
||||
constexpr index_t NumWarp = kBlockSize / get_warp_size();
|
||||
|
||||
// FIXME
|
||||
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
|
||||
kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
|
||||
}
|
||||
#else
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -19,30 +19,7 @@ struct BlockGemmASmemBRegCRegV1DefaultPolicy
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
|
||||
|
||||
constexpr index_t NumWarp = kBlockSize / get_warp_size();
|
||||
|
||||
// FIXME
|
||||
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
|
||||
kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
|
||||
}
|
||||
#else
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
|
||||
@@ -120,10 +120,6 @@ struct BlockNormReduceSync
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
|
||||
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
|
||||
// const auto rs_idx =
|
||||
// mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
|
||||
|
||||
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
|
||||
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
|
||||
|
||||
@@ -360,17 +356,6 @@ struct BlockNormReduceCrossWarpSync
|
||||
template <typename BlockShape>
|
||||
CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_size)
|
||||
{
|
||||
#if 0
|
||||
using S = BlockShape;
|
||||
index_t LastloopN = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N;
|
||||
constexpr index_t NThread = S::WarpPerBlock_N * S::ThreadPerWarp_N;
|
||||
index_t iNLane = get_thread_id() % NThread;
|
||||
index_t iN0 = LastloopN / (S::Vector_N * S::ThreadPerWarp_N);
|
||||
index_t iN1 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) / S::Vector_N;
|
||||
index_t N2 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) % S::Vector_N;
|
||||
index_t iN3 = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0;
|
||||
return iN0 * S::Vector_N + iN3;
|
||||
#endif
|
||||
using S_ = BlockShape;
|
||||
constexpr index_t ThreadsPerBlock_N = S_::WarpPerBlock_N * S_::ThreadPerWarp_N;
|
||||
|
||||
|
||||
@@ -140,28 +140,6 @@ struct BlockReduce2d
|
||||
ReducePacksPerXDim{});
|
||||
}
|
||||
|
||||
#if 0
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto spans = XDistributedTensor_::get_distributed_spans();
|
||||
|
||||
// FIXME: hard coded to reduce 2nd axis
|
||||
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
|
||||
constexpr auto y_dstr_idx = make_tuple(dstr_idx_i0);
|
||||
|
||||
auto y = y_tensor[y_dstr_idx];
|
||||
|
||||
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
|
||||
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
|
||||
const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
|
||||
|
||||
y = reduce_func(y, x);
|
||||
});
|
||||
|
||||
y_tensor(y_dstr_idx) = y;
|
||||
});
|
||||
#endif
|
||||
|
||||
template <typename XDistributedTensor_>
|
||||
CK_TILE_DEVICE static auto MakeYBlockTile()
|
||||
{
|
||||
@@ -240,10 +218,6 @@ struct BlockReduce2dSync
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
|
||||
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
|
||||
// const auto rs_idx =
|
||||
// y_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
|
||||
|
||||
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
// loop over thread data
|
||||
|
||||
Reference in New Issue
Block a user