Created
January 22, 2019 17:06
-
-
Save lolz0r/c3ab23c6763667edf2d8b01ee154aa13 to your computer and use it in GitHub Desktop.
Learned basis function, pytorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ConvSeluSVD(nn.Module):
    """3x3 convolution + SELU whose kernels are mixtures of three basis filters.

    Instead of storing nine free weights per (output, input) channel pair,
    the layer stores three coefficients per pair (``params``).  Each 3x3
    kernel is rebuilt on every forward pass as a linear combination of the
    rows of a ``(3, 9)`` basis matrix.

    The basis is either owned by the layer (``ownBasis=True``, stored in the
    trainable parameter ``basisWeights``) or supplied by the caller on every
    ``forward`` call, which allows one basis to be shared by several layers.

    Args:
        inputSize: Number of input channels.
        outputSize: Number of output channels.
        stride: Stride passed to the convolution.
        maxpool: If true, apply a 2x2 max-pool after the activation.
        ownBasis: If true, create a trainable basis initialised from
            ``_DEFAULT_BASIS``; otherwise ``basisWeights`` is ``None`` and a
            basis must be passed to ``forward``.
    """

    # Initial values for an owned basis: three flattened 3x3 filters.
    # NOTE(review): the class name suggests these came from an SVD of
    # trained kernels; that cannot be confirmed from this file.
    _DEFAULT_BASIS = (
        (-0.21535662, -0.30022025, -0.26041868,
         -0.314888, -0.45471892, -0.3971264,
         -0.26603645, -0.3896653, -0.33079177),
        (0.34970352, 0.50572443, 0.36894855,
         0.07661748, 0.08152138, 0.02740295,
         -0.28591475, -0.49375448, -0.38343033),
        (-0.3019736, -0.02775075, 0.29349312,
         -0.50207216, -0.05312577, 0.5471206,
         -0.39858055, -0.09402011, 0.31616086),
    )

    def __init__(self, inputSize, outputSize, stride=1, maxpool=False,
                 ownBasis=False):
        super().__init__()
        self.inputSize = inputSize
        self.outputSize = outputSize
        self.stride = stride
        self.maxpool = maxpool

        # One row of 3 mixing coefficients per (output, input) channel pair.
        # The (N, 1, 3) shape is kept so existing checkpoints still load.
        self.params = nn.Parameter(
            torch.empty(outputSize * inputSize, 1, 3).normal_(0, 0.02)
        )
        self.selu = nn.SELU(True)
        self.bias = nn.Parameter(torch.zeros(outputSize))

        # Always register the attribute so ``forward`` can tell "no basis"
        # apart from a typo; a ``None`` parameter is left out of state_dict.
        own_basis = (
            nn.Parameter(torch.tensor(self._DEFAULT_BASIS))
            if ownBasis else None
        )
        self.register_parameter("basisWeights", own_basis)

    def forward(self, input, basis_=None):
        """Convolve ``input`` with the reconstructed kernels and apply SELU.

        Args:
            input: Tensor accepted by ``conv2d`` with ``inputSize`` channels.
            basis_: Optional ``(3, 9)`` basis matrix.  When omitted, the
                layer's own ``basisWeights`` is used.

        Returns:
            The activated (and optionally max-pooled) feature map.

        Raises:
            ValueError: If no basis is passed and the layer has none of its
                own (constructed with ``ownBasis=False``).
        """
        basis = self.basisWeights if basis_ is None else basis_
        if basis is None:
            raise ValueError(
                "ConvSeluSVD was built with ownBasis=False; "
                "pass a (3, 9) basis tensor as `basis_`."
            )

        # (N, 3) @ (3, 9) -> (N, 9): every row is one flattened 3x3 kernel.
        flat_kernels = torch.matmul(self.params.squeeze(1), basis)
        weights = flat_kernels.view(self.outputSize, self.inputSize, 3, 3)

        x = torch.nn.functional.conv2d(
            input,
            weights,
            bias=self.bias,
            stride=self.stride,
            padding=1,
            dilation=1,
            groups=1,
        )
        x = self.selu(x)
        if self.maxpool:
            x = torch.nn.functional.max_pool2d(x, 2)
        return x
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment