Last active: December 22, 2020 22:12
-
-
Save alecbz/3c83dcc112672c5796dee05b994344cc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Distribution of the sum of repeated rolls with uneven outcome probabilities.
# Question: https://www.reddit.com/r/askmath/comments/kic0id/uneven_probabilities/
from collections import defaultdict | |
from functools import lru_cache | |
import math | |
import matplotlib.pyplot as plt | |
from pytablewriter import MarkdownTableWriter | |
# Single-roll distribution: outcome value -> probability.
# Outcome 1 is twice as likely as each of 2, 3 and 4.
PROB = {1: 0.4, 2: 0.2, 3: 0.2, 4: 0.2}
@lru_cache()
def _n_tries_items(n, outcomes):
    """Return the distribution of the sum of ``n`` rolls as sorted (total, p) pairs.

    ``outcomes`` is a tuple of (value, probability) pairs so the arguments are
    hashable for ``lru_cache``. The result is an immutable tuple, so callers
    can never corrupt a cached entry.
    """
    dist = {0: 1}  # zero rolls: total 0 with certainty
    for _ in range(n):
        # Convolve the running distribution with one more roll. Iterating
        # instead of recursing avoids RecursionError for large n.
        nxt = defaultdict(int)
        for total, p in dist.items():
            for value, q in outcomes:
                nxt[total + value] += p * q
        dist = nxt
    return tuple(sorted(dist.items()))


def n_tries(n, prob=None):
    """Return the probability distribution of the sum of ``n`` rolls.

    Args:
        n: Number of rolls (must be >= 0). ``n == 0`` yields ``{0: 1}``.
        prob: Optional single-roll distribution mapping outcome value to
            probability. Defaults to the module-level ``PROB``.

    Returns:
        A new ``defaultdict(int)`` mapping each reachable total to its
        probability, in ascending order of total. Unreachable totals read as
        0. Each call returns a fresh copy, so mutating it (including the
        implicit insert done by ``dist[missing_key]``) cannot leak into the
        memoized results.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n!r}")
    if prob is None:
        prob = PROB
    outcomes = tuple(sorted(prob.items()))
    return defaultdict(int, _n_tries_items(n, outcomes))
def table(max_rolls=10):
    """Print a Markdown table of the sum distribution for 1..max_rolls rolls.

    Row ``n`` lists, for each total ``v`` from 1 up to the largest total
    reachable with ``n`` rolls, the probability that the sum of ``n`` rolls
    equals ``v``, formatted as a percentage. Shorter rows are padded with
    empty cells so every row has as many cells as the header.

    Args:
        max_rolls: Largest number of rolls to tabulate (default 10).
    """
    dists = [(n, n_tries(n)) for n in range(1, max_rolls + 1)]
    # The column count comes from the widest row over *all* roll counts,
    # not from whichever distribution happened to be computed last.
    biggest = max((max(dist, default=0) for _, dist in dists), default=0)

    values = []
    for n, dist in dists:
        top = max(dist, default=0)
        # .get() rather than dist[val]: indexing a defaultdict with a missing
        # total inserts that key, mutating the returned distribution.
        cells = [f"{dist.get(val, 0) * 100:.1f}%" for val in range(1, top + 1)]
        values.append([n] + cells + [""] * (biggest - top))

    writer = MarkdownTableWriter(
        table_name="diamond probabilities",
        headers=["rolls/value"] + list(range(1, biggest + 1)),
        value_matrix=values,
    )
    writer.write_table()
def plot(ns=(2, 3, 4, 5, 10, 15, 20)):
    """Plot the sum distribution for several roll counts and show the figure.

    Args:
        ns: Iterable of roll counts; each is drawn as one line labelled
            "<n> tries". Defaults to 2, 3, 4, 5, 10, 15 and 20 rolls.
    """
    for n in ns:
        dist = n_tries(n)
        # Skip zero-probability totals: indexing a defaultdict with a missing
        # total inserts a 0 entry, so a distribution that has been read that
        # way can contain unreachable totals that would plot as bogus points.
        vals = sorted(v for v, p in dist.items() if p)
        plt.plot(vals, [dist[v] for v in vals], marker="o", label=f"{n} tries")
    plt.xlabel("sum of rolls")
    plt.ylabel("probability")
    plt.legend()
    plt.show()
if __name__ == "__main__":
    # Print the Markdown table first, then display the plot.
    for step in (table, plot):
        step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.