|
#include <torch/extension.h> |
|
#include <ATen/ATen.h> |
|
|
|
#include <vector> |
|
|
|
|
|
|
|
// Forward declaration of the Welford mean/variance routine. The definition is
// not in this file (the _CUDA suffix suggests a CUDA translation unit linked
// into the extension -- confirm in the build script).
// Bound to Python as `welford_mean_var` in PYBIND11_MODULE below.
//
// input:   tensor the statistics are gathered from. Layout and dtype
//          requirements are set by the definition and are not visible here.
// returns: a list of tensors; going by the binding description ("welford mean
//          variance") presumably {mean, variance} -- confirm order and whether
//          the variance is biased against the definition.
//
// NOTE: the tensor is taken by value to mirror the out-of-view definition;
// keep the two signatures in sync when changing either.
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);
|
|
|
|
|
|
|
|
|
// Forward declaration of the parallel Welford reduction. The definition is not
// in this file. Bound to Python as `welford_parallel` ("welford parallel reduce
// mean variance") in PYBIND11_MODULE below.
//
// mean_feature_nodes:       partial means to be combined; judging by the name,
//                           one entry per feature per node -- TODO confirm shape.
// var_biased_feature_nodes: matching partial biased variances -- TODO confirm.
// numel:                    element count used when merging the partial
//                           statistics (presumably per node -- confirm).
// eps:                      small constant; presumably added to the variance
//                           before an inverse std is formed -- confirm.
// returns: a list of tensors holding the combined statistics; exact contents
//          and order are defined by the implementation.
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
                                              const at::Tensor var_biased_feature_nodes,
                                              int numel,
                                              const float eps);
|
|
|
|
|
|
|
|
|
// Forward declaration of the batch-norm forward routine. The definition is not
// in this file. Bound to Python as `batchnorm_forward` ("batchnorm forward") in
// PYBIND11_MODULE below.
//
// input:   tensor to normalize.
// mean:    mean statistics used for normalization.
// inv_std: inverse standard deviation statistics (per the parameter name).
// weight:  optional scale; may be absent.
// shift:   optional bias; may be absent.
// returns: a single tensor -- presumably the normalized output with the same
//          shape as `input`; confirm against the definition.
//
// NOTE(review): dtype relationships between input/weight/shift and
// mean/inv_std are enforced (if at all) by the definition, not here.
at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
                                  const at::Tensor mean,
                                  const at::Tensor inv_std,
                                  const at::optional<at::Tensor> weight,
                                  const at::optional<at::Tensor> shift);
|
|
|
|
|
|
|
|
|
|
|
// Forward declaration of the batch-norm backward reduction. The definition is
// not in this file. Bound to Python as `reduce_bn` ("batchnorm backward reduce
// grad sum and bias/weight grad") in PYBIND11_MODULE below.
//
// grad_output: gradient flowing in from the layer output.
// input:       the tensor that was fed to the forward pass.
// mean:        mean statistics used in the forward pass.
// inv_std:     inverse standard deviation statistics used in the forward pass.
// weight:      optional scale; may be absent.
// returns: a list of tensors; per the binding description these cover the
//          reduced gradient sums and the bias/weight gradients. The exact
//          order is defined by the implementation -- confirm before unpacking.
std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
                                       const at::Tensor input,
                                       const at::Tensor mean,
                                       const at::Tensor inv_std,
                                       const at::optional<at::Tensor> weight);
|
|
|
|
|
|
|
|
|
// Forward declaration of the batch-norm backward (input gradient) routine. The
// definition is not in this file. Bound to Python as `batchnorm_backward`
// ("batchnorm backward dgrad") in PYBIND11_MODULE below.
//
// grad_output: gradient flowing in from the layer output.
// input:       the tensor that was fed to the forward pass.
// mean:        mean statistics used in the forward pass.
// inv_std:     inverse standard deviation statistics used in the forward pass.
// weight:      optional scale; may be absent.
// mean_dy:     reduced gradient term; by its name the mean of grad_output --
//              presumably produced by `reduce_bn` (confirm with the caller).
// mean_dy_xmu: reduced gradient term; by its name the mean of
//              grad_output * (input - mean) -- confirm with the caller.
// returns: a single tensor -- presumably the gradient w.r.t. `input`.
at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
                                   const at::Tensor input,
                                   const at::Tensor mean,
                                   const at::Tensor inv_std,
                                   const at::optional<at::Tensor> weight,
                                   const at::Tensor mean_dy,
                                   const at::Tensor mean_dy_xmu);
|
|
|
|
|
|
|
|
|
// Forward declaration of the channels-last Welford mean/variance routine. The
// definition is not in this file. Bound to Python as `welford_mean_var_c_last`
// ("welford mean variance nhwc") in PYBIND11_MODULE below.
//
// input:   tensor the statistics are gathered from; the binding description
//          indicates an NHWC layout is expected -- confirm how (or whether)
//          the definition validates this.
// returns: a list of tensors; presumably {mean, variance} -- confirm order and
//          whether the variance is biased against the definition.
std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
|
|
|
|
|
|
|
|
|
|
|
// Forward declaration of the channels-last batch-norm forward routine. The
// definition is not in this file. Bound to Python as `batchnorm_forward_c_last`
// ("batchnorm forward nhwc") in PYBIND11_MODULE below.
//
// input:   tensor to normalize; NHWC layout per the binding description.
// mean:    mean statistics used for normalization.
// inv_std: inverse standard deviation statistics (per the parameter name).
// weight:  optional scale; may be absent.
// shift:   optional bias; may be absent.
// returns: a single tensor -- presumably the normalized output with the same
//          shape as `input`; confirm against the definition.
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
                                         const at::Tensor mean,
                                         const at::Tensor inv_std,
                                         const at::optional<at::Tensor> weight,
                                         const at::optional<at::Tensor> shift);
|
|
|
|
|
|
|
|
|
|
|
// Forward declaration of the channels-last batch-norm backward reduction. The
// definition is not in this file. Bound to Python as `reduce_bn_c_last`
// ("batchnorm backwards reduce grad sum and bias/weight grad nhwc") in
// PYBIND11_MODULE below.
//
// grad_output: gradient flowing in from the layer output; NHWC layout per the
//              binding description.
// input:       the tensor that was fed to the forward pass.
// mean:        mean statistics used in the forward pass.
// inv_std:     inverse standard deviation statistics used in the forward pass.
// weight:      optional scale; may be absent.
// returns: a list of tensors covering the reduced gradient sums and the
//          bias/weight gradients. The exact order is defined by the
//          implementation -- confirm before unpacking.
std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
                                              const at::Tensor input,
                                              const at::Tensor mean,
                                              const at::Tensor inv_std,
                                              const at::optional<at::Tensor> weight);
|
|
|
|
|
|
|
|
|
|
|
// Forward declaration of the channels-last batch-norm backward (input
// gradient) routine. The definition is not in this file. Bound to Python as
// `batchnorm_backward_c_last` ("batchnorm backward dgrad nhwc") in
// PYBIND11_MODULE below.
//
// grad_output: gradient flowing in from the layer output; NHWC layout per the
//              binding description.
// input:       the tensor that was fed to the forward pass.
// mean:        mean statistics used in the forward pass.
// inv_std:     inverse standard deviation statistics used in the forward pass.
// weight:      optional scale; may be absent.
// mean_dy:     reduced gradient term; by its name the mean of grad_output --
//              presumably produced by `reduce_bn_c_last` (confirm with caller).
// mean_dy_xmu: reduced gradient term; by its name the mean of
//              grad_output * (input - mean) -- confirm with the caller.
// returns: a single tensor -- presumably the gradient w.r.t. `input`.
at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
                                          const at::Tensor input,
                                          const at::Tensor mean,
                                          const at::Tensor inv_std,
                                          const at::optional<at::Tensor> weight,
                                          const at::Tensor mean_dy,
                                          const at::Tensor mean_dy_xmu);
|
|
|
// Python bindings for the batch-norm helpers declared above.
//
// TORCH_EXTENSION_NAME is not defined in this file; it is expected to be
// injected by the extension build, which therefore decides the Python import
// name. Each m.def() registers one free function under the given Python name,
// with the trailing string used as its Python docstring. No argument names
// are registered, so the functions take positional arguments only.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Default-layout entry points.
  m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
  m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
  m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward");
  m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight grad");
  m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad");

  // Channels-last entry points (described as "nhwc" in their docstrings).
  // There is no `_c_last` counterpart of welford_parallel in this file.
  m.def("welford_mean_var_c_last", &welford_mean_var_c_last_CUDA, "welford mean variance nhwc");
  m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
  m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
  m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
}
|
|