File size: 2,148 Bytes
8026e91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
diff --git a/csrc/fused_adam_cuda_kernel.cu b/csrc/fused_adam_cuda_kernel.cu
index 34f7aa2..95581d1 100644
--- a/csrc/fused_adam_cuda_kernel.cu
+++ b/csrc/fused_adam_cuda_kernel.cu
@@ -19,8 +19,8 @@ typedef enum{
 
 template <typename T, typename GRAD_T>
 __global__ void adam_cuda_kernel(
-        T* __restrict__ p,
-        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
+        GRAD_T* __restrict__ p,
+        T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
         T* __restrict__ m,
         T* __restrict__ v,
         const GRAD_T * __restrict__ g,
@@ -50,7 +50,7 @@ __global__ void adam_cuda_kernel(
                 else // Mode 1
                     denom = sqrtf(v[j]) + eps;
                 float update = (m[j]/denom) + (decay*p[j]);
-                p[j] = p[j] - (step_size*update);
+                p[j] = (GRAD_T) (p[j] - (step_size*update));
                 if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
         }
 }
@@ -93,14 +93,14 @@ void fused_adam_cuda(
 
         if (g.scalar_type() == at::ScalarType::Half) {
 //all other values should be fp32 for half gradients
-            AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
+//            AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
 //dispatch is done on the gradient type
             using namespace at; // prevents "toString is undefined" errors
             DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", 
                 using accscalar_t = at::acc_type<scalar_t_0, true>;
                 adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
-                        p.data<accscalar_t>(),
-                        p_copy.numel() ? p_copy.data<scalar_t_0>() : NULL,
+                        p.data<scalar_t_0>(),
+                        NULL, //don't output p_copy for fp32, it's wasted write
                         m.data<accscalar_t>(),
                         v.data<accscalar_t>(),
                         g.data<scalar_t_0>(),