device_batchnorm_backward_impl.hpp Source File
device_batchnorm_backward_impl.hpp
Go to the documentation of this file.
13 #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
14 #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp"
15 #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp"
123 static auto MakeMultiblockFirstReduceOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition convolution_backward_data_specialization.hpp:8
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__global__ void kernel_multiblock_welford_first_half(const XGridDesc_M_K x_grid_desc_m_k, const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, index_t num_k_block_tile_iteration, const XDataType *const __restrict__ p_x, MeanVarDataType *const p_welford_mean, MeanVarDataType *const p_welford_variance, int32_t *const p_welford_count)
Definition gridwise_multiblock_welford_first_half.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__global__ void kernel_welford_second_half_reduce_first_half(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const MeanVarGridDesc_M mean_var_grid_desc_m, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, index_t num_mean_var_count_k_block_tile_iteration, AccDataType epsilon, bool haveSavedMeanInvVar, const MeanVarDataType *const __restrict__ p_savedMean, const MeanVarDataType *const __restrict__ p_savedInvVar, const MeanVarDataType *const __restrict__ p_in_welford_mean, const MeanVarDataType *const __restrict__ p_in_welford_variance, const int32_t *const __restrict__ p_in_welford_count, const DyElementwiseOp dy_elementwise_op, MeanVarDataType *const __restrict__ p_out_welford_mean, MeanVarDataType *const __restrict__ p_out_welford_inv_variance, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, DscaleDbiasDataType *const __restrict__ p_reduce_dscale, DscaleDbiasDataType *const __restrict__ p_reduce_dbias)
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:27
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__global__ void kernel_batchnorm_backward_with_blockwise_welford(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, long_index_t reduce_size, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, const ScaleDataType *const __restrict__ p_scale, bool haveSavedMeanInvVar, const MeanVarDataType *const __restrict__ p_savedMean, const MeanVarDataType *const __restrict__ p_savedInvVar, const DyElementwiseOp dy_elementwise_op, DxDataType *const __restrict__ p_dx, DscaleDbiasDataType *const __restrict__ p_dscale, DscaleDbiasDataType *const __restrict__ p_dbias)
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:31
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__global__ void kernel_reduce_second_half_batchnorm_backward_final(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k, const DscaleDbiasGridDesc_M_K dscale_dbias_grid_desc_m_k, const MeanVarGridDesc_M mean_var_grid_desc_m, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, index_t blkgroup_size, long_index_t reduce_size, index_t num_xy_k_block_tile_iteration, index_t num_dscale_dbias_k_block_tile_iteration, const DscaleDbiasDataType *const __restrict__ p_reduce_dscale, const DscaleDbiasDataType *const __restrict__ p_reduce_dbias, const MeanVarDataType *const __restrict__ p_mean, const MeanVarDataType *const __restrict__ p_inv_var, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, const ScaleDataType *const __restrict__ p_scale, const DyElementwiseOp dy_elementwise_op, DxDataType *const __restrict__ p_dx, DscaleDbiasDataType *const __restrict__ p_dscale, DscaleDbiasDataType *const __restrict__ p_dbias)
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:26
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
Definition ck/stream_config.hpp:10
Definition gridwise_batchnorm_backward_blockwise_welford.hpp:100
Definition gridwise_multiblock_welford_first_half.hpp:55
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:99
Definition gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:96
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition functional2.hpp:33
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition device_batchnorm_backward.hpp:27
Definition device_batchnorm_backward_impl.hpp:197
std::array< index_t, Rank > dyStrides_
Definition device_batchnorm_backward_impl.hpp:295
XYGridDesc_M_K x_grid_desc_m_k
Definition device_batchnorm_backward_impl.hpp:320
AccDataType epsilon_
Definition device_batchnorm_backward_impl.hpp:289
DscaleDbiasDataType * p_dscale_
Definition device_batchnorm_backward_impl.hpp:310
std::array< index_t, Rank > xStrides_
Definition device_batchnorm_backward_impl.hpp:294
std::array< index_t, Rank > xyLengths_
Definition device_batchnorm_backward_impl.hpp:293
std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleStrides_
Definition device_batchnorm_backward_impl.hpp:299
bool haveSavedMeanInvVar_
Definition device_batchnorm_backward_impl.hpp:291
const MeanVarDataType * p_savedMean_
Definition device_batchnorm_backward_impl.hpp:306
int blkGroupSize
Definition device_batchnorm_backward_impl.hpp:316
std::array< index_t, Rank > dxStrides_
Definition device_batchnorm_backward_impl.hpp:296
ScaleBiasGridDesc_M dscale_dbias_grid_desc_m
Definition device_batchnorm_backward_impl.hpp:324
std::array< index_t, Rank - NumBatchNormReduceDim > bnMeanVarStrides_
Definition device_batchnorm_backward_impl.hpp:301
void * workspace_reduce_dbias
Definition device_batchnorm_backward_impl.hpp:335
const ScaleDataType * p_scale_
Definition device_batchnorm_backward_impl.hpp:305
long_index_t reduce_length
Definition device_batchnorm_backward_impl.hpp:314
const DyDataType * p_dy_
Definition device_batchnorm_backward_impl.hpp:304
Argument(const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > dyStrides, const std::array< index_t, Rank > dxStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< ck::index_t, NumInvariantDim > bnScaleBiasMeanVarLengths, const std::array< ck::index_t, NumInvariantDim > bnScaleStrides, const std::array< ck::index_t, NumInvariantDim > bnDscaleDbiasStrides, const std::array< ck::index_t, NumInvariantDim > bnMeanVarStrides, const XDataType *p_x, const DyDataType *p_dy, const ScaleDataType *p_scale, const MeanVarDataType *p_savedMean, const MeanVarDataType *p_savedInvVar, const DyElementwiseOp dy_elementwise_op, double epsilon, DxDataType *p_dx, DscaleDbiasDataType *p_dscale, DscaleDbiasDataType *p_dbias)
Definition device_batchnorm_backward_impl.hpp:198
ScaleBiasGridDesc_M scale_grid_desc_m
Definition device_batchnorm_backward_impl.hpp:323
size_t gridSize
Definition device_batchnorm_backward_impl.hpp:318
DxDataType * p_dx_
Definition device_batchnorm_backward_impl.hpp:309
void * workspace_variance
Definition device_batchnorm_backward_impl.hpp:328
MeanVarGridDesc_M mean_var_grid_desc_m
Definition device_batchnorm_backward_impl.hpp:325
const XDataType * p_x_
Definition device_batchnorm_backward_impl.hpp:303
void * workspace_savedMean
Definition device_batchnorm_backward_impl.hpp:331
int numBlockTileIteration
Definition device_batchnorm_backward_impl.hpp:317
void * workspace_mean
Definition device_batchnorm_backward_impl.hpp:327
void * workspace_savedInvVar
Definition device_batchnorm_backward_impl.hpp:332
long_index_t invariant_length
Definition device_batchnorm_backward_impl.hpp:313
DscaleDbiasDataType * p_dbias_
Definition device_batchnorm_backward_impl.hpp:311
std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleBiasMeanVarLengths_
Definition device_batchnorm_backward_impl.hpp:298
std::array< index_t, Rank - NumBatchNormReduceDim > bnDscaleDbiasStrides_
Definition device_batchnorm_backward_impl.hpp:300
void * workspace_count
Definition device_batchnorm_backward_impl.hpp:329
XYGridDesc_M_K dy_grid_desc_m_k
Definition device_batchnorm_backward_impl.hpp:321
const MeanVarDataType * p_savedInvVar_
Definition device_batchnorm_backward_impl.hpp:307
const DyElementwiseOp dy_elementwise_op_
Definition device_batchnorm_backward_impl.hpp:308
void * workspace_reduce_dscale
Definition device_batchnorm_backward_impl.hpp:334
XYGridDesc_M_K dx_grid_desc_m_k
Definition device_batchnorm_backward_impl.hpp:322
Definition device_batchnorm_backward_impl.hpp:436
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batchnorm_backward_impl.hpp:437
float Run(const BaseArgument *pArg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batchnorm_backward_impl.hpp:742
Definition device_batchnorm_backward_impl.hpp:58
std::string GetTypeString() const override
Definition device_batchnorm_backward_impl.hpp:858
static constexpr index_t NumInvariantDim
Definition device_batchnorm_backward_impl.hpp:71
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > dyStrides, const std::array< index_t, Rank > dxStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< ck::index_t, NumInvariantDim > bnScaleBiasMeanVarLengths, const std::array< ck::index_t, NumInvariantDim > bnScaleStrides, const std::array< ck::index_t, NumInvariantDim > bnDscaleDbiasStrides, const std::array< ck::index_t, NumInvariantDim > bnMeanVarStrides, const void *p_x, const void *p_dy, const void *p_scale, const void *p_savedMean, const void *p_savedInvVar, double epsilon, const DyElementwiseOp dy_elementwise_op, void *p_dx, void *p_dscale, void *p_dbias) override
Definition device_batchnorm_backward_impl.hpp:812
static constexpr index_t M_BlockTileSize
Definition device_batchnorm_backward_impl.hpp:73
bool IsSupportedArgument(const BaseArgument *pArg) override
Definition device_batchnorm_backward_impl.hpp:749
static auto MakeMultiblockFirstReduceOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
Definition device_batchnorm_backward_impl.hpp:123
ScaleBiasGridDesc_M MeanVarGridDesc_M
Definition device_batchnorm_backward_impl.hpp:194
static constexpr index_t K_BlockTileSize
Definition device_batchnorm_backward_impl.hpp:74
static auto MakeXY2dDescriptor(const std::array< index_t, Rank > &xyLengths, const std::array< index_t, Rank > &xyStrides, int blkGroupSize, int numBlockTileIteration)
Definition device_batchnorm_backward_impl.hpp:76
decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1})) ScaleBiasGridDesc_M
Definition device_batchnorm_backward_impl.hpp:193
static auto MakeScaleBiasMeanVar1dDescriptor(const std::array< index_t, NumInvariantDim > &lengths, const std::array< index_t, NumInvariantDim > &strides)
Definition device_batchnorm_backward_impl.hpp:163
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition device_batchnorm_backward_impl.hpp:338
static auto MakeMultiblockFinalReduceInputMK2dDescriptor(int invariantLength, int blkGroupSize)
Definition device_batchnorm_backward_impl.hpp:141
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_batchnorm_backward_impl.hpp:379
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batchnorm_backward_impl.hpp:853
decltype(MakeXY2dDescriptor({1}, {1}, 1, 1)) XYGridDesc_M_K
Definition device_batchnorm_backward_impl.hpp:192
Definition welford_helper.hpp:44