convolution_tensor
from typing import Optional

# Assumes the surrounding autodiff framework provides TensorOp, an NDArray
# backend with pad / as_strided / compact / reshape, and the ops
# dilate, flip, transpose and conv (needle-style, NHWC layout).


class Conv(TensorOp):
    def __init__(self, stride: Optional[int] = 1, padding: Optional[int] = 0):
        self.stride = stride or 1
        self.padding = padding or 0

    def compute(self, A, B):
        N, H, W, C_in = A.shape
        K, _, _, C_out = B.shape
        P = self.padding
        S = self.stride

        A_pad = A.pad(axes=((0, 0), (P, P), (P, P), (0, 0)))
        Ns, Hs, Ws, Cs = A_pad.strides
        conv_strides = (Ns, Hs * S, Ws * S, Hs, Ws, Cs)
        conv_shape = (N, (H + 2 * P - K) // S + 1, (W + 2 * P - K) // S + 1)
        inner_dim = K * K * C_in

        # im2col: a 6-D strided view with one (K, K, C_in) receptive field
        # per output position, flattened so the convolution is one matmul.
        out = A_pad.as_strided(conv_shape + (K, K, C_in), conv_strides).compact()
        out = out.reshape((out.size // inner_dim, inner_dim)) @ B.compact().reshape((inner_dim, C_out))
        return out.reshape(conv_shape + (C_out,))

    def gradient(self, out_grad, node):
        Z, W = node.inputs
        N, Hz, Wz, C_in = Z.shape
        _, Ho, Wo, _ = out_grad.shape
        K, _, _, C_out = W.shape
        revP = K - 1 - self.padding

        # Undo the stride by dilating the output gradient.
        if self.stride > 1:
            out_grad = dilate(out_grad, (1, 2), self.stride - 1)

        # Work backwards to the expected dilated-gradient dimensions.
        H_g = (Hz - 1) + K - 2 * revP
        W_g = (Wz - 1) + K - 2 * revP
        assert H_g == out_grad.shape[1]
        assert W_g == out_grad.shape[2]
        # TODO: a slice operator is missing; it would be needed when the
        # input size (plus padding, minus K) is not divisible by the
        # stride, which we avoid here.

        # Gradient w.r.t. Z: a full convolution of out_grad with the kernel
        # flipped along both spatial dimensions and C_in/C_out swapped:
        #   W: K,K,C_in,C_out -> K,K,C_out,C_in
        fW = flip(W, (0, 1))
        dZ = conv(out_grad, transpose(fW, (2, 3)), padding=revP)
        assert dZ.shape == Z.shape

        # Gradient w.r.t. W: convolve the transposed input with the
        # transposed output gradient.
        #   Z: N,H,W,C_in -> C_in,H,W,N       (treat N as input channels)
        #   out_grad: N,H,W,C_out -> W,H,N,C_out -> H,W,N,C_out
        #       (H,W act as the kernel window)
        #   result: C_in,K,K,C_out -> K,K,C_in,C_out
        #       (keep the order of the two kernel dimensions)
        tZ = transpose(Z, (0, 3))
        tOut_grad = transpose(transpose(out_grad, (0, 2)), (0, 1))
        tW = conv(tZ, tOut_grad, padding=self.padding)  # same padding as forward
        dW = transpose(transpose(tW, (0, 2)), (0, 1))
        assert dW.shape == W.shape

        return dZ, dW
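
As a sanity check on the strided im2col trick used in compute(), here is a minimal self-contained NumPy sketch of the same idea, verified against a naive loop. The function name conv2d_im2col and the test shapes are illustrative, not part of the gist.

import numpy as np
from numpy.lib.stride_tricks import as_strided

def conv2d_im2col(A, B, stride=1, padding=0):
    # A: (N, H, W, C_in), B: (K, K, C_in, C_out), same layout as compute() above.
    N, H, W, C_in = A.shape
    K, _, _, C_out = B.shape
    P, S = padding, stride
    A = np.pad(A, ((0, 0), (P, P), (P, P), (0, 0)))
    Ns, Hs, Ws, Cs = A.strides
    Ho, Wo = (H + 2 * P - K) // S + 1, (W + 2 * P - K) // S + 1
    # 6-D view: one (K, K, C_in) receptive field per output position.
    view = as_strided(A, (N, Ho, Wo, K, K, C_in), (Ns, Hs * S, Ws * S, Hs, Ws, Cs))
    cols = view.reshape(N * Ho * Wo, K * K * C_in)  # reshape copies: the view is not contiguous
    return (cols @ B.reshape(K * K * C_in, C_out)).reshape(N, Ho, Wo, C_out)

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 7, 7, 3))
B = rng.standard_normal((3, 3, 3, 4))
out = conv2d_im2col(A, B, stride=2, padding=1)

# Naive reference: loop over output positions.
Ap = np.pad(A, ((0, 0), (1, 1), (1, 1), (0, 0)))
ref = np.zeros_like(out)
for i in range(out.shape[1]):
    for j in range(out.shape[2]):
        patch = Ap[:, 2 * i:2 * i + 3, 2 * j:2 * j + 3, :]
        ref[:, i, j, :] = np.tensordot(patch, B, axes=([1, 2, 3], [0, 1, 2]))
assert np.allclose(out, ref)

The reshape of the strided view is where the one real data copy happens, mirroring the .compact() call in the gist.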
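The dZ formula in gradient() relies on the identity that, for a stride-1 cross-correlation, the gradient w.r.t. the input is a "full" cross-correlation of the output gradient with the spatially flipped kernel (padding K - 1 - P). A minimal single-channel NumPy check of that identity against finite differences (all names here are illustrative, not from the gist):

import numpy as np

def corr2d(x, k):
    # "valid" 2-D cross-correlation (no kernel flip), as in compute() above.
    Hk, Wk = k.shape
    Ho, Wo = x.shape[0] - Hk + 1, x.shape[1] - Wk + 1
    out = np.zeros((Ho, Wo))
    for i in range(Ho):
        for j in range(Wo):
            out[i, j] = (x[i:i + Hk, j:j + Wk] * k).sum()
    return out

rng = np.random.default_rng(0)
K, H = 3, 6
Z = rng.standard_normal((H, H))
W = rng.standard_normal((K, K))
out = corr2d(Z, W)
g = rng.standard_normal(out.shape)  # upstream gradient

# Analytic dZ: full correlation of g with the flipped kernel
# (padding = K - 1 here, since P = 0 in the forward pass).
dZ = corr2d(np.pad(g, K - 1), W[::-1, ::-1])

# Numerical dZ via finite differences on L = sum(out * g).
eps, dZ_num = 1e-6, np.zeros_like(Z)
for i in range(H):
    for j in range(H):
        Zp = Z.copy()
        Zp[i, j] += eps
        dZ_num[i, j] = ((corr2d(Zp, W) - out) * g).sum() / eps

assert np.allclose(dZ, dZ_num, atol=1e-4)

The stride > 1 case reduces to this one: dilating out_grad by stride - 1 (as the gist does) turns the strided convolution's gradient back into the stride-1 identity checked here.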