Created
February 24, 2019 12:01
-
-
Save Souldiv/cabb5e794b9c341335a70269e7ab2730 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
    """Compute the spatial output size (h, w) of a 2-D convolution.

    Applies, independently to the height and width dimensions::

        out = floor((in + 2*pad - dilation*(kernel_size - 1) - 1) / stride) + 1

    Each argument may be a single int, which is used for both dimensions.
    It may instead be a 2-element tuple or list giving the (height, width)
    values separately.

    Args:
        h_w: Input (height, width), or a single int for a square input.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        pad: Zero-padding added to both sides of each dimension.
        dilation: Spacing between kernel elements.

    Returns:
        A tuple ``(h, w)`` with the output height and width.

    Raises:
        ValueError: If a tuple/list argument does not have exactly two elements.
    """

    def _pair(value):
        # isinstance (rather than ``type(value) is tuple``) also accepts lists
        # and tuple subclasses such as ``torch.Size``. The original check
        # wrapped those into a pair of sequences and then crashed.
        if isinstance(value, (tuple, list)):
            first, second = value
            return first, second
        return value, value

    in_h, in_w = _pair(h_w)
    k_h, k_w = _pair(kernel_size)
    s_h, s_w = _pair(stride)
    p_h, p_w = _pair(pad)
    # dilation is normalized too, so per-axis dilation tuples work instead of
    # triggering tuple repetition in ``dilation * (k - 1)``.
    d_h, d_w = _pair(dilation)

    h = (in_h + 2 * p_h - d_h * (k_h - 1) - 1) // s_h + 1
    w = (in_w + 2 * p_w - d_w * (k_w - 1) - 1) // s_w + 1
    return h, w
def convtransp_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1,
                            output_padding=0):
    """Compute the spatial output size (h, w) of a 2-D transposed convolution.

    Applies, independently to the height and width dimensions::

        out = (in - 1)*stride - 2*pad + dilation*(kernel_size - 1)
              + output_padding + 1

    Each argument may be a single int, which is used for both dimensions.
    It may instead be a 2-element tuple or list giving the (height, width)
    values separately.

    Args:
        h_w: Input (height, width), or a single int for a square input.
        kernel_size: Kernel size.
        stride: Stride.
        pad: Implicit padding removed from both sides of each dimension.
        dilation: Spacing between kernel elements.
        output_padding: Extra size added to one side of each output
            dimension. Defaults to 0.

    Returns:
        A tuple ``(h, w)`` with the output height and width.

    Raises:
        ValueError: If a tuple/list argument does not have exactly two elements.
    """

    def _pair(value):
        # isinstance also accepts lists and tuple subclasses (e.g. torch.Size).
        if isinstance(value, (tuple, list)):
            first, second = value
            return first, second
        return value, value

    in_h, in_w = _pair(h_w)
    k_h, k_w = _pair(kernel_size)
    s_h, s_w = _pair(stride)
    p_h, p_w = _pair(pad)
    d_h, d_w = _pair(dilation)
    op_h, op_w = _pair(output_padding)

    # The previous version computed ``- 2*pad + kernel_size + pad``. That added
    # ``pad`` back where output_padding belongs and ignored dilation.
    h = (in_h - 1) * s_h - 2 * p_h + d_h * (k_h - 1) + op_h + 1
    w = (in_w - 1) * s_w - 2 * p_w + d_w * (k_w - 1) + op_w + 1
    return h, w
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment