Skip to content

Instantly share code, notes, and snippets.

@tsbertalan
Last active September 14, 2018 19:39
Show Gist options
  • Save tsbertalan/b64ae509f2bd5f44bc262cc3580b999c to your computer and use it in GitHub Desktop.
Save tsbertalan/b64ae509f2bd5f44bc262cc3580b999c to your computer and use it in GitHub Desktop.
Matrix-of-matrices (block matrix) class implementing recursive block matrix multiplication.
MatOMat([[array([[ 0.05056171, 0.49995133],
[-0.99590893, 0.69359851]]),
array([[-0.41830152, -1.58457724],
[-0.64770677, 0.59857517]])],
[array([[ 0.33225003, -1.14747663],
[ 0.61866969, -0.08798693]]),
array([[ 0.4250724 , 0.33225315],
[-1.15681626, 0.35099715]])],
[array([[-0.60688728, 1.54697933],
[ 0.72334161, 0.04613557]]),
array([[-0.98299165, 0.05443274],
[ 0.15989294, -1.20894816]])]])
@
MatOMat([[array([[ 2.22336022, 0.39429521],
[ 1.69235772, -1.11281215]]),
array([[ 1.63574754, -1.36096559],
[-0.65122583, 0.54245131]]),
array([[ 0.04800625, -2.35807363],
[-1.10558404, 0.83783635]])],
[array([[ 2.08787087, 0.91484096],
[-0.27620335, 0.7965119 ]]),
array([[-1.14379857, 0.50991978],
[-1.3474603 , -0.0093601 ]]),
array([[-0.13070464, 0.80208661],
[-0.30296397, 1.20200259]])]])
=
MatOMat([[array([[-0.79868078, -2.81671485],
[ 2.35407818, -0.58279507]]),
array([[-0.94914879, 1.16335729],
[ 2.59928392, -0.86656829]]),
array([[-1.62745963, 2.00418427],
[ 0.32908715, -2.99194956]])],
[array([[-0.79868078, -2.81671485],
[ 2.35407818, -0.58279507]]),
array([[-0.94914879, 1.16335729],
[ 2.59928392, -0.86656829]]),
array([[-1.62745963, 2.00418427],
[ 0.32908715, -2.99194956]])],
[array([[-0.79868078, -2.81671485],
[ 2.35407818, -0.58279507]]),
array([[-0.94914879, 1.16335729],
[ 2.59928392, -0.86656829]]),
array([[-1.62745963, 2.00418427],
[ 0.32908715, -2.99194956]])]])
The entry at position (0, 1) in A is:
[[-0.41830152 -1.58457724]
[-0.64770677 0.59857517]]
#!/usr/bin/env python3
import numpy as np
from pprint import pformat
class MatOMat:
    """A matrix whose entries are themselves matrices (a block matrix).

    ``storage`` is a list of rows, each row being a list of blocks. Blocks may
    be anything that supports ``@`` and ``+`` among themselves, e.g. NumPy
    arrays or nested ``MatOMat`` instances (which makes ``@`` recursive).
    """

    def __init__(self, A):
        self.storage = A
        # Rows are assumed to all have the same length; this is not checked.
        self.nrows = len(A)
        # An empty matrix has no first row to measure, so it has zero columns.
        self.ncols = len(A[0]) if len(A) else 0

    def getrow(self, i):
        """Return row ``i`` as the stored list of blocks."""
        return self.storage[i]

    def getcol(self, j):
        """Return column ``j`` as a new list of blocks, top to bottom."""
        return [row[j] for row in self.storage]

    @staticmethod
    def _block_sum(terms):
        """Add a list of blocks left to right; return 0 for an empty list.

        Starting from the first term instead of ``0`` lets blocks that can
        only be added to each other (such as nested ``MatOMat``) be summed.
        """
        if not terms:
            return 0
        total = terms[0]
        for term in terms[1:]:
            total = total + term
        return total

    def dot(self, other):
        """Return the block-matrix product ``self @ other`` as a new MatOMat.

        Entry (i, j) of the result is the sum over k of
        ``self.storage[i][k] @ other.storage[k][j]``.

        Raises:
            AssertionError: if ``self.ncols != other.nrows``.
        """
        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        if self.ncols != other.nrows:
            raise AssertionError(
                "Cols of first (%d) != rows of second (%d)" % (self.ncols, other.nrows)
            )
        # Each result row must be its own list: ``[[None]*n]*m`` would share
        # one row object, so every row would end up equal to the last one.
        res = [
            [
                self._block_sum(
                    [a @ b for (a, b) in zip(self.getrow(i), other.getcol(j))]
                )
                for j in range(other.ncols)
            ]
            for i in range(self.nrows)
        ]
        return MatOMat(res)

    def __add__(self, other):
        """Return the block-wise sum of two MatOMats of the same block shape.

        Needed so that products of nested MatOMat blocks can be accumulated
        in ``dot``.

        Raises:
            ValueError: if the two operands have different block shapes.
        """
        if not isinstance(other, MatOMat):
            return NotImplemented
        if (self.nrows, self.ncols) != (other.nrows, other.ncols):
            raise ValueError(
                "Cannot add %dx%d MatOMat to %dx%d MatOMat"
                % (self.nrows, self.ncols, other.nrows, other.ncols)
            )
        return MatOMat([
            [a + b for a, b in zip(row_a, row_b)]
            for row_a, row_b in zip(self.storage, other.storage)
        ])

    def __matmul__(self, other):
        return self.dot(other)

    def __rmatmul__(self, other):
        """Compute ``other @ self`` when ``other`` lacks ``@`` (e.g. a nested list of blocks)."""
        if not isinstance(other, MatOMat):
            other = MatOMat(other)
        return other.dot(self)

    def __repr__(self):
        return '%s(%s)' % (type(self).__name__, pformat(self.storage))

    def __getitem__(self, sl):
        """Index the storage after converting it with ``np.asarray``.

        With equally shaped array blocks the conversion yields a stacked
        array, so ``A[i, j]`` gives the values of block (i, j) from a fresh
        array (it does not alias ``storage``). NOTE(review): with differently
        shaped blocks the conversion is expected to fail on recent NumPy.
        """
        return np.asarray(self.storage).__getitem__(sl)
def test():
    """Print a small block-matrix product as a demo.

    Seeds NumPy's global RNG with 4, builds a 3x2 block matrix ``A`` and a
    2x3 block matrix ``B`` of 2x2 normal-sample blocks, then prints ``A``,
    ``B``, their product and the block of ``A`` at position (0, 1).
    """
    np.random.seed(4)

    def rand_block(nr=2, nc=2):
        # One nr x nc block drawn from the (seeded) global RNG.
        return np.random.normal(size=(nr, nc))

    # Built row by row, left to right, so the draws happen in the same
    # order as writing the nested list literal out by hand.
    A = MatOMat([[rand_block() for _ in range(2)] for _ in range(3)])
    B = MatOMat([[rand_block() for _ in range(3)] for _ in range(2)])

    for piece in (A, '@', B, '='):
        print(piece)
    print(A @ B)
    print()
    print('The entry at position (0, 1) in A is:')
    print(A[0, 1])
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment