Last active: September 14, 2018 19:39
-
-
Save tsbertalan/b64ae509f2bd5f44bc262cc3580b999c to your computer and use it in GitHub Desktop.
Matrix-of-matrices (block matrix) class implementing recursive matrix multiplication.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
MatOMat([[array([[ 0.05056171, 0.49995133],
[-0.99590893, 0.69359851]]),
array([[-0.41830152, -1.58457724],
[-0.64770677, 0.59857517]])],
[array([[ 0.33225003, -1.14747663],
[ 0.61866969, -0.08798693]]),
array([[ 0.4250724 , 0.33225315],
[-1.15681626, 0.35099715]])],
[array([[-0.60688728, 1.54697933],
[ 0.72334161, 0.04613557]]),
array([[-0.98299165, 0.05443274],
[ 0.15989294, -1.20894816]])]])
@
MatOMat([[array([[ 2.22336022, 0.39429521],
[ 1.69235772, -1.11281215]]),
array([[ 1.63574754, -1.36096559],
[-0.65122583, 0.54245131]]),
array([[ 0.04800625, -2.35807363],
[-1.10558404, 0.83783635]])],
[array([[ 2.08787087, 0.91484096],
[-0.27620335, 0.7965119 ]]),
array([[-1.14379857, 0.50991978],
[-1.3474603 , -0.0093601 ]]),
array([[-0.13070464, 0.80208661],
[-0.30296397, 1.20200259]])]])
=
MatOMat([[array([[-0.79868078, -2.81671485],
[ 2.35407818, -0.58279507]]),
array([[-0.94914879, 1.16335729],
[ 2.59928392, -0.86656829]]),
array([[-1.62745963, 2.00418427],
[ 0.32908715, -2.99194956]])],
[array([[-0.79868078, -2.81671485],
[ 2.35407818, -0.58279507]]),
array([[-0.94914879, 1.16335729],
[ 2.59928392, -0.86656829]]),
array([[-1.62745963, 2.00418427],
[ 0.32908715, -2.99194956]])],
[array([[-0.79868078, -2.81671485],
[ 2.35407818, -0.58279507]]),
array([[-0.94914879, 1.16335729],
[ 2.59928392, -0.86656829]]),
array([[-1.62745963, 2.00418427],
[ 0.32908715, -2.99194956]])]])
(Note: the three block-rows of this result are identical because this output was produced by a version of `dot` that built its result with `[[None]*other.ncols]*self.nrows`. That expression makes every row an alias of the same list, so each row shows the last block-row computed, i.e. row 2 of A @ B. Rows 0 and 1 shown here are not the true product.)
The entry at position (0, 1) in A is:
[[-0.41830152 -1.58457724]
[-0.64770677 0.59857517]]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import numpy as np | |
from pprint import pformat | |
class MatOMat:
    """Block matrix: a 2-D grid whose entries ("blocks") are matrices.

    ``storage`` is a list of rows, each row a list of blocks. Blocks are only
    combined with ``@`` and ``+`` (the latter via the builtin ``sum``), so any
    objects supporting those operators with each other work, e.g. 2-D NumPy
    arrays of compatible shapes.
    """

    def __init__(self, A):
        """Wrap ``A``, a sequence of rows of blocks (``A`` is not copied).

        The block-column count is taken from the first row; ragged rows are
        not checked. An empty ``A`` gives a 0 x 0 block matrix.
        """
        self.storage = A
        self.nrows = len(A)
        # Guard the empty case, where A[0] would raise IndexError.
        self.ncols = len(A[0]) if self.nrows else 0

    def getrow(self, i):
        """Return block-row ``i`` (the stored row itself, not a copy)."""
        return self.storage[i]

    def getcol(self, j):
        """Return block-column ``j`` as a new list of blocks."""
        return [row[j] for row in self.storage]

    def dot(self, other):
        """Return the block-matrix product ``self @ other`` as a new MatOMat.

        Block (i, j) of the result is ``sum_k self[i][k] @ other[k][j]``.

        Raises:
            ValueError: if the number of block-columns of ``self`` differs
                from the number of block-rows of ``other``.
        """
        # An explicit raise (not ``assert``) so the check survives ``python -O``;
        # without it ``zip`` would silently truncate and give a wrong product.
        if self.ncols != other.nrows:
            raise ValueError(
                "Cols of first (%d) != rows of second (%d)"
                % (self.ncols, other.nrows)
            )
        # Each result row must be its own list: ``[[None] * n] * m`` makes all
        # m rows aliases of a single list, so every row would end up equal to
        # the last one computed.
        res = [
            [
                sum(a @ b for a, b in zip(self.getrow(i), other.getcol(j)))
                for j in range(other.ncols)
            ]
            for i in range(self.nrows)
        ]
        return MatOMat(res)

    def __matmul__(self, other):
        """Implement ``self @ other``; see :meth:`dot`."""
        return self.dot(other)

    def __rmatmul__(self, other):
        """Implement ``other @ self`` when ``other`` does not handle ``@``.

        A non-MatOMat ``other`` (e.g. a nested list of blocks) is wrapped in
        a MatOMat and multiplied on the left of ``self``.
        """
        if not isinstance(other, MatOMat):
            other = MatOMat(other)
        return other.dot(self)

    def __repr__(self):
        """Return ``MatOMat(<pretty-printed storage>)``."""
        return '%s(%s)' % (type(self).__name__, pformat(self.storage))

    def __getitem__(self, sl):
        """Index the grid of blocks as one NumPy array built from ``storage``.

        ``np.asarray`` stacks the blocks into a new array, so ``M[i, j]``
        returns (a copy of) block (i, j). This assumes all blocks share one
        shape.
        """
        return np.asarray(self.storage).__getitem__(sl)
def test():
    """Print a seeded random block-matrix product and one block of ``A``.

    ``A`` is a 3 x 2 grid of 2 x 2 blocks and ``B`` is a 2 x 3 grid of 2 x 2
    blocks. The RNG is seeded, so the printed output is reproducible.
    """
    np.random.seed(4)

    def block(nr=2, nc=2):
        # One nr x nc block drawn from the standard normal distribution.
        return np.random.normal(size=(nr, nc))

    # Blocks are drawn in row-major order, so the RNG stream is consumed in
    # the same order as writing the grid out literally row by row.
    A = MatOMat([[block() for _col in range(2)] for _row in range(3)])
    B = MatOMat([[block() for _col in range(3)] for _row in range(2)])

    for item in (A, '@', B, '='):
        print(item)
    print(A @ B)
    print()

    print('The entry at position (0, 1) in A is:')
    print(A[0, 1])
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    test()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.