File size: 3,073 Bytes
61c2d32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <torch/extension.h>

#include <vector>

#include "inplace_abn.h"

// Dispatches the mean/variance computation to the backend matching `x`.
//
// Non-CUDA tensors are handled by mean_var_cpu. CUDA tensors are handled by
// mean_var_cuda_h when their scalar type is Half and by mean_var_cuda
// otherwise.
//
// x: input tensor; its device and scalar type select the implementation.
// Returns the vector of tensors produced by the selected backend, unchanged
// (presumably {mean, var} -- confirm against inplace_abn.h).
std::vector<at::Tensor> mean_var(at::Tensor x) {
  if (!x.is_cuda()) {
    return mean_var_cpu(x);
  }
  // Tensor::type() is deprecated; scalar_type() queries the dtype directly.
  if (x.scalar_type() == at::ScalarType::Half) {
    return mean_var_cuda_h(x);
  }
  return mean_var_cuda(x);
}

// Dispatches the forward computation to the backend matching `x`.
//
// Non-CUDA tensors are handled by forward_cpu. CUDA tensors are handled by
// forward_cuda_h when the scalar type of `x` is Half and by forward_cuda
// otherwise. Only `x` is inspected; every argument is forwarded unchanged.
//
// x:            input tensor; selects device and precision of the backend.
// mean, var:    statistics tensors passed through to the backend.
// weight, bias: affine parameter tensors passed through to the backend.
// affine:       flag passed through to the backend.
// eps:          scalar passed through to the backend.
// Returns the tensor produced by the selected backend.
at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
                   bool affine, float eps) {
  if (!x.is_cuda()) {
    return forward_cpu(x, mean, var, weight, bias, affine, eps);
  }
  // Tensor::type() is deprecated; scalar_type() queries the dtype directly.
  if (x.scalar_type() == at::ScalarType::Half) {
    return forward_cuda_h(x, mean, var, weight, bias, affine, eps);
  }
  return forward_cuda(x, mean, var, weight, bias, affine, eps);
}

// Dispatches the first stage of the backward pass to the backend matching `z`.
//
// Non-CUDA tensors are handled by edz_eydz_cpu. CUDA tensors are handled by
// edz_eydz_cuda_h when the scalar type of `z` is Half and by edz_eydz_cuda
// otherwise. Only `z` is inspected; every argument is forwarded unchanged.
//
// z:            tensor whose device and scalar type select the backend.
// dz:           gradient tensor passed through to the backend.
// weight, bias: affine parameter tensors passed through to the backend.
// affine:       flag passed through to the backend.
// eps:          scalar passed through to the backend.
// Returns the vector of tensors produced by the selected backend
// (presumably {edz, eydz} -- confirm against inplace_abn.h).
std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
                                 bool affine, float eps) {
  if (!z.is_cuda()) {
    return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
  }
  // Tensor::type() is deprecated; scalar_type() queries the dtype directly.
  if (z.scalar_type() == at::ScalarType::Half) {
    return edz_eydz_cuda_h(z, dz, weight, bias, affine, eps);
  }
  return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
}

// Dispatches the second stage of the backward pass to the backend matching `z`.
//
// Non-CUDA tensors are handled by backward_cpu. CUDA tensors are handled by
// backward_cuda_h when the scalar type of `z` is Half and by backward_cuda
// otherwise. Only `z` is inspected; every argument is forwarded unchanged.
//
// z:            tensor whose device and scalar type select the backend.
// dz:           gradient tensor passed through to the backend.
// var:          variance tensor passed through to the backend.
// weight, bias: affine parameter tensors passed through to the backend.
// edz, eydz:    tensors passed through to the backend (presumably the outputs
//               of edz_eydz -- confirm against the Python caller).
// affine:       flag passed through to the backend.
// eps:          scalar passed through to the backend.
// Returns the tensor produced by the selected backend.
at::Tensor backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
                    at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
  if (!z.is_cuda()) {
    return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
  }
  // Tensor::type() is deprecated; scalar_type() queries the dtype directly.
  if (z.scalar_type() == at::ScalarType::Half) {
    return backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps);
  }
  return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
}

// Applies leaky ReLU to `z` via ATen's in-place variant (trailing underscore),
// so the result is written into `z` itself and nothing is returned.
//
// z:     tensor modified in place.
// slope: negative-slope coefficient handed to at::leaky_relu_.
void leaky_relu_forward(at::Tensor z, float slope) {
  at::leaky_relu_(z, slope);
}

// Dispatches the leaky ReLU backward step to the backend matching `z`.
//
// Non-CUDA tensors are handled by leaky_relu_backward_cpu. CUDA tensors are
// handled by leaky_relu_backward_cuda_h when the scalar type of `z` is Half
// and by leaky_relu_backward_cuda otherwise. The backends return nothing, so
// any result is presumably written into `z` and/or `dz` in place -- confirm
// against the backend implementations.
//
// z:     tensor whose device and scalar type select the backend.
// dz:    gradient tensor passed through to the backend.
// slope: negative-slope coefficient passed through to the backend.
void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
  if (!z.is_cuda()) {
    leaky_relu_backward_cpu(z, dz, slope);
    return;
  }
  // Tensor::type() is deprecated; scalar_type() queries the dtype directly.
  if (z.scalar_type() == at::ScalarType::Half) {
    leaky_relu_backward_cuda_h(z, dz, slope);
  } else {
    leaky_relu_backward_cuda(z, dz, slope);
  }
}

// Applies ELU to `z` via ATen's in-place variant (trailing underscore), using
// at::elu_'s default parameters; the result is written into `z` itself and
// nothing is returned.
//
// z: tensor modified in place.
void elu_forward(at::Tensor z) {
  at::elu_(z);
}

// Routes the ELU backward step to the implementation for the device of `z`:
// elu_backward_cpu for non-CUDA tensors, elu_backward_cuda for CUDA tensors.
// Unlike the other dispatchers in this file, no Half-specific variant is
// selected here; every CUDA tensor goes to elu_backward_cuda.
//
// z:  tensor whose device selects the backend.
// dz: gradient tensor passed through to the backend.
void elu_backward(at::Tensor z, at::Tensor dz) {
  const bool on_gpu = z.is_cuda();
  if (!on_gpu) {
    elu_backward_cpu(z, dz);
    return;
  }
  elu_backward_cuda(z, dz);
}

// Python bindings: exposes each dispatcher defined in this file under the
// same name it has in C++, together with a short docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Batch-norm statistics, forward pass, and the two backward stages.
  m.def("mean_var",
        &mean_var,
        "Mean and variance computation");
  m.def("forward",
        &forward,
        "In-place forward computation");
  m.def("edz_eydz",
        &edz_eydz,
        "First part of backward computation");
  m.def("backward",
        &backward,
        "Second part of backward computation");

  // Activation functions.
  m.def("leaky_relu_forward",
        &leaky_relu_forward,
        "Leaky relu forward computation");
  m.def("leaky_relu_backward",
        &leaky_relu_backward,
        "Leaky relu backward computation and inversion");
  m.def("elu_forward",
        &elu_forward,
        "Elu forward computation");
  m.def("elu_backward",
        &elu_backward,
        "Elu backward computation and inversion");
}