# Best Python Snippets
# Source: GitHub gist ourway/220198e46aa1d3c500a8cbb9468ede1d (last active July 6, 2022 05:00).
# NOTE: GitHub flagged this file as containing hidden or bidirectional Unicode
# text that may be interpreted or compiled differently than what appears below.
# To review, open the file in an editor that reveals hidden Unicode characters.
## 1::calculating_with_dictionaries
# example.py
#
# Example of calculating with dictionaries: each price is paired with its
# ticker so that min(), max() and sorted() compare on the price first.

prices = {"ACME": 45.23, "AAPL": 612.78, "IBM": 205.55, "HPQ": 37.20, "FB": 10.75}

# (price, name) tuples; tuples compare element by element, so price decides.
price_name_pairs = [(price, name) for name, price in prices.items()]

# Find min and max price
min_price = min(price_name_pairs)
max_price = max(price_name_pairs)
print("min price:", min_price)
print("max price:", max_price)

print("sorted prices:")
prices_sorted = sorted(price_name_pairs)
for price, name in prices_sorted:
    print(" ", name, price)
################################################################################ | |
## 1::determine_the_top_n_items_occurring_in_a_list
# example.py
#
# Determine the most common words in a list
from collections import Counter

# The same 29 words as a whitespace-separated string, split into a list.
words = (
    "look into my eyes look into my eyes the eyes the eyes the eyes "
    "not around the eyes don't look around the eyes look into my eyes "
    "you're under"
).split()

word_counts = Counter(words)
top_three = word_counts.most_common(3)
print(top_three)
# outputs [('eyes', 8), ('the', 5), ('look', 4)]

# Example of merging in more words: update() adds to the existing tallies.
morewords = ["why", "are", "you", "not", "looking", "in", "my", "eyes"]
word_counts.update(morewords)
print(word_counts.most_common(3))
################################################################################ | |
## 1::extracting_a_subset_of_a_dictionary
# Example of extracting a subset from a dictionary
from pprint import pprint

prices = {"ACME": 45.23, "AAPL": 612.78, "IBM": 205.55, "HPQ": 37.20, "FB": 10.75}

# Make a dictionary of all prices over 200
p1 = {name: price for name, price in prices.items() if price > 200}
print("All prices over 200")
pprint(p1)

# Make a dictionary of tech stocks (names missing from prices are ignored)
tech_names = {"AAPL", "IBM", "HPQ", "MSFT"}
p2 = {name: price for name, price in prices.items() if name in tech_names}
print("All techs")
pprint(p2)
################################################################################ | |
## 1::filtering_list_elements
# Examples of different ways to filter data
from itertools import compress

mylist = [1, 4, -5, 10, -7, 2, 3, -1]

# All positive values
pos = [value for value in mylist if value > 0]
print(pos)

# All negative values
neg = [value for value in mylist if value < 0]
print(neg)

# Negative values clipped to 0
neg_clip = [max(value, 0) for value in mylist]
print(neg_clip)

# Positive values clipped to 0
pos_clip = [min(value, 0) for value in mylist]
print(pos_clip)

# Compressing example: keep the addresses whose matching count exceeds 5
addresses = [
    "5412 N CLARK",
    "5148 N CLARK",
    "5800 E 58TH",
    "2122 N CLARK",
    "5645 N RAVENSWOOD",
    "1060 W ADDISON",
    "4801 N BROADWAY",
    "1039 W GRANVILLE",
]
counts = [0, 3, 10, 4, 1, 7, 6, 1]

# Boolean selector, one flag per address.
more5 = [count > 5 for count in counts]
a = list(compress(addresses, more5))
print(a)
################################################################################ | |
## 1::finding_out_what_two_dictionaries_have_in_common
# example.py
#
# Find out what two dictionaries have in common, using set operations on
# their keys and (key, value) items.
a = {"x": 1, "y": 2, "z": 3}
b = {"w": 10, "x": 11, "y": 2}

print("Common keys:", set(a).intersection(b))
print("Keys in a not in b:", set(a).difference(b))
print("(key,value) pairs in common:", set(a.items()).intersection(b.items()))
################################################################################ | |
## 1::finding_the_largest_or_smallest_n_items
# example.py
#
# Example of using heapq to find the N smallest or largest items
import heapq

portfolio = [
    {"name": "IBM", "shares": 100, "price": 91.1},
    {"name": "AAPL", "shares": 50, "price": 543.22},
    {"name": "FB", "shares": 200, "price": 21.09},
    {"name": "HPQ", "shares": 35, "price": 31.75},
    {"name": "YHOO", "shares": 45, "price": 16.35},
    {"name": "ACME", "shares": 75, "price": 115.65},
]


def price_of(stock):
    """Sort key: the price field of one portfolio record."""
    return stock["price"]


# Three cheapest and three most expensive holdings, ranked by price.
cheap = heapq.nsmallest(3, portfolio, key=price_of)
expensive = heapq.nlargest(3, portfolio, key=price_of)
print(cheap)
print(expensive)
################################################################################ | |
## 1::grouping-records-together-based-on-a-field
from collections import defaultdict
from itertools import groupby

rows = [
    {"address": "5412 N CLARK", "date": "07/01/2012"},
    {"address": "5148 N CLARK", "date": "07/04/2012"},
    {"address": "5800 E 58TH", "date": "07/02/2012"},
    {"address": "2122 N CLARK", "date": "07/03/2012"},
    {"address": "5645 N RAVENSWOOD", "date": "07/02/2012"},
    {"address": "1060 W ADDISON", "date": "07/02/2012"},
    {"address": "4801 N BROADWAY", "date": "07/01/2012"},
    {"address": "1039 W GRANVILLE", "date": "07/04/2012"},
]


def by_date(record):
    """Grouping key: the date string of one record."""
    return record["date"]


# groupby() only merges adjacent items, so sort on the same key first.
rows.sort(key=by_date)
for date, items in groupby(rows, key=by_date):
    print(date)
    for i in items:
        print(" ", i)

# Example of building a multidict: date -> list of rows with that date
rows_by_date = defaultdict(list)
for row in rows:
    rows_by_date[by_date(row)].append(row)

for r in rows_by_date["07/01/2012"]:
    print(r)
################################################################################ | |
## 1::implementing_a_priority_queue
# example.py
#
# Example of a priority queue
import heapq


class PriorityQueue:
    """Queue whose pop() returns the item with the highest priority.

    Items pushed with equal priority come back in insertion order.
    """

    def __init__(self):
        self._queue = []
        self._index = 0

    def push(self, item, priority):
        """Add *item* with the given numeric *priority*."""
        # heapq is a min-heap, so the priority is negated; the running index
        # breaks ties before the items themselves would be compared.
        entry = (-priority, self._index, item)
        heapq.heappush(self._queue, entry)
        self._index += 1

    def pop(self):
        """Remove and return the highest-priority item."""
        *_, item = heapq.heappop(self._queue)
        return item


# Example use
class Item:
    """Minimal named object that defines no ordering of its own."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"Item({self.name!r})"


q = PriorityQueue()
for name, priority in (("foo", 1), ("bar", 5), ("spam", 4), ("grok", 1)):
    q.push(Item(name), priority)

print("Should be bar:", q.pop())
print("Should be spam:", q.pop())
print("Should be foo:", q.pop())
print("Should be grok:", q.pop())
################################################################################ | |
## 1::keeping_the_last_n_items
from collections import deque


def search(lines, pattern, history=5):
    """Yield ``(line, previous_lines)`` for every line containing *pattern*.

    *previous_lines* is the deque of up to *history* lines seen before the
    match. The same deque object is yielded every time and keeps changing
    as iteration continues, so consume it before asking for the next match.
    """
    recent = deque(maxlen=history)
    for current in lines:
        if pattern in current:
            yield current, recent
        recent.append(current)


# Example use on a file
if __name__ == "__main__":
    with open("somefile.txt") as f:
        for line, prevlines in search(f, "python", 5):
            # Context first, then the matching line, then a separator.
            print(*prevlines, sep="", end="")
            print(line, end="")
            print("-" * 20)
################################################################################ | |
## 1::mapping_names_to_sequence_elements
# example.py
from collections import namedtuple

Stock = namedtuple("Stock", ["name", "shares", "price"])


def compute_cost(records):
    """Return the total of shares * price over *records*.

    Each record is a 3-item sequence in ``Stock`` field order
    (name, shares, price).
    """
    total = 0.0
    for stock in (Stock(*rec) for rec in records):
        total += stock.shares * stock.price
    return total


# Some Data
records = [("GOOG", 100, 490.1), ("ACME", 100, 123.45), ("IBM", 50, 91.15)]
print(compute_cost(records))
################################################################################ | |
## 1::removing_duplicates_from_a_sequence_while_maintaining_order
# example.py
#
# Remove duplicate entries from a sequence while keeping order
def dedupe(items):
    """Yield each distinct item of *items* once, in first-seen order.

    Items are tracked in a set, so they must be hashable.
    """
    seen = set()
    for item in items:
        if item in seen:
            continue
        yield item
        seen.add(item)


if __name__ == "__main__":
    a = [1, 5, 2, 1, 9, 1, 5, 10]
    print(a)
    print(list(dedupe(a)))
################################################################################ | |
## 1::removing_duplicates_from_a_sequence_while_maintaining_order
# example2.py
#
# Remove duplicate entries from a sequence while keeping order
def dedupe(items, key=None):
    """Yield items of *items* whose identity value was not seen before.

    *key*, when given, maps an item to the hashable value used for the
    duplicate check; without it the item itself is used.
    """
    seen = set()
    for item in items:
        marker = key(item) if key is not None else item
        if marker in seen:
            continue
        yield item
        seen.add(marker)


if __name__ == "__main__":
    a = [
        {"x": 2, "y": 3},
        {"x": 1, "y": 4},
        {"x": 2, "y": 3},
        {"x": 2, "y": 3},
        {"x": 10, "y": 15},
    ]
    print(a)
    # dicts are unhashable, so compare them by their (x, y) pair
    print(list(dedupe(a, key=lambda d: (d["x"], d["y"]))))
################################################################################ | |
## 1::sort_a_list_of_dictionaries_by_a_common_key
# example.py
#
# Sort a list of dicts on a common key
from operator import itemgetter
from pprint import pprint

rows = [
    {"fname": "Brian", "lname": "Jones", "uid": 1003},
    {"fname": "David", "lname": "Beazley", "uid": 1002},
    {"fname": "John", "lname": "Cleese", "uid": 1001},
    {"fname": "Big", "lname": "Jones", "uid": 1004},
]

# Single-field sorts
rows_by_fname = sorted(rows, key=itemgetter("fname"))
rows_by_uid = sorted(rows, key=itemgetter("uid"))

print("Sorted by fname:")
pprint(rows_by_fname)
print("Sorted by uid:")
pprint(rows_by_uid)

# Two-field sort: itemgetter with several names yields a tuple key
by_last_then_first = itemgetter("lname", "fname")
rows_by_lfname = sorted(rows, key=by_last_then_first)
print("Sorted by lname,fname:")
pprint(rows_by_lfname)
################################################################################ | |
## 1::sort_objects_without_native_comparison_support
from operator import attrgetter


class User:
    """Plain object with a user_id and no comparison methods."""

    def __init__(self, user_id):
        self.user_id = user_id

    def __repr__(self):
        return f"User({self.user_id})"


# Example
users = [User(23), User(3), User(99)]
print(users)

# Sort it by user-id; the key function supplies the ordering User lacks
by_user_id = attrgetter("user_id")
print(sorted(users, key=by_user_id))
################################################################################ | |
## 1::transforming_and_reducing_data_at_the_same_time
# example.py
#
# Some examples of using generators in arguments
import os

# Is there any .py file directly inside the user's home directory?
home_dir = os.path.expanduser("~")
files = os.listdir(home_dir)
has_python = any(name.endswith(".py") for name in files)
print("There be python!" if has_python else "Sorry, no python.")

# Output a tuple as CSV
s = ("ACME", 50, 123.45)
print(",".join(map(str, s)))

# Data reduction across fields of a data structure
portfolio = [
    {"name": "GOOG", "shares": 50},
    {"name": "YHOO", "shares": 75},
    {"name": "AOL", "shares": 20},
    {"name": "SCOX", "shares": 65},
]
min_shares = min(holding["shares"] for holding in portfolio)
print(min_shares)
################################################################################ | |
## 1::unpack_a_fixed_number_of_elements_from_iterables_of_arbitrary_length
# example.py
#
# Unpacking of tagged tuples of varying sizes
records = [
    ("foo", 1, 2),
    ("bar", "hello"),
    ("foo", 3, 4),
]


def do_foo(x, y):
    """Handle a "foo" record, which carries two values."""
    print("foo", x, y)


def do_bar(s):
    """Handle a "bar" record, which carries one value."""
    print("bar", s)


# Tag -> handler table; records with any other tag are skipped.
handlers = {"foo": do_foo, "bar": do_bar}

for tag, *args in records:
    handler = handlers.get(tag)
    if handler is not None:
        handler(*args)
################################################################################ | |
## 1::working_with_multiple_mappings_as_a_single_mapping
# example.py
#
# Example of combining dicts into a chainmap
from collections import ChainMap

a = {"x": 1, "z": 3}
b = {"y": 2, "z": 4}

# (a) Simple example of combining: lookups try a first, then b
c = ChainMap(a, b)
print(c["x"])  # Outputs 1 (from a)
print(c["y"])  # Outputs 2 (from b)
print(c["z"])  # Outputs 3 (from a)

# Output some common values
print("len(c):", len(c))
print("c.keys():", list(c.keys()))
print("c.values():", list(c.values()))

# Modify some values; writes and deletes only touch the first mapping (a)
c.update(z=10, w=40)
c.pop("x")
print("a:", a)

# Example of stacking mappings (like scopes)
values = ChainMap({"x": 1})

# Add a new mapping
values = values.new_child({"x": 2})

# Add a new mapping
values = values.new_child({"x": 3})
print(values)
print(values["x"])

# Discard last mapping
values = values.parents
print(values)
print(values["x"])

# Discard last mapping
values = values.parents
print(values)
print(values["x"])
################################################################################ | |
## 10::loading_modules_from_a_remote_machine_using_import_hooks | |
# Example of explicit module loading using imp library | |
import imp | |
import urllib.request | |
import sys | |
def load_module(url):
    """Download Python source from *url*, execute it, and return the module.

    The module is registered in ``sys.modules`` under the URL itself, so a
    second call with the same URL re-executes the source in the same module
    object. ``__file__`` is set to the URL and ``__package__`` to "".

    SECURITY: this executes whatever code the server returns. Only use it
    with servers you fully trust.
    """
    # ``imp.new_module`` no longer exists on current Python; ``types.ModuleType``
    # builds the same empty module object. Imported locally so this function
    # does not depend on the file's ``import imp`` line.
    import types

    # Close the HTTP response deterministically once the source is read.
    with urllib.request.urlopen(url) as u:
        source = u.read().decode("utf-8")

    # Compile before touching sys.modules so a syntax error registers nothing.
    code = compile(source, url, "exec")

    newly_registered = url not in sys.modules
    mod = sys.modules.setdefault(url, types.ModuleType(url))
    mod.__file__ = url
    mod.__package__ = ""
    try:
        exec(code, mod.__dict__)
    except BaseException:
        # Do not leave a half-initialised module behind if we created it.
        if newly_registered:
            sys.modules.pop(url, None)
        raise
    return mod


if __name__ == "__main__":
    fib = load_module("http://localhost:15000/fib.py")
    print(fib.fib(10))
    spam = load_module("http://localhost:15000/spam.py")
    spam.hello("Guido")
    print(fib)
    print(spam)
################################################################################ | |
## 10::loading_modules_from_a_remote_machine_using_import_hooks
# metaexample.py
#
# Example of using a meta-path importer: once the finder is installed, the
# plain import statements below are served from the given base URL.

# Enable for debugging (flip the condition to turn on DEBUG logging)
if False:
    import logging

    logging.basicConfig(level=logging.DEBUG)

import urlimport

urlimport.install_meta("http://localhost:15000")

import fib
import spam
import grok.blah

print(grok.blah.__file__)
################################################################################ | |
## 10::loading_modules_from_a_remote_machine_using_import_hooks
# Example of a path-based import hook: with the hook installed, a URL placed
# on sys.path is searched like any other path entry.

# Enable for debugging (flip the condition to turn on DEBUG logging)
if False:
    import logging

    logging.basicConfig(level=logging.DEBUG)

import urlimport

urlimport.install_path_hook()

import sys

sys.path.append("http://localhost:15000")

import fib
import spam
import grok.blah

print(grok.blah.__file__)
################################################################################ | |
## 10::loading_modules_from_a_remote_machine_using_import_hooks | |
# urlimport.py | |
import sys | |
import importlib.abc | |
import imp | |
from urllib.request import urlopen | |
from urllib.error import HTTPError, URLError | |
from html.parser import HTMLParser | |
# Debugging | |
import logging | |
log = logging.getLogger(__name__) | |
# Get links from a given URL | |
def _get_links(url): | |
class LinkParser(HTMLParser): | |
def handle_starttag(self, tag, attrs): | |
if tag == "a": | |
attrs = dict(attrs) | |
links.add(attrs.get("href").rstrip("/")) | |
links = set() | |
try: | |
log.debug("Getting links from %s" % url) | |
u = urlopen(url) | |
parser = LinkParser() | |
parser.feed(u.read().decode("utf-8")) | |
except Exception as e: | |
log.debug("Could not get links. %s", e) | |
log.debug("links: %r", links) | |
return links | |
class UrlMetaFinder(importlib.abc.MetaPathFinder):
    """Meta-path finder that resolves modules and packages under a base URL.

    Link listings are cached per directory URL in ``_links``; one
    ``UrlModuleLoader`` per directory URL is kept in ``_loaders``.

    NOTE(review): ``find_module`` is the legacy finder protocol; newer
    interpreters look for ``find_spec`` instead. Confirm the target Python
    version before relying on this class.
    """

    def __init__(self, baseurl):
        self._baseurl = baseurl
        self._links = {}
        self._loaders = {baseurl: UrlModuleLoader(baseurl)}

    def find_module(self, fullname, path=None):
        """Return a loader for *fullname*, or None if it is not served here.

        *path* is None for a top-level import; otherwise its first entry is
        the parent package's URL and must start with this finder's base URL.
        """
        log.debug("find_module: fullname=%r, path=%r", fullname, path)
        if path is None:
            baseurl = self._baseurl
        else:
            if not path[0].startswith(self._baseurl):
                return None
            baseurl = path[0]

        basename = fullname.split(".")[-1]
        log.debug("find_module: baseurl=%r, basename=%r", baseurl, basename)

        # Check link cache. The cache is keyed by directory URL, so test
        # ``baseurl``; testing ``basename`` never hit and refetched the
        # listing on every import.
        if baseurl not in self._links:
            self._links[baseurl] = _get_links(baseurl)

        # Check if it's a package
        if basename in self._links[baseurl]:
            log.debug("find_module: trying package %r", fullname)
            # Build the URL from the directory being searched so nested
            # packages resolve below their parent, not below the root.
            fullurl = baseurl + "/" + basename
            # Attempt to load the package (which accesses __init__.py)
            loader = UrlPackageLoader(fullurl)
            try:
                loader.load_module(fullname)
                self._links[fullurl] = _get_links(fullurl)
                self._loaders[fullurl] = UrlModuleLoader(fullurl)
                log.debug("find_module: package %r loaded", fullname)
            except ImportError as e:
                log.debug("find_module: package failed. %s", e)
                loader = None
            return loader

        # A normal module
        filename = basename + ".py"
        if filename in self._links[baseurl]:
            log.debug("find_module: module %r found", fullname)
            return self._loaders[baseurl]
        log.debug("find_module: module %r not found", fullname)
        return None

    def invalidate_caches(self):
        """Forget all cached link listings."""
        log.debug("invalidating link cache")
        self._links.clear()
# Module Loader for a URL
class UrlModuleLoader(importlib.abc.SourceLoader):
    """Loader that fetches ``<baseurl>/<name>.py`` and executes it.

    Downloaded source text is cached per file URL in ``_source_cache``.
    """

    def __init__(self, baseurl):
        self._baseurl = baseurl
        self._source_cache = {}

    def module_repr(self, module):
        return "<urlmodule %r from %r>" % (module.__name__, module.__file__)

    # Required method
    def load_module(self, fullname):
        """Fetch, compile and execute *fullname*; return the module object.

        Raises ImportError when the source cannot be downloaded.
        """
        # ``imp.new_module`` no longer exists on current Python;
        # ``types.ModuleType`` creates the same empty module. Local import so
        # the class does not depend on the file's ``import imp`` line.
        import types

        code = self.get_code(fullname)
        mod = sys.modules.setdefault(fullname, types.ModuleType(fullname))
        mod.__file__ = self.get_filename(fullname)
        mod.__loader__ = self
        mod.__package__ = fullname.rpartition(".")[0]
        exec(code, mod.__dict__)
        return mod

    # Optional extensions
    def get_code(self, fullname):
        """Return the compiled code object for *fullname*."""
        src = self.get_source(fullname)
        return compile(src, self.get_filename(fullname), "exec")

    def get_data(self, path):
        """Not supported; present to satisfy the loader ABC. Returns None."""
        return None

    def get_filename(self, fullname):
        """Return the URL of the .py file for the last part of *fullname*."""
        return self._baseurl + "/" + fullname.split(".")[-1] + ".py"

    def get_source(self, fullname):
        """Return the source text for *fullname*, downloading it if needed.

        Raises ImportError (chained to the network error) on failure.
        """
        filename = self.get_filename(fullname)
        log.debug("loader: reading %r", filename)
        if filename in self._source_cache:
            log.debug("loader: cached %r", filename)
            return self._source_cache[filename]
        try:
            # Close the response once the body has been read.
            with urlopen(filename) as u:
                source = u.read().decode("utf-8")
            log.debug("loader: %r loaded", filename)
            self._source_cache[filename] = source
            return source
        except (HTTPError, URLError) as e:
            log.debug("loader: %r failed. %s", filename, e)
            raise ImportError("Can't load %s" % filename) from e

    def is_package(self, fullname):
        return False
# Package loader for a URL
class UrlPackageLoader(UrlModuleLoader):
    """Loader for a package directory: executes ``<baseurl>/__init__.py``."""

    def load_module(self, fullname):
        """Load the package and return it.

        Sets ``__path__`` to the package URL and ``__package__`` to the
        package's own name.
        """
        mod = super().load_module(fullname)
        mod.__path__ = [self._baseurl]
        mod.__package__ = fullname
        # The loader protocol expects load_module to return the module;
        # the original fell off the end and returned None.
        return mod

    def get_filename(self, fullname):
        """Return the URL of the package's ``__init__.py``."""
        return self._baseurl + "/" + "__init__.py"

    def is_package(self, fullname):
        return True
# Utility functions for installing/uninstalling the loader
_installed_meta_cache = {}  # address -> finder currently on sys.meta_path


def install_meta(address):
    """Append a UrlMetaFinder for *address* to sys.meta_path, once."""
    if address in _installed_meta_cache:
        return
    finder = UrlMetaFinder(address)
    _installed_meta_cache[address] = finder
    sys.meta_path.append(finder)
    log.debug("%r installed on sys.meta_path", finder)


def remove_meta(address):
    """Remove the finder installed for *address*; no-op if there is none."""
    if address not in _installed_meta_cache:
        return
    finder = _installed_meta_cache.pop(address)
    sys.meta_path.remove(finder)
    log.debug("%r removed from sys.meta_path", finder)
# Path finder class for a URL
class UrlPathFinder(importlib.abc.PathEntryFinder):
    """Path entry finder for one URL placed on sys.path.

    ``_links`` caches the link listing of the URL; None means "not fetched".
    """

    def __init__(self, baseurl):
        self._baseurl = baseurl
        self._links = None
        self._loader = UrlModuleLoader(baseurl)

    def find_loader(self, fullname):
        """Return ``(loader, portions)`` for *fullname* under this URL.

        A package yields ``(loader_or_None, [package_url])``; a module
        yields ``(loader, [])``; anything else yields ``(None, [])``.
        """
        log.debug("find_loader: %r", fullname)
        basename = fullname.split(".")[-1]

        # Check link cache
        if self._links is None:
            # Placeholder first, real listing second ("See discussion" in
            # the original). Presumably this keeps a re-entrant call made
            # while fetching from seeing None - TODO confirm.
            self._links = []  # See discussion
            self._links = _get_links(self._baseurl)

        if basename not in self._links:
            # A normal module
            filename = basename + ".py"
            if filename in self._links:
                log.debug("find_loader: module %r found", fullname)
                return (self._loader, [])
            log.debug("find_loader: module %r not found", fullname)
            return (None, [])

        # It's a package
        log.debug("find_loader: trying package %r", fullname)
        fullurl = self._baseurl + "/" + basename
        # Attempt to load the package (which accesses __init__.py)
        loader = UrlPackageLoader(fullurl)
        try:
            loader.load_module(fullname)
            log.debug("find_loader: package %r loaded", fullname)
        except ImportError:
            # Without a loadable __init__.py, report a namespace portion only.
            log.debug("find_loader: %r is a namespace package", fullname)
            loader = None
        return (loader, [fullurl])

    def invalidate_caches(self):
        """Drop the cached link listing so it is fetched again on next use."""
        log.debug("invalidating link cache")
        self._links = None
# Check path to see if it looks like a URL
_url_path_cache = {}  # path -> UrlPathFinder, so each URL gets one finder


def handle_url(path):
    """Path hook: return a UrlPathFinder for http(s) URLs, else None.

    Finders are cached per path, so repeated calls return the same object.
    """
    if not path.startswith(("http://", "https://")):
        log.debug("Handle path? %s. [No]", path)
        return None

    log.debug("Handle path? %s. [Yes]", path)
    try:
        finder = _url_path_cache[path]
    except KeyError:
        finder = UrlPathFinder(path)
        _url_path_cache[path] = finder
    return finder
def install_path_hook():
    """Register handle_url as a path hook and reset the importer cache."""
    sys.path_hooks.append(handle_url)
    # Cached finders were chosen before the hook existed; drop them.
    sys.path_importer_cache.clear()
    log.debug("Installing handle_url")


def remove_path_hook():
    """Unregister handle_url and reset the importer cache.

    Raises ValueError (from list.remove) if the hook is not installed.
    """
    sys.path_hooks.remove(handle_url)
    sys.path_importer_cache.clear()
    log.debug("Removing handle_url")
################################################################################ | |
## 10::making_separate_directories_import_under_a_common_namespace | |
import sys | |
# Put both directories on the import path; each one supplies part of "spam".
sys.path += ["foo-package", "bar-package"]

import spam.blah
import spam.grok
################################################################################ | |
## 10::monkeypatching_modules_on_import | |
from postimport import when_imported | |
@when_imported("threading")
def warn_threads(mod):
    """Post-import action for "threading": print a warning message."""
    print("Threads? Are you crazy?")


if __name__ == "__main__":
    # Importing threading here is what triggers the registered action.
    import threading
################################################################################ | |
## 10::monkeypatching_modules_on_import | |
from postimport import when_imported | |
from functools import wraps | |
def logged(func):
    """Decorator: print the call (name, args, kwargs), then call *func*.

    The wrapper keeps *func*'s metadata via functools.wraps and returns
    whatever *func* returns.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        print("Calling", func.__name__, args, kwargs)
        result = func(*args, **kwargs)
        return result

    return wrapper
# Example
@when_imported("math")
def add_logging(mod):
    """Post-import action for "math": wrap cos and sin with logged()."""
    for func_name in ("cos", "sin"):
        setattr(mod, func_name, logged(getattr(mod, func_name)))


if __name__ == "__main__":
    import math

    print(math.cos(2))
    print(math.sin(2))
################################################################################ | |
## 10::monkeypatching_modules_on_import | |
# postimport.py | |
import importlib | |
import sys | |
from collections import defaultdict | |
# fullname -> list of callables to run once that module has been imported
_post_import_hooks = defaultdict(list)


class PostImportFinder:
    """Meta-path finder that intercepts an import to run registered hooks.

    ``_skip`` holds names whose import is currently being handled, so the
    nested import issued by the loader is passed on to the normal finders.

    NOTE(review): ``find_module``/``load_module`` is the legacy import
    protocol; newer interpreters look for ``find_spec`` instead. Confirm the
    target Python version.
    """

    def __init__(self):
        self._skip = set()

    def find_module(self, fullname, path=None):
        """Return a PostImportLoader, or None while *fullname* is in flight."""
        if fullname in self._skip:
            return None
        self._skip.add(fullname)
        return PostImportLoader(self)


class PostImportLoader:
    """Loader that performs the real import, then runs the hooks."""

    def __init__(self, finder):
        self._finder = finder

    def load_module(self, fullname):
        """Import *fullname*, call its hooks with the module, return it.

        Exceptions from the import or from a hook propagate unchanged.
        """
        try:
            importlib.import_module(fullname)
            module = sys.modules[fullname]
            # .get() avoids creating an empty defaultdict entry for every
            # module that is ever imported.
            for func in _post_import_hooks.get(fullname, ()):
                func(module)
        finally:
            # Always release the name. Otherwise one failed import (or a
            # failing hook) would leave it in _skip and the hooks would
            # never run on a later retry.
            self._finder._skip.discard(fullname)
        return module


def when_imported(fullname):
    """Decorator: run the function with module *fullname* once it is imported.

    If the module is already in sys.modules the function is called
    immediately; otherwise it is queued. The function is returned unchanged.
    """

    def decorate(func):
        if fullname in sys.modules:
            func(sys.modules[fullname])
        else:
            _post_import_hooks[fullname].append(func)
        return func

    return decorate


# Install at the front so this finder sees every import first.
sys.meta_path.insert(0, PostImportFinder())
################################################################################ | |
## 10::splitting_a_module_into_multiple_files | |
import mymodule | |
# Exercise both classes through the single "mymodule" import.
a = mymodule.A()
a.spam()

b = mymodule.B()
b.bar()
################################################################################ | |
## 11::adding_ssl_to_network_servers | |
# echoclient.py | |
# | |
# An example of a client that connects to an SSL server | |
# and verifies its certificate | |
from socket import socket, AF_INET, SOCK_STREAM | |
import ssl | |
# Build a client-side TLS context that trusts only the server's certificate.
# ssl.wrap_socket() no longer exists on current Python; an explicit
# SSLContext is the supported way to wrap a socket.
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# The original call verified the certificate but not the host name, so host
# name checking stays off to keep the same behaviour.
context.check_hostname = False
context.verify_mode = ssl.CERT_REQUIRED
context.load_verify_locations("server_cert.pem")

s = socket(AF_INET, SOCK_STREAM)

# Wrap with an SSL layer and require the server to present its certificate
ssl_s = context.wrap_socket(s)
try:
    ssl_s.connect(("localhost", 20000))

    # Communicate with the server; sendall() retries until every byte is sent
    ssl_s.sendall(b"Hello World!")
    resp = ssl_s.recv(8192)
    print("Got:", resp)
finally:
    # Done - close even if connecting or the exchange failed
    ssl_s.close()
################################################################################ | |
## 11::adding_ssl_to_network_servers | |
from socket import socket, AF_INET, SOCK_STREAM | |
from socket import SOL_SOCKET, SO_REUSEADDR | |
import ssl | |
KEYFILE = "server_key.pem" # Private key of the server | |
CERTFILE = "server_cert.pem" # Server certificate (given to client) | |
def echo_client(s):
    """Echo everything received on socket *s* back to the peer.

    Runs until the peer closes its side (recv returns b""), then closes *s*
    and prints a notice. The socket is closed even if an error occurs.
    """
    try:
        while True:
            data = s.recv(8192)
            if not data:
                break
            # sendall() loops until the whole buffer is written; send() may
            # write only part of it and silently drop the rest.
            s.sendall(data)
    finally:
        s.close()
    print("Connection closed")
def echo_server(address):
    """Serve TLS echo connections on *address* forever, one at a time.

    The server presents CERTFILE/KEYFILE to clients. Client certificates are
    not requested. Errors while accepting or serving a connection are
    printed and the loop continues.
    """
    # ssl.wrap_socket() no longer exists on current Python; build an explicit
    # server context. Loading the key pair first means a missing file fails
    # before any socket is opened.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(certfile=CERTFILE, keyfile=KEYFILE)

    s = socket(AF_INET, SOCK_STREAM)
    s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    s.bind(address)
    s.listen(1)

    # Wrap with an SSL layer (server side; the server cert is sent to clients)
    s_ssl = context.wrap_socket(s, server_side=True)

    # Wait for connections
    while True:
        try:
            c, a = s_ssl.accept()
            print("Got connection", c, a)
            echo_client(c)
        except Exception as e:
            print("{}: {}".format(e.__class__.__name__, e))


echo_server(("", 20000))
################################################################################ | |
## 11::adding_ssl_to_network_servers | |
# ssl_xmlrpc_client.py | |
# | |
# An XML-RPC client that verifies the server certificate | |
from xmlrpc.client import SafeTransport, ServerProxy | |
import ssl | |
class VerifyCertSafeTransport(SafeTransport):
    """HTTPS transport for xmlrpc that verifies the server certificate.

    Parameters:
        cafile: CA / server certificate file used to verify the server.
        certfile, keyfile: optional client certificate and key to present.
        check_hostname: keyword-only; when true the certificate must also
            match the host name. Defaults to False, which is what the
            original TLSv1 context did.
    """

    def __init__(self, cafile, certfile=None, keyfile=None, *, check_hostname=False):
        # PROTOCOL_TLSv1 pins the obsolete TLS 1.0 protocol;
        # PROTOCOL_TLS_CLIENT negotiates the best version both sides support.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        context.check_hostname = check_hostname
        context.verify_mode = ssl.CERT_REQUIRED
        context.load_verify_locations(cafile)
        if certfile:
            context.load_cert_chain(certfile, keyfile)
        # Hand the context to SafeTransport, which passes it on to
        # HTTPSConnection itself.
        super().__init__(context=context)
        self._ssl_context = context

    def make_connection(self, host):
        """Return the HTTPS connection for *host*.

        The original injected ``{"context": ...}`` as extra connection
        arguments. SafeTransport already passes its own ``context=``, so
        that raised TypeError for a duplicated keyword argument.
        """
        return super().make_connection(host)
# Create the client proxy: verify the server with server_cert.pem and
# present client_cert.pem / client_key.pem as the client identity.
transport = VerifyCertSafeTransport(
    "server_cert.pem", "client_cert.pem", "client_key.pem"
)
# transport = VerifyCertSafeTransport('server_cert.pem')
s = ServerProxy("https://localhost:15000", transport=transport, allow_none=True)

# Exercise the remote key/value API
s.set("foo", "bar")
s.set("spam", [1, 2, 3])
print(s.keys())
print(s.get("foo"))
print(s.get("spam"))
s.delete("spam")
print(s.exists("spam"))
################################################################################ | |
## 11::adding_ssl_to_network_servers | |
# ssl_xmlrpc_server.py | |
# | |
# An example of an SSL-XMLRPC Server. | |
import ssl | |
from xmlrpc.server import SimpleXMLRPCServer | |
from sslmixin import SSLMixin | |
class SSLSimpleXMLRPCServer(SSLMixin, SimpleXMLRPCServer):
    """SimpleXMLRPCServer combined with SSLMixin; adds nothing of its own."""
class KeyValueServer:
    """In-memory key/value store exposed through an SSL XML-RPC server."""

    # Names of the methods registered as remote procedures.
    _rpc_methods_ = ["get", "set", "delete", "exists", "keys"]

    def __init__(self, *args, **kwargs):
        """Create the store and server; arguments are passed to the server."""
        self._data = {}
        self._serv = SSLSimpleXMLRPCServer(*args, allow_none=True, **kwargs)
        register = self._serv.register_function
        for method_name in self._rpc_methods_:
            register(getattr(self, method_name))

    def get(self, name):
        """Return the value stored under *name* (KeyError if missing)."""
        return self._data[name]

    def set(self, name, value):
        """Store *value* under *name*."""
        self._data[name] = value

    def delete(self, name):
        """Remove *name* from the store (KeyError if missing)."""
        self._data.pop(name)

    def exists(self, name):
        """Return True if *name* is stored."""
        return name in self._data

    def keys(self):
        """Return the stored names as a list."""
        return [*self._data]

    def serve_forever(self):
        """Run the underlying server's request loop."""
        self._serv.serve_forever()
if __name__ == "__main__":
    KEYFILE = "server_key.pem"  # Private key of the server
    CERTFILE = "server_cert.pem"  # Server certificate
    CA_CERTS = "client_cert.pem"  # Certificates of accepted clients

    # Require clients to present a certificate that verifies against CA_CERTS.
    ssl_options = {
        "keyfile": KEYFILE,
        "certfile": CERTFILE,
        "ca_certs": CA_CERTS,
        "cert_reqs": ssl.CERT_REQUIRED,
    }
    kvserv = KeyValueServer(("", 15000), **ssl_options)
    kvserv.serve_forever()
################################################################################ | |
## 11::adding_ssl_to_network_servers | |
import ssl | |
class SSLMixin:
    """Mixin that makes a socketserver-style server wrap connections in TLS.

    Keyword-only parameters (everything else is passed to the next class):
        keyfile, certfile: the server's private key and certificate files.
        ca_certs: file with certificates used to verify clients.
        cert_reqs: an ``ssl.CERT_*`` value; defaults to ``ssl.CERT_NONE``.
    """

    def __init__(
        self,
        *args,
        keyfile=None,
        certfile=None,
        ca_certs=None,
        cert_reqs=ssl.CERT_NONE,
        **kwargs
    ):
        self._keyfile = keyfile
        self._certfile = certfile
        self._ca_certs = ca_certs
        self._cert_reqs = cert_reqs
        # Built on the first request so that constructing the server behaves
        # as before (certificate files are not read at construction time).
        self._ssl_context = None
        super().__init__(*args, **kwargs)

    def _get_ssl_context(self):
        """Build the server SSLContext from the stored settings, once."""
        if self._ssl_context is None:
            # ssl.wrap_socket() no longer exists on current Python; this is
            # the same configuration expressed through an explicit context.
            context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            context.verify_mode = self._cert_reqs
            if self._ca_certs:
                context.load_verify_locations(self._ca_certs)
            if self._certfile:
                context.load_cert_chain(self._certfile, self._keyfile)
            self._ssl_context = context
        return self._ssl_context

    def get_request(self):
        """Accept a connection and return ``(tls_socket, client_address)``."""
        client, addr = super().get_request()
        try:
            client_ssl = self._get_ssl_context().wrap_socket(client, server_side=True)
        except Exception:
            # A failed handshake must not leak the accepted socket.
            client.close()
            raise
        return client_ssl, addr
################################################################################ | |
## 11::creating_a_simple_rest_based_interface | |
from urllib.request import urlopen | |
# Each response is used as a context manager so the HTTP connection is
# closed as soon as its body has been read.
with urlopen("http://localhost:8080/hello?name=Guido") as u:
    print(u.read().decode("utf-8"))

with urlopen("http://localhost:8080/localtime") as u:
    print(u.read().decode("utf-8"))
################################################################################ | |
## 11::creating_a_simple_rest_based_interface | |
import time | |
# HTML template for the /hello response; {name} is filled in per request.
_hello_resp = """\
<html>
<head>
<title>Hello {name}</title>
</head>
<body>
<h1>Hello {name}!</h1>
</body>
</html>"""


def hello_world(environ, start_response):
    """WSGI handler: yield an HTML greeting for the "name" parameter.

    Reads ``environ["params"]`` (a dict) and yields the page as UTF-8
    bytes. A missing name renders as "None", as before.
    """
    # Local import: the name comes straight from the request, so it must be
    # HTML-escaped before it is placed in the page.
    from html import escape

    start_response("200 OK", [("Content-type", "text/html")])
    params = environ["params"]
    safe_name = escape(str(params.get("name")))
    resp = _hello_resp.format(name=safe_name)
    yield resp.encode("utf-8")
# XML template for the /localtime response; {t} is a time.struct_time.
_localtime_resp = """\
<?xml version="1.0"?>
<time>
<year>{t.tm_year}</year>
<month>{t.tm_mon}</month>
<day>{t.tm_mday}</day>
<hour>{t.tm_hour}</hour>
<minute>{t.tm_min}</minute>
<second>{t.tm_sec}</second>
</time>"""


def localtime(environ, start_response):
    """WSGI handler: yield the current local time as an XML document."""
    headers = [("Content-type", "application/xml")]
    start_response("200 OK", headers)
    now = time.localtime()
    document = _localtime_resp.format(t=now)
    yield document.encode("utf-8")
if __name__ == "__main__":
    from resty import PathDispatcher
    from wsgiref.simple_server import make_server

    # Create the dispatcher and register functions
    dispatcher = PathDispatcher()
    for route, handler in (("/hello", hello_world), ("/localtime", localtime)):
        dispatcher.register("GET", route, handler)

    # Launch a basic server
    httpd = make_server("", 8080, dispatcher)
    print("Serving on port 8080...")
    httpd.serve_forever()
################################################################################ | |
## 11::creating_a_simple_rest_based_interface | |
# resty.py | |
import cgi | |
def notfound_404(environ, start_response):
    """WSGI handler used when no route matches: plain-text 404."""
    start_response("404 Not Found", [("Content-type", "text/plain")])
    return [b"Not Found"]


def _request_params(environ):
    """Return the request parameters of a WSGI *environ* as a dict.

    A name given once maps to its string value; a repeated name maps to a
    list of values. When the ``cgi`` module can be imported it is used
    exactly as before. On interpreters without it, the query string and
    (for methods other than GET/HEAD) a URL-encoded body are parsed with
    urllib.parse; multipart bodies are not parsed in that fallback.
    """
    try:
        import cgi  # not available on newer Python versions
    except ImportError:
        cgi = None

    if cgi is not None:
        fields = cgi.FieldStorage(environ["wsgi.input"], environ=environ)
        return {key: fields.getvalue(key) for key in fields}

    from urllib.parse import parse_qsl

    pairs = []
    if environ["REQUEST_METHOD"].upper() not in ("GET", "HEAD"):
        content_type = environ.get("CONTENT_TYPE", "")
        if not content_type or content_type.startswith(
            "application/x-www-form-urlencoded"
        ):
            length = int(environ.get("CONTENT_LENGTH") or 0)
            if length > 0:
                body = environ["wsgi.input"].read(length)
                pairs.extend(parse_qsl(body.decode("utf-8")))
    pairs.extend(parse_qsl(environ.get("QUERY_STRING", "")))

    grouped = {}
    for key, value in pairs:
        grouped.setdefault(key, []).append(value)
    return {
        key: values[0] if len(values) == 1 else values
        for key, values in grouped.items()
    }


class PathDispatcher:
    """WSGI application that routes on (request method, path).

    Before calling a handler, the parsed request parameters are stored in
    ``environ["params"]``. Unregistered routes go to ``notfound_404``.
    """

    def __init__(self):
        # (lower-cased method, path) -> handler
        self.pathmap = {}

    def __call__(self, environ, start_response):
        path = environ["PATH_INFO"]
        method = environ["REQUEST_METHOD"].lower()
        environ["params"] = _request_params(environ)
        handler = self.pathmap.get((method, path), notfound_404)
        return handler(environ, start_response)

    def register(self, method, path, function):
        """Route *method* + *path* to *function*; returns *function*."""
        self.pathmap[method.lower(), path] = function
        return function
################################################################################ | |
## 11::creating_a_tcp_server | |
from socket import socket, AF_INET, SOCK_STREAM | |
# Connect to the echo server, send one line and print the reply.
# The with-block closes the socket even if connect/send/recv raises.
with socket(AF_INET, SOCK_STREAM) as s:
    s.connect(("localhost", 20000))
    # sendall() keeps writing until the whole message is sent; send() may not
    s.sendall(b"Hello\n")
    resp = s.recv(8192)
    print("Response:", resp)
################################################################################ | |
## 11::creating_a_tcp_server | |
from socketserver import BaseRequestHandler, TCPServer | |
class EchoHandler(BaseRequestHandler):
    """Request handler that echoes every received chunk back to the client."""

    def handle(self):
        """Serve one client until it closes its side of the connection."""
        print("Got connection from", self.client_address)
        while True:
            msg = self.request.recv(8192)
            if not msg:
                # Empty read: the peer closed the connection
                break
            # sendall() retries until the whole chunk is written; a bare
            # send() may write only part of it and silently drop the rest.
            self.request.sendall(msg)


if __name__ == "__main__":
    serv = TCPServer(("", 20000), EchoHandler)
    print("Echo server running on port 20000")
    serv.serve_forever()
################################################################################ | |
## 11::creating_a_tcp_server | |
from socketserver import BaseRequestHandler, TCPServer | |
class EchoHandler(BaseRequestHandler):
    """Request handler that echoes every received chunk back to the client."""

    def handle(self):
        """Serve one client until it closes its side of the connection."""
        print("Got connection from", self.client_address)
        while True:
            msg = self.request.recv(8192)
            if not msg:
                # Empty read: the peer closed the connection
                break
            # sendall() retries until the whole chunk is written; a bare
            # send() may write only part of it and silently drop the rest.
            self.request.sendall(msg)


if __name__ == "__main__":
    serv = TCPServer(("", 20000), EchoHandler)
    print("Echo server running on port 20000")
    serv.serve_forever()
################################################################################ | |
## 11::creating_a_tcp_server | |
from socketserver import StreamRequestHandler, TCPServer | |
class EchoHandler(StreamRequestHandler):
    """Line-oriented echo handler using the file-like stream interface."""

    def handle(self):
        """Copy each line read from the client straight back to it."""
        print("Got connection from", self.client_address)
        # rfile / wfile are the file-like read and write sides of the socket
        reader, writer = self.rfile, self.wfile
        while True:
            line = reader.readline()
            if not line:
                break
            writer.write(line)


if __name__ == "__main__":
    address = ("", 20000)
    serv = TCPServer(address, EchoHandler)
    print("Echo server running on port 20000")
    serv.serve_forever()
################################################################################ | |
## 11::creating_a_tcp_server | |
from socketserver import StreamRequestHandler, TCPServer | |
class EchoHandler(StreamRequestHandler):
    """Line-oriented echo handler using the file-like stream interface."""

    def handle(self):
        """Copy each line read from the client straight back to it."""
        print("Got connection from", self.client_address)
        # rfile / wfile are the file-like read and write sides of the socket
        reader, writer = self.rfile, self.wfile
        while True:
            line = reader.readline()
            if not line:
                break
            writer.write(line)


if __name__ == "__main__":
    import socket

    # Create the server unbound so socket options can be applied first
    serv = TCPServer(("", 20000), EchoHandler, bind_and_activate=False)
    listener = serv.socket
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
    # Now bind to the address and start listening
    serv.server_bind()
    serv.server_activate()
    print("Echo server running on port 20000")
    serv.serve_forever()
################################################################################ | |
## 11::creating_a_tcp_server | |
from socketserver import StreamRequestHandler, TCPServer | |
import socket | |
class EchoHandler(StreamRequestHandler):
    """Line-oriented echo handler that gives up on idle clients."""

    # Settings read by StreamRequestHandler when it sets up the connection
    timeout = 5  # socket timeout, in seconds
    rbufsize = -1  # buffer size used for self.rfile
    wbufsize = 0  # buffer size used for self.wfile
    disable_nagle_algorithm = False  # when true, TCP_NODELAY is set

    def handle(self):
        """Echo lines back to the client until EOF or a socket timeout."""
        print("Got connection from", self.client_address)
        # rfile / wfile are the file-like read and write sides of the socket
        reader, writer = self.rfile, self.wfile
        try:
            while True:
                line = reader.readline()
                if not line:
                    break
                writer.write(line)
        except socket.timeout:
            print("Timed out!")


if __name__ == "__main__":
    address = ("", 20000)
    serv = TCPServer(address, EchoHandler)
    print("Echo server running on port 20000")
    serv.serve_forever()
################################################################################ | |
## 11::creating_a_tcp_server | |
# Echo server using sockets directly | |
from socket import socket, AF_INET, SOCK_STREAM | |
def echo_handler(address, client_sock):
    """Echo everything received on client_sock until the peer disconnects.

    address is only used for the log line. The socket is always closed
    before returning, including when recv() or sendall() raises.
    """
    print("Got connection from {}".format(address))
    try:
        while True:
            msg = client_sock.recv(8192)
            if not msg:
                # Empty read: the peer closed the connection
                break
            client_sock.sendall(msg)
    finally:
        client_sock.close()
def echo_server(address, backlog=5):
    """Accept TCP clients on address and serve them one at a time.

    Runs until an exception escapes; the with-block then releases the
    listening socket instead of leaving it open.
    """
    with socket(AF_INET, SOCK_STREAM) as sock:
        sock.bind(address)
        sock.listen(backlog)
        while True:
            client_sock, client_addr = sock.accept()
            echo_handler(client_addr, client_sock)


if __name__ == "__main__":
    echo_server(("", 20000))
################################################################################ | |
## 11::creating_a_tcp_server | |
from socketserver import StreamRequestHandler, TCPServer | |
class EchoHandler(StreamRequestHandler):
    """Line-oriented echo handler using the file-like stream interface."""

    def handle(self):
        """Copy each line read from the client straight back to it."""
        print("Got connection from", self.client_address)
        # rfile / wfile are the file-like read and write sides of the socket
        reader, writer = self.rfile, self.wfile
        while True:
            line = reader.readline()
            if not line:
                break
            writer.write(line)


if __name__ == "__main__":
    from threading import Thread

    NWORKERS = 16
    serv = TCPServer(("", 20000), EchoHandler)
    # Start a fixed pool of daemon threads that all serve the same server
    workers = [Thread(target=serv.serve_forever) for _ in range(NWORKERS)]
    for worker in workers:
        worker.daemon = True
        worker.start()
    print("Multithreaded server running on port 20000")
    # The main thread serves as well, which also keeps the process alive
    serv.serve_forever()
################################################################################ | |
## 11::creating_a_udp_server | |
from socket import socket, AF_INET, SOCK_DGRAM | |
# Send an empty datagram to port 20000 on localhost, then print the
# (data, address) pair that comes back.
server_address = ("localhost", 20000)
s = socket(AF_INET, SOCK_DGRAM)
s.sendto(b"", server_address)
reply = s.recvfrom(8192)
print(reply)
################################################################################ | |
## 11::creating_a_udp_server | |
from socketserver import BaseRequestHandler, UDPServer | |
import time | |
class TimeHandler(BaseRequestHandler):
    """Datagram handler that answers any message with the current time."""

    def handle(self):
        """Reply to the sender with time.ctime() encoded as ASCII."""
        print("Got connection from", self.client_address)
        # self.request is unpacked as a (message, socket) pair; the message
        # content is ignored
        _message, reply_sock = self.request
        now = time.ctime().encode("ascii")
        reply_sock.sendto(now, self.client_address)


if __name__ == "__main__":
    address = ("", 20000)
    serv = UDPServer(address, TimeHandler)
    serv.serve_forever()
################################################################################ | |
## 11::creating_a_udp_server | |
from socket import socket, AF_INET, SOCK_DGRAM | |
import time | |
def time_server(address):
    """Answer every datagram received on address with the current time.

    Loops forever; the content of incoming messages is ignored.
    """
    sock = socket(AF_INET, SOCK_DGRAM)
    sock.bind(address)
    while True:
        _msg, sender = sock.recvfrom(8192)
        print("Got message from", sender)
        reply = time.ctime().encode("ascii")
        sock.sendto(reply, sender)


if __name__ == "__main__":
    listen_address = ("", 20000)
    time_server(listen_address)
################################################################################ | |
## 11::event_driven_io_explained | |
class EventHandler:
    """Base class for objects driven by event_loop().

    Subclasses must implement fileno() and override the wants_*/handle_*
    pairs for the directions they care about.
    """

    def fileno(self):
        "Return the associated file descriptor"
        # NotImplementedError is the exception class; NotImplemented is a
        # non-callable constant, so calling it raised TypeError instead.
        raise NotImplementedError("must implement")

    def wants_to_receive(self):
        "Return True if receiving is allowed"
        return False

    def handle_receive(self):
        "Perform the receive operation"
        pass

    def wants_to_send(self):
        "Return True if sending is requested"
        return False

    def handle_send(self):
        "Send outgoing data"
        pass
import select | |
def event_loop(handlers):
    """Dispatch I/O events to handlers forever using select().

    handlers is a mutable collection of objects with the EventHandler
    interface; handlers may add to or remove from it while the loop runs.
    On every pass the handlers that want to receive or send are polled with
    select(), and the ready ones get handle_receive() / handle_send().
    """
    while True:
        wants_recv = [h for h in handlers if h.wants_to_receive()]
        wants_send = [h for h in handlers if h.wants_to_send()]
        can_recv, can_send, _ = select.select(wants_recv, wants_send, [])
        for h in can_recv:
            # Skip handlers that an earlier callback in this pass removed
            if h in handlers:
                h.handle_receive()
        for h in can_send:
            # A handler may have closed and removed itself in handle_receive
            # above; calling handle_send on it would hit a closed socket.
            if h in handlers:
                h.handle_send()
################################################################################ | |
## 11::event_driven_io_explained | |
from socket import socket, AF_INET, SOCK_STREAM | |
# Send one line to the server on port 16000 and print the reply.
# The with-block closes the socket even if connect/send/recv raises.
with socket(AF_INET, SOCK_STREAM) as s:
    s.connect(("localhost", 16000))
    # sendall() keeps writing until the whole message is sent; send() may not
    s.sendall(b"Hello\n")
    print("Got:", s.recv(8192))
################################################################################ | |
## 11::event_driven_io_explained | |
# TCP Example | |
import socket | |
from eventhandler import EventHandler, event_loop | |
class TCPServer(EventHandler):
    """Event handler for a listening TCP socket.

    Each accepted connection is wrapped in client_handler and appended to
    handler_list.
    """

    def __init__(self, address, client_handler, handler_list):
        self.sock = listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
        listener.bind(address)
        listener.listen(1)
        self.client_handler = client_handler
        self.handler_list = handler_list

    def fileno(self):
        """Return the listening socket's file descriptor."""
        return self.sock.fileno()

    def wants_to_receive(self):
        """Always ready to accept another connection."""
        return True

    def handle_receive(self):
        """Accept one connection and register a handler for it."""
        client, _addr = self.sock.accept()
        # Add the client to the event loop's handler list
        new_handler = self.client_handler(client, self.handler_list)
        self.handler_list.append(new_handler)
class TCPClient(EventHandler):
    """Event handler for one connected client with an outgoing byte buffer."""

    def __init__(self, sock, handler_list):
        self.sock = sock
        self.handler_list = handler_list
        # Bytes queued for sending; drained by handle_send()
        self.outgoing = bytearray()

    def fileno(self):
        """Return the client socket's file descriptor."""
        return self.sock.fileno()

    def close(self):
        """Close the socket and drop this handler from the handler list."""
        self.sock.close()
        # Remove myself from the event loop's handler list
        self.handler_list.remove(self)

    def wants_to_send(self):
        """Return True while there is queued outgoing data."""
        return bool(self.outgoing)

    def handle_send(self):
        """Send as much queued data as the socket accepts; keep the rest."""
        nsent = self.sock.send(self.outgoing)
        unsent = self.outgoing[nsent:]
        self.outgoing = unsent
class TCPEchoClient(TCPClient):
    """Client handler that queues everything it receives for sending back."""

    def wants_to_receive(self):
        """Always willing to read from the client."""
        return True

    def handle_receive(self):
        """Queue received data for echoing; close on end of stream."""
        data = self.sock.recv(8192)
        if data:
            self.outgoing.extend(data)
            return
        # Empty read: the peer closed the connection
        self.close()


if __name__ == "__main__":
    handlers = []
    # The server gets the same list so it can register accepted clients
    echo_listener = TCPServer(("", 16000), TCPEchoClient, handlers)
    handlers.append(echo_listener)
    event_loop(handlers)
################################################################################ | |
## 11::event_driven_io_explained | |
import socket | |
from concurrent.futures import ThreadPoolExecutor | |
from eventhandler import EventHandler, event_loop | |
class ThreadPoolHandler(EventHandler):
    """Run functions in a thread pool and deliver results via the event loop.

    Worker threads queue (callback, result) pairs in self.pending and write
    one byte to a socket pair; the event loop sees done_sock become readable
    and calls handle_receive(), which runs the callback in the loop's thread.
    """

    def __init__(self, nworkers):
        self.signal_done_sock, self.done_sock = socket.socketpair()
        self.pending = []
        self.pool = ThreadPoolExecutor(nworkers)

    def fileno(self):
        """Return the descriptor the event loop should poll."""
        return self.done_sock.fileno()

    # Callback that executes when the thread is done
    def _complete(self, callback, r):
        # Queue the result first, then signal: one byte per queued entry
        self.pending.append((callback, r.result()))
        self.signal_done_sock.send(b"x")

    # Run a function in a thread pool
    def run(self, func, args=(), kwargs=None, *, callback):
        """Submit func(*args, **kwargs); callback(result) runs on completion.

        kwargs defaults to no keyword arguments (None instead of a shared
        mutable {} default).
        """
        if kwargs is None:
            kwargs = {}
        r = self.pool.submit(func, *args, **kwargs)
        r.add_done_callback(lambda r: self._complete(callback, r))

    def wants_to_receive(self):
        """Always poll for completed work."""
        return True

    # Run callback functions of completed work
    def handle_receive(self):
        """Run the callback for one completed piece of work.

        Exactly one signal byte is consumed per queued entry, and the list is
        never rebound, so a result appended by a worker thread while this
        runs cannot be dropped. Further completions leave done_sock readable,
        so the event loop calls this again.
        """
        self.done_sock.recv(1)
        callback, result = self.pending.pop(0)
        callback(result)
# A really bad fibonacci implementation | |
def fib(n):
    """Return fib(n) by naive recursion, with fib(0) == fib(1) == 1.

    Deliberately slow: it serves as CPU-heavy work for the thread pool.
    """
    return 1 if n < 2 else fib(n - 1) + fib(n - 2)
class UDPServer(EventHandler):
    """Event handler owning a UDP socket bound to an address.

    Subclasses provide handle_receive().
    """

    def __init__(self, address):
        udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock = udp_sock
        udp_sock.bind(address)

    def fileno(self):
        """Return the UDP socket's file descriptor."""
        return self.sock.fileno()

    def wants_to_receive(self):
        """Always willing to receive datagrams."""
        return True
class UDPFibServer(UDPServer):
    """UDP server that computes fib(n) in the thread pool for each request."""

    def handle_receive(self):
        """Read one request and schedule fib(n); reply when it completes.

        Uses the module-level `pool` created in the __main__ section.
        """
        msg, addr = self.sock.recvfrom(128)
        try:
            n = int(msg)
        except ValueError:
            # A malformed datagram must not take the whole event loop down;
            # drop it and keep serving.
            return
        pool.run(fib, (n,), callback=lambda r: self.respond(r, addr))

    def respond(self, result, addr):
        """Send result back to addr as ASCII text."""
        self.sock.sendto(str(result).encode("ascii"), addr)
if __name__ == "__main__":
    # `pool` is module-level on purpose: UDPFibServer.handle_receive uses it
    pool = ThreadPoolHandler(16)
    fib_server = UDPFibServer(("", 16000))
    handlers = [pool, fib_server]
    event_loop(handlers)
################################################################################ | |
## 11::event_driven_io_explained | |
from socket import * | |
# Send the numbers 0..39 to the UDP server on port 16000, one request at a
# time, printing the payload of each reply.
sock = socket(AF_INET, SOCK_DGRAM)
server_address = ("localhost", 16000)
for x in range(40):
    request = str(x).encode("ascii")
    sock.sendto(request, server_address)
    resp = sock.recvfrom(8192)
    print(resp[0])
################################################################################ | |
## 11::event_driven_io_explained | |
from socket import * | |
# Query the two UDP servers in turn: an empty datagram to port 14000, then
# b"Hello" to port 15000, printing each (data, address) reply.
s = socket(AF_INET, SOCK_DGRAM)
for payload, port in ((b"", 14000), (b"Hello", 15000)):
    s.sendto(payload, ("localhost", port))
    print(s.recvfrom(128))
################################################################################ | |
## 11::event_driven_io_explained | |
import socket | |
import time | |
from eventhandler import EventHandler, event_loop | |
class UDPServer(EventHandler):
    """Event handler owning a UDP socket bound to an address.

    Subclasses provide handle_receive().
    """

    def __init__(self, address):
        udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock = udp_sock
        udp_sock.bind(address)

    def fileno(self):
        """Return the UDP socket's file descriptor."""
        return self.sock.fileno()

    def wants_to_receive(self):
        """Always willing to receive datagrams."""
        return True
class UDPTimeServer(UDPServer):
    """UDP server that answers every datagram with the current time."""

    def handle_receive(self):
        """Reply to the sender with time.ctime() encoded as ASCII."""
        # Only the sender's address matters; at most one byte is read
        _msg, sender = self.sock.recvfrom(1)
        now = time.ctime().encode("ascii")
        self.sock.sendto(now, sender)
class UDPEchoServer(UDPServer):
    """UDP server that sends every datagram back to its sender."""

    def handle_receive(self):
        """Receive one datagram and return it unchanged."""
        received = self.sock.recvfrom(8192)
        payload, sender = received
        self.sock.sendto(payload, sender)


if __name__ == "__main__":
    time_server = UDPTimeServer(("", 14000))
    echo_server = UDPEchoServer(("", 15000))
    handlers = [time_server, echo_server]
    event_loop(handlers)
################################################################################ | |
## 11::implementing_remote_procedure_call | |
import json | |
class RPCProxy:
    """Client-side proxy that turns attribute calls into JSON RPC requests."""

    def __init__(self, connection):
        self._connection = connection

    def __getattr__(self, name):
        """Return a callable that invokes the remote function called name."""

        def do_rpc(*args, **kwargs):
            # Request format: JSON-encoded (function name, args, kwargs)
            request = json.dumps((name, args, kwargs))
            self._connection.send(request)
            reply = self._connection.recv()
            return json.loads(reply)

        return do_rpc
# Example use | |
from multiprocessing.connection import Client | |
server_address = ("localhost", 17000)
c = Client(server_address, authkey=b"peekaboo")
proxy = RPCProxy(c)
print(proxy.add(2, 3))
print(proxy.sub(2, 3))
# Call with arguments the remote sub() cannot handle and print what happens
try:
    outcome = proxy.sub([1, 2], 4)
    print(outcome)
except Exception as e:
    print(e)
################################################################################ | |
## 11::implementing_remote_procedure_call | |
# rpcserver.py | |
import json | |
class RPCHandler:
    """Server-side dispatcher for JSON-encoded remote procedure calls."""

    def __init__(self):
        # Maps function name -> callable
        self._functions = {}

    def register_function(self, func):
        """Expose func to clients under its __name__."""
        self._functions[func.__name__] = func

    def _call(self, func_name, args, kwargs):
        # Run one registered function and return its JSON-encoded result
        func = self._functions[func_name]
        return json.dumps(func(*args, **kwargs))

    def handle_connection(self, connection):
        """Serve requests on connection until the peer closes it (EOFError).

        Any exception raised while running a call or sending its result is
        reported to the client as the JSON-encoded string of the exception.
        """
        try:
            while True:
                # Receive a message
                func_name, args, kwargs = json.loads(connection.recv())
                # Run the RPC and send a response
                try:
                    connection.send(self._call(func_name, args, kwargs))
                except Exception as e:
                    connection.send(json.dumps(str(e)))
        except EOFError:
            pass
# Example use | |
from multiprocessing.connection import Listener | |
from threading import Thread | |
def rpc_server(handler, address, authkey):
    """Accept connections forever, serving each in its own daemon thread.

    Every accepted connection is passed to handler.handle_connection.
    """
    listener = Listener(address, authkey=authkey)
    while True:
        connection = listener.accept()
        worker = Thread(target=handler.handle_connection, args=(connection,))
        worker.daemon = True
        worker.start()
# Some remote functions | |
def add(x, y):
    """Return x + y."""
    total = x + y
    return total


def sub(x, y):
    """Return x - y."""
    difference = x - y
    return difference


# Register with a handler
handler = RPCHandler()
for remote_function in (add, sub):
    handler.register_function(remote_function)

# Run the server
server_address = ("localhost", 17000)
rpc_server(handler, server_address, authkey=b"peekaboo")
################################################################################ | |
## 11::implementing_remote_procedure_call | |
import pickle | |
class RPCProxy:
    """Client-side proxy that turns attribute calls into pickled RPC requests.

    NOTE(review): replies are decoded with pickle.loads, which can execute
    arbitrary code; only connect to servers you trust.
    """

    def __init__(self, connection):
        self._connection = connection

    def __getattr__(self, name):
        """Return a callable that invokes the remote function called name."""

        def do_rpc(*args, **kwargs):
            # Request format: pickled (function name, args, kwargs)
            request = pickle.dumps((name, args, kwargs))
            self._connection.send(request)
            result = pickle.loads(self._connection.recv())
            # An exception instance as the reply is raised locally
            if not isinstance(result, Exception):
                return result
            raise result

        return do_rpc
# Example use | |
from multiprocessing.connection import Client | |
server_address = ("localhost", 17000)
c = Client(server_address, authkey=b"peekaboo")
proxy = RPCProxy(c)
for left, right in ((2, 3),):
    print(proxy.add(left, right))
    print(proxy.sub(left, right))
# Call with arguments the remote sub() cannot handle and print the error
try:
    proxy.sub([1, 2], 4)
except Exception as e:
    print(e)
################################################################################ | |
## 11::implementing_remote_procedure_call | |
# rpcserver.py | |
import pickle | |
class RPCHandler:
    """Server-side dispatcher for pickled remote procedure calls.

    NOTE(review): requests are decoded with pickle.loads, which can execute
    arbitrary code; only accept connections from trusted clients.
    """

    def __init__(self):
        # Maps function name -> callable
        self._functions = {}

    def register_function(self, func):
        """Expose func to clients under its __name__."""
        self._functions[func.__name__] = func

    def _call(self, func_name, args, kwargs):
        # Run one registered function and return its pickled result
        func = self._functions[func_name]
        return pickle.dumps(func(*args, **kwargs))

    def handle_connection(self, connection):
        """Serve requests on connection until the peer closes it (EOFError).

        Any exception raised while running a call or sending its result is
        pickled and sent to the client in place of a result.
        """
        try:
            while True:
                # Receive a message
                func_name, args, kwargs = pickle.loads(connection.recv())
                # Run the RPC and send a response
                try:
                    connection.send(self._call(func_name, args, kwargs))
                except Exception as e:
                    connection.send(pickle.dumps(e))
        except EOFError:
            pass
# Example use | |
from multiprocessing.connection import Listener | |
from threading import Thread | |
def rpc_server(handler, address, authkey):
    """Accept connections forever, serving each in its own daemon thread.

    Every accepted connection is passed to handler.handle_connection.
    """
    listener = Listener(address, authkey=authkey)
    while True:
        connection = listener.accept()
        worker = Thread(target=handler.handle_connection, args=(connection,))
        worker.daemon = True
        worker.start()
# Some remote functions | |
def add(x, y):
    """Return x + y."""
    total = x + y
    return total


def sub(x, y):
    """Return x - y."""
    difference = x - y
    return difference


# Register with a handler
handler = RPCHandler()
for remote_function in (add, sub):
    handler.register_function(remote_function)

# Run the server
server_address = ("localhost", 17000)
rpc_server(handler, server_address, authkey=b"peekaboo")
################################################################################ | |
## 11::interacting_with_http_services_as_a_client | |
# A basic GET request | |
from urllib import request, parse | |
# Base URL being accessed | |
url = "http://httpbin.org/get"
# Dictionary of query parameters (if any)
parms = {"name1": "value1", "name2": "value2"}
# Encode the query string
querystring = parse.urlencode(parms)
# Make a GET request and read the response; the with-block closes the
# HTTP response even if read() raises
with request.urlopen(url + "?" + querystring) as u:
    resp = u.read()
import json
from pprint import pprint
# The body is decoded as UTF-8 and parsed as JSON
json_resp = json.loads(resp.decode("utf-8"))
pprint(json_resp)
################################################################################ | |
## 11::interacting_with_http_services_as_a_client | |
# A basic POST request | |
from urllib import request, parse | |
# Base URL being accessed | |
url = "http://httpbin.org/post"
# Dictionary of query parameters (if any)
parms = {"name1": "value1", "name2": "value2"}
# Encode the query string
querystring = parse.urlencode(parms)
# Make a POST request (passing a data argument makes urlopen POST) and read
# the response; the with-block closes the HTTP response even if read() raises
with request.urlopen(url, querystring.encode("ascii")) as u:
    resp = u.read()
import json
from pprint import pprint
# The body is decoded as UTF-8 and parsed as JSON
json_resp = json.loads(resp.decode("utf-8"))
pprint(json_resp)
################################################################################ | |
## 11::interacting_with_http_services_as_a_client | |
# A POST request using requests library | |
import requests | |
# Base URL being accessed | |
url = "http://httpbin.org/post"
# Dictionary of query parameters (if any)
parms = {"name1": "value1", "name2": "value2"}
# Extra headers
headers = {"User-agent": "none/ofyourbusiness", "Spam": "Eggs"}
resp = requests.post(url, data=parms, headers=headers)
# Decoded text returned by the request
text = resp.text
from pprint import pprint
# Response.json is a method: call it to get the decoded JSON body rather
# than printing the bound method object
pprint(resp.json())
################################################################################ | |
## 11::interacting_with_http_services_as_a_client | |
# Example of a HEAD request | |
import requests | |
resp = requests.head("http://www.python.org/index.html")
status = resp.status_code
# Not every response carries all of these headers (a redirect reply, for
# instance, may have no Last-Modified), so use .get() and print None for a
# missing header instead of raising KeyError.
last_modified = resp.headers.get("last-modified")
content_type = resp.headers.get("content-type")
content_length = resp.headers.get("content-length")
print(status)
print(last_modified)
print(content_type)
print(content_length)
################################################################################ | |
## 11::passing_a_socket_file_descriptor_between_processes | |
from socket import socket, AF_INET, SOCK_STREAM | |
# Send two lines to the server on port 15000, printing the reply to each.
# The with-block closes the socket even if connect/send/recv raises.
with socket(AF_INET, SOCK_STREAM) as s:
    s.connect(("localhost", 15000))
    # sendall() keeps writing until the whole message is sent; send() may not
    s.sendall(b"Hello\n")
    print("Got:", s.recv(8192))
    s.sendall(b"World\n")
    print("Got:", s.recv(8192))
################################################################################ | |
## 11::passing_a_socket_file_descriptor_between_processes | |
# server.py | |
import socket | |
import struct | |
def send_fd(sock, fd):
    """
    Send a single file descriptor.

    The descriptor travels as SCM_RIGHTS ancillary data next to a one-byte
    payload; the call then waits for a two-byte b"OK" acknowledgement.
    """
    packed_fd = struct.pack("i", fd)
    ancillary = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, packed_fd)]
    sock.sendmsg([b"x"], ancillary)
    ack = sock.recv(2)
    assert ack == b"OK"
def server(work_address, port):
    """Accept TCP clients on port and hand each one to the worker.

    First waits for a worker to connect on the Unix-domain socket at
    work_address, then forwards every accepted client's file descriptor to
    it with send_fd() and closes the local copy. Loops forever.
    """
    # Wait for the worker to connect
    work_serv = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    work_serv.bind(work_address)
    work_serv.listen(1)
    worker, _ = work_serv.accept()

    # Now run a TCP/IP server and send clients to worker
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
    listener.bind(("", port))
    listener.listen(1)
    while True:
        client, client_addr = listener.accept()
        print("SERVER: Got connection from", client_addr)
        send_fd(worker, client.fileno())
        client.close()


if __name__ == "__main__":
    import sys

    if len(sys.argv) == 3:
        server(sys.argv[1], int(sys.argv[2]))
    else:
        print("Usage: server.py server_address port", file=sys.stderr)
        raise SystemExit(1)
################################################################################ | |
## 11::passing_a_socket_file_descriptor_between_processes | |
# Example of file descriptor passing using multiprocessing | |
import multiprocessing | |
from multiprocessing.reduction import recv_handle, send_handle | |
import socket | |
def worker(in_p, out_p):
    """Child process: receive client sockets over in_p and echo their data.

    out_p is this process's copy of the other pipe end and is closed
    immediately. Each descriptor received is wrapped in a socket object and
    served until the client disconnects. Loops forever.
    """
    out_p.close()
    while True:
        fd = recv_handle(in_p)
        print("CHILD: GOT FD", fd)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=fd) as s:
            while True:
                msg = s.recv(1024)
                if not msg:
                    # Empty read: the client closed the connection
                    break
                print("CHILD: RECV {!r}".format(msg))
                # sendall() echoes the whole chunk; send() may write only
                # part of it
                s.sendall(msg)
def server(address, in_p, out_p, worker_pid):
    """Accept TCP clients on address and pass each one to the worker.

    in_p is this process's copy of the other pipe end and is closed
    immediately. Every accepted client's descriptor is sent over out_p to
    the process with id worker_pid, then the local copy is closed.
    """
    in_p.close()
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
    listener.bind(address)
    listener.listen(1)
    while True:
        client, client_addr = listener.accept()
        print("SERVER: Got connection from", client_addr)
        send_handle(out_p, client.fileno(), worker_pid)
        client.close()


if __name__ == "__main__":
    c1, c2 = multiprocessing.Pipe()
    worker_p = multiprocessing.Process(target=worker, args=(c1, c2))
    worker_p.start()
    # The server needs the worker's pid to pass descriptors to it
    server_args = (("", 15000), c1, c2, worker_p.pid)
    server_p = multiprocessing.Process(target=server, args=server_args)
    server_p.start()
    # The parent keeps neither pipe end
    for pipe_end in (c1, c2):
        pipe_end.close()
################################################################################ | |
## 11::passing_a_socket_file_descriptor_between_processes | |
# servermp.py | |
from multiprocessing.connection import Listener | |
from multiprocessing.reduction import send_handle | |
import socket | |
def server(work_address, port):
    """Accept TCP clients on port and hand each one to the worker.

    Waits for a worker to connect to a Listener on work_address and reads
    the first message from it as the worker's pid. Every accepted client's
    descriptor is then sent to the worker with send_handle() and the local
    copy closed. Loops forever.
    """
    # Wait for the worker to connect
    work_serv = Listener(work_address, authkey=b"peekaboo")
    worker = work_serv.accept()
    worker_pid = worker.recv()

    # Now run a TCP/IP server and send clients to worker
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
    listener.bind(("", port))
    listener.listen(1)
    while True:
        client, client_addr = listener.accept()
        print("SERVER: Got connection from", client_addr)
        send_handle(worker, client.fileno(), worker_pid)
        client.close()


if __name__ == "__main__":
    import sys

    if len(sys.argv) == 3:
        server(sys.argv[1], int(sys.argv[2]))
    else:
        print("Usage: server.py server_address port", file=sys.stderr)
        raise SystemExit(1)
################################################################################ | |
## 11::passing_a_socket_file_descriptor_between_processes | |
# worker.py | |
import socket | |
import struct | |
def recv_fd(sock):
    """
    Receive a single file descriptor

    Reads a one-byte payload together with its ancillary data, checks that
    the first ancillary item is SCM_RIGHTS, acknowledges with b"OK" and
    returns the descriptor as an int.
    """
    fd_size = struct.calcsize("i")
    received = sock.recvmsg(1, socket.CMSG_LEN(fd_size))
    msg, ancdata, flags, addr = received
    cmsg_level, cmsg_type, cmsg_data = ancdata[0]
    assert cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS
    sock.sendall(b"OK")
    (fd,) = struct.unpack("i", cmsg_data)
    return fd
def worker(server_address):
    """Connect to the server's Unix socket and echo for every client it sends.

    Each descriptor received with recv_fd() is wrapped in a socket object
    and served until the client disconnects. Loops forever.
    """
    serv = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    serv.connect(server_address)
    while True:
        fd = recv_fd(serv)
        print("WORKER: GOT FD", fd)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=fd) as client:
            while True:
                msg = client.recv(1024)
                if not msg:
                    # Empty read: the client closed the connection
                    break
                print("WORKER: RECV {!r}".format(msg))
                # sendall() echoes the whole chunk; send() may write only
                # part of it
                client.sendall(msg)


if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print("Usage: worker.py server_address", file=sys.stderr)
        raise SystemExit(1)
    worker(sys.argv[1])
################################################################################ | |
## 11::passing_a_socket_file_descriptor_between_processes | |
# workermp.py | |
from multiprocessing.connection import Client | |
from multiprocessing.reduction import recv_handle | |
import os | |
import socket | |
def worker(server_address):
    """Connect to the server and echo for every client socket it passes on.

    Sends this process's pid first, then wraps each descriptor received
    with recv_handle() in a socket object and serves it until the client
    disconnects. Loops forever.
    """
    serv = Client(server_address, authkey=b"peekaboo")
    serv.send(os.getpid())
    while True:
        fd = recv_handle(serv)
        print("WORKER: GOT FD", fd)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=fd) as client:
            while True:
                msg = client.recv(1024)
                if not msg:
                    # Empty read: the client closed the connection
                    break
                print("WORKER: RECV {!r}".format(msg))
                # sendall() echoes the whole chunk; send() may write only
                # part of it
                client.sendall(msg)


if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print("Usage: worker.py server_address", file=sys.stderr)
        raise SystemExit(1)
    worker(sys.argv[1])
################################################################################ | |
## 11::simple_authentication_of_clients | |
# auth.py | |
import hmac | |
import os | |
# Hash algorithm for the challenge-response HMAC. hmac.new() requires an
# explicit digestmod on Python 3.8+; client and server must use the same one.
_DIGESTMOD = "sha256"


def client_authenticate(connection, secret_key):
    """
    Authenticate client to a remote service.
    connection represents a network connection.
    secret_key is a key known only to both client/server.

    Reads the server's 32-byte challenge and sends back its HMAC digest.
    """
    message = connection.recv(32)
    digest = hmac.new(secret_key, message, _DIGESTMOD).digest()
    connection.send(digest)


def server_authenticate(connection, secret_key):
    """
    Request client authentication.

    Sends a random 32-byte challenge and returns True if the client's reply
    is the expected HMAC digest, False otherwise.
    """
    message = os.urandom(32)
    connection.send(message)
    digest = hmac.new(secret_key, message, _DIGESTMOD).digest()
    response = connection.recv(len(digest))
    # compare_digest compares in constant time to avoid leaking timing info
    return hmac.compare_digest(digest, response)
################################################################################ | |
## 11::simple_authentication_of_clients | |
from socket import socket, AF_INET, SOCK_STREAM | |
from auth import client_authenticate | |
secret_key = b"peekaboo"
server_address = ("localhost", 18000)
s = socket(AF_INET, SOCK_STREAM)
s.connect(server_address)
# Answer the server's challenge before sending any application data
client_authenticate(s, secret_key)
s.send(b"Hello World")
resp = s.recv(1024)
print("Got:", resp)
################################################################################ | |
## 11::simple_authentication_of_clients | |
from socket import socket, AF_INET, SOCK_STREAM | |
from auth import server_authenticate | |
secret_key = b"peekaboo"


def echo_handler(client_sock):
    """Authenticate the client, then echo its data until it disconnects.

    The socket is always closed before returning: on failed authentication,
    after the client disconnects, and when recv()/sendall() raises.
    """
    try:
        if not server_authenticate(client_sock, secret_key):
            return
        while True:
            msg = client_sock.recv(8192)
            if not msg:
                # Empty read: the peer closed the connection
                break
            client_sock.sendall(msg)
    finally:
        client_sock.close()
def echo_server(address):
    """Accept TCP clients on address forever, serving them one at a time."""
    listener = socket(AF_INET, SOCK_STREAM)
    listener.bind(address)
    listener.listen(5)
    while True:
        client_sock, _addr = listener.accept()
        echo_handler(client_sock)


print("Echo server running on port 18000")
listen_address = ("", 18000)
echo_server(listen_address)
################################################################################ | |
## 11::simple_communication_between_interpreters | |
from multiprocessing.connection import Client | |
c = Client(("localhost", 25000), authkey=b"peekaboo")
# Send a string, an int and a list in turn, printing the reply to each
for payload in ("hello", 42, [1, 2, 3, 4, 5]):
    c.send(payload)
    print("Got:", c.recv())
################################################################################ | |
## 11::simple_communication_between_interpreters | |
from multiprocessing.connection import Listener | |
import traceback | |
def echo_client(conn):
    """Send every message received on conn straight back.

    Returns after printing a notice once conn.recv() raises EOFError.
    """
    try:
        while True:
            conn.send(conn.recv())
    except EOFError:
        print("Connection closed")
def echo_server(address, authkey):
    """Accept authenticated connections forever, echoing on each in turn.

    Exceptions raised while accepting or serving a client are printed with
    their traceback and the server keeps running.
    """
    listener = Listener(address, authkey=authkey)
    while True:
        try:
            connection = listener.accept()
            echo_client(connection)
        except Exception:
            traceback.print_exc()


listen_address = ("", 25000)
echo_server(listen_address, authkey=b"peekaboo")
################################################################################ | |
## 11::simple_remote_procedure_call_with_xmlrpc | |
from xmlrpc.client import ServerProxy | |
s = ServerProxy("http://localhost:15000", allow_none=True)
# Store two values, then read them back
for key, value in (("foo", "bar"), ("spam", [1, 2, 3])):
    s.set(key, value)
print(s.keys())
for key in ("foo", "spam"):
    print(s.get(key))
# Remove one key and print whether the server still has it
s.delete("spam")
print(s.exists("spam"))
################################################################################ | |
## 11::simple_remote_procedure_call_with_xmlrpc | |
from xmlrpc.server import SimpleXMLRPCServer | |
class KeyValueServer:
    """In-memory key-value store exposed over XML-RPC."""

    # Names of the methods registered with the XML-RPC server
    _rpc_methods_ = ["get", "set", "delete", "exists", "keys"]

    def __init__(self, address):
        self._data = {}
        self._serv = SimpleXMLRPCServer(address, allow_none=True)
        for method_name in self._rpc_methods_:
            bound_method = getattr(self, method_name)
            self._serv.register_function(bound_method)

    def get(self, name):
        """Return the value stored under name; KeyError if absent."""
        return self._data[name]

    def set(self, name, value):
        """Store value under name."""
        self._data[name] = value

    def delete(self, name):
        """Remove name; KeyError if absent."""
        del self._data[name]

    def exists(self, name):
        """Return True if name is stored."""
        return name in self._data

    def keys(self):
        """Return a list of all stored names."""
        return list(self._data.keys())

    def serve_forever(self):
        """Handle XML-RPC requests until interrupted."""
        self._serv.serve_forever()


# Example
if __name__ == "__main__":
    address = ("", 15000)
    kvserv = KeyValueServer(address)
    kvserv.serve_forever()
################################################################################ | |
## 11::zero_copy_sending_and_receiving_of_large_arrays | |
from zerocopy import recv_into | |
from socket import * | |
c = socket(AF_INET, SOCK_STREAM)
server_address = ("localhost", 25000)
c.connect(server_address)
import numpy
# Receive buffer: 50 million zero-initialised floats, filled in place
a = numpy.zeros(shape=50000000, dtype=float)
print(a[0:10])
recv_into(a, c)
# Show both ends of the array after the transfer
for window in (a[0:10], a[-10:]):
    print(window)
################################################################################ | |
## 11::zero_copy_sending_and_receiving_of_large_arrays | |
from zerocopy import send_from | |
from socket import * | |
s = socket(AF_INET, SOCK_STREAM)
listen_address = ("", 25000)
s.bind(listen_address)
s.listen(1)
# Wait for a single client
c, a = s.accept()
import numpy
# NOTE: `a` (the client address) is rebound to the array being sent
a = numpy.arange(0.0, 50000000.0)
send_from(a, c)
c.close()
################################################################################ | |
## 11::zero_copy_sending_and_receiving_of_large_arrays | |
# zerocopy.py | |
def send_from(arr, dest):
    """Send the raw bytes of arr over dest without copying them.

    arr is viewed as unsigned bytes through a memoryview; dest.send() is
    called repeatedly with the unsent remainder until nothing is left.
    """
    remaining = memoryview(arr).cast("B")
    while len(remaining) > 0:
        remaining = remaining[dest.send(remaining):]
def recv_into(arr, source):
    """Fill arr completely with bytes read from source, without copying.

    arr is viewed as unsigned bytes through a memoryview and
    source.recv_into() writes directly into the unfilled remainder.

    Raises EOFError if the peer closes the connection before arr is full;
    recv_into() then returns 0 and the loop would otherwise spin forever.
    """
    view = memoryview(arr).cast("B")
    while len(view):
        nrecv = source.recv_into(view)
        if nrecv == 0:
            raise EOFError(
                "connection closed with {} bytes still expected".format(len(view))
            )
        view = view[nrecv:]
################################################################################ | |
## 12::defining_an_actor_task | |
from queue import Queue | |
from threading import Thread, Event | |
# Sentinel used for shutdown | |
class ActorExit(Exception):
    """Shutdown sentinel: the class goes into the mailbox, recv() raises it."""


class Actor:
    """Object with a mailbox whose run() method executes in its own thread."""

    def __init__(self):
        self._mailbox = Queue()

    def send(self, msg):
        """
        Send a message to the actor
        """
        self._mailbox.put(msg)

    def recv(self):
        """
        Receive an incoming message

        Blocks until one is available; raises ActorExit when the message is
        the shutdown sentinel.
        """
        msg = self._mailbox.get()
        if msg is not ActorExit:
            return msg
        raise ActorExit()

    def close(self):
        """
        Close the actor, thus shutting it down
        """
        self.send(ActorExit)

    def start(self):
        """
        Start concurrent execution
        """
        # Set by _bootstrap when run() finishes; join() waits on it
        self._terminated = Event()
        runner = Thread(target=self._bootstrap)
        runner.daemon = True
        runner.start()

    def _bootstrap(self):
        # Thread entry point: ActorExit ends run() quietly, and termination
        # is signalled however run() ends
        try:
            self.run()
        except ActorExit:
            pass
        finally:
            self._terminated.set()

    def join(self):
        """Block until the actor's thread has finished running."""
        self._terminated.wait()

    def run(self):
        """
        Run method to be implemented by the user
        """
        # Default behaviour: discard messages until shut down
        while True:
            self.recv()
# Sample actor: echoes every message it receives
class PrintActor(Actor):
    """Actor that prints each received message prefixed with "Got:"."""

    def run(self):
        while True:
            received = self.recv()
            print("Got:", received)
if __name__ == "__main__":
    # Sample use: start the actor, send two messages, then shut it down
    p = PrintActor()
    p.start()
    for greeting in ("Hello", "World"):
        p.send(greeting)
    p.close()
    p.join()
################################################################################ | |
## 12::defining_an_actor_task | |
from actor import Actor | |
class TaggedActor(Actor):
    """
    Actor whose messages are sequences of the form ``(tag, *args)``.

    Each message is dispatched to the method named ``do_<tag>``, called
    with the remaining items as positional arguments.
    """

    def run(self):
        while True:
            tag, *args = self.recv()
            handler = getattr(self, "do_" + tag)
            handler(*args)

    # Methods corresponding to different message tags
    def do_A(self, x):
        print("Running A", x)

    def do_B(self, x, y):
        print("Running B", x, y)
# Example
if __name__ == "__main__":
    a = TaggedActor()
    a.start()
    # ("A", 1) invokes do_A(1); ("B", 2, 3) invokes do_B(2, 3)
    for message in (("A", 1), ("B", 2, 3)):
        a.send(message)
    a.close()
    a.join()
################################################################################ | |
## 12::defining_an_actor_task | |
from actor import Actor | |
from threading import Event | |
class Result:
    """Placeholder for a value that another thread will supply later."""

    def __init__(self):
        self._ready = Event()
        self._value = None

    def set_result(self, value):
        """Store *value* and wake every thread blocked in result()."""
        self._value = value
        self._ready.set()

    def result(self):
        """Block until set_result() has been called, then return the value."""
        self._ready.wait()
        return self._value
class Worker(Actor):
    """Actor that runs submitted callables and reports through Result objects."""

    def submit(self, func, *args, **kwargs):
        """Queue ``func(*args, **kwargs)`` and return the Result that will hold it."""
        pending = Result()
        self.send((func, args, kwargs, pending))
        return pending

    def run(self):
        while True:
            job, positional, keyword, pending = self.recv()
            outcome = job(*positional, **keyword)
            pending.set_result(outcome)
# Example use
if __name__ == "__main__":
    worker = Worker()
    worker.start()
    r = worker.submit(pow, 2, 3)
    answer = r.result()  # blocks until the worker has computed pow(2, 3)
    print(answer)
    worker.close()
    worker.join()
################################################################################ | |
## 12::how_to_communicate_between_threads | |
from queue import Queue | |
from threading import Thread | |
import time | |
# Unique marker placed on the queue to signal that production is finished
_sentinel = object()


# A thread that produces data
def producer(out_q, count=10, delay=2):
    """
    Put the integers ``count, count-1, ..., 1`` on *out_q*, then the sentinel.

    Args:
        out_q: queue-like object with a ``put(item)`` method.
        count: how many items to produce (default 10, the original value).
        delay: seconds to sleep after each item (default 2, the original value).
    """
    for n in range(count, 0, -1):
        # Produce some data
        out_q.put(n)
        time.sleep(delay)
    # Put the sentinel on the queue to indicate completion
    out_q.put(_sentinel)
# A thread that consumes data
def consumer(in_q):
    """
    Print items taken from *in_q* until the sentinel object is seen.

    The sentinel is put back on the queue before exiting, so any other
    consumer reading the same queue will see it as well.
    """
    while True:
        item = in_q.get()
        if item is not _sentinel:
            # Process the data
            print("Got:", item)
            continue
        # Termination: re-queue the sentinel and stop
        in_q.put(_sentinel)
        break
    print("Consumer shutting down")
if __name__ == "__main__":
    # Run one consumer and one producer over a shared queue
    q = Queue()
    t1 = Thread(target=consumer, args=(q,))
    t2 = Thread(target=producer, args=(q,))
    for t in (t1, t2):
        t.start()
    for t in (t1, t2):
        t.join()
################################################################################ | |
## 12::how_to_communicate_between_threads | |
import heapq | |
import threading | |
import time | |
class PriorityQueue:
    """
    Blocking priority queue built on heapq and a threading.Condition.

    Items with a larger *priority* are returned first; items with equal
    priority come out in insertion order (an insertion counter breaks ties).
    """

    def __init__(self):
        self._queue = []
        self._count = 0
        self._cv = threading.Condition()

    def put(self, item, priority):
        """Insert *item* with the given *priority* and wake one waiting getter."""
        with self._cv:
            # Negated priority turns heapq's min-heap into a max-heap.
            entry = (-priority, self._count, item)
            heapq.heappush(self._queue, entry)
            self._count += 1
            self._cv.notify()

    def get(self):
        """Block until an item is available, then return the highest-priority one."""
        with self._cv:
            self._cv.wait_for(lambda: self._queue)
            _, _, item = heapq.heappop(self._queue)
            return item
def producer(q):
    """Put the demo items on *q*; the final None item has the lowest priority."""
    print("Producing items")
    demo_items = (("C", 5), ("A", 15), ("B", 10), ("D", 0), (None, -100))
    for item, priority in demo_items:
        q.put(item, priority)
def consumer(q, delay=5):
    """
    Wait *delay* seconds, then print items from *q* until None is received.

    Args:
        q: queue-like object with a blocking ``get()`` method.
        delay: seconds to wait before the first get (default 5, the
            original value).
    """
    time.sleep(delay)
    print("Getting items")
    while True:
        item = q.get()
        if item is None:
            break
        print("Got:", item)
    print("Consumer done")
if __name__ == "__main__":
    # One producer and one (delayed) consumer sharing a priority queue
    q = PriorityQueue()
    t1 = threading.Thread(target=producer, args=(q,))
    t2 = threading.Thread(target=consumer, args=(q,))
    for t in (t1, t2):
        t.start()
    for t in (t1, t2):
        t.join()
################################################################################ | |
## 12::how_to_create_a_thread_pool | |
from socket import AF_INET, SOCK_STREAM, socket | |
from concurrent.futures import ThreadPoolExecutor | |
def echo_client(sock, client_addr):
    """
    Handle a client connection by echoing everything it sends.

    Returns once the client closes its side. The socket is closed on every
    exit path, including when recv()/sendall() raises.
    """
    print("Got connection from", client_addr)
    with sock:
        while True:
            msg = sock.recv(65536)
            if not msg:
                break
            sock.sendall(msg)
        print("Client closed connection")
def echo_server(addr):
    """
    Accept connections on *addr* forever, echoing each on a pool thread.

    The listening socket is a context manager so it is closed if bind,
    listen or accept raises (including KeyboardInterrupt).
    """
    print("Echo server running at", addr)
    pool = ThreadPoolExecutor(128)
    with socket(AF_INET, SOCK_STREAM) as sock:
        sock.bind(addr)
        sock.listen(5)
        while True:
            client_sock, client_addr = sock.accept()
            pool.submit(echo_client, client_sock, client_addr)


echo_server(("", 15000))
################################################################################ | |
## 12::how_to_create_a_thread_pool | |
from socket import socket, AF_INET, SOCK_STREAM | |
from threading import Thread | |
from queue import Queue | |
def echo_client(q):
    """
    Worker loop: serve client connections taken from *q*, one after another.

    *q* yields ``(sock, client_addr)`` pairs. The original body served a
    single connection and returned, so every worker thread died after its
    first client; this version keeps taking connections. Each socket is
    closed when its client is done, and a socket error on one client is
    reported without ending the worker.
    """
    while True:
        sock, client_addr = q.get()
        print("Got connection from", client_addr)
        try:
            with sock:
                while True:
                    msg = sock.recv(65536)
                    if not msg:
                        break
                    sock.sendall(msg)
        except OSError as exc:
            print("Client connection error:", exc)
        print("Client closed connection")
def echo_server(addr, nworkers):
    """
    Accept connections on *addr* forever, handing each to a worker thread.

    *nworkers* daemon threads running echo_client() are started up front
    and fed ``(socket, address)`` pairs through a shared queue. The
    listening socket is closed if bind, listen or accept raises.
    """
    print("Echo server running at", addr)
    # Launch the client workers
    q = Queue()
    for _ in range(nworkers):
        t = Thread(target=echo_client, args=(q,), daemon=True)
        t.start()
    # Run the server
    with socket(AF_INET, SOCK_STREAM) as sock:
        sock.bind(addr)
        sock.listen(5)
        while True:
            client_sock, client_addr = sock.accept()
            q.put((client_sock, client_addr))


echo_server(("", 15000), 128)
################################################################################ | |
## 12::how_to_create_a_thread_pool | |
from concurrent.futures import ThreadPoolExecutor | |
import urllib.request | |
def fetch_url(url):
    """
    Return the full response body of *url* as bytes.

    The response object is used as a context manager so it is closed once
    the body has been read, or if reading fails.
    """
    with urllib.request.urlopen(url) as u:
        return u.read()
# The executor is a context manager so its worker threads are shut down
# once both downloads have completed (or if one of them raises).
with ThreadPoolExecutor(10) as pool:
    # Submit work to the pool
    a = pool.submit(fetch_url, "http://www.python.org")
    b = pool.submit(fetch_url, "http://www.pypy.org")
    # Get the results back
    x = a.result()
    y = b.result()
################################################################################ | |
## 12::how_to_determine_if_a_thread_has_started | |
from threading import Thread, Event | |
import time | |
# Code to execute in an independent thread
def countdown(n, started_evt):
    """
    Announce startup through *started_evt*, then count down from *n*.

    The event is set before the first tick so the launching thread can
    tell that this function has begun running. Each tick is followed by
    a 5 second sleep.
    """
    print("countdown starting")
    started_evt.set()
    while n > 0:
        print("T-minus", n)
        time.sleep(5)
        n -= 1
# Event the worker sets once it has actually begun executing
started_evt = Event()

# Launch the thread, handing it the startup event
print("Launching countdown")
t = Thread(target=countdown, kwargs={"n": 10, "started_evt": started_evt})
t.start()

# Block here until the worker reports that it has started
started_evt.wait()
print("countdown is running")
################################################################################ | |
## 12::how_to_determine_if_a_thread_has_started | |
import threading | |
import time | |
class PeriodicTimer:
    """
    Timer that wakes all waiting threads once every *interval* seconds.

    A flag alternating between 0 and 1 marks each tick; waiters block on a
    condition variable until the flag differs from the value they saw.
    """

    def __init__(self, interval):
        self._interval = interval
        self._flag = 0
        self._cv = threading.Condition()

    def start(self):
        """Run the timer loop in a daemon thread."""
        ticker = threading.Thread(target=self.run, daemon=True)
        ticker.start()

    def run(self):
        """Sleep for the interval, flip the flag, notify all waiters; repeat forever."""
        while True:
            time.sleep(self._interval)
            with self._cv:
                self._flag = 1 - self._flag
                self._cv.notify_all()

    def wait_for_tick(self):
        """Block the caller until the next tick of the timer."""
        with self._cv:
            seen = self._flag
            self._cv.wait_for(lambda: self._flag != seen)
# Example use of the timer: one shared timer ticking every 5 seconds
ptimer = PeriodicTimer(interval=5)
ptimer.start()
# Two threads that synchronize on the timer
def countdown(nticks):
    """Print a countdown from *nticks*, one line per tick of ``ptimer``."""
    remaining = nticks
    while remaining > 0:
        ptimer.wait_for_tick()
        print("T-minus", remaining)
        remaining -= 1
def countup(last):
    """Print 0, 1, ... up to (not including) *last*, one line per tick of ``ptimer``."""
    current = 0
    while current < last:
        ptimer.wait_for_tick()
        print("Counting", current)
        current += 1
# Start both counters; each advances once per timer tick
for target, limit in ((countdown, 10), (countup, 5)):
    threading.Thread(target=target, args=(limit,)).start()
################################################################################ | |
## 12::how_to_determine_if_a_thread_has_started | |
import threading | |
import time | |
# Worker thread
def worker(n, sema):
    """
    Block until *sema* can be acquired, then report that worker *n* runs.

    The semaphore is not released again, so each release() by the
    controlling code lets exactly one worker through.
    """
    sema.acquire()  # wait to be signalled
    print("Working", n)  # do some work
# Create some threads, all blocked on a semaphore that starts at zero
sema = threading.Semaphore(0)
nworkers = 10
for n in range(nworkers):
    t = threading.Thread(target=worker, args=(n, sema), daemon=True)
    t.start()

# Release two workers, pausing before and after each release
for announcement in (
    "About to release first worker",
    "About to release second worker",
):
    print(announcement)
    time.sleep(5)
    sema.release()
    time.sleep(1)
print("Goodbye")
################################################################################ | |
## 12::how_to_lock_critical_sections | |
import threading | |
class SharedCounter:
    """A counter whose updates are serialized by a lock, for use across threads."""

    def __init__(self, initial_value=0):
        self._value = initial_value
        self._value_lock = threading.Lock()

    def incr(self, delta=1):
        """Add *delta* to the counter while holding the lock."""
        with self._value_lock:
            self._value = self._value + delta

    def decr(self, delta=1):
        """Subtract *delta* from the counter while holding the lock."""
        with self._value_lock:
            self._value = self._value - delta
def test(c): | |
for n in range(1000000): | |
c.incr() | |
for n in range(1000000): | |
c.decr() | |
if __name__ == "__main__":
    # Three threads update one counter; it must end where it started.
    c = SharedCounter()
    workers = [threading.Thread(target=test, args=(c,)) for _ in range(3)]
    for w in workers:
        w.start()
    print("Running test")
    for w in workers:
        w.join()
    assert c._value == 0
    print("Looks good!")
################################################################################ | |
## 12::how_to_start_and_stop_threads | |
from threading import Thread | |
import time | |
class CountdownTask:
    """Countdown that another thread can stop by calling terminate()."""

    def __init__(self):
        self._running = True

    def terminate(self):
        """Ask run() to stop; it checks the flag before each tick."""
        self._running = False

    def run(self, n):
        """Count down from *n*, sleeping 5 seconds per tick, until 0 or terminated."""
        while True:
            if not self._running or not n > 0:
                break
            print("T-minus", n)
            n -= 1
            time.sleep(5)
# Run the countdown in a thread, let it tick for 20 seconds, then stop it
c = CountdownTask()
t = Thread(target=c.run, kwargs={"n": 10})
t.start()
time.sleep(20)
print("About to terminate")
c.terminate()
t.join()
print("Terminated")
################################################################################ | |
## 12::implementing_publish_subscribe_messaging | |
from collections import defaultdict | |
class Exchange:
    """Fan-out point: every message sent here is delivered to all subscribers."""

    def __init__(self):
        self._subscribers = set()

    def attach(self, task):
        """Subscribe *task*; it must provide a ``send(msg)`` method."""
        self._subscribers.add(task)

    def detach(self, task):
        """Unsubscribe *task*. Raises KeyError if it is not attached."""
        self._subscribers.remove(task)

    def send(self, msg):
        """
        Deliver *msg* to every subscriber attached when the call starts.

        Iterates over a snapshot so a subscriber may attach or detach
        (itself or others) while handling the message; iterating the live
        set would raise RuntimeError in that case.
        """
        for subscriber in tuple(self._subscribers):
            subscriber.send(msg)
# Registry of exchanges by name; a missing name creates a new Exchange
_exchanges = defaultdict(Exchange)


def get_exchange(name):
    """Return the Exchange registered under *name*, creating it on first use."""
    exchange = _exchanges[name]
    return exchange
if __name__ == "__main__":

    # Example task (just for testing)
    class Task:
        """Subscriber that prints each message with its own name."""

        def __init__(self, name):
            self.name = name

        def send(self, msg):
            print("{} got: {!r}".format(self.name, msg))

    task_a = Task("A")
    task_b = Task("B")
    exc = get_exchange("spam")
    for subscriber in (task_a, task_b):
        exc.attach(subscriber)
    for text in ("msg1", "msg2"):
        exc.send(text)
    for subscriber in (task_a, task_b):
        exc.detach(subscriber)
    # Nobody is attached any more, so this message reaches no one.
    exc.send("msg3")
################################################################################ | |
## 12::implementing_publish_subscribe_messaging | |
from contextlib import contextmanager | |
from collections import defaultdict | |
class Exchange:
    """Fan-out point: every message sent here is delivered to all subscribers."""

    def __init__(self):
        self._subscribers = set()

    def attach(self, task):
        """Subscribe *task*; it must provide a ``send(msg)`` method."""
        self._subscribers.add(task)

    def detach(self, task):
        """Unsubscribe *task*. Raises KeyError if it is not attached."""
        self._subscribers.remove(task)

    @contextmanager
    def subscribe(self, *tasks):
        """Context manager: attach *tasks* on entry and detach them on exit."""
        for task in tasks:
            self.attach(task)
        try:
            yield
        finally:
            for task in tasks:
                self.detach(task)

    def send(self, msg):
        """
        Deliver *msg* to every subscriber attached when the call starts.

        Iterates over a snapshot so a subscriber may attach or detach
        (itself or others) while handling the message; iterating the live
        set would raise RuntimeError in that case.
        """
        for subscriber in tuple(self._subscribers):
            subscriber.send(msg)
# Registry of exchanges by name; a missing name creates a new Exchange
_exchanges = defaultdict(Exchange)


def get_exchange(name):
    """Return the Exchange registered under *name*, creating it on first use."""
    exchange = _exchanges[name]
    return exchange
# Example of using the subscribe() method
if __name__ == "__main__":

    # Example task (just for testing)
    class Task:
        """Subscriber that prints each message with its own name."""

        def __init__(self, name):
            self.name = name

        def send(self, msg):
            print("{} got: {!r}".format(self.name, msg))

    task_a = Task("A")
    task_b = Task("B")
    exc = get_exchange("spam")
    with exc.subscribe(task_a, task_b):
        for text in ("msg1", "msg2"):
            exc.send(text)
    # Both tasks were detached when the block ended, so nobody receives this.
    exc.send("msg3")
################################################################################ | |
## 12::launching_a_daemon_process_on_unix | |
#!/usr/bin/env python3 | |
# daemon.py | |
import os | |
import sys | |
import atexit | |
import signal | |
def daemonize(pidfile, *, stdin="/dev/null", stdout="/dev/null", stderr="/dev/null"):
    """
    Turn the calling process into a Unix daemon using the double-fork idiom.

    Args:
        pidfile: path where the daemon's PID is written; its existence is
            treated as "already running".
        stdin, stdout, stderr: files that replace the standard streams.

    Raises:
        RuntimeError: if *pidfile* already exists, or if either fork fails
            (the original OSError is chained as ``__cause__``).
        SystemExit: raised with status 0 in each parent process after a
            successful fork, so only the final child returns.
    """
    if os.path.exists(pidfile):
        raise RuntimeError("Already running")

    # First fork (detaches from parent)
    try:
        if os.fork() > 0:
            raise SystemExit(0)  # Parent exit
    except OSError as e:
        raise RuntimeError("fork #1 failed.") from e

    os.chdir("/")
    os.umask(0)
    os.setsid()

    # Second fork (relinquish session leadership)
    try:
        if os.fork() > 0:
            raise SystemExit(0)
    except OSError as e:
        raise RuntimeError("fork #2 failed.") from e

    # Flush I/O buffers
    sys.stdout.flush()
    sys.stderr.flush()

    # Replace file descriptors for stdin, stdout, and stderr
    with open(stdin, "rb", 0) as f:
        os.dup2(f.fileno(), sys.stdin.fileno())
    with open(stdout, "ab", 0) as f:
        os.dup2(f.fileno(), sys.stdout.fileno())
    with open(stderr, "ab", 0) as f:
        os.dup2(f.fileno(), sys.stderr.fileno())

    # Write the PID file
    with open(pidfile, "w") as f:
        print(os.getpid(), file=f)

    # Arrange to have the PID file removed on exit/signal
    atexit.register(lambda: os.remove(pidfile))

    # Signal handler for termination (required): raising SystemExit on
    # SIGTERM lets the atexit hook above run.
    def sigterm_handler(signo, frame):
        raise SystemExit(1)

    signal.signal(signal.SIGTERM, sigterm_handler)
def main():
    """Daemon body: write a start-up line, then a heartbeat line every 10 seconds."""
    import time

    started = "Daemon started with pid {}\n".format(os.getpid())
    sys.stdout.write(started)
    while True:
        heartbeat = "Daemon Alive! {}\n".format(time.ctime())
        sys.stdout.write(heartbeat)
        time.sleep(10)
if __name__ == "__main__":
    # Command-line entry point: "start" daemonizes and runs main();
    # "stop" sends SIGTERM to the PID recorded in the PID file.
    PIDFILE = "/tmp/daemon.pid"

    if len(sys.argv) != 2:
        print("Usage: {} [start|stop]".format(sys.argv[0]), file=sys.stderr)
        raise SystemExit(1)

    if sys.argv[1] == "start":
        try:
            # stderr previously went to the misspelled "/tmp/dameon.log";
            # both streams now append to the same log file.
            daemonize(PIDFILE, stdout="/tmp/daemon.log", stderr="/tmp/daemon.log")
        except RuntimeError as e:
            print(e, file=sys.stderr)
            raise SystemExit(1)
        main()
    elif sys.argv[1] == "stop":
        if os.path.exists(PIDFILE):
            with open(PIDFILE) as f:
                os.kill(int(f.read()), signal.SIGTERM)
        else:
            print("Not running", file=sys.stderr)
            raise SystemExit(1)
    else:
        print("Unknown command {!r}".format(sys.argv[1]), file=sys.stderr)
        raise SystemExit(1)
################################################################################ | |
## 12::locking_with_deadlock_avoidance | |
import threading | |
from contextlib import contextmanager | |
# Thread-local state storing information on locks already acquired
_local = threading.local()


@contextmanager
def acquire(*locks):
    """
    Context manager that acquires *locks* in ascending ``id()`` order.

    The locks held by the current thread through this function are tracked
    in thread-local storage. A nested request whose lowest id is not above
    every id already held raises RuntimeError("Lock Order Violation").
    """
    # Sort locks by object identifier
    ordered = sorted(locks, key=id)

    # Make sure lock order of previously acquired locks is not violated
    held = getattr(_local, "acquired", [])
    if held and max(map(id, held)) >= id(ordered[0]):
        raise RuntimeError("Lock Order Violation")

    # Record, then acquire, all of the locks
    held.extend(ordered)
    _local.acquired = held
    try:
        for lock in ordered:
            lock.acquire()
        yield
    finally:
        # Release locks in reverse order of acquisition
        for lock in reversed(ordered):
            lock.release()
        del held[-len(ordered):]
################################################################################ | |
## 12::locking_with_deadlock_avoidance | |
import threading | |
from deadlock import acquire | |
# Two locks that the threads below request in opposite orders
x_lock, y_lock = threading.Lock(), threading.Lock()


def thread_1():
    """Loop forever, taking both locks in one acquire() call (x first)."""
    while True:
        with acquire(x_lock, y_lock):
            print("Thread-1")


def thread_2():
    """Loop forever, taking both locks in one acquire() call (y first)."""
    while True:
        with acquire(y_lock, x_lock):
            print("Thread-2")


input("This program runs forever. Press [return] to start, Ctrl-C to exit")

t1 = threading.Thread(target=thread_1, daemon=True)
t1.start()
t2 = threading.Thread(target=thread_2, daemon=True)
t2.start()

import time

# Keep the main thread alive; the daemon threads end with the process.
while True:
    time.sleep(1)
################################################################################ | |
## 12::locking_with_deadlock_avoidance | |
import threading | |
import time | |
from deadlock import acquire | |
# Two locks that the threads below acquire in nested, opposite orders
x_lock, y_lock = threading.Lock(), threading.Lock()


def thread_1():
    """Loop forever: take x_lock, then y_lock nested inside it."""
    while True:
        with acquire(x_lock):
            with acquire(y_lock):
                print("Thread-1")
                time.sleep(1)


def thread_2():
    """Loop forever: take y_lock, then x_lock nested inside it."""
    while True:
        with acquire(y_lock):
            with acquire(x_lock):
                print("Thread-2")
                time.sleep(1)


input("This program crashes with an exception. Press [return] to start")

t1 = threading.Thread(target=thread_1, daemon=True)
t1.start()
t2 = threading.Thread(target=thread_2, daemon=True)
t2.start()

# Give the daemon threads five seconds to run before the program exits.
time.sleep(5)
################################################################################ | |
## 12::locking_with_deadlock_avoidance | |
import threading | |
from deadlock import acquire | |
# The philosopher thread
def philosopher(left, right):
    """
    Loop forever: take both chopstick locks via acquire() and report eating.

    Uses threading.current_thread(); the camelCase alias currentThread()
    is deprecated since Python 3.10.
    """
    while True:
        with acquire(left, right):
            print(threading.current_thread(), "eating")
# The chopsticks (represented by locks)
NSTICKS = 5
chopsticks = [threading.Lock() for n in range(NSTICKS)]

# Create all of the philosophers; each uses chopstick n and its neighbour,
# wrapping around at the end of the table
for n, left_stick in enumerate(chopsticks):
    right_stick = chopsticks[(n + 1) % NSTICKS]
    t = threading.Thread(
        target=philosopher, args=(left_stick, right_stick), daemon=True
    )
    t.start()

import time

# Keep the main thread alive; the daemon threads end with the process.
while True:
    time.sleep(1)
################################################################################ | |
## 12::polling_multiple_thread_queues | |
import queue | |
import socket | |
import os | |
class PollableQueue(queue.Queue):
    """
    Queue that can be passed to select(): it exposes a file descriptor.

    A connected socket pair mirrors the queue: put() writes one byte and
    get() reads one, so the read side is readable whenever items are queued.
    """

    def __init__(self):
        super().__init__()
        # Create a pair of connected sockets
        if os.name != "posix":
            # Compatibility on non-POSIX systems: connect through a
            # temporary listening socket on the loopback interface
            listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            listener.bind(("127.0.0.1", 0))
            listener.listen(1)
            self._putsocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self._putsocket.connect(listener.getsockname())
            self._getsocket, _ = listener.accept()
            listener.close()
        else:
            self._putsocket, self._getsocket = socket.socketpair()

    def fileno(self):
        """Return the descriptor of the read side, for use with select()."""
        return self._getsocket.fileno()

    def put(self, item):
        """Enqueue *item*, then signal readability with one byte."""
        super().put(item)
        self._putsocket.send(b"x")

    def get(self):
        """Consume one signal byte, then return the next queued item."""
        self._getsocket.recv(1)
        return super().get()
# Example code that performs polling:
if __name__ == "__main__":
    import select
    import threading
    import time

    def consumer(queues):
        """Print items from whichever of *queues* select() reports readable."""
        while True:
            readable, _, _ = select.select(queues, [], [])
            for ready_queue in readable:
                print("Got:", ready_queue.get())

    q1 = PollableQueue()
    q2 = PollableQueue()
    q3 = PollableQueue()
    t = threading.Thread(target=consumer, args=([q1, q2, q3],), daemon=True)
    t.start()

    # Feed data to the queues
    for target_queue, value in ((q1, 1), (q2, 10), (q3, "hello"), (q2, 15)):
        target_queue.put(value)

    # Give thread time to run
    time.sleep(1)
################################################################################ | |
## 12::simple_parallel_programming | |
# findrobots.py | |
import gzip | |
import io | |
import glob | |
def find_robots(filename):
    """
    Find all of the hosts that access robots.txt in a single log file.

    *filename* is a gzip-compressed ASCII log. For each line, whitespace
    field 0 is taken as the host and field 6 as the requested path. Blank
    lines and lines with fewer than seven fields are skipped instead of
    raising IndexError.

    Returns the set of hosts that requested "/robots.txt".
    """
    robots = set()
    with gzip.open(filename, "rt", encoding="ascii") as f:
        for line in f:
            fields = line.split()
            if len(fields) > 6 and fields[6] == "/robots.txt":
                robots.add(fields[0])
    return robots
def find_all_robots(logdir):
    """
    Find all hosts across an entire sequence of files.

    Applies find_robots() to every ``*.log.gz`` file in *logdir*, one file
    after another, and returns the union of the results as a set.
    """
    pattern = logdir + "/*.log.gz"
    per_file_hosts = map(find_robots, glob.glob(pattern))
    return set().union(*per_file_hosts)
if __name__ == "__main__":
    import time

    # Time the scan of the "logs" directory, then list every host found
    start = time.time()
    robots = find_all_robots("logs")
    end = time.time()
    elapsed = end - start
    for ipaddr in robots:
        print(ipaddr)
    print("Took {:f} seconds".format(elapsed))
################################################################################ | |
## 12::simple_parallel_programming | |
# findrobots.py | |
import gzip | |
import io | |
import glob | |
from concurrent import futures | |
def find_robots(filename):
    """
    Find all of the hosts that access robots.txt in a single log file.

    *filename* is a gzip-compressed ASCII log. For each line, whitespace
    field 0 is taken as the host and field 6 as the requested path. Blank
    lines and lines with fewer than seven fields are skipped instead of
    raising IndexError.

    Returns the set of hosts that requested "/robots.txt".
    """
    robots = set()
    with gzip.open(filename, "rt", encoding="ascii") as f:
        for line in f:
            fields = line.split()
            if len(fields) > 6 and fields[6] == "/robots.txt":
                robots.add(fields[0])
    return robots
def find_all_robots(logdir):
    """
    Find all hosts across an entire sequence of files.

    Applies find_robots() to every ``*.log.gz`` file in *logdir* through a
    ProcessPoolExecutor and returns the union of the results as a set.
    """
    pattern = logdir + "/*.log.gz"
    log_files = glob.glob(pattern)
    with futures.ProcessPoolExecutor() as pool:
        return set().union(*pool.map(find_robots, log_files))
if __name__ == "__main__":
    import time

    # Time the parallel scan of the "logs" directory, then list every host
    start = time.time()
    robots = find_all_robots("logs")
    end = time.time()
    elapsed = end - start
    for ipaddr in robots:
        print(ipaddr)
    print("Took {:f} seconds".format(elapsed))
################################################################################ | |
## 12::storing_thread_specific_state | |
from socket import socket, AF_INET, SOCK_STREAM | |
import threading | |
class LazyConnection:
    """
    Context manager that opens a socket on entry and closes it on exit.

    The socket is kept in thread-local storage, so each thread using the
    same LazyConnection gets its own connection. Nested use within one
    thread is rejected.
    """

    def __init__(self, address, family=AF_INET, type=SOCK_STREAM):
        self.address = address
        # Honour the caller's arguments; the original stored the defaults
        # unconditionally, so family/type were silently ignored.
        self.family = family
        self.type = type
        self.local = threading.local()

    def __enter__(self):
        """Connect and return the socket. Raises RuntimeError if already connected."""
        if hasattr(self.local, "sock"):
            raise RuntimeError("Already connected")
        self.local.sock = socket(self.family, self.type)
        self.local.sock.connect(self.address)
        return self.local.sock

    def __exit__(self, exc_ty, exc_val, tb):
        """Close this thread's socket and forget it."""
        self.local.sock.close()
        del self.local.sock
def test(conn):
    """Fetch /index.html through *conn* and print the size of the response."""
    request_lines = (
        b"GET /index.html HTTP/1.0\r\n",
        b"Host: www.python.org\r\n",
        b"\r\n",
    )
    # Entering the block runs conn.__enter__(): the connection is opened
    with conn as s:
        for request_line in request_lines:
            s.send(request_line)
        # Read until recv() returns b"", i.e. the server closed the connection
        resp = b"".join(iter(lambda: s.recv(8192), b""))
    # conn.__exit__() has executed: the connection is closed again
    print("Got {} bytes".format(len(resp)))
if __name__ == "__main__":
    # Two threads share one LazyConnection; each gets its own socket
    conn = LazyConnection(("www.python.org", 80))
    t1 = threading.Thread(target=test, args=(conn,))
    t2 = threading.Thread(target=test, args=(conn,))
    for t in (t1, t2):
        t.start()
    for t in (t1, t2):
        t.join()
################################################################################ | |
## 12::storing_thread_specific_state | |
from socket import socket, AF_INET, SOCK_STREAM | |
import threading | |
class LazyConnection:
    """
    Context manager that opens a new socket on every entry.

    Open sockets are kept on a per-thread stack in thread-local storage,
    so the manager can be nested and shared between threads; each exit
    closes the most recently opened socket of the current thread.
    """

    def __init__(self, address, family=AF_INET, type=SOCK_STREAM):
        self.address = address
        # Honour the caller's arguments; the original stored the defaults
        # unconditionally, so family/type were silently ignored.
        self.family = family
        self.type = type
        self.local = threading.local()

    def __enter__(self):
        """Connect a new socket, push it on this thread's stack and return it."""
        sock = socket(self.family, self.type)
        sock.connect(self.address)
        if not hasattr(self.local, "connections"):
            self.local.connections = []
        self.local.connections.append(sock)
        return sock

    def __exit__(self, exc_ty, exc_val, tb):
        """Pop and close this thread's most recently opened socket."""
        self.local.connections.pop().close()
def test(conn):
    """Exercise *conn* once on its own, then twice nested, printing response sizes."""
    # Example use: a single connection
    index_request = (
        b"GET /index.html HTTP/1.0\r\n",
        b"Host: www.python.org\r\n",
        b"\r\n",
    )
    with conn as s:
        for request_line in index_request:
            s.send(request_line)
        # Read until recv() returns b"", i.e. the server closed the connection
        resp = b"".join(iter(lambda: s.recv(8192), b""))
    print("Got {} bytes".format(len(resp)))

    # Two simultaneous connections from the same manager, sent in lockstep
    downloads_request = (
        b"GET /downloads HTTP/1.0\r\n",
        b"Host: www.python.org\r\n",
        b"\r\n",
    )
    with conn as s1, conn as s2:
        for line1, line2 in zip(downloads_request, index_request):
            s1.send(line1)
            s2.send(line2)
        resp1 = b"".join(iter(lambda: s1.recv(8192), b""))
        resp2 = b"".join(iter(lambda: s2.recv(8192), b""))
    print("resp1 got {} bytes".format(len(resp1)))
    print("resp2 got {} bytes".format(len(resp2)))
if __name__ == "__main__":
    # Three threads share one LazyConnection object
    conn = LazyConnection(("www.python.org", 80))
    t1 = threading.Thread(target=test, args=(conn,))
    t2 = threading.Thread(target=test, args=(conn,))
    t3 = threading.Thread(target=test, args=(conn,))
    for t in (t1, t2, t3):
        t.start()
    for t in (t1, t2, t3):
        t.join()
################################################################################ | |
## 12::using_generators_as_an_alternative_to_threads | |
from collections import deque | |
class ActorScheduler:
    """Single-threaded scheduler delivering queued messages to generator actors."""

    def __init__(self):
        self._actors = {}  # Mapping of names to actors
        self._msg_queue = deque()  # Pending (actor, message) pairs

    def new_actor(self, name, actor):
        """
        Admit a newly started actor to the scheduler and give it a name.

        A None message is queued first; delivering it advances the new
        generator to its first yield.
        """
        self._msg_queue.append((actor, None))
        self._actors[name] = actor

    def send(self, name, msg):
        """Queue *msg* for the actor called *name*; unknown names are ignored."""
        actor = self._actors.get(name)
        if not actor:
            return
        self._msg_queue.append((actor, msg))

    def run(self):
        """Deliver messages in FIFO order for as long as any are pending."""
        pending = self._msg_queue
        while pending:
            actor, msg = pending.popleft()
            try:
                actor.send(msg)
            except StopIteration:
                # The actor finished; drop the message.
                pass
# Example use
if __name__ == "__main__":

    def printer():
        """Actor that prints every message it is sent."""
        while True:
            received = yield
            print("Got:", received)

    def counter(sched):
        """Actor that counts down by messaging the printer and then itself."""
        while True:
            # Receive the current count
            n = yield
            if n == 0:
                break
            # Report to the printer, then queue the next count for ourselves
            for name, value in (("printer", n), ("counter", n - 1)):
                sched.send(name, value)

    sched = ActorScheduler()
    # Create the initial actors
    sched.new_actor("printer", printer())
    sched.new_actor("counter", counter(sched))

    # Send an initial message to the counter to initiate
    sched.send("counter", 10000)
    sched.run()
################################################################################ | |
## 12::using_generators_as_an_alternative_to_threads | |
from collections import deque | |
from select import select | |
# This class represents a generic yield event in the scheduler
class YieldEvent:
    """Base class for objects a task yields to the scheduler; both hooks do nothing."""

    def handle_yield(self, sched, task):
        """Hook called by the scheduler when *task* yields this event."""

    def handle_resume(self, sched, task):
        """Hook called by the scheduler when the awaited I/O is ready."""
# Task Scheduler
class Scheduler:
    """
    Cooperative scheduler for generator tasks that yield YieldEvent objects.

    Tasks wait in a ready queue; those blocked on I/O are parked in
    per-descriptor tables and woken through select().
    """

    def __init__(self):
        self._numtasks = 0  # Total num of tasks
        self._ready = deque()  # Tasks ready to run
        self._read_waiting = {}  # Tasks waiting to read
        self._write_waiting = {}  # Tasks waiting to write

    def _iopoll(self):
        """Block in select() and resume every task whose descriptor is ready."""
        readable, writable, _ = select(self._read_waiting, self._write_waiting, [])
        # Readers are resumed before writers.
        for waiting, ready_fds in (
            (self._read_waiting, readable),
            (self._write_waiting, writable),
        ):
            for fd in ready_fds:
                evt, task = waiting.pop(fd)
                evt.handle_resume(self, task)

    def new(self, task):
        """Add a newly started task to the scheduler."""
        self._ready.append((task, None))
        self._numtasks += 1

    def add_ready(self, task, msg=None):
        """
        Append an already started task to the ready queue.

        *msg* is what to send into the task when it resumes.
        """
        self._ready.append((task, msg))

    def _read_wait(self, fileno, evt, task):
        """Park *task* until *fileno* is readable."""
        self._read_waiting[fileno] = (evt, task)

    def _write_wait(self, fileno, evt, task):
        """Park *task* until *fileno* is writable."""
        self._write_waiting[fileno] = (evt, task)

    def run(self):
        """
        Run the task scheduler until there are no tasks.

        Raises RuntimeError if a task yields something that is not a
        YieldEvent.
        """
        while self._numtasks:
            if not self._ready:
                self._iopoll()
            task, msg = self._ready.popleft()
            try:
                # Run the coroutine to the next yield
                yielded = task.send(msg)
                if not isinstance(yielded, YieldEvent):
                    raise RuntimeError("unrecognized yield event")
                yielded.handle_yield(self, task)
            except StopIteration:
                self._numtasks -= 1
# Example implementation of coroutine based socket I/O
class ReadSocket(YieldEvent):
    """Yield event: receive up to *nbytes* from *sock* once it is readable."""

    def __init__(self, sock, nbytes):
        self.sock = sock
        self.nbytes = nbytes

    def handle_yield(self, sched, task):
        # Park the task until the socket becomes readable.
        sched._read_wait(self.sock.fileno(), self, task)

    def handle_resume(self, sched, task):
        # Hand the received bytes back to the task as its yield result.
        sched.add_ready(task, self.sock.recv(self.nbytes))
class WriteSocket(YieldEvent):
    """Yield event: send *data* on *sock* once it is writable."""

    def __init__(self, sock, data):
        self.sock = sock
        self.data = data

    def handle_yield(self, sched, task):
        # Park the task until the socket becomes writable.
        sched._write_wait(self.sock.fileno(), self, task)

    def handle_resume(self, sched, task):
        # The task resumes with the number of bytes send() accepted.
        sched.add_ready(task, self.sock.send(self.data))
class AcceptSocket(YieldEvent):
    """Yield event: accept a connection on *sock* once it is readable."""

    def __init__(self, sock):
        self.sock = sock

    def handle_yield(self, sched, task):
        # Park the task until the listening socket becomes readable.
        sched._read_wait(self.sock.fileno(), self, task)

    def handle_resume(self, sched, task):
        # The task resumes with whatever accept() returned.
        sched.add_ready(task, self.sock.accept())
# Wrapper around a socket object for use with yield
class Socket:
    """
    Socket wrapper whose recv/send/accept return yield events.

    Any other attribute is looked up on the wrapped socket.
    """

    def __init__(self, sock):
        self._sock = sock

    def recv(self, maxbytes):
        """Return a ReadSocket event for up to *maxbytes* bytes."""
        return ReadSocket(self._sock, maxbytes)

    def send(self, data):
        """Return a WriteSocket event for *data*."""
        return WriteSocket(self._sock, data)

    def accept(self):
        """Return an AcceptSocket event."""
        return AcceptSocket(self._sock)

    def __getattr__(self, name):
        # Only called for names not found on the wrapper itself.
        return getattr(self._sock, name)
if __name__ == "__main__":
    from socket import socket, AF_INET, SOCK_STREAM
    import time

    # Example of a function involving generators. This should
    # be called using line = yield from readline(sock)
    def readline(sock):
        """Read one byte at a time up to and including a newline, or until EOF."""
        pieces = []
        while True:
            byte = yield sock.recv(1)
            if byte:
                pieces.append(byte)
            if not byte or byte == b"\n":
                break
        return b"".join(pieces)

    # Echo server using generators
    class EchoServer:
        """Line-based echo server driven by the generator scheduler."""

        def __init__(self, addr, sched):
            self.sched = sched
            sched.new(self.server_loop(addr))

        def server_loop(self, addr):
            """Accept connections forever, starting a handler task for each."""
            s = Socket(socket(AF_INET, SOCK_STREAM))
            s.bind(addr)
            s.listen(5)
            while True:
                c, a = yield s.accept()
                print("Got connection from ", a)
                self.sched.new(self.client_handler(Socket(c)))

        def client_handler(self, client):
            """Echo each line back prefixed with b"GOT:" until the client closes."""
            while True:
                line = yield from readline(client)
                if not line:
                    break
                reply = b"GOT:" + line
                # send() may accept only part of the data; repeat until done
                while reply:
                    nsent = yield client.send(reply)
                    reply = reply[nsent:]
            client.close()
            print("Client closed")

    sched = Scheduler()
    EchoServer(("", 16000), sched)
    sched.run()
################################################################################ | |
## 12::using_generators_as_an_alternative_to_threads | |
# A very simple example of a coroutine/generator scheduler | |
# Two simple generator functions | |
def countdown(n):
    """Generator task: print a T-minus line per step, then "Blastoff!".

    Control is yielded (with no value) after every printed step.
    """
    remaining = n
    while remaining > 0:
        print("T-minus", remaining)
        yield
        remaining -= 1
    print("Blastoff!")
def countup(n):
    """Generator task: print "Counting up k" for k = 0 .. n-1.

    Control is yielded (with no value) after every printed step.
    """
    current = 0
    while current < n:
        print("Counting up", current)
        yield
        current += 1
from collections import deque | |
class TaskScheduler:
    """Round-robin scheduler for generator-based tasks."""

    def __init__(self):
        self._task_queue = deque()

    def new_task(self, task):
        """
        Admit a newly started task to the scheduler
        """
        self._task_queue.append(task)

    def run(self):
        """
        Run until there are no more tasks
        """
        pending = self._task_queue
        while pending:
            current = pending.popleft()
            try:
                # Advance the task to its next yield
                next(current)
            except StopIteration:
                # Finished: drop it instead of re-queueing
                continue
            pending.append(current)
# Example use: queue three generator tasks and run them until all finish
sched = TaskScheduler()
sched.new_task(countdown(10))
sched.new_task(countdown(5))
sched.new_task(countup(15))
sched.run()
################################################################################ | |
## 13::adding_logging_to_libraries | |
# somelib.py | |
import logging | |
# Module-level logger named after this module, with a NullHandler attached
log = logging.getLogger(__name__)
log.addHandler(logging.NullHandler())

# Example function (for testing)
def func():
    """Emit one CRITICAL and one DEBUG record on the module logger."""
    log.critical("A Critical Error!")
    log.debug("A debug message")
################################################################################ | |
## 13::executing_an_external_command_and_getting_its_output | |
import subprocess | |
# Run ``netstat -a`` and print its output decoded as UTF-8
try:
    out_bytes = subprocess.check_output(["netstat", "-a"])
    out_text = out_bytes.decode("utf-8")
    print(out_text)
except subprocess.CalledProcessError as e:
    # Report the failure together with the command's exit status
    print("It did not work. Reason:", e)
    print("Exitcode:", e.returncode)
################################################################################ | |
## 13::executing_an_external_command_and_getting_its_output | |
import subprocess | |
# Some text to send
text = b"""
hello world
this is a test
goodbye
"""
# Launch a command with pipes
p = subprocess.Popen(["wc"], stdout=subprocess.PIPE, stdin=subprocess.PIPE)
# Send the data and get the output
# (stderr is not piped above, so the second value is None)
stdout, stderr = p.communicate(text)
# NOTE: ``text`` is rebound here from the input bytes to the decoded output
text = stdout.decode("utf-8")
print(text)
################################################################################ | |
## 13::finding_files | |
#!/usr/bin/env python3.3 | |
import os | |
import time | |
def modified_within(top, seconds):
    """Print the path of every file under *top* modified recently.

    Args:
        top: directory to walk recursively.
        seconds: a file is reported when its modification time is newer
            than ``time.time() - seconds`` (taken once, at call time).

    Files whose modification time cannot be read are skipped.
    """
    cutoff = time.time() - seconds
    for path, dirs, files in os.walk(top):
        for name in files:
            fullpath = os.path.join(path, name)
            try:
                mtime = os.path.getmtime(fullpath)
            except OSError:
                # Dangling symlink, or the file vanished between the directory
                # listing and the stat call; an exists() pre-check cannot
                # close that window, so just skip the entry.
                continue
            if mtime > cutoff:
                print(fullpath)
# Command-line entry point: expects exactly a directory and a number of seconds
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 3:
        print("Usage: {} dir seconds".format(sys.argv[0]))
        raise SystemExit(1)
    modified_within(sys.argv[1], float(sys.argv[2]))
################################################################################ | |
## 13::generating_a_range_of_ip_addresses_from_a_cidr_address | |
from socket import AF_INET, AF_INET6, inet_pton, inet_ntop | |
def cidr_range(cidr_address):
    """Generate every address in a CIDR block as a string, lowest first.

    Args:
        cidr_address: ``"address/prefixlen"``.  The address is treated as
            IPv6 when the string contains ``":"``, otherwise as IPv4.

    Yields:
        Each address of the block, formatted by ``inet_ntop``.

    Raises:
        ValueError: if the string is not of the form ``a/b``, the prefix
            length is not an integer, or it lies outside 0..address-bits.
        OSError: if ``inet_pton`` rejects the address part.
    """
    family = AF_INET6 if ":" in cidr_address else AF_INET
    address, maskstr = cidr_address.split("/")
    maskbits = int(maskstr)
    # Parse the supplied address into bytes
    addr_bytes = inet_pton(family, address)
    # Calculate number of address bytes and mask bits
    addr_len = len(addr_bytes)
    total_bits = addr_len * 8
    # Without this check a too-large prefix makes 2 ** negative a float
    # (TypeError below) and a negative one overflows to_bytes.
    if not 0 <= maskbits <= total_bits:
        raise ValueError(
            "prefix length must be between 0 and {}: {!r}".format(total_bits, maskstr)
        )
    numaddrs = 2 ** (total_bits - maskbits)
    # -numaddrs has all bits set except the low host bits
    mask = -numaddrs
    # Generate addresses
    addr = int.from_bytes(addr_bytes, "big") & mask
    for n in range(numaddrs):
        yield inet_ntop(family, (addr + n).to_bytes(addr_len, "big"))
# Demo: print one IPv4 block (/27) and one IPv6 block (/125)
if __name__ == "__main__":
    for a in cidr_range("123.45.67.89/27"):
        print(a)
    for a in cidr_range("12:3456:78:90ab:cd:ef01:23:34/125"):
        print(a)
################################################################################ | |
## 13::getting_the_terminal_size | |
import os | |
# Query the terminal size and report its columns and lines
sz = os.get_terminal_size()
print(sz.columns, "columns")
print(sz.lines, "lines")
################################################################################ | |
## 13::making_a_stopwatch | |
import time | |
class Timer:
    """Stopwatch that accumulates elapsed time across start/end pairs.

    *func* is the clock: a zero-argument callable returning the current
    time.  It defaults to ``time.perf_counter``.
    """

    def __init__(self, func=time.perf_counter):
        self._func = func
        self._start = None
        self.elapsed = 0.0

    def start(self):
        """Begin timing; raises RuntimeError if already running."""
        if self.running:
            raise RuntimeError("Already started")
        self._start = self._func()

    def end(self):
        """Stop timing and add the interval to ``elapsed``.

        Raises RuntimeError if the timer was not started.
        """
        if not self.running:
            raise RuntimeError("Not started")
        finished = self._func()
        self.elapsed += finished - self._start
        self._start = None

    def reset(self):
        """Zero the accumulated time."""
        self.elapsed = 0.0

    @property
    def running(self):
        """True between a ``start()`` and the matching ``end()``."""
        return self._start is not None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args):
        self.end()
# Demo: time a busy loop via start()/end(), then again via the with-statement
if __name__ == "__main__":

    def countdown(n):
        while n > 0:
            n -= 1

    t = Timer()
    t.start()
    countdown(1000000)
    t.end()
    print(t.elapsed)
    # elapsed accumulates, so the second figure includes the first run
    with t:
        countdown(1000000)
    print(t.elapsed)
################################################################################ | |
## 13::parsing_command_line_options | |
# search.py | |
""" | |
Hypothetical command line tool for searching a collection of | |
files for one or more text patterns. | |
""" | |
import argparse | |
# Build the command-line interface
parser = argparse.ArgumentParser(description="Search some files")
# Zero or more positional file names
parser.add_argument(dest="filenames", metavar="filename", nargs="*")
# -p/--pat is mandatory and may be repeated; values are collected in a list
parser.add_argument(
    "-p",
    "--pat",
    metavar="pattern",
    required=True,
    dest="patterns",
    action="append",
    help="text pattern to search for",
)
parser.add_argument("-v", dest="verbose", action="store_true", help="verbose mode")
parser.add_argument("-o", dest="outfile", action="store", help="output file")
parser.add_argument(
    "--speed",
    dest="speed",
    action="store",
    # A tuple rather than a set: argparse lists the choices in iteration
    # order in help and error text, and set order varies between runs.
    choices=("slow", "fast"),
    default="slow",
    help="search speed",
)
args = parser.parse_args()
# Output the collected arguments
print(args.filenames)
print(args.patterns)
print(args.verbose)
print(args.outfile)
print(args.speed)
################################################################################ | |
## 13::prompting_for_a_password_at_runtime | |
import getpass | |
# Look up the current user name and prompt for a password
user = getpass.getuser()
passwd = getpass.getpass()
print("User:", user)
# NOTE(review): prints the password in clear text -- demonstration only
print("Passwd:", passwd)
################################################################################ | |
## 13::putting_limits_on_memory_and_cpu_usage | |
import signal | |
import resource | |
import os | |
def time_exceeded(signo, frame):
    """Signal handler: announce the timeout and exit with status 1."""
    print("Time's up!")
    raise SystemExit(1)
def set_max_runtime(seconds):
    """Limit this process's CPU time and exit cleanly when it runs out.

    Sets the soft ``RLIMIT_CPU`` limit to *seconds* (the hard limit is left
    unchanged) and routes ``SIGXCPU`` to ``time_exceeded``.

    Raises whatever ``resource.setrlimit`` raises for an unacceptable limit.
    """
    # Install the handler before arming the limit, so SIGXCPU can never be
    # delivered while its default action (terminating the process) applies.
    signal.signal(signal.SIGXCPU, time_exceeded)
    _, hard = resource.getrlimit(resource.RLIMIT_CPU)
    resource.setrlimit(resource.RLIMIT_CPU, (seconds, hard))
# Demo: cap the runtime at 15 and spin until the handler ends the process
if __name__ == "__main__":
    set_max_runtime(15)
    while True:
        pass
################################################################################ | |
## 13::reading_configuration_files | |
from configparser import ConfigParser | |
# Load config.ini and print a few values, using the typed getters
# (getboolean / getint) where the value is not a plain string
cfg = ConfigParser()
cfg.read("config.ini")
print("sections:", cfg.sections())
print("installation:library", cfg.get("installation", "library"))
print("debug:log_errors", cfg.getboolean("debug", "log_errors"))
print("server:port", cfg.getint("server", "port"))
print("server:nworkers", cfg.getint("server", "nworkers"))
print("server:signature", cfg.get("server", "signature"))
################################################################################ | |
## 13::simple_logging_for_scripts | |
import logging | |
def main():
    """Configure logging to app.log at ERROR level and emit sample messages."""
    # Configure the logging system
    logging.basicConfig(filename="app.log", level=logging.ERROR)
    # Variables (to make the calls that follow work)
    hostname = "www.python.org"
    item = "spam"
    filename = "data.csv"
    mode = "r"
    # Example logging calls (insert into your program)
    # One call per level, from most to least severe
    logging.critical("Host %s unknown", hostname)
    logging.error("Couldn't find %r", item)
    logging.warning("Feature is deprecated")
    logging.info("Opening file %r, mode=%r", filename, mode)
    logging.debug("Got here")


if __name__ == "__main__":
    main()
################################################################################ | |
## 13::simple_logging_for_scripts | |
import logging | |
import logging.config | |
def main():
    """Configure logging from logconfig.ini and emit sample messages."""
    # Configure the logging system
    logging.config.fileConfig("logconfig.ini")
    # Variables (to make the calls that follow work)
    hostname = "www.python.org"
    item = "spam"
    filename = "data.csv"
    mode = "r"
    # Example logging calls (insert into your program)
    # One call per level, from most to least severe
    logging.critical("Host %s unknown", hostname)
    logging.error("Couldn't find %r", item)
    logging.warning("Feature is deprecated")
    logging.info("Opening file %r, mode=%r", filename, mode)
    logging.debug("Got here")


if __name__ == "__main__":
    main()
################################################################################ | |
## 14::logging_test_output_to_a_file | |
import unittest | |
# A simple function to illustrate
def parse_int(s):
    """Return ``int(s)``; errors raised by ``int`` propagate unchanged."""
    return int(s)
class TestConversion(unittest.TestCase):
    """Exception tests for ``parse_int``."""

    # Testing that an exception gets raised
    def test_bad_int(self):
        with self.assertRaises(ValueError):
            parse_int("N/A")

    # Testing an exception plus regex on exception message
    def test_bad_int_msg(self):
        with self.assertRaisesRegex(ValueError, "invalid literal .*"):
            parse_int("N/A")
# Example of testing an exception along with inspection of exception instance | |
import errno | |
class TestIO(unittest.TestCase):
    """Inspect the exception instance raised when opening a missing file."""

    def test_file_not_found(self):
        # assertRaises fails the test by itself when nothing is raised, and
        # its context manager keeps the exception for inspection.  Closing
        # inline avoids leaking a handle if the open unexpectedly succeeds.
        with self.assertRaises(IOError) as caught:
            open("/file/not/found").close()
        self.assertEqual(caught.exception.errno, errno.ENOENT)
import sys | |
def main(out=sys.stderr, verbosity=2):
    """Run every test found in this module, writing the report to *out*."""
    this_module = sys.modules[__name__]
    suite = unittest.TestLoader().loadTestsFromModule(this_module)
    runner = unittest.TextTestRunner(out, verbosity=verbosity)
    runner.run(suite)


if __name__ == "__main__":
    # Send the test report to a file instead of the terminal
    with open("testing.out", "w") as f:
        main(f)
################################################################################ | |
## 14::make_your_programs_run_faster | |
import time | |
def test(func):
    """Time 100 calls of ``func(range(1000000))`` and print the total.

    Prints ``<function name> : <elapsed seconds>``.
    """
    nums = range(1000000)
    # perf_counter is the clock meant for measuring intervals; time.time()
    # follows the wall clock and can jump if the system time is adjusted.
    start = time.perf_counter()
    for _ in range(100):
        func(nums)
    end = time.perf_counter()
    print(func.__name__, ":", end - start)
import math | |
def compute_roots_1(nums):
    """Baseline variant: looks up ``math.sqrt`` and ``result.append`` every iteration."""
    result = []
    for n in nums:
        result.append(math.sqrt(n))
    return result
from math import sqrt | |
def compute_roots_2(nums):
    """Variant using the directly imported ``sqrt`` and a pre-bound ``append``."""
    result = []
    # Bind the method once so the loop skips the attribute lookup
    result_append = result.append
    for n in nums:
        result_append(sqrt(n))
    return result
def compute_roots_3(nums):
    """Variant that binds both ``sqrt`` and ``append`` to local names first."""
    # Local names shadow the module-level ``sqrt`` inside this function
    sqrt = math.sqrt
    result = []
    result_append = result.append
    for n in nums:
        result_append(sqrt(n))
    return result
# Time each variant in turn
tests = [compute_roots_1, compute_roots_2, compute_roots_3]
for func in tests:
    test(func)
################################################################################ | |
## 14::profiling_and_timing_your_program | |
# timethis.py | |
import time | |
from functools import wraps | |
def timethis(func):
    """Decorator that prints how long each call to *func* took.

    The printed line is ``<module>.<name> : <seconds>``; the wrapped
    function's return value is passed through unchanged.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        began = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - began
        print("{}.{} : {}".format(func.__module__, func.__name__, duration))
        return result

    return wrapper
# Demo: time a decorated busy loop
if __name__ == "__main__":

    @timethis
    def countdown(n):
        while n > 0:
            n -= 1

    countdown(10000000)
################################################################################ | |
## 14::raising_an_exception_in_response_to_another_exception | |
# Different styles of raising chained exceptions | |
# Example 1: Explicit chaining. Use this whenever your | |
# intent is to raise a new exception in response to another | |
def example1():
    """Raise RuntimeError explicitly chained (``from``) to the ValueError."""
    try:
        int("N/A")
    except ValueError as parse_error:
        raise RuntimeError("A parsing error occurred") from parse_error
# Example 2: Implicit chaining. This occurs if there's an | |
# unexpected exception in the except block. | |
def example2():
    """Trigger implicit chaining: the handler itself fails with NameError."""
    try:
        int("N/A")
    except ValueError as e:
        # ``err`` is deliberately undefined (the bound name is ``e``)
        print("It failed. Reason:", err)  # Intentional error
# Example 3: Discarding the previous exception | |
def example3():
    """Raise RuntimeError with the original ValueError suppressed (``from None``)."""
    try:
        int("N/A")
    except ValueError:
        raise RuntimeError("A parsing error occurred") from None
# Demo: run each example and print the traceback it produces
if __name__ == "__main__":
    import traceback

    print("****** EXPLICIT EXCEPTION CHAINING ******")
    try:
        example1()
    except Exception:
        traceback.print_exc()
    print()
    print("****** IMPLICIT EXCEPTION CHAINING ******")
    try:
        example2()
    except Exception:
        traceback.print_exc()
    print()
    print("****** DISCARDED CHAINING *******")
    try:
        example3()
    except Exception:
        traceback.print_exc()
################################################################################ | |
## 14::skipping_or_anticipating_test_failures | |
import unittest | |
import os | |
import platform | |
class Tests(unittest.TestCase):
    """Examples of skipping tests and of marking an expected failure."""

    def test_0(self):
        self.assertTrue(True)

    # Skipped unconditionally
    @unittest.skip("skipped test")
    def test_1(self):
        self.fail("should have failed!")

    # Skipped when os.name is "posix"
    @unittest.skipIf(os.name == "posix", "Not supported on Unix")
    def test_2(self):
        import winreg

    # Skipped unless platform.system() reports "Darwin"
    @unittest.skipUnless(platform.system() == "Darwin", "Mac specific test")
    def test_3(self):
        self.assertTrue(True)

    # The assertion below fails on purpose
    @unittest.expectedFailure
    def test_4(self):
        self.assertEqual(2 + 2, 5)


if __name__ == "__main__":
    unittest.main(verbosity=2)
################################################################################ | |
## 14::testing_for_exceptional_conditions_in_unit_tests | |
import unittest | |
# A simple function to illustrate
def parse_int(s):
    """Return ``int(s)``; errors raised by ``int`` propagate unchanged."""
    return int(s)
class TestConversion(unittest.TestCase):
    """Exception tests for ``parse_int``."""

    # Testing that an exception gets raised
    def test_bad_int(self):
        with self.assertRaises(ValueError):
            parse_int("N/A")

    # Testing an exception plus regex on exception message
    def test_bad_int_msg(self):
        with self.assertRaisesRegex(ValueError, "invalid literal .*"):
            parse_int("N/A")
# Example of testing an exception along with inspection of exception instance | |
import errno | |
class TestIO(unittest.TestCase):
    """Inspect the exception instance raised when opening a missing file."""

    def test_file_not_found(self):
        # assertRaises fails the test by itself when nothing is raised, and
        # its context manager keeps the exception for inspection.  Closing
        # inline avoids leaking a handle if the open unexpectedly succeeds.
        with self.assertRaises(IOError) as caught:
            open("/file/not/found").close()
        self.assertEqual(caught.exception.errno, errno.ENOENT)


if __name__ == "__main__":
    unittest.main()
################################################################################ | |
## 14::testing_output_sent_to_stdout | |
# mymodule.py | |
def urlprint(protocol, host, domain):
    """Print ``protocol://host.domain`` followed by a newline."""
    print("{}://{}.{}".format(protocol, host, domain))
################################################################################ | |
## 14::testing_output_sent_to_stdout | |
from io import StringIO | |
from unittest import TestCase | |
from unittest.mock import patch | |
import mymodule | |
class TestURLPrint(TestCase):
    """Check what ``mymodule.urlprint`` writes to standard output."""

    def test_url_gets_to_stdout(self):
        parts = ("http", "www", "example.com")
        expected_url = "{}://{}.{}\n".format(*parts)
        # Swap sys.stdout for an in-memory buffer while the function runs
        with patch("sys.stdout", new=StringIO()) as fake_out:
            mymodule.urlprint(*parts)
            self.assertEqual(fake_out.getvalue(), expected_url)


if __name__ == "__main__":
    import unittest

    unittest.main()
################################################################################ | |
## 15::accessing_c_code_using_ctypes | |
import sample | |
# Call each wrapped function from the sample module and print the result
print(sample.gcd(35, 42))
print(sample.in_mandel(0, 0, 500))
print(sample.in_mandel(2.0, 1.0, 500))
print(sample.divide(42, 8))
print(sample.avg([1, 2, 3]))
p1 = sample.Point(1, 2)
p2 = sample.Point(4, 5)
print(sample.distance(p1, p2))
################################################################################ | |
## 15::accessing_c_code_using_ctypes | |
# sample.py | |
import ctypes | |
import os | |
# .so file is located in the directory above. See Makefile for
# build instructions.  The path is resolved relative to this file, not the
# current working directory, so the module loads from wherever it is imported.
_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "libsample.so")
_mod = ctypes.cdll.LoadLibrary(_path)

# Declare argument and result types for each C function so ctypes converts
# values correctly.

# int gcd(int, int)
gcd = _mod.gcd
gcd.argtypes = (ctypes.c_int, ctypes.c_int)
gcd.restype = ctypes.c_int

# int in_mandel(double, double, int)
in_mandel = _mod.in_mandel
in_mandel.argtypes = (ctypes.c_double, ctypes.c_double, ctypes.c_int)
in_mandel.restype = ctypes.c_int

# int divide(int, int, int *)
_divide = _mod.divide
_divide.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int))
_divide.restype = ctypes.c_int
def divide(x, y):
    """Call the C ``divide`` and return ``(quotient, remainder)``.

    The C function delivers the remainder through its ``int *`` argument,
    so a ``c_int`` is allocated here to receive it.
    """
    remainder = ctypes.c_int()
    quotient = _divide(x, y, remainder)
    return quotient, remainder.value
# double avg(double *, int n)
# Define a special type for the 'double *' argument
class DoubleArrayType:
    """Argument converter for ``double *`` parameters.

    ``from_param`` dispatches on the argument's type name to a matching
    ``from_<typename>`` method.
    """

    def from_param(self, param):
        """Convert *param*; raises TypeError for unsupported types."""
        typename = type(param).__name__
        converter = getattr(self, "from_" + typename, None)
        if converter is not None:
            return converter(param)
        if isinstance(param, ctypes.Array):
            # ctypes arrays are passed through unchanged
            return param
        raise TypeError("Can't convert %s" % typename)

    # Cast from array.array objects
    def from_array(self, param):
        if param.typecode != "d":
            raise TypeError("must be an array of doubles")
        address, _ = param.buffer_info()
        return ctypes.cast(address, ctypes.POINTER(ctypes.c_double))

    # Cast from lists/tuples
    def from_list(self, param):
        array_type = ctypes.c_double * len(param)
        return array_type(*param)

    from_tuple = from_list

    # Cast from a numpy array
    def from_ndarray(self, param):
        pointer_type = ctypes.POINTER(ctypes.c_double)
        return param.ctypes.data_as(pointer_type)
# An instance (not the class) is used as the argtype; its from_param method
# performs the conversion of the first argument
DoubleArray = DoubleArrayType()
_avg = _mod.avg
_avg.argtypes = (DoubleArray, ctypes.c_int)
_avg.restype = ctypes.c_double


def avg(values):
    """Call the C ``avg`` with *values* and its length."""
    return _avg(values, len(values))
# struct Point { }
class Point(ctypes.Structure):
    """ctypes structure with two double fields, ``x`` and ``y``."""

    _fields_ = [("x", ctypes.c_double), ("y", ctypes.c_double)]


# double distance(Point *, Point *)
distance = _mod.distance
distance.argtypes = (ctypes.POINTER(Point), ctypes.POINTER(Point))
distance.restype = ctypes.c_double
################################################################################ | |
## 15::consuming_an_iterable_from_c | |
import sample | |
# Pass a list, then a generator, to sample.consume_iterable
sample.consume_iterable([1, 2, 3, 4])


def countdown(n):
    """Yield n, n-1, ..., 1."""
    while n > 0:
        yield n
        n -= 1


sample.consume_iterable(countdown(10))
################################################################################ | |
## 15::consuming_an_iterable_from_c | |
# setup.py | |
from distutils.core import setup, Extension | |
# Build the "sample" extension module from sample.c
setup(name="sample", ext_modules=[Extension("sample", ["sample.c"],)])
################################################################################ | |
## 15::defining_and_exporting_c_apis_from_extension_modules | |
import sample | |
import ptexample | |
# Create a Point with sample and hand it to ptexample.print_point
p1 = sample.Point(2, 3)
ptexample.print_point(p1)
################################################################################ | |
## 15::defining_and_exporting_c_apis_from_extension_modules | |
# setup.py | |
from distutils.core import setup, Extension | |
# Build the "ptexample" extension module from ptexample.c
setup(
    name="ptexample",
    ext_modules=[
        Extension(
            "ptexample",
            ["ptexample.c"],
            include_dirs=["..", "."],  # May need pysample.h directory
        )
    ],
)
################################################################################ | |
## 15::defining_and_exporting_c_apis_from_extension_modules | |
# setup.py | |
from distutils.core import setup, Extension | |
# Build the "sample" extension from ../sample.c and pysample.c, with the
# parent directory on the include path
setup(
    name="sample",
    ext_modules=[
        Extension("sample", ["../sample.c", "pysample.c"], include_dirs=[".."],)
    ],
)
################################################################################ | |
## 15::diagnosing_segmentation_faults | |
# example.py | |
import sample | |
def foo():
    """Print a notice, then call ``sample.die()``."""
    print("About to die")
    sample.die()
def bar():
    """Print a notice, then call ``foo()``."""
    print("About to call the function that dies")
    foo()
def spam():
    """Print a notice, then call ``bar()``."""
    print("About to call the function that calls the function that dies")
    bar()
# Enable faulthandler before starting the call chain
if __name__ == "__main__":
    import faulthandler

    faulthandler.enable()
    spam()
################################################################################ | |
## 15::diagnosing_segmentation_faults | |
# setup.py | |
from distutils.core import setup, Extension | |
# Build the "sample" extension module from sample.c
setup(name="sample", ext_modules=[Extension("sample", ["sample.c"],)])
################################################################################ | |
## 15::managing_opaque_pointers_in_c_extension_modules | |
import sample | |
# Create two points, print them and their distance, then delete both
p1 = sample.Point(2, 3)
p2 = sample.Point(4, 5)
print(p1)
print(p2)
print(sample.distance(p1, p2))
del p1
del p2
print("Done")
################################################################################ | |
## 15::managing_opaque_pointers_in_c_extension_modules | |
# setup.py | |
from distutils.core import setup, Extension | |
# Build the "sample" extension from ../sample.c and pysample.c, with the
# parent directory on the include path
setup(
    name="sample",
    ext_modules=[
        Extension("sample", ["../sample.c", "pysample.c"], include_dirs=[".."],)
    ],
)
################################################################################ | |
## 15::passing_null_terminated_strings_to_c_libraries | |
import sample | |
import sys | |
# Pass a bytes object to print_chars
sample.print_chars(b"hello world")
# Print the string's size before and after each call, to show whether the
# call changed it
s = "Spicy Jalape\u00f1o"
print(sys.getsizeof(s))
sample.print_chars_str(s)
print(sys.getsizeof(s))
del s
# Same experiment with the alternative function and a new string object
s = "spicy Jalape\u00f1o"
print(sys.getsizeof(s))
sample.print_chars_str_alt(s)
print(sys.getsizeof(s))
################################################################################ | |
## 15::passing_null_terminated_strings_to_c_libraries | |
# setup.py | |
from distutils.core import setup, Extension | |
# Build the "sample" extension module from sample.c
setup(name="sample", ext_modules=[Extension("sample", ["sample.c"],)])
################################################################################ | |
## 15::passing_unicode_strings_to_c_libraries | |
import sample | |
# Pass the same non-ASCII string to both print functions
s = "Spicy Jalape\u00f1o"
sample.print_chars(s)
sample.print_wchars(s)
################################################################################ | |
## 15::passing_unicode_strings_to_c_libraries | |
# setup.py | |
from distutils.core import setup, Extension | |
# Build the "sample" extension module from sample.c
setup(name="sample", ext_modules=[Extension("sample", ["sample.c"],)])
################################################################################ | |
## 15::reading_file_like_objects_from_c | |
import sample

# The with-block closes the file even if consume_file raises
with open("sample.c") as f:
    sample.consume_file(f)
print("**** DONE")
################################################################################ | |
## 15::reading_file_like_objects_from_c | |
# setup.py | |
from distutils.core import setup, Extension | |
# Build the "sample" extension module from sample.c
setup(name="sample", ext_modules=[Extension("sample", ["sample.c"],)])
################################################################################ | |
## 15::turning_a_function_pointer_into_a_callable | |
import ctypes | |
# Passing None loads the symbols already available in the running process
lib = ctypes.cdll.LoadLibrary(None)
# Get the address of sin() from the C math library
addr = ctypes.cast(lib.sin, ctypes.c_void_p).value
print(addr)
# e.g. 140735505915760 (the value differs from run to run)
# Turn the address into a callable function
# CFUNCTYPE(restype, *argtypes): here double -> double
functype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
func = functype(addr)
print(func)
# Call the resulting function
print(func(2))
print(func(0))
################################################################################ | |
## 15::using_cython_to_write_high_performance_array_operations | |
# array module example | |
import sample | |
import array | |
# Clip an array.array in place (the input is also passed as the output)
a = array.array("d", [1, -3, 4, 7, 2, 0])
print(a)
sample.clip(a, 1, 4, a)
print(a)
# numpy example
import numpy
b = numpy.random.uniform(-10, 10, size=1000000)
print(b)
c = numpy.zeros_like(b)
print(c)
sample.clip(b, -5, 5, c)
print(c)
print(min(c))
print(max(c))
# Timing test
from timeit import timeit
print("numpy.clip")
print(timeit("numpy.clip(b,-5,5,c)", "from __main__ import b,c,numpy", number=1000))
print("sample.clip")
print(timeit("sample.clip(b,-5,5,c)", "from __main__ import b,c,sample", number=1000))
print("sample.clip_fast")
print(
    timeit("sample.clip_fast(b,-5,5,c)", "from __main__ import b,c,sample", number=1000)
)
# 2D test
d = numpy.random.uniform(-10, 10, size=(1000, 1000))
print(d)
sample.clip2d(d, -5, 5, d)
print(d)
################################################################################ | |
## 15::using_cython_to_write_high_performance_array_operations | |
from distutils.core import setup | |
from distutils.extension import Extension | |
from Cython.Distutils import build_ext | |
# Build the "sample" extension from sample.pyx using Cython's build_ext
ext_modules = [Extension("sample", ["sample.pyx"])]
setup(name="Sample app", cmdclass={"build_ext": build_ext}, ext_modules=ext_modules)
################################################################################ | |
## 15::working_with_c_strings_of_dubious_encoding | |
import sample | |
# Fetch a string from C, show its repr, then pass it back to C
s = sample.retstr()
print(repr(s))
sample.print_chars(s)
################################################################################ | |
## 15::working_with_c_strings_of_dubious_encoding | |
# setup.py | |
from distutils.core import setup, Extension | |
# Build the "sample" extension module from sample.c
setup(name="sample", ext_modules=[Extension("sample", ["sample.c"],)])
################################################################################ | |
## 15::wrapping_c_code_with_swig | |
import sample | |
# Call the wrapped functions and read the Point attributes
print(sample.gcd(42, 8))
print(sample.divide(42, 8))
p1 = sample.Point(2, 3)
p2 = sample.Point(4, 5)
print(sample.distance(p1, p2))
print(p1.x)
print(p1.y)
import array
# avg is given an array.array of doubles
a = array.array("d", [1, 2, 3])
print(sample.avg(a))
################################################################################ | |
## 15::wrapping_c_code_with_swig | |
# setup.py | |
from distutils.core import setup, Extension | |
# Package the pure-Python "sample" module together with the "_sample"
# extension built from ../sample.c and sample_wrap.c.
setup(
    name="sample",
    # py_modules takes module names, not file names: "sample.py" would be
    # read as module "py" inside a package called "sample".
    py_modules=["sample"],
    ext_modules=[
        Extension(
            "_sample",
            ["../sample.c", "sample_wrap.c"],
            include_dirs=[".."],
            define_macros=[],
            undef_macros=[],
            library_dirs=[],
            libraries=[],
        )
    ],
)
################################################################################ | |
## 15::wrapping_existing_c_code_with_cython | |
import sample | |
# Call the wrapped functions and print the Point objects
print(sample.gcd(42, 8))
print(sample.divide(42, 8))
p1 = sample.Point(2, 3)
p2 = sample.Point(4, 5)
print(p1)
print(p2)
print(sample.distance(p1, p2))
import array
# avg is given an array.array of doubles
a = array.array("d", [1, 2, 3])
print(sample.avg(a))
################################################################################ | |
## 15::wrapping_existing_c_code_with_cython | |
from distutils.core import setup | |
from distutils.extension import Extension | |
from Cython.Distutils import build_ext | |
# Build "sample" from sample.pyx, linking against the "sample" library and
# using the parent directory for headers and for link-time/run-time lookup
ext_modules = [
    Extension(
        "sample",
        ["sample.pyx"],
        include_dirs=[".."],
        libraries=["sample"],
        library_dirs=[".."],
        runtime_library_dirs=[".."],
    )
]
setup(
    name="Sample extension module",
    cmdclass={"build_ext": build_ext},
    ext_modules=ext_modules,
)
################################################################################ | |
## 15::wrapping_existing_c_code_with_cython | |
from distutils.core import setup | |
from distutils.extension import Extension | |
from Cython.Distutils import build_ext | |
# Build "sample" from sample_alt.pyx, linking against the "sample" library
# and using the parent directory for headers and for link-time/run-time lookup
ext_modules = [
    Extension(
        "sample",
        ["sample_alt.pyx"],
        include_dirs=[".."],
        libraries=["sample"],
        runtime_library_dirs=[".."],
        library_dirs=[".."],
    )
]
setup(
    name="Sample extension module",
    cmdclass={"build_ext": build_ext},
    ext_modules=ext_modules,
)
################################################################################ | |
## 15::writing_a_simple_c_extension_module | |
import sample | |
# Call each function exported by the extension module and print the result
print(sample.gcd(35, 42))
print(sample.in_mandel(0, 0, 500))
print(sample.in_mandel(2.0, 1.0, 500))
print(sample.divide(42, 8))
################################################################################ | |
## 15::writing_a_simple_c_extension_module | |
# setup.py | |
from distutils.core import setup, Extension | |
# Build the "sample" extension from ../sample.c and pysample.c, with the
# parent directory on the include path
setup(
    name="sample",
    ext_modules=[
        Extension("sample", ["../sample.c", "pysample.c"], include_dirs=[".."],)
    ],
)
################################################################################ | |
## 15::writing_an_extension_function_that_operates_on_arrays | |
import array | |
from sample import avg | |
# Average an array.array of doubles
print(avg(array.array("d", [1, 2, 3])))
# numpy is optional.  Only the import sits in the try body, so an
# ImportError raised while running avg() is not swallowed along with a
# missing numpy.
try:
    import numpy
except ImportError:
    pass
else:
    print(avg(numpy.array([1.0, 2.0, 3.0])))
################################################################################ | |
## 15::writing_an_extension_function_that_operates_on_arrays | |
# setup.py | |
from distutils.core import setup, Extension | |
# Build the "sample" extension from ../sample.c and pysample.c, with the
# parent directory on the include path
setup(
    name="sample",
    ext_modules=[
        Extension("sample", ["../sample.c", "pysample.c"], include_dirs=[".."],)
    ],
)
################################################################################ | |
## 2::combining_and_concatenating_strings | |
# example.py | |
# | |
# Example of combining text via generators | |
def sample():
    """Yield the four demo text fragments, in order."""
    yield from ("Is", "Chicago", "Not", "Chicago?")
# (a) Simple join operator: build the whole string in memory
text = "".join(sample())
print(text)
# (b) Redirection of parts to I/O: write each fragment as it is produced
import sys
for part in sample():
    sys.stdout.write(part)
sys.stdout.write("\n")
# (c) Combination of parts into buffers and larger I/O operations
def combine(source, maxsize):
    """Join string fragments from *source* into larger chunks.

    Fragments are buffered until their total length exceeds *maxsize*; the
    buffer is then joined and yielded.  Whatever remains when *source* is
    exhausted is yielded as a final chunk.

    Args:
        source: iterable of strings.
        maxsize: buffered length that triggers a flush once exceeded.

    Yields:
        Joined chunks, in input order.
    """
    parts = []
    size = 0
    for part in source:
        parts.append(part)
        size += len(part)
        if size > maxsize:
            yield "".join(parts)
            parts = []
            size = 0
    # Flush the remainder, but only if something is buffered: otherwise an
    # empty source, or one ending exactly on a flush, would emit a spurious
    # empty chunk.
    if parts:
        yield "".join(parts)
# Write the fragments in chunks, using a flush threshold of 32768
for part in combine(sample(), 32768):
    sys.stdout.write(part)
sys.stdout.write("\n")
################################################################################ | |
## 2::matching_and_searching_for_text_patterns_using_regular_expressions | |
# example.py | |
# | |
# Examples of simple regular expression matching | |
import re | |
# Some sample text
text = "Today is 11/27/2012. PyCon starts 3/13/2013."
# (a) Find all matching dates
datepat = re.compile(r"\d+/\d+/\d+")
print(datepat.findall(text))
# (b) Find all matching dates with capture groups
# (with groups, findall returns one tuple of group strings per match)
datepat = re.compile(r"(\d+)/(\d+)/(\d+)")
for month, day, year in datepat.findall(text):
    print("{}-{}-{}".format(year, month, day))
# (c) Iterative search
for m in datepat.finditer(text):
    print(m.groups())
################################################################################ | |
## 2::matching_strings_using_shell_wildcard_patterns | |
# example.py | |
# | |
# Example of using shell-wildcard style matching in list comprehensions | |
from fnmatch import fnmatchcase as match | |
addresses = [
    "5412 N CLARK ST",
    "1060 W ADDISON ST",
    "1039 W GRANVILLE AVE",
    "2122 N CLARK ST",
    "4802 N BROADWAY",
]

# Case-sensitive wildcard match: every address ending in " ST"
a = list(filter(lambda addr: match(addr, "* ST"), addresses))
print(a)

# House numbers 5400-5499 on any street whose name contains CLARK
b = list(filter(lambda addr: match(addr, "54[0-9][0-9] *CLARK*"), addresses))
print(b)
################################################################################ | |
## 2::normalizing_unicode_text_to_a_standard_representation | |
# example.py | |
# | |
# Example of unicode normalization | |
# The same word spelled two ways: s1 uses the precomposed n-tilde
# (U+00F1), s2 uses a plain "n" followed by a combining tilde (U+0303).
s1 = "Spicy Jalape\u00f1o"
s2 = "Spicy Jalapen\u0303o"

# (a) Print them out (usually looks identical)
print(s1, s2, sep="\n")

# (b) The raw strings differ in both content and length
print("s1 == s2", s1 == s2)
print(len(s1), len(s2))

# (c) After NFC normalization the two spellings compare equal
import unicodedata

n_s1, n_s2 = (unicodedata.normalize("NFC", raw) for raw in (s1, s2))
print("n_s1 == n_s2", n_s1 == n_s2)
print(len(n_s1), len(n_s2))

# (d) Decompose with NFD, then drop the combining marks to strip accents
t1 = unicodedata.normalize("NFD", s1)
print("".join(filter(lambda ch: not unicodedata.combining(ch), t1)))
################################################################################ | |
## 2::reformatting_text_to_fixed_number_of_columns | |
# example.py | |
# | |
# Examples of reformatting text to different column widths | |
# One long line of text, assembled from adjacent literals
s = (
    "Look into my eyes, look into my eyes, the eyes, the eyes, "
    "the eyes, not around the eyes, don't look around the eyes, "
    "look into my eyes, you're under."
)

import textwrap

# Each example prints the wrapped text followed by one blank line.
print(textwrap.fill(s, 70), end="\n\n")
print(textwrap.fill(s, 40), end="\n\n")
print(textwrap.fill(s, 40, initial_indent=" "), end="\n\n")
print(textwrap.fill(s, 40, subsequent_indent=" "), end="\n\n")
################################################################################ | |
## 2::sanitizing_and_cleaning_up_text | |
# example.py | |
# | |
# Example of some tricky sanitization problems | |
# A string mixing accented letters with tab, form-feed and CR/LF characters
s = "p\xfdt\u0125\xf6\xf1\x0cis\tawesome\r\n"
print(s)

# (a) Remapping whitespace: tab and form feed become spaces, CR is deleted
remap = str.maketrans({"\t": " ", "\f": " ", "\r": None})
a = s.translate(remap)
print("whitespace remapped:", a)

# (b) Remove all combining characters/marks
import unicodedata
import sys

# Translation table mapping every combining code point to None (delete)
cmb_chrs = {
    codepoint: None
    for codepoint in range(sys.maxunicode)
    if unicodedata.combining(chr(codepoint))
}
b = unicodedata.normalize("NFD", a)
c = b.translate(cmb_chrs)
print("accents removed:", c)

# (c) Accent removal using I/O decoding: non-ASCII code points are dropped
d = b.encode("ascii", "ignore").decode("ascii")
print("accents removed via I/O:", d)
################################################################################ | |
## 2::searching_and_replacing_text | |
# example.py | |
# | |
# Examples of simple regular expression substitution | |
import re | |
# Sample text with two month/day/year dates
text = "Today is 11/27/2012. PyCon starts 3/13/2013."
datepat = re.compile(r"(\d+)/(\d+)/(\d+)")

# (a) Substitution with a template: reorder the groups as year-month-day
print(datepat.sub(r"\3-\1-\2", text))

# (b) Substitution with a callback
from calendar import month_abbr


def change_date(m):
    """Return the matched date rewritten as "<day> <month abbreviation> <year>".

    *m* is a match object whose groups 1, 2 and 3 hold the month number,
    the day and the year respectively.
    """
    month, day, year = m.group(1, 2, 3)
    return "{} {} {}".format(day, month_abbr[int(month)], year)


print(datepat.sub(change_date, text))
################################################################################ | |
## 2::specifying_a_regular_expression_for_the_shortest_match | |
# example.py | |
# | |
# Example of a regular expression that finds shortest matches | |
import re | |
# Text holding two separate quoted strings
text = 'Computer says "no." Phone says "yes."'

# Try the same search twice:
#   (a) greedy ".*"      - longest match, runs from the first quote to the last
#   (b) non-greedy ".*?" - shortest match, one result per quoted string
for str_pat in (re.compile(r"\"(.*)\""), re.compile(r"\"(.*?)\"")):
    print(str_pat.findall(text))
################################################################################ | |
## 2::splitting_strings_on_any_of_multiple_delimiters | |
# example.py | |
# | |
# Example of splitting a string on multiple delimiters using a regex | |
import re | |
line = "asdf fjdk; afed, fjek,asdf, foo"

# (a) Split on a semicolon, comma or whitespace character, each followed
#     by any amount of extra whitespace
parts = re.split(r"[;,\s]\s*", line)
print(parts)

# (b) With a capture group the matched delimiters are kept in the result
fields = re.split(r"(;|,|\s)\s*", line)
print(fields)

# (c) Rebuild a string from the fields above: even positions are values,
#     odd positions are delimiters (padded with "" so both lists line up)
values = fields[::2]
delimiters = fields[1::2] + [""]
print("value =", values)
print("delimiters =", delimiters)
newline = "".join(map("".join, zip(values, delimiters)))
print("newline =", newline)

# (d) A non-capturing group splits the same way without keeping delimiters
parts = re.split(r"(?:,|;|\s)\s*", line)
print(parts)
################################################################################ | |
## 2::tokenizing_text | |
# example.py | |
# | |
# Example of a tokenizer | |
import re | |
from collections import namedtuple | |
# One named group per token kind; the group name becomes the token type.
NAME = r"(?P<NAME>[a-zA-Z_][a-zA-Z_0-9]*)"
NUM = r"(?P<NUM>\d+)"
PLUS = r"(?P<PLUS>\+)"
TIMES = r"(?P<TIMES>\*)"
EQ = r"(?P<EQ>=)"
WS = r"(?P<WS>\s+)"

master_pat = re.compile("|".join([NAME, NUM, PLUS, TIMES, EQ, WS]))

Token = namedtuple("Token", ["type", "value"])


def generate_tokens(pat, text):
    """Yield a Token(type, value) for each consecutive match of *pat* in *text*.

    Scanning starts at the beginning of *text* and stops at the first
    position where *pat* no longer matches.
    """
    next_match = pat.scanner(text).match
    while True:
        found = next_match()
        if found is None:
            return
        yield Token(found.lastgroup, found.group())


for tok in generate_tokens(master_pat, "foo = 42"):
    print(tok)
################################################################################ | |
## 2::variable_interpolation_in_strings | |
# example.py | |
# | |
# Examples of variable interpolation | |
# Mapping used for "safe" substitutions
class safesub(dict):
    """dict whose missing keys render back as their own ``{key}`` placeholder."""

    def __missing__(self, key):
        return "{%s}" % key


s = "{name} has {n} messages."

# (a) Simple substitution: every field is found in the current namespace
name = "Guido"
n = 37
print(s.format_map(vars()))

# (b) Safe substitution with missing values: {n} is left untouched
del n
print(s.format_map(safesub(vars())))

# (c) Safe substitution + frame hack
n = 37
import sys


def sub(text):
    """Fill the fields of *text* from the local variables of the caller.

    Fields the caller does not define are left in place as ``{field}``.
    """
    caller_namespace = sys._getframe(1).f_locals
    return text.format_map(safesub(caller_namespace))


for template in (
    "Hello {name}",
    "{name} has {n} messages",
    "Your favorite color is {color}",
):
    print(sub(template))
################################################################################ | |
## 2::writing_a_regular_expression_for_multiline_patterns | |
# example.py | |
# | |
# Regular expression that matches multiline patterns | |
import re | |
# A C-style comment that spans two lines
text = "/* this is a\nmultiline comment */\n"

# "." does not match a newline by default, so the body is written as
# "any character or a newline", matched non-greedily up to the closing "*/".
comment = re.compile(r"/\*((?:.|\n)*?)\*/")
print(comment.findall(text))
################################################################################ | |
## 2::writing_a_simple_recursive_descent_parser | |
# example.py | |
# | |
# An example of writing a simple recursive descent parser | |
import re | |
import collections | |
# Token specification: one named group per token kind
NUM = r"(?P<NUM>\d+)"
PLUS = r"(?P<PLUS>\+)"
MINUS = r"(?P<MINUS>-)"
TIMES = r"(?P<TIMES>\*)"
DIVIDE = r"(?P<DIVIDE>/)"
LPAREN = r"(?P<LPAREN>\()"
RPAREN = r"(?P<RPAREN>\))"
WS = r"(?P<WS>\s+)"
master_pat = re.compile("|".join([NUM, PLUS, MINUS, TIMES, DIVIDE, LPAREN, RPAREN, WS]))

# Tokenizer
Token = collections.namedtuple("Token", ["type", "value"])


def generate_tokens(text):
    """Yield the non-whitespace tokens of *text* as Token(type, value) tuples.

    Raises:
        SyntaxError: if *text* contains a character that matches no token
            pattern.  The error surfaces when iteration reaches that point.
            Without this check the scan would stop silently and the rest of
            the input would be ignored (e.g. "2 $ 3" would tokenize as "2").
    """
    scanner = master_pat.scanner(text)
    pos = 0  # end offset of the last successful match
    for m in iter(scanner.match, None):
        pos = m.end()
        tok = Token(m.lastgroup, m.group())
        if tok.type != "WS":
            yield tok
    # The scanner returns None both at end of input and at an illegal
    # character; only the former is a clean finish.
    if pos != len(text):
        raise SyntaxError(
            "Unexpected character {!r} at position {}".format(text[pos], pos)
        )
# Parser
class ExpressionEvaluator:
    """
    Implementation of a recursive descent parser.  Each method
    implements a single grammar rule.  Use the ._accept() method
    to test and accept the current lookahead token.  Use the ._expect()
    method to exactly match and discard the next token on the input
    (or raise a SyntaxError if it doesn't match).
    """

    def parse(self, text):
        """Evaluate the arithmetic expression in *text* and return its value.

        Raises:
            SyntaxError: if *text* is not a single complete expression,
                including when tokens are left over after a valid prefix
                (e.g. "2 3" or "(1))").
        """
        self.tokens = generate_tokens(text)
        self.tok = None  # Last symbol consumed
        self.nexttok = None  # Next symbol tokenized
        self._advance()  # Load first lookahead token
        result = self.expr()
        # expr() stops at the first token it cannot use; anything still
        # pending here means the input was not one well-formed expression.
        if self.nexttok is not None:
            raise SyntaxError("Unexpected token " + self.nexttok.type)
        return result

    def _advance(self):
        "Advance one token ahead"
        self.tok, self.nexttok = self.nexttok, next(self.tokens, None)

    def _accept(self, toktype):
        "Test and consume the next token if it matches toktype"
        if self.nexttok and self.nexttok.type == toktype:
            self._advance()
            return True
        else:
            return False

    def _expect(self, toktype):
        "Consume next token if it matches toktype or raise SyntaxError"
        if not self._accept(toktype):
            raise SyntaxError("Expected " + toktype)

    # Grammar rules follow

    def expr(self):
        "expression ::= term { ('+'|'-') term }*"
        exprval = self.term()
        while self._accept("PLUS") or self._accept("MINUS"):
            # Remember the operator before term() advances past it
            op = self.tok.type
            right = self.term()
            if op == "PLUS":
                exprval += right
            elif op == "MINUS":
                exprval -= right
        return exprval

    def term(self):
        "term ::= factor { ('*'|'/') factor }*"
        termval = self.factor()
        while self._accept("TIMES") or self._accept("DIVIDE"):
            # Remember the operator before factor() advances past it
            op = self.tok.type
            right = self.factor()
            if op == "TIMES":
                termval *= right
            elif op == "DIVIDE":
                termval /= right
        return termval

    def factor(self):
        "factor ::= NUM | ( expr )"
        if self._accept("NUM"):
            return int(self.tok.value)
        elif self._accept("LPAREN"):
            exprval = self.expr()
            self._expect("RPAREN")
            return exprval
        else:
            raise SyntaxError("Expected NUMBER or LPAREN")


if __name__ == "__main__":
    e = ExpressionEvaluator()
    print(e.parse("2"))
    print(e.parse("2 + 3"))
    print(e.parse("2 + 3 * 4"))
    print(e.parse("2 + (3 + 4) * 5"))
# Variant of the evaluator that builds a parse tree instead of a number
class ExpressionTreeBuilder(ExpressionEvaluator):
    """Parser that returns nested (operator, left, right) tuples.

    Leaves are ints; operators are the strings "+", "-", "*" and "/".
    Tokenizing and lookahead handling are inherited from
    ExpressionEvaluator.
    """

    def expr(self):
        "expression ::= term { ('+'|'-') term }"
        tree = self.term()
        while self._accept("PLUS") or self._accept("MINUS"):
            # _accept() only succeeded for PLUS or MINUS, so the token just
            # consumed is one of the two; read it before term() moves on.
            symbol = "+" if self.tok.type == "PLUS" else "-"
            tree = (symbol, tree, self.term())
        return tree

    def term(self):
        "term ::= factor { ('*'|'/') factor }"
        tree = self.factor()
        while self._accept("TIMES") or self._accept("DIVIDE"):
            # Same reasoning as in expr(): the consumed token is TIMES or DIVIDE
            symbol = "*" if self.tok.type == "TIMES" else "/"
            tree = (symbol, tree, self.factor())
        return tree

    def factor(self):
        "factor ::= NUM | ( expr )"
        if self._accept("NUM"):
            return int(self.tok.value)
        if self._accept("LPAREN"):
            inner = self.expr()
            self._expect("RPAREN")
            return inner
        raise SyntaxError("Expected NUMBER or LPAREN")


if __name__ == "__main__":
    e = ExpressionTreeBuilder()
    for source in ("2 + 3", "2 + 3 * 4", "2 + (3 + 4) * 5", "2 + 3 + 4"):
        print(e.parse(source))
################################################################################ | |
## 2::writing_a_simple_recursive_descent_parser | |
# plyexample.py | |
# | |
# Example of parsing with PLY | |
from ply.lex import lex | |
from ply.yacc import yacc | |
# Names of all token types the lexer can produce
tokens = ["NUM", "PLUS", "MINUS", "TIMES", "DIVIDE", "LPAREN", "RPAREN"]

# Characters skipped between tokens (space, tab, newline)
t_ignore = " \t\n"

# Single-character tokens, each described by a regular expression
t_LPAREN = r"\("
t_RPAREN = r"\)"
t_PLUS = r"\+"
t_MINUS = r"-"
t_TIMES = r"\*"
t_DIVIDE = r"/"


# Token rule with an action: the docstring is the token's regex
def t_NUM(t):
    r"\d+"
    # Replace the matched digit string with its integer value
    t.value = int(t.value)
    return t
# Error handler
def t_error(t):
    # Called with a token whose value is the remaining, unmatched input.
    # Report the offending character and step over it so lexing can go on.
    print("Bad character: {!r}".format(t.value[0]))
    # PLY 3.x exposes skip() on the lexer attached to the token; the old
    # token-level t.skip() no longer exists.
    t.lexer.skip(1)
# Build the lexer from the t_* definitions of this module
lexer = lex()


# Grammar rules and handler functions.  The docstring of each p_* function
# is the grammar production it handles; p[0] receives the rule's value and
# p[1], p[2], ... hold the values of the symbols on the right-hand side.
def p_expr(p):
    """
    expr : expr PLUS term
         | expr MINUS term
    """
    left, symbol, right = p[1], p[2], p[3]
    if symbol == "+":
        p[0] = left + right
    elif symbol == "-":
        p[0] = left - right


def p_expr_term(p):
    """
    expr : term
    """
    p[0] = p[1]


def p_term(p):
    """
    term : term TIMES factor
         | term DIVIDE factor
    """
    left, symbol, right = p[1], p[2], p[3]
    if symbol == "*":
        p[0] = left * right
    elif symbol == "/":
        p[0] = left / right


def p_term_factor(p):
    """
    term : factor
    """
    p[0] = p[1]


def p_factor(p):
    """
    factor : NUM
    """
    p[0] = p[1]


def p_factor_group(p):
    """
    factor : LPAREN expr RPAREN
    """
    # The value of a parenthesized group is the value of the inner expr
    p[0] = p[2]


def p_error(p):
    # Syntax error hook: just report the problem
    print("Syntax error")


parser = yacc()

if __name__ == "__main__":
    for source in ("2", "2+3", "2+(3+4)*5"):
        print(parser.parse(source))
################################################################################ | |
## 3::determining_last_fridays_date | |
from datetime import datetime, timedelta | |
# Day names indexed the same way as datetime.weekday() (Monday == 0)
weekdays = [
    "Monday",
    "Tuesday",
    "Wednesday",
    "Thursday",
    "Friday",
    "Saturday",
    "Sunday",
]


def get_previous_byday(dayname, start_date=None):
    """Return the most recent *dayname* strictly before *start_date*.

    *dayname* must be one of the entries of ``weekdays`` (ValueError
    otherwise).  *start_date* defaults to ``datetime.today()``.  If
    *start_date* itself falls on *dayname*, the result is one full week
    earlier.
    """
    if start_date is None:
        start_date = datetime.today()
    # Days to step back, in 0..6; 0 means "same weekday", which counts as 7
    offset = (start_date.weekday() - weekdays.index(dayname)) % 7
    return start_date - timedelta(days=offset or 7)
################################################################################ | |
## 3::finding_the_date_range_for_the_current_month | |
from datetime import datetime, date, timedelta | |
import calendar | |
def get_month_range(start_date=None):
    """Return ``(start_date, end_date)`` spanning one month's worth of days.

    *start_date* defaults to the first day of the current month.
    *end_date* is *start_date* plus the number of days in *start_date*'s
    month, so it is exclusive when *start_date* is the first of the month.
    """
    if start_date is None:
        start_date = date.today().replace(day=1)
    _, days_in_month = calendar.monthrange(start_date.year, start_date.month)
    return (start_date, start_date + timedelta(days=days_in_month))


# Print every day of the current month
first_day, last_day = get_month_range()
a_day = timedelta(days=1)
while first_day < last_day:
    print(first_day)
    first_day += a_day


def daterange(start, stop, step):
    """Yield start, start + step, ... for as long as the value is < stop."""
    current = start
    while current < stop:
        yield current
        current += step


# Ten consecutive days
for d in daterange(date(2012, 8, 1), date(2012, 8, 11), timedelta(days=1)):
    print(d)

# Two days in half-hour steps
for d in daterange(datetime(2012, 8, 1), datetime(2012, 8, 3), timedelta(minutes=30)):
    print(d)
################################################################################ | |
## 4::creating_data_processing_pipelines | |
import os | |
import fnmatch | |
import gzip | |
import bz2 | |
import re | |
def gen_find(filepat, top):
    """
    Walk the directory tree rooted at *top* and yield the path of every
    file whose name matches the shell wildcard pattern *filepat*.
    """
    for dirpath, _subdirs, filenames in os.walk(top):
        yield from (
            os.path.join(dirpath, filename)
            for filename in fnmatch.filter(filenames, filepat)
        )
def gen_opener(filenames):
    """
    Open a sequence of filenames one at a time producing a file object.

    Names ending in ".gz" or ".bz2" are opened with gzip/bz2; everything
    else with the builtin open().  All files are opened in text mode.

    Each file is closed when the consumer asks for the next one, and also
    when the generator is closed or an exception is thrown into it, so an
    abandoned pipeline does not leak the file that was open at the time.
    """
    for filename in filenames:
        if filename.endswith(".gz"):
            opener = gzip.open
        elif filename.endswith(".bz2"):
            opener = bz2.open
        else:
            opener = open
        # The with-block covers the yield, so GeneratorExit and errors
        # raised at the suspension point still close the file.
        with opener(filename, "rt") as f:
            yield f
def gen_concatenate(iterators):
    """
    Yield every item of the first iterable in *iterators*, then every item
    of the second, and so on, as one flat sequence.
    """
    for source in iterators:
        yield from source
def gen_grep(pattern, lines):
    """
    Yield only those items of *lines* in which the regular expression
    *pattern* is found somewhere.
    """
    found_in = re.compile(pattern).search
    yield from filter(found_in, lines)
if __name__ == "__main__":
    # Example 1: print every log line that mentions "python" (any case)
    # from the access-log* files found under the "www" directory
    lognames = gen_find("access-log*", "www")
    files = gen_opener(lognames)
    lines = gen_concatenate(files)
    pylines = gen_grep("(?i)python", lines)
    for line in pylines:
        print(line)

    # Example 2: same pipeline, extended to sum the last whitespace-separated
    # column of each matching line, skipping entries recorded as "-"
    lognames = gen_find("access-log*", "www")
    files = gen_opener(lognames)
    lines = gen_concatenate(files)
    pylines = gen_grep("(?i)python", lines)
    bytecolumn = (line.rsplit(None, 1)[1] for line in pylines)
    # Named byte_counts rather than "bytes" so the builtin is not shadowed
    byte_counts = (int(x) for x in bytecolumn if x != "-")
    print("Total", sum(byte_counts))
################################################################################ | |
## 4::creating_new_iteration_patterns_with_generators | |
def frange(start, stop, increment):
    """Yield start, start + increment, ... while the value is below *stop*.

    Like range(), but the arguments may be floats.
    """
    current = start
    while current < stop:
        yield current
        current += increment


for n in frange(0, 4, 0.5):
    print(n)
################################################################################ | |
## 4::delegating-iteration | |
# A container class that hands iteration over to the list it wraps
class Node:
    """Tree node holding a value and an ordered list of child nodes."""

    def __init__(self, value):
        self._value = value
        self._children = list()

    def __repr__(self):
        return "Node({!r})".format(self._value)

    def add_child(self, node):
        """Append *node* to this node's children."""
        self._children.append(node)

    def __iter__(self):
        # Iterating a node means iterating its direct children
        return iter(self._children)


# Example
if __name__ == "__main__":
    root = Node(0)
    child1, child2 = Node(1), Node(2)
    for child in (child1, child2):
        root.add_child(child)
    for ch in root:
        print(ch)
    # Outputs: Node(1), Node(2)
################################################################################ | |
## 4::easy_implementation_of_the_iterator_protocol | |
# example.py | |
# | |
# Depth-first traversal written as a generator method
class Node:
    """Tree node holding a value and an ordered list of child nodes."""

    def __init__(self, value):
        self._value = value
        self._children = list()

    def __repr__(self):
        return "Node({!r})".format(self._value)

    def add_child(self, node):
        """Append *node* to this node's children."""
        self._children.append(node)

    def __iter__(self):
        # Iterating a node means iterating its direct children
        return iter(self._children)

    def depth_first(self):
        """Yield this node, then each child's subtree, in child order."""
        yield self
        for child in self:
            yield from child.depth_first()


# Example
if __name__ == "__main__":
    root = Node(0)
    child1, child2 = Node(1), Node(2)
    root.add_child(child1)
    root.add_child(child2)
    for parent, value in ((child1, 3), (child1, 4), (child2, 5)):
        parent.add_child(Node(value))
    for ch in root.depth_first():
        print(ch)
    # Outputs: Node(0), Node(1), Node(3), Node(4), Node(2), Node(5)
################################################################################ | |
## 4::easy_implementation_of_the_iterator_protocol | |
# The same depth-first traversal, done the hard way with an iterator class
class Node:
    """Tree node holding a value and an ordered list of child nodes."""

    def __init__(self, value):
        self._value = value
        self._children = list()

    def __repr__(self):
        return "Node(%r)" % self._value

    def add_child(self, other_node):
        """Append *other_node* to this node's children."""
        self._children.append(other_node)

    def __iter__(self):
        # Iterating a node means iterating its direct children
        return iter(self._children)

    def depth_first(self):
        """Return an iterator over this node and all of its descendants."""
        return DepthFirstIterator(self)


class DepthFirstIterator(object):
    """
    Iterator producing a node first and then, child by child, everything
    produced by each child's own depth_first() iterator.
    """

    def __init__(self, start_node):
        self._node = start_node
        self._children_iter = None  # iterator over the direct children
        self._child_iter = None  # traversal of the child being visited

    def __iter__(self):
        return self

    def __next__(self):
        # Very first call: hand out the start node itself and prepare to
        # walk its children on later calls.
        if self._children_iter is None:
            self._children_iter = iter(self._node)
            return self._node
        while True:
            if self._child_iter:
                # In the middle of a child's subtree: pass its items along
                try:
                    return next(self._child_iter)
                except StopIteration:
                    # That subtree is finished; go on to the next child
                    self._child_iter = None
            else:
                # Start on the next child.  When no children remain, the
                # StopIteration raised here ends this iterator as well.
                self._child_iter = next(self._children_iter).depth_first()


# Example
if __name__ == "__main__":
    root = Node(0)
    child1, child2 = Node(1), Node(2)
    root.add_child(child1)
    root.add_child(child2)
    for parent, value in ((child1, 3), (child1, 4), (child2, 5)):
        parent.add_child(Node(value))
    for ch in root.depth_first():
        print(ch)
    # Outputs: Node(0), Node(1), Node(3), Node(4), Node(2), Node(5)
################################################################################ | |
## 4::generators_with_state | |
# Example of a generator with extra state that can be | |
# accessed. Simply define as a class! | |
from collections import deque | |
class linehistory:
    """Iterable wrapper that remembers the most recently produced lines.

    ``history`` is a deque of ``(line number, line)`` pairs, numbered from
    1, holding at most *histlen* entries.
    """

    def __init__(self, lines, histlen=3):
        self.lines = lines
        self.history = deque(maxlen=histlen)

    def __iter__(self):
        lineno = 0
        for line in self.lines:
            lineno += 1
            # Record the line before handing it out so the consumer sees
            # the current line as the newest history entry.
            self.history.append((lineno, line))
            yield line

    def clear(self):
        """Forget all remembered lines."""
        self.history.clear()
# Whenever a line mentions "python", print it together with the lines
# that came just before it (as remembered by linehistory).
with open("somefile.txt") as f:
    lines = linehistory(f)
    for line in lines:
        if "python" not in line:
            continue
        for lineno, hline in lines.history:
            print("{}:{}".format(lineno, hline), end="")
################################################################################ | |
## 4::how_to_flatten_a_nested_sequence | |
# Example of flattening a nested sequence using subgenerators
# (Iterable lives in collections.abc; the alias in collections was removed
# in Python 3.10.)
from collections.abc import Iterable


def flatten(items, ignore_types=(str, bytes)):
    """Yield the leaf items of an arbitrarily nested iterable, left to right.

    Any element that is itself iterable is descended into, except
    instances of *ignore_types*, which are yielded whole.  The same
    *ignore_types* applies at every nesting level.
    """
    for x in items:
        if isinstance(x, Iterable) and not isinstance(x, ignore_types):
            # Pass ignore_types down so nested levels honour it too
            yield from flatten(x, ignore_types)
        else:
            yield x


items = [1, 2, [3, 4, [5, 6], 7], 8]
# Produces 1 2 3 4 5 6 7 8
for x in flatten(items):
    print(x)

items = ["Dave", "Paula", ["Thomas", "Lewis"]]
for x in flatten(items):
    print(x)
################################################################################ | |
## 4::iterate_over_the_index-value_pairs_of_a_list | |
# Example of iterating over lines of a file with an extra lineno attribute
def parse_data(filename):
    """Check that the second field of every line in *filename* is an integer.

    Lines are split on whitespace.  For each line whose second field is
    missing or is not a valid integer, a message with the 1-based line
    number is printed; processing then continues with the next line.
    """
    with open(filename, "rt") as f:
        for lineno, line in enumerate(f, 1):
            fields = line.split()
            try:
                # Placeholder for real processing of the parsed value
                count = int(fields[1])
            except (ValueError, IndexError) as e:
                # IndexError covers blank lines and lines with one field
                print("Line {}: Parse error: {}".format(lineno, e))
# Run the check on the sample data file; lines whose second field is not
# an integer are reported on standard output.
parse_data("sample.dat")
################################################################################ | |
## 4::iterating_in_reverse | |
# A class that supports both iter() and reversed()
class Countdown:
    """Counts from *start* down to 1; reversed() counts from 1 up to *start*."""

    def __init__(self, start):
        self.start = start

    def __iter__(self):
        # Forward direction: start, start - 1, ... while the value is > 0
        remaining = self.start
        while remaining > 0:
            yield remaining
            remaining -= 1

    def __reversed__(self):
        # Reverse direction: 1, 2, ... up to and including start
        current = 1
        while current <= self.start:
            yield current
            current += 1


c = Countdown(5)
print("Forward:")
for x in c:
    print(x)
print("Reverse:")
for x in reversed(c):
    print(x)
################################################################################ | |
## 4::iterating_in_sorted_order_over_merged_sorted_iterables | |
# Iterating over merged sorted iterables | |
import heapq | |
# Two inputs, each already in ascending order
a = [1, 4, 7, 10]
b = [2, 5, 6, 11]

# heapq.merge() yields the items of both inputs as one ascending sequence
for c in heapq.merge(a, b):
    print(c)
################################################################################ | |
## 4::iterating_on_items_in_separate_containers | |
# Example of iterating over two sequences as one | |
from itertools import chain | |
# Two separate containers with different kinds of items
a = [1, 2, 3, 4]
b = ["x", "y", "z"]

# chain() walks a and then b without building a combined list
for x in chain(a, b):
    print(x)
################################################################################ | |
## 5::adding_or_changing_the_encoding_of_an_already_open_file | |
# Example of adding a text encoding to existing file-like object | |
import urllib.request | |
import io | |
# The response object is a binary file-like object.  Wrapping it in a
# TextIOWrapper adds UTF-8 decoding on top of it.  The with-block makes
# sure the HTTP response is closed once the text has been read.
with urllib.request.urlopen("http://www.python.org") as u:
    f = io.TextIOWrapper(u, encoding="utf-8")
    text = f.read()
print(text)
################################################################################ | |
## 5::getting_a_directory_listing | |
# Example of getting a directory listing | |
import os | |
import os.path | |
import glob | |
# All Python source files in the current directory
pyfiles = glob.glob("*.py")

# Get file sizes and modification dates as (name, size, mtime) tuples
name_sz_date = list(
    zip(pyfiles, map(os.path.getsize, pyfiles), map(os.path.getmtime, pyfiles))
)
for r in name_sz_date:
    print(r)

# Get file metadata: pair each name with its full os.stat() result
file_metadata = list(zip(pyfiles, map(os.stat, pyfiles)))
for name, meta in file_metadata:
    print(name, meta.st_size, meta.st_mtime)
################################################################################ | |
## 5::iterating_over_fixed-sized_records | |
# Example of iterating of fixed-size records | |
# | |
# The file 'data.bin' contains 32-byte fixed size records | |
# that consist of a 4-digit number followed by a 28-byte string. | |
from functools import partial | |
RECORD_SIZE = 32  # bytes per record

with open("data.bin", "rb") as f:
    # iter(callable, sentinel) keeps calling f.read(RECORD_SIZE) until it
    # returns b"", which is what read() gives back at end of file.
    records = iter(lambda: f.read(RECORD_SIZE), b"")
    for r in records:
        print(r)
################################################################################ | |
## 5::reading_and_writing_text_data | |
# Some examples of reading text files with different options | |
# | |
# The file sample.txt is a UTF-8 encoded text file with Windows | |
# line-endings (\r\n). | |
# (a) Reading a basic text file as UTF-8.  The encoding is passed
#     explicitly: without it open() falls back to a platform/locale
#     dependent default, which is not UTF-8 everywhere.
print("Reading a simple text file (UTF-8)")
with open("sample.txt", "rt", encoding="utf-8") as f:
    for line in f:
        print(repr(line))

# (b) Reading a text file with universal newlines turned off:
#     newline="" leaves the line endings in the data untranslated
print("Reading text file with universal newlines off")
with open("sample.txt", "rt", encoding="utf-8", newline="") as f:
    for line in f:
        print(repr(line))

# (c) Reading text file as ASCII with replacement error handling:
#     undecodable bytes are replaced by a marker character
print("Reading text as ASCII with replacement error handling")
with open("sample.txt", "rt", encoding="ascii", errors="replace") as f:
    for line in f:
        print(repr(line))

# (d) Reading text file as ASCII with ignore error handling:
#     undecodable bytes are dropped
print("Reading text as ASCII with ignore error handling")
with open("sample.txt", "rt", encoding="ascii", errors="ignore") as f:
    for line in f:
        print(repr(line))
################################################################################ | |
## 5::wrapping_an_existing_file_descriptor_as_a_file_object | |
from socket import socket, AF_INET, SOCK_STREAM | |
def echo_client(client_sock, addr):
    """Echo every line received on *client_sock* back to the sender.

    The socket's file descriptor is wrapped in two latin-1 text-mode file
    objects (one for reading, one for writing).  The wrappers are opened
    with closefd=False, so closing them leaves the descriptor alone; the
    socket itself is closed here once the peer stops sending, or if the
    echo loop fails part-way through.

    *addr* is only used for the log message.
    """
    print("Got connection from", addr)
    try:
        fd = client_sock.fileno()
        # Make text-mode file wrappers for socket reading/writing
        with open(fd, "rt", encoding="latin-1", closefd=False) as client_in:
            with open(fd, "wt", encoding="latin-1", closefd=False) as client_out:
                # Echo lines back to the client using file I/O
                for line in client_in:
                    client_out.write(line)
                    client_out.flush()
    finally:
        # Release the connection even when reading or writing raised
        client_sock.close()
def echo_server(address):
    """Listen on *address* and serve echo clients one at a time, forever."""
    sock = socket(AF_INET, SOCK_STREAM)
    sock.bind(address)
    sock.listen(1)
    while True:
        # accept() returns (connection socket, peer address), which is
        # exactly the argument pair echo_client() takes
        echo_client(*sock.accept())


if __name__ == "__main__":
    print("Echo serving running on localhost:25000")
    echo_server(("", 25000))
################################################################################ | |
## 5::writing_bytes_to_a_text_file | |
# Example of writing raw bytes on a file opened in text mode | |
import sys | |
# The raw bytes to emit
data = b"Hello World\n"

# sys.stdout is a text stream and only accepts str.  Its .buffer attribute
# is the binary stream underneath it, which takes the bytes unencoded.
sys.stdout.buffer.write(data)
################################################################################ | |
## 6::incremental_parsing_of_huge_xml_files | |
# Example of incremental XML parsing | |
# | |
# The file 'potholes.xml' is a greatly condensed version of a larger | |
# file available for download at | |
# | |
# https://data.cityofchicago.org/api/views/7as2-ds3y/rows.xml?accessType=DOWNLOAD | |
from xml.etree.ElementTree import iterparse | |
def parse_and_remove(filename, path):
    """Incrementally yield the elements of *filename* located at *path*.

    *path* is a "/"-separated list of tag names counted from just below
    the document's root element.  Each matching element is yielded once it
    is complete; when the consumer resumes the generator, the element is
    removed from its parent.
    """
    wanted = path.split("/")
    events = iterparse(filename, ("start", "end"))
    # Skip the root element
    next(events)

    # Tags and elements of the currently open elements (root excluded)
    tag_stack = []
    elem_stack = []
    for event, elem in events:
        if event == "start":
            tag_stack.append(elem.tag)
            elem_stack.append(elem)
            continue
        if event != "end":
            continue
        if tag_stack == wanted:
            yield elem
            # elem_stack[-1] is elem itself; [-2] is its parent
            elem_stack[-2].remove(elem)
        try:
            tag_stack.pop()
            elem_stack.pop()
        except IndexError:
            # Closing event of the root element, which was never pushed
            pass
# Count potholes per zip code and list the zip codes, most potholes first
from collections import Counter

data = parse_and_remove("potholes.xml", "row/row")
# Counter consumes the rows one by one, so only one row is being
# inspected at a time.
potholes_by_zip = Counter(pothole.findtext("zip") for pothole in data)
for zipcode, num in potholes_by_zip.most_common():
    print(zipcode, num)
################################################################################ | |
## 6::parsing_modifying_and_rewriting_xml | |
# example.py | |
# | |
# Example of reading an XML document, making changes, and writing it back out | |
from xml.etree.ElementTree import parse, Element | |
doc = parse("pred.xml")
root = doc.getroot()

# Remove a few elements
root.remove(root.find("sri"))
root.remove(root.find("cr"))

# Insert a new element after <nm>...</nm>
# list(root) gives the direct children; Element.getchildren() no longer
# exists (removed in Python 3.9).
nm_index = list(root).index(root.find("nm"))
e = Element("spam")
e.text = "This is a test"
root.insert(nm_index + 1, e)

# Write back to a file
doc.write("newpred.xml", xml_declaration=True)
################################################################################ | |
## 6::parsing_simple_xml_data | |
from urllib.request import urlopen | |
from xml.etree.ElementTree import parse | |
# Download the RSS feed and parse it.  The with-block closes the HTTP
# response as soon as the document has been read into memory.
with urlopen("http://planet.python.org/rss20.xml") as u:
    doc = parse(u)

# Extract and output tags of interest
for item in doc.iterfind("channel/item"):
    title = item.findtext("title")
    date = item.findtext("pubDate")
    link = item.findtext("link")
    print(title)
    print(date)
    print(link)
    print()
################################################################################ | |
## 6::parsing_xml_documents_with_namespaces | |
# example.py | |
# | |
# Example of XML namespace handling | |
from xml.etree.ElementTree import parse | |
class XMLNamespaces:
    """Expand short ``{prefix}`` markers in element paths to ``{uri}`` form.

    Prefixes are given as keyword arguments (``prefix=uri``) or added
    later with register().  Calling the instance with a path returns the
    path with every registered ``{prefix}`` replaced by ``{uri}``.
    """

    def __init__(self, **kwargs):
        self.namespaces = {}
        for prefix in kwargs:
            self.register(prefix, kwargs[prefix])

    def register(self, name, uri):
        """Associate the prefix *name* with the namespace *uri*."""
        # Stored with the braces already applied, ready for substitution
        self.namespaces[name] = "{" + uri + "}"

    def __call__(self, path):
        return path.format_map(self.namespaces)
doc = parse("sample.xml")

# Register the XHTML namespace under the short prefix "html"
ns = XMLNamespaces(html="http://www.w3.org/1999/xhtml")

# Look up the namespaced <html> element below <content>
e = doc.find(ns("content/{html}html"))
print(e)

# Fetch the text of the namespaced <title> element inside <head>
text = doc.findtext(ns("content/{html}html/{html}head/{html}title"))
print(text)
################################################################################ | |
## 6::reading_and_writing_binary_arrays_of_structures | |
from struct import Struct | |
def read_records(format, f):
    """Return a generator of tuples unpacked from the binary file *f*.

    *format* is a struct format string describing one record.  Records are
    read one at a time, each with a read of exactly one record's size,
    until a read returns b"".
    """
    record_struct = Struct(format)
    size, unpack = record_struct.size, record_struct.unpack

    def _iter_records():
        while True:
            chunk = f.read(size)
            if chunk == b"":
                return
            yield unpack(chunk)

    return _iter_records()


# Example
if __name__ == "__main__":
    with open("data.b", "rb") as f:
        for rec in read_records("<idd", f):
            # Process rec
            print(rec)
################################################################################ | |
## 6::reading_and_writing_binary_arrays_of_structures | |
from struct import Struct | |
def unpack_records(format, data):
    """Return a generator of tuples unpacked from the bytes object *data*.

    *format* is a struct format string describing one record.  Records are
    decoded in place at offsets 0, size, 2*size, ... without slicing *data*.
    """
    record_struct = Struct(format)
    # Offsets are worked out up front; decoding happens lazily
    offsets = range(0, len(data), record_struct.size)

    def _iter_records():
        for offset in offsets:
            yield record_struct.unpack_from(data, offset)

    return _iter_records()


# Example
if __name__ == "__main__":
    with open("data.b", "rb") as f:
        data = f.read()
    for rec in unpack_records("<idd", data):
        # Process record
        print(rec)
################################################################################ | |
## 6::reading_and_writing_binary_arrays_of_structures | |
from struct import Struct | |
def write_records(records, format, f):
    """
    Pack each tuple of *records* according to the struct format string
    *format* and write the resulting bytes to the binary file *f*.
    """
    pack = Struct(format).pack
    for record in records:
        f.write(pack(*record))


# Example
if __name__ == "__main__":
    records = [(1, 2.3, 4.5), (6, 7.8, 9.0), (12, 13.4, 56.7)]
    with open("data.b", "wb") as f:
        write_records(records, "<idd", f)
################################################################################ | |
## 6::reading_and_writing_csv_data | |
# example.py | |
# | |
# Various samples of reading CSV files | |
import csv | |
# (a) Reading as tuples
print("Reading as tuples:")
# The csv module documentation asks for files to be opened with newline=""
# so the reader does its own newline handling (quoted fields that contain
# line breaks are otherwise misparsed on some platforms).
with open("stocks.csv", newline="") as f:
    f_csv = csv.reader(f)
    headers = next(f_csv)  # first row holds the column names
    for row in f_csv:
        # process row
        print(" ", row)

# (b) Reading as namedtuples
print("Reading as namedtuples")
from collections import namedtuple

with open("stocks.csv", newline="") as f:
    f_csv = csv.reader(f)
    # Column headers become the namedtuple field names.
    Row = namedtuple("Row", next(f_csv))
    for r in f_csv:
        row = Row(*r)
        # Process row
        print(" ", row)

# (c) Reading as dictionaries
print("Reading as dicts")
with open("stocks.csv", newline="") as f:
    f_csv = csv.DictReader(f)
    for row in f_csv:
        # process row
        print(" ", row)

# (d) Reading into tuples with type conversion
print("Reading into named tuples with type conversion")
col_types = [str, float, str, str, float, int]
with open("stocks.csv", newline="") as f:
    f_csv = csv.reader(f)
    headers = next(f_csv)
    for row in f_csv:
        # Apply conversions to the row items
        row = tuple(convert(value) for convert, value in zip(col_types, row))
        print(row)

# (e) Converting selected dict fields
print("Reading as dicts with type conversion")
field_types = [("Price", float), ("Change", float), ("Volume", int)]
with open("stocks.csv", newline="") as f:
    for row in csv.DictReader(f):
        # Replace only the listed fields with their converted values.
        row.update((key, conversion(row[key])) for key, conversion in field_types)
        print(row)
################################################################################ | |
## 6::reading_and_writing_json_data | |
# Some advanced JSON examples involving ordered dicts and classes | |
import json | |
# Some JSON encoded text | |
s = '{"name": "ACME", "shares": 50, "price": 490.1}' | |
# (a) Turning JSON into an OrderedDict | |
from collections import OrderedDict | |
data = json.loads(s, object_pairs_hook=OrderedDict) | |
print(data) | |
# (b) Using JSON to populate an instance | |
class JSONObject: | |
def __init__(self, d): | |
self.__dict__ = d | |
data = json.loads(s, object_hook=JSONObject) | |
print(data.name) | |
print(data.shares) | |
print(data.price) | |
# (c) Encoding instances | |
class Point: | |
def __init__(self, x, y): | |
self.x = x | |
self.y = y | |
def serialize_instance(obj):
    """Return a dict describing *obj* for use as ``json.dumps(default=...)``.

    The dict holds the class name under ``"__classname__"`` followed by the
    instance attributes from ``vars(obj)``.
    """
    return {"__classname__": type(obj).__name__, **vars(obj)}
p = Point(3, 4) | |
s = json.dumps(p, default=serialize_instance) | |
print(s) | |
# (d) Decoding instances | |
classes = {"Point": Point} | |
def unserialize_object(d):
    """``json.loads`` object hook that rebuilds instances from tagged dicts.

    A dict with a truthy ``"__classname__"`` entry is turned into an instance
    of the class registered under that name in ``classes``; the instance is
    created without calling ``__init__`` and the remaining items are set as
    attributes.  Any other dict is returned as-is (minus the tag key).
    """
    clsname = d.pop("__classname__", None)
    if not clsname:
        return d
    cls = classes[clsname]
    # Bypass __init__: attributes come straight from the decoded dict.
    obj = cls.__new__(cls)
    for attr, value in d.items():
        setattr(obj, attr, value)
    return obj
a = json.loads(s, object_hook=unserialize_object) | |
print(a) | |
print(a.x) | |
print(a.y) | |
################################################################################ | |
## 6::reading_nested_and_variable_sized_binary_structures | |
import struct | |
class StructField:
    """Descriptor that decodes one struct-formatted field from ``_buffer``."""

    def __init__(self, format, offset):
        self.format = format
        self.offset = offset

    def __get__(self, instance, cls):
        # Class-level access returns the descriptor itself.
        if instance is None:
            return self
        values = struct.unpack_from(self.format, instance._buffer, self.offset)
        # Single-item formats are unwrapped; multi-item formats stay tuples.
        if len(values) == 1:
            return values[0]
        return values


class Structure:
    """Base class holding the raw bytes that StructField descriptors read."""

    def __init__(self, bytedata):
        self._buffer = memoryview(bytedata)
if __name__ == "__main__":

    class PolyHeader(Structure):
        # Fixed-offset fields of the header, all little-endian.
        file_code = StructField("<i", 0)
        min_x = StructField("<d", 4)
        min_y = StructField("<d", 12)
        max_x = StructField("<d", 20)
        max_y = StructField("<d", 28)
        num_polys = StructField("<i", 36)

    # Read inside a context manager so the file handle is always closed;
    # the bytes are fully loaded before the header is parsed.
    with open("polys.bin", "rb") as f:
        data = f.read()
    phead = PolyHeader(data)
    print(phead.file_code == 0x1234)
    print("min_x=", phead.min_x)
    print("max_x=", phead.max_x)
    print("min_y=", phead.min_y)
    print("max_y=", phead.max_y)
    print("num_polys=", phead.num_polys)
################################################################################ | |
## 6::reading_nested_and_variable_sized_binary_structures | |
# Example 2: Introduction of a metaclass | |
import struct | |
class StructField:
    """Descriptor that decodes one struct-formatted field from ``_buffer``."""

    def __init__(self, format, offset):
        self.format = format
        self.offset = offset

    def __get__(self, instance, cls):
        if instance is None:
            return self
        values = struct.unpack_from(self.format, instance._buffer, self.offset)
        if len(values) == 1:
            return values[0]
        return values


class StructureMeta(type):
    """
    Metaclass that turns a class's ``_fields_`` list of ``(format, name)``
    pairs into StructField descriptors and records the total ``struct_size``.
    """

    def __init__(self, clsname, bases, clsdict):
        order_prefix = ""
        position = 0
        for spec, attr in getattr(self, "_fields_", []):
            # A byte-order prefix sticks: it applies to this field and to
            # every later field that does not supply its own.
            if spec.startswith(("<", ">", "!", "@")):
                order_prefix, spec = spec[0], spec[1:]
            spec = order_prefix + spec
            setattr(self, attr, StructField(spec, position))
            position += struct.calcsize(spec)
        setattr(self, "struct_size", position)


class Structure(metaclass=StructureMeta):
    """Base class for structures described by a ``_fields_`` list."""

    def __init__(self, bytedata):
        self._buffer = memoryview(bytedata)

    @classmethod
    def from_file(cls, f):
        """Build an instance from the next ``struct_size`` bytes of *f*."""
        raw = f.read(cls.struct_size)
        return cls(raw)
if __name__ == "__main__":

    class PolyHeader(Structure):
        _fields_ = [
            ("<i", "file_code"),
            ("d", "min_x"),
            ("d", "min_y"),
            ("d", "max_x"),
            ("d", "max_y"),
            ("i", "num_polys"),
        ]

    # from_file() reads the header bytes eagerly, so the file can be closed
    # as soon as it returns; the context manager guarantees that.
    with open("polys.bin", "rb") as f:
        phead = PolyHeader.from_file(f)
    print(phead.file_code == 0x1234)
    print("min_x=", phead.min_x)
    print("max_x=", phead.max_x)
    print("min_y=", phead.min_y)
    print("max_y=", phead.max_y)
    print("num_polys=", phead.num_polys)
################################################################################ | |
## 6::reading_nested_and_variable_sized_binary_structures | |
# Example 3: Nested structure support | |
import struct | |
class StructField:
    """
    Descriptor representing a simple structure field
    """

    def __init__(self, format, offset):
        self.format = format
        self.offset = offset

    def __get__(self, instance, cls):
        if instance is None:
            return self
        values = struct.unpack_from(self.format, instance._buffer, self.offset)
        if len(values) == 1:
            return values[0]
        return values


class NestedStruct:
    """
    Descriptor representing a nested structure
    """

    def __init__(self, name, struct_type, offset):
        self.name = name
        self.struct_type = struct_type
        self.offset = offset

    def __get__(self, instance, cls):
        if instance is None:
            return self
        stop = self.offset + self.struct_type.struct_size
        nested = self.struct_type(instance._buffer[self.offset:stop])
        # Store the result on the instance under the same name; this
        # descriptor defines only __get__, so later lookups find the
        # instance attribute instead of calling __get__ again.
        setattr(instance, self.name, nested)
        return nested


class StructureMeta(type):
    """
    Metaclass that builds StructField / NestedStruct descriptors from a
    class's ``_fields_`` list and records the total ``struct_size``.
    """

    def __init__(self, clsname, bases, clsdict):
        order_prefix = ""
        position = 0
        for spec, attr in getattr(self, "_fields_", []):
            if isinstance(spec, StructureMeta):
                # The field is itself a Structure subclass.
                setattr(self, attr, NestedStruct(attr, spec, position))
                position += spec.struct_size
                continue
            # A byte-order prefix carries over to later unprefixed fields.
            if spec.startswith(("<", ">", "!", "@")):
                order_prefix, spec = spec[0], spec[1:]
            spec = order_prefix + spec
            setattr(self, attr, StructField(spec, position))
            position += struct.calcsize(spec)
        setattr(self, "struct_size", position)


class Structure(metaclass=StructureMeta):
    """Base class for structures described by a ``_fields_`` list."""

    def __init__(self, bytedata):
        self._buffer = memoryview(bytedata)

    @classmethod
    def from_file(cls, f):
        """Build an instance from the next ``struct_size`` bytes of *f*."""
        raw = f.read(cls.struct_size)
        return cls(raw)
if __name__ == "__main__":

    class Point(Structure):
        _fields_ = [("<d", "x"), ("d", "y")]

    class PolyHeader(Structure):
        _fields_ = [
            ("<i", "file_code"),
            (Point, "min"),
            (Point, "max"),
            ("i", "num_polys"),
        ]

    # from_file() reads the header bytes eagerly, so the file can be closed
    # as soon as it returns; the context manager guarantees that.
    with open("polys.bin", "rb") as f:
        phead = PolyHeader.from_file(f)
    print(phead.file_code == 0x1234)
    print("min.x=", phead.min.x)
    print("max.x=", phead.max.x)
    print("min.y=", phead.min.y)
    print("max.y=", phead.max.y)
    print("num_polys=", phead.num_polys)
################################################################################ | |
## 6::reading_nested_and_variable_sized_binary_structures | |
# Example 4: Variable sized chunks | |
import struct | |
class StructField:
    """
    Descriptor representing a simple structure field
    """

    def __init__(self, format, offset):
        self.format = format
        self.offset = offset

    def __get__(self, instance, cls):
        if instance is None:
            return self
        decoded = struct.unpack_from(self.format, instance._buffer, self.offset)
        return decoded if len(decoded) != 1 else decoded[0]


class NestedStruct:
    """
    Descriptor representing a nested structure
    """

    def __init__(self, name, struct_type, offset):
        self.name = name
        self.struct_type = struct_type
        self.offset = offset

    def __get__(self, instance, cls):
        if instance is None:
            return self
        first = self.offset
        window = instance._buffer[first:first + self.struct_type.struct_size]
        nested = self.struct_type(window)
        # Remember the decoded structure on the instance under this name.
        setattr(instance, self.name, nested)
        return nested


class StructureMeta(type):
    """
    Metaclass that builds field descriptors from ``_fields_`` and sets
    ``struct_size`` to the combined size of all fields.
    """

    def __init__(self, clsname, bases, clsdict):
        prefix = ""
        cursor = 0
        for kind, attr in getattr(self, "_fields_", []):
            if isinstance(kind, StructureMeta):
                descriptor = NestedStruct(attr, kind, cursor)
                width = kind.struct_size
            else:
                # Remember an explicit byte-order marker for later fields.
                if kind.startswith(("<", ">", "!", "@")):
                    prefix = kind[0]
                    kind = kind[1:]
                kind = prefix + kind
                descriptor = StructField(kind, cursor)
                width = struct.calcsize(kind)
            setattr(self, attr, descriptor)
            cursor += width
        setattr(self, "struct_size", cursor)


class Structure(metaclass=StructureMeta):
    """Base class for structures described by a ``_fields_`` list."""

    def __init__(self, bytedata):
        self._buffer = memoryview(bytedata)

    @classmethod
    def from_file(cls, f):
        """Build an instance from the next ``struct_size`` bytes of *f*."""
        return cls(f.read(cls.struct_size))
class SizedRecord:
    """A size-prefixed chunk of bytes that can be iterated as records."""

    def __init__(self, bytedata):
        self._buffer = memoryview(bytedata)

    @classmethod
    def from_file(cls, f, size_fmt, includes_size=True):
        """Read one size-prefixed record from *f*.

        *size_fmt* is the struct format of the length prefix.  When
        *includes_size* is true the stored length counts the prefix itself,
        so that many bytes are subtracted before reading the payload.
        """
        prefix_len = struct.calcsize(size_fmt)
        (total,) = struct.unpack(size_fmt, f.read(prefix_len))
        payload = f.read(total - includes_size * prefix_len)
        return cls(payload)

    def iter_as(self, code):
        """Yield the payload as consecutive items described by *code*.

        A ``str`` is treated as a struct format and yields tuples; a
        Structure class (an instance of StructureMeta) yields instances
        built over slices of the buffer.
        """
        buffer = self._buffer
        if isinstance(code, str):
            layout = struct.Struct(code)
            for start in range(0, len(buffer), layout.size):
                yield layout.unpack_from(buffer, start)
        elif isinstance(code, StructureMeta):
            width = code.struct_size
            for start in range(0, len(buffer), width):
                yield code(buffer[start:start + width])
if __name__ == "__main__": | |
class Point(Structure): | |
_fields_ = [("<d", "x"), ("d", "y")] | |
class PolyHeader(Structure): | |
_fields_ = [ | |
("<i", "file_code"), | |
(Point, "min"), | |
(Point, "max"), | |
("i", "num_polys"), | |
] | |
def read_polys(filename): | |
polys = [] | |
with open(filename, "rb") as f: | |
phead = PolyHeader.from_file(f) | |
for n in range(phead.num_polys): | |
rec = SizedRecord.from_file(f, "<i") | |
poly = [(p.x, p.y) for p in rec.iter_as(Point)] | |
polys.append(poly) | |
return polys | |
polys = read_polys("polys.bin") | |
print(polys) | |
################################################################################ | |
## 6::reading_nested_and_variable_sized_binary_structures | |
import struct | |
import itertools | |
polys = [ | |
[(1.0, 2.5), (3.5, 4.0), (2.5, 1.5)], | |
[(7.0, 1.2), (5.1, 3.0), (0.5, 7.5), (0.8, 9.0)], | |
[(3.4, 6.3), (1.2, 0.5), (4.6, 9.2)], | |
] | |
def write_polys(filename, polys):
    """Write *polys* (a list of lists of ``(x, y)`` points) to *filename*.

    The file starts with a little-endian header: code 0x1234, the bounding
    box (min_x, min_y, max_x, max_y) and the polygon count.  Each polygon
    follows as a 4-byte size (which counts the size field itself) and its
    points as pairs of doubles.
    """
    # Determine bounding box
    points = list(itertools.chain.from_iterable(polys))
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)

    point_size = struct.calcsize("<dd")
    with open(filename, "wb") as out:
        out.write(struct.pack("<iddddi", 0x1234, min_x, min_y, max_x, max_y, len(polys)))
        for poly in polys:
            # +4 accounts for the size prefix itself.
            out.write(struct.pack("<i", len(poly) * point_size + 4))
            for pt in poly:
                out.write(struct.pack("<dd", *pt))
# Call it with our polygon data | |
write_polys("polys.bin", polys) | |
################################################################################ | |
## 7::accessing_variables_defined_inside_a_closure | |
# Example of accessing variables inside a closure | |
def sample():
    """Return a closure whose captured variable ``n`` is reachable from outside.

    The returned function prints ``n`` and carries ``get_n`` / ``set_n``
    function attributes that read and rebind the captured variable.
    """
    n = 0

    # Closure function
    def func():
        print("n=", n)

    # Accessor methods for n
    def set_n(value):
        nonlocal n  # rebind the enclosing variable, not a new local
        n = value

    def get_n():
        return n

    # Attach as function attributes
    func.get_n, func.set_n = get_n, set_n
    return func
if __name__ == "__main__": | |
f = sample() | |
f() | |
n = 0 | |
f.set_n(10) | |
f() | |
print(f.get_n()) | |
################################################################################ | |
## 7::accessing_variables_defined_inside_a_closure | |
# Example of faking classes with a closure | |
import sys | |
class ClosureInstance:
    """Object whose attributes are the callables found in a mapping of locals.

    When no mapping is given, the local variables of the calling frame are
    used.
    """

    def __init__(self, locals=None):
        if locals is None:
            # Frame 1 is the caller of this constructor.
            locals = sys._getframe(1).f_locals
        # Update instance dictionary with callables
        for key, value in locals.items():
            if callable(value):
                self.__dict__[key] = value

    # Redirect special methods
    def __len__(self):
        # len() looks up __len__ on the type, so forward to the stored callable.
        return self.__dict__["__len__"]()
# Example use | |
def Stack(): | |
items = [] | |
def push(item): | |
items.append(item) | |
def pop(): | |
return items.pop() | |
def __len__(): | |
return len(items) | |
return ClosureInstance() | |
if __name__ == "__main__": | |
s = Stack() | |
print(s) | |
s.push(10) | |
s.push(20) | |
s.push("Hello") | |
print(len(s)) | |
print(s.pop()) | |
print(s.pop()) | |
print(s.pop()) | |
################################################################################ | |
## 7::accessing_variables_defined_inside_a_closure | |
# Example of a normal class | |
# Example use | |
class Stack2: | |
def __init__(self): | |
self.items = [] | |
def push(self, item): | |
self.items.append(item) | |
def pop(self): | |
return self.items.pop() | |
def __len__(self): | |
return len(self.items) | |
if __name__ == "__main__": | |
import example2 | |
from timeit import timeit | |
print("Using a class") | |
s = Stack2() | |
print(timeit("s.push(1); s.pop()", "from __main__ import s")) | |
print("Using a closure") | |
s = example2.Stack() | |
print(timeit("s.push(1); s.pop()", "from __main__ import s")) | |
################################################################################ | |
## 7::carrying_extra_state_with_callback_functions | |
# This example is about the problem of carrying extra state around | |
# through callback functions. To test the examples, this very | |
# simple code emulates the typical control of a callback. | |
def apply_async(func, args, *, callback):
    """Call ``func(*args)`` and pass its result to *callback*.

    Despite the name, both calls happen synchronously before returning.
    """
    callback(func(*args))
# A simple function for testing | |
def add(x, y): | |
return x + y | |
# (a) A simple callback example | |
print("# --- Simple Example") | |
def print_result(result): | |
print("Got:", result) | |
apply_async(add, (2, 3), callback=print_result) | |
apply_async(add, ("hello", "world"), callback=print_result) | |
# (b) Using a bound method | |
print("# --- Using a bound-method") | |
class ResultHandler: | |
def __init__(self): | |
self.sequence = 0 | |
def handler(self, result): | |
self.sequence += 1 | |
print("[{}] Got: {}".format(self.sequence, result)) | |
r = ResultHandler() | |
apply_async(add, (2, 3), callback=r.handler) | |
apply_async(add, ("hello", "world"), callback=r.handler) | |
# (c) Using a closure | |
print("# --- Using a closure") | |
def make_handler(): | |
sequence = 0 | |
def handler(result): | |
nonlocal sequence | |
sequence += 1 | |
print("[{}] Got: {}".format(sequence, result)) | |
return handler | |
handler = make_handler() | |
apply_async(add, (2, 3), callback=handler) | |
apply_async(add, ("hello", "world"), callback=handler) | |
# (d) Using a coroutine | |
print("# --- Using a coroutine") | |
def make_handler(): | |
sequence = 0 | |
while True: | |
result = yield | |
sequence += 1 | |
print("[{}] Got: {}".format(sequence, result)) | |
handler = make_handler() | |
next(handler) # Advance to the yield | |
apply_async(add, (2, 3), callback=handler.send) | |
apply_async(add, ("hello", "world"), callback=handler.send) | |
# (e) Partial function application | |
print("# --- Using partial") | |
class SequenceNo: | |
def __init__(self): | |
self.sequence = 0 | |
def handler(result, seq): | |
seq.sequence += 1 | |
print("[{}] Got: {}".format(seq.sequence, result)) | |
seq = SequenceNo() | |
from functools import partial | |
apply_async(add, (2, 3), callback=partial(handler, seq=seq)) | |
apply_async(add, ("hello", "world"), callback=partial(handler, seq=seq)) | |
################################################################################ | |
## 7::functions_that_accept_any_number_of_arguments | |
# Examples of *args and **kwargs functions | |
def avg(first, *rest):
    """Return the arithmetic mean of one or more numbers."""
    count = 1 + len(rest)
    total = first + sum(rest)
    return total / count
print(avg(1, 2)) | |
print(avg(1, 2, 3, 4)) | |
import html | |
def make_element(name, value, **attrs):
    """Build a markup element string ``<name attr="...">value</name>``.

    Args:
        name: Tag name, inserted verbatim.
        value: Text content (a string); HTML-escaped.
        **attrs: Attribute names and values.  Values are converted with
            ``str()`` and HTML-escaped, including quotes, so a value cannot
            break out of its double-quoted attribute.

    Returns:
        The element as a string.
    """
    attr_str = "".join(
        ' %s="%s"' % (key, html.escape(str(val))) for key, val in attrs.items()
    )
    return "<{name}{attrs}>{value}</{name}>".format(
        name=name, attrs=attr_str, value=html.escape(value)
    )
# Example | |
# Creates '<item size="large" quantity="6">Albatross</item>' | |
print(make_element("item", "Albatross", size="large", quantity=6)) | |
print(make_element("p", "<spam>")) | |
################################################################################ | |
## 7::functions_that_only_accept_keyword_arguments | |
# examples of keyword-only argument functions | |
# A simple keyword-only argument | |
def recv(maxsize, *, block=True): | |
print(maxsize, block) | |
recv(8192, block=False) # Works | |
try: | |
recv(8192, False) # Fails | |
except TypeError as e: | |
print(e) | |
# Adding keyword-only args to *args functions | |
def minimum(*values, clip=None):
    """Return the smallest of *values*, raised to *clip* if it falls below it.

    *clip* is keyword-only; ``None`` means no lower bound is applied.
    """
    smallest = min(values)
    if clip is None:
        return smallest
    return clip if clip > smallest else smallest
print(minimum(1, 5, 2, -5, 10)) | |
print(minimum(1, 5, 2, -5, 10, clip=0)) | |
################################################################################ | |
## 7::functions_with_default_arguments | |
# Examples of a function with default arguments | |
# (a) Dangers of using a mutable default argument | |
def spam(b=[]): | |
return b | |
a = spam() | |
print(a) | |
a.append(1) | |
a.append(2) | |
b = spam() | |
print(b) # Carefully observe result | |
print("-" * 10) | |
# (b) Better alternative for mutable defaults | |
def spam(b=None): | |
if b is None: | |
b = [] | |
return b | |
a = spam() | |
print(a) | |
a.append(1) | |
a.append(2) | |
b = spam() | |
print(b) | |
print("-" * 10) | |
# (c) Example of testing if an argument was supplied or not | |
_no_value = object() | |
def spam(b=_no_value): | |
if b is _no_value: | |
print("No b value supplied") | |
else: | |
print("b=", b) | |
spam() | |
spam(None) | |
spam(0) | |
spam([]) | |
################################################################################ | |
## 7::inlining_callback_functions | |
# Example of implementing an inlined-callback function | |
# Sample function to illustrate callback control flow | |
def apply_async(func, args, *, callback):
    """Call ``func(*args)`` and pass its result to *callback*."""
    callback(func(*args))


# Inlined callback implementation
from queue import Queue
from functools import wraps


class Async:
    """Request, yielded by a generator, to run ``func(*args)``."""

    def __init__(self, func, args):
        self.func = func
        self.args = args


def inlined_async(func):
    """Decorator that drives a generator yielding ``Async`` requests.

    Each yielded request is run through the module-level ``apply_async``
    (looked up at call time), and its result is sent back into the
    generator as the value of the ``yield`` expression.
    """

    @wraps(func)
    def wrapper(*args):
        gen = func(*args)
        pending = Queue()
        pending.put(None)  # the first send() into a fresh generator must be None
        try:
            while True:
                request = gen.send(pending.get())
                apply_async(request.func, request.args, callback=pending.put)
        except StopIteration:
            # The generator finished; nothing more to drive.
            pass

    return wrapper
# Sample use | |
def add(x, y): | |
return x + y | |
@inlined_async | |
def test(): | |
r = yield Async(add, (2, 3)) | |
print(r) | |
r = yield Async(add, ("hello", "world")) | |
print(r) | |
for n in range(10): | |
r = yield Async(add, (n, n)) | |
print(r) | |
print("Goodbye") | |
if __name__ == "__main__": | |
# Simple test | |
print("# --- Simple test") | |
test() | |
print("# --- Multiprocessing test") | |
import multiprocessing | |
pool = multiprocessing.Pool() | |
apply_async = pool.apply_async | |
test() | |
################################################################################ | |
## 7::making_an_n-argument_callable_work_as_a_callable_with_fewer_arguments | |
# Example of using partial() with sorting a list of (x,y) coordinates | |
import functools | |
points = [(1, 2), (3, 4), (5, 6), (7, 7)] | |
import math | |
def distance(p1, p2):
    """Return the Euclidean distance between two ``(x, y)`` points."""
    (x1, y1), (x2, y2) = p1, p2
    dx = x2 - x1
    dy = y2 - y1
    return math.hypot(dx, dy)
pt = (4, 3) | |
points.sort(key=functools.partial(distance, pt)) | |
print(points) | |
################################################################################ | |
## 7::making_an_n-argument_callable_work_as_a_callable_with_fewer_arguments | |
# Using partial to supply extra arguments to a callback function | |
import functools | |
def output_result(result, log=None):
    """Log *result* at debug level on *log*; do nothing when *log* is None."""
    if log is None:
        return
    log.debug("Got: %r", result)
# A sample function | |
def add(x, y): | |
return x + y | |
if __name__ == "__main__": | |
import logging | |
from multiprocessing import Pool | |
from functools import partial | |
logging.basicConfig(level=logging.DEBUG) | |
log = logging.getLogger("test") | |
p = Pool() | |
p.apply_async(add, (3, 4), callback=functools.partial(output_result, log=log)) | |
p.close() | |
p.join() | |
################################################################################ | |
## 7::making_an_n-argument_callable_work_as_a_callable_with_fewer_arguments | |
# Using partial to supply extra arguments to a class constructor | |
from socketserver import StreamRequestHandler, TCPServer | |
class EchoHandler(StreamRequestHandler): | |
# ack is added keyword-only argument. *args, **kwargs are | |
# any normal parameters supplied (which are passed on) | |
def __init__(self, *args, ack, **kwargs): | |
self.ack = ack | |
super().__init__(*args, **kwargs) | |
def handle(self): | |
for line in self.rfile: | |
self.wfile.write(self.ack + line) | |
if __name__ == "__main__": | |
from functools import partial | |
serv = TCPServer(("", 15000), partial(EchoHandler, ack=b"RECEIVED:")) | |
print("Echo server running on port 15000") | |
serv.serve_forever() | |
################################################################################ | |
## 8::calling_a_method_on_a_parent_class | |
class A: | |
def spam(self): | |
print("A.spam") | |
class B(A): | |
def spam(self): | |
print("B.spam") | |
super().spam() # Call parent spam() | |
if __name__ == "__main__": | |
b = B() | |
b.spam() | |
################################################################################ | |
## 8::calling_a_method_on_a_parent_class | |
class A: | |
def __init__(self): | |
self.x = 0 | |
class B(A): | |
def __init__(self): | |
super().__init__() | |
self.y = 1 | |
if __name__ == "__main__": | |
b = B() | |
print(b.x, b.y) | |
################################################################################ | |
## 8::calling_a_method_on_a_parent_class | |
class Proxy:
    """Wrapper that forwards attribute access to an internal object.

    Names starting with an underscore are stored on the proxy itself; all
    other attribute reads and writes go to the wrapped object.
    """

    def __init__(self, obj):
        self._obj = obj

    # Delegate attribute lookup to internal obj
    def __getattr__(self, name):
        # __getattr__ runs only after normal lookup fails.  If "_obj" itself
        # is missing (e.g. an instance made with __new__ and not yet
        # initialised), reading self._obj below would re-enter this method
        # forever, so fail with a normal AttributeError instead.
        if name == "_obj":
            raise AttributeError(name)
        return getattr(self._obj, name)

    # Delegate attribute assignment
    def __setattr__(self, name, value):
        if name.startswith("_"):
            super().__setattr__(name, value)  # Call original __setattr__
        else:
            setattr(self._obj, name, value)
if __name__ == "__main__": | |
class A: | |
def __init__(self, x): | |
self.x = x | |
def spam(self): | |
print("A.spam") | |
a = A(42) | |
p = Proxy(a) | |
print(p.x) | |
print(p.spam()) | |
p.x = 37 | |
print("Should be 37:", p.x) | |
print("Should be 37:", a.x) | |
################################################################################ | |
## 8::calling_a_method_on_a_parent_class | |
# Tricky initialization problem involving multiple inheritance. | |
# Does NOT use super() | |
class Base: | |
def __init__(self): | |
print("Base.__init__") | |
class A(Base): | |
def __init__(self): | |
Base.__init__(self) | |
print("A.__init__") | |
class B(Base): | |
def __init__(self): | |
Base.__init__(self) | |
print("B.__init__") | |
class C(A, B): | |
def __init__(self): | |
A.__init__(self) | |
B.__init__(self) | |
print("C.__init__") | |
if __name__ == "__main__": | |
# Please observe double call of Base.__init__ | |
c = C() | |
################################################################################ | |
## 8::calling_a_method_on_a_parent_class | |
# Tricky initialization problem involving multiple inheritance. | |
# Uses super() | |
class Base: | |
def __init__(self): | |
print("Base.__init__") | |
class A(Base): | |
def __init__(self): | |
super().__init__() | |
print("A.__init__") | |
class B(Base): | |
def __init__(self): | |
super().__init__() | |
print("B.__init__") | |
class C(A, B): | |
def __init__(self): | |
super().__init__() # Only one call to super() here | |
print("C.__init__") | |
if __name__ == "__main__": | |
# Observe that each class initialized only once | |
c = C() | |
################################################################################ | |
## 8::calling_a_method_on_an_object_given_the_name_as_a_string | |
# Example of calling methods by name | |
import math | |
class Point:
    """A 2-D point with a distance helper."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __repr__(self):
        template = "Point({!r:},{!r:})"
        return template.format(self.x, self.y)

    def distance(self, x, y):
        """Return the Euclidean distance from this point to ``(x, y)``."""
        dx = self.x - x
        dy = self.y - y
        return math.hypot(dx, dy)
p = Point(2, 3) | |
# Method 1 : Use getattr | |
d = getattr(p, "distance")(0, 0) # Calls p.distance(0, 0) | |
print(d) | |
# Method 2: Use methodcaller | |
import operator | |
d = operator.methodcaller("distance", 0, 0)(p) | |
print(d) | |
# Application in sorting | |
points = [ | |
Point(1, 2), | |
Point(3, 0), | |
Point(10, -3), | |
Point(-5, -7), | |
Point(-1, 8), | |
Point(3, 2), | |
] | |
# Sort by distance from origin (0, 0) | |
points.sort(key=operator.methodcaller("distance", 0, 0)) | |
for p in points: | |
print(p) | |
################################################################################ | |
## 8::changing_the_string_representation_of_instances | |
class Pair:
    """Two values with separate ``repr()`` and ``str()`` renderings."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __repr__(self):
        # {0.x!r} formats attribute x of positional argument 0 with repr().
        shown = "Pair({0.x!r}, {0.y!r})"
        return shown.format(self)

    def __str__(self):
        shown = "({0.x}, {0.y})"
        return shown.format(self)
################################################################################ | |
## 8::creating_a_new_kind_of_class_or_instance_attribute | |
# Descriptor attribute for an integer type-checked attribute | |
class Integer:
    """Descriptor for an attribute that only accepts ``int`` values.

    The value is stored in the instance ``__dict__`` under *name*.
    """

    def __init__(self, name):
        self.name = name

    def __get__(self, instance, cls):
        # Class-level access returns the descriptor itself.
        if instance is None:
            return self
        try:
            return instance.__dict__[self.name]
        except KeyError:
            # An unset attribute must surface as AttributeError so that
            # hasattr() and getattr(obj, name, default) behave normally.
            raise AttributeError(self.name) from None

    def __set__(self, instance, value):
        if not isinstance(value, int):
            raise TypeError("Expected an int")
        instance.__dict__[self.name] = value

    def __delete__(self, instance):
        try:
            del instance.__dict__[self.name]
        except KeyError:
            raise AttributeError(self.name) from None
class Point: | |
x = Integer("x") | |
y = Integer("y") | |
def __init__(self, x, y): | |
self.x = x | |
self.y = y | |
if __name__ == "__main__": | |
p = Point(2, 3) | |
print(p.x) | |
p.y = 5 | |
try: | |
p.x = 2.3 | |
except TypeError as e: | |
print(e) | |
################################################################################ | |
## 8::creating_a_new_kind_of_class_or_instance_attribute | |
# Descriptor for a type-checked attribute | |
class Typed:
    """Descriptor for an attribute restricted to *expected_type*.

    The value is stored in the instance ``__dict__`` under *name*.
    """

    def __init__(self, name, expected_type):
        self.name = name
        self.expected_type = expected_type

    def __get__(self, instance, cls):
        # Class-level access returns the descriptor itself.
        if instance is None:
            return self
        try:
            return instance.__dict__[self.name]
        except KeyError:
            # An unset attribute must surface as AttributeError so that
            # hasattr() and getattr(obj, name, default) behave normally.
            raise AttributeError(self.name) from None

    def __set__(self, instance, value):
        if not isinstance(value, self.expected_type):
            raise TypeError("Expected " + str(self.expected_type))
        instance.__dict__[self.name] = value

    def __delete__(self, instance):
        try:
            del instance.__dict__[self.name]
        except KeyError:
            raise AttributeError(self.name) from None
# Class decorator that applies it to selected attributes | |
def typeassert(**kwargs):
    """Class decorator factory: ``name=type`` pairs become Typed attributes.

    Returns a decorator that attaches one ``Typed`` descriptor per keyword
    argument to the decorated class and returns that same class.
    """

    def decorate(cls):
        for attr, expected in kwargs.items():
            # Attach a Typed descriptor to the class
            descriptor = Typed(attr, expected)
            setattr(cls, attr, descriptor)
        return cls

    return decorate
# Example use | |
@typeassert(name=str, shares=int, price=float) | |
class Stock: | |
def __init__(self, name, shares, price): | |
self.name = name | |
self.shares = shares | |
self.price = price | |
if __name__ == "__main__": | |
s = Stock("ACME", 100, 490.1) | |
print(s.name, s.shares, s.price) | |
s.shares = 50 | |
try: | |
s.shares = "a lot" | |
except TypeError as e: | |
print(e) | |
################################################################################ | |
## 8::creating_an_instance_without_invoking_init | |
from time import localtime | |
class Date:
    """Simple year/month/day record with an alternate constructor."""

    def __init__(self, year, month, day):
        self.year, self.month, self.day = year, month, day

    # Class method that bypasses __init__
    @classmethod
    def today(cls):
        """Return an instance for the current local date."""
        inst = cls.__new__(cls)  # create without running __init__
        now = localtime()
        inst.year, inst.month, inst.day = now.tm_year, now.tm_mon, now.tm_mday
        return inst
d = Date.__new__(Date) | |
print(d) | |
print(hasattr(d, "year")) | |
data = {"year": 2012, "month": 8, "day": 29} | |
d.__dict__.update(data) | |
print(d.year) | |
print(d.month) | |
d = Date.today() | |
print(d.year, d.month, d.day) | |
################################################################################ | |
## 8::creating_cached_instances | |
# Simple example | |
class Spam:
    """Named object used to demonstrate instance caching."""

    def __init__(self, name):
        self.name = name


# Caching support
import weakref

# Weak values: an entry disappears once nothing else references the Spam.
_spam_cache = weakref.WeakValueDictionary()


def get_spam(name):
    """Return the cached ``Spam`` for *name*, creating it if needed.

    The cache is read with a single ``get()`` call.  Testing ``name in
    cache`` and then indexing is unsafe with weak values: the entry can be
    collected between the two steps and the lookup would raise ``KeyError``.
    """
    s = _spam_cache.get(name)
    if s is None:
        s = Spam(name)
        _spam_cache[name] = s
    return s
if __name__ == "__main__": | |
a = get_spam("foo") | |
b = get_spam("bar") | |
print("a is b:", a is b) | |
c = get_spam("foo") | |
print("a is c:", a is c) | |
################################################################################ | |
## 8::creating_cached_instances | |
import weakref | |
class CachedSpamManager:
    """Owns a weak-value cache of ``Spam`` instances keyed by name."""

    def __init__(self):
        self._cache = weakref.WeakValueDictionary()

    def get_spam(self, name):
        """Return the cached ``Spam`` for *name*, creating it if needed."""
        # One get() instead of "in" followed by indexing: with weak values
        # the entry can vanish between the two steps and raise KeyError.
        s = self._cache.get(name)
        if s is None:
            s = Spam(name)
            self._cache[name] = s
        return s


class Spam:
    """Named object; instances are handed out by ``Spam.manager``."""

    def __init__(self, name):
        self.name = name


# One manager shared by all callers, attached to the class.
Spam.manager = CachedSpamManager()


def get_spam(name):
    """Return the shared ``Spam`` instance for *name*."""
    return Spam.manager.get_spam(name)
if __name__ == "__main__": | |
a = get_spam("foo") | |
b = get_spam("bar") | |
print("a is b:", a is b) | |
c = get_spam("foo") | |
print("a is c:", a is c) | |
################################################################################ | |
## 8::creating_cached_instances | |
# Example involving new and some of its problems | |
import weakref | |
class Spam:
    """Caches instances by name in ``__new__``.

    ``__init__`` still runs on every ``Spam(name)`` call, including calls
    that return an already-cached instance.
    """

    _spam_cache = weakref.WeakValueDictionary()

    def __new__(cls, name):
        cache = cls._spam_cache
        if name not in cache:
            self = super().__new__(cls)
            cache[name] = self
            return self
        return cache[name]

    def __init__(self, name):
        print("Initializing Spam")
        self.name = name
if __name__ == "__main__":
    print("This should print 'Initializing Spam' twice")
    # Both calls resolve to the same cached object.
    s, t = Spam("Dave"), Spam("Dave")
    print(s is t)
################################################################################ | |
## 8::creating_managed_attributes | |
# Example of managed attributes via properties | |
class Person:
    """Person whose ``first_name`` is a managed attribute that accepts only ``str``."""

    def __init__(self, first_name):
        # Assigning here goes through the property setter, so the type
        # check also applies at construction time.
        self.first_name = first_name

    @property
    def first_name(self):
        """The stored first name."""
        return self._first_name

    @first_name.setter
    def first_name(self, value):
        if isinstance(value, str):
            self._first_name = value
        else:
            raise TypeError("Expected a string")
if __name__ == "__main__":
    a = Person("Guido")
    print(a.first_name)
    a.first_name = "Dave"
    print(a.first_name)
    # A non-string value is rejected by the setter.
    try:
        a.first_name = 42
    except TypeError as e:
        print(e)
################################################################################ | |
## 8::customized_formatting | |
# Format code -> template. Each template reads its fields from the object
# passed to str.format() as ``d``.
_formats = {
    "ymd": "{d.year}-{d.month}-{d.day}",
    "mdy": "{d.month}/{d.day}/{d.year}",
    "dmy": "{d.day}/{d.month}/{d.year}",
}


class Date:
    """Date supporting the custom format codes listed in ``_formats``."""

    def __init__(self, year, month, day):
        self.year, self.month, self.day = year, month, day

    def __format__(self, code):
        """Render the date with the template for ``code``; "" selects "ymd"."""
        template = _formats["ymd" if code == "" else code]
        return template.format(d=self)
################################################################################ | |
## 8::delegation_and_proxies | |
class A:
    """Delegation target: the object class B forwards unknown attributes to."""

    def spam(self, x):
        """Print a marker; ``x`` is accepted but not used."""
        print("A.spam")

    def foo(self):
        """Print a marker."""
        print("A.foo")
class B:
    """Wraps an ``A`` instance and forwards unknown attribute lookups to it."""

    def __init__(self):
        self._a = A()

    def bar(self):
        print("B.bar")

    # Expose all of the methods defined on class A
    def __getattr__(self, name):
        # __getattr__ only runs after normal lookup has failed. If ``_a``
        # itself is missing (instance created without __init__, as copy and
        # pickle do), reading self._a below would re-enter this method
        # forever, so fail cleanly instead.
        if name == "_a":
            raise AttributeError(name)
        return getattr(self._a, name)
if __name__ == "__main__":
    b = B()
    # Defined on B itself.
    b.bar()
    # Not defined on B: forwarded to the wrapped A instance.
    b.spam(42)
################################################################################ | |
## 8::delegation_and_proxies | |
# A proxy class that wraps around another object, but | |
# exposes its public attributes | |
class Proxy:
    """Wraps another object and forwards public attribute access to it.

    Names starting with "_" are set and deleted on the proxy itself; all
    other names are set and deleted on the wrapped object.
    """

    def __init__(self, obj):
        self._obj = obj

    # Delegate attribute lookup to internal obj
    def __getattr__(self, name):
        # If ``_obj`` itself is missing (instance created without __init__,
        # as copy and pickle do), reading self._obj below would re-enter
        # this method forever, so fail cleanly instead.
        if name == "_obj":
            raise AttributeError(name)
        print("getattr:", name)
        return getattr(self._obj, name)

    # Delegate attribute assignment
    def __setattr__(self, name, value):
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            print("setattr:", name, value)
            setattr(self._obj, name, value)

    # Delegate attribute deletion
    def __delattr__(self, name):
        if name.startswith("_"):
            super().__delattr__(name)
        else:
            print("delattr:", name)
            delattr(self._obj, name)
if __name__ == "__main__":

    class Spam:
        """Small target object for the proxy demonstration."""

        def __init__(self, x):
            self.x = x

        def bar(self, y):
            print("Spam.bar:", self.x, y)

    # Build the target and wrap it
    s = Spam(2)
    p = Proxy(s)

    # Everything below goes through the proxy
    print(p.x)  # Outputs 2
    p.bar(3)  # Outputs "Spam.bar: 2 3"
    p.x = 37  # Changes s.x to 37
################################################################################ | |
## 8::delegation_and_proxies | |
class ListLike:
    """Wraps a list: forwards its methods and implements len/indexing explicitly."""

    def __init__(self):
        self._items = []

    def __getattr__(self, name):
        # If ``_items`` itself is missing (instance created without
        # __init__, as copy and pickle do), reading self._items below would
        # re-enter this method forever, so fail cleanly instead.
        if name == "_items":
            raise AttributeError(name)
        return getattr(self._items, name)

    # Added special methods to support certain list operations
    def __len__(self):
        return len(self._items)

    def __getitem__(self, index):
        return self._items[index]

    def __setitem__(self, index, value):
        self._items[index] = value

    def __delitem__(self, index):
        del self._items[index]
if __name__ == "__main__":
    a = ListLike()
    # list methods, reached through __getattr__
    a.append(2)
    a.insert(0, 1)
    a.sort()
    # operations backed by the explicit special methods
    print(len(a))
    print(a[0])
################################################################################ | |
## 8::delegation_and_proxies | |
class A:
    """Delegation target wrapped by class B."""

    def spam(self):
        """Print a marker."""
        print("A.spam")

    def foo(self):
        """Print a marker."""
        print("A.foo")
class B:
    """Delegation used in place of inheritance: wraps an ``A`` and extends ``spam``."""

    def __init__(self):
        self._a = A()

    def spam(self):
        print("B.spam")
        self._a.spam()  # Similar to super()

    def __getattr__(self, name):
        # If ``_a`` itself is missing (instance created without __init__,
        # as copy and pickle do), reading self._a below would re-enter this
        # method forever, so fail cleanly instead.
        if name == "_a":
            raise AttributeError(name)
        return getattr(self._a, name)
if __name__ == "__main__":
    b = B()
    # B.spam runs first, then calls the wrapped A.spam.
    b.spam()
    # foo is not defined on B: forwarded to the wrapped A instance.
    b.foo()
################################################################################ | |
## 8::extending_a_property_in_a_subclass | |
# Example of managed attributes via properties | |
class Person:
    """Person whose ``name`` is a string-only property that cannot be deleted."""

    def __init__(self, name):
        # Goes through the property setter, so the type check applies here.
        self.name = name

    @property
    def name(self):
        """The stored name."""
        return self._name

    @name.setter
    def name(self, value):
        if isinstance(value, str):
            self._name = value
        else:
            raise TypeError("Expected a string")

    @name.deleter
    def name(self):
        raise AttributeError("Can't delete attribute")
class SubPerson(Person):
    """Redefines all three accessors of ``name``, each logging and then deferring to the parent."""

    @property
    def name(self):
        print("Getting name")
        return super().name

    @name.setter
    def name(self, value):
        print("Setting name to", value)
        # Looked up on the class, the parent attribute is the property
        # object itself; its __set__ runs the parent's setter.
        parent_prop = super(SubPerson, SubPerson).name
        parent_prop.__set__(self, value)

    @name.deleter
    def name(self):
        print("Deleting name")
        parent_prop = super(SubPerson, SubPerson).name
        parent_prop.__delete__(self)
class SubPerson2(Person):
    """Replaces only the setter of ``Person.name``."""

    @Person.name.setter
    def name(self, value):
        print("Setting name to", value)
        # Defer to the parent property's setter.
        parent_prop = super(SubPerson2, SubPerson2).name
        parent_prop.__set__(self, value)
class SubPerson3(Person):
    """Replaces only the getter of ``Person.name``."""

    # @property
    @Person.name.getter
    def name(self):
        print("Getting name")
        parent_value = super().name
        return parent_value
if __name__ == "__main__":
    a = Person("Guido")
    print(a.name)
    a.name = "Dave"
    print(a.name)
    # A non-string value is rejected by the setter.
    try:
        a.name = 42
    except TypeError as e:
        print(e)
################################################################################ | |
## 8::extending_a_property_in_a_subclass | |
# Example of extending a descriptor-managed attribute in a subclass
class String:
    """Descriptor that keeps a string-only value in the owning instance's ``__dict__``.

    ``name`` is the key used in the instance ``__dict__``.
    """

    def __init__(self, name):
        self.name = name

    def __get__(self, instance, cls):
        # Accessed on the class: return the descriptor itself.
        if instance is None:
            return self
        try:
            return instance.__dict__[self.name]
        except KeyError:
            # An unset attribute must raise AttributeError so hasattr()
            # and getattr(..., default) behave normally.
            raise AttributeError(self.name) from None

    def __set__(self, instance, value):
        if not isinstance(value, str):
            raise TypeError("Expected a string")
        instance.__dict__[self.name] = value

    def __delete__(self, instance):
        # Deleters in subclasses call this; without it, deleting the
        # attribute fails on the missing __delete__.
        try:
            del instance.__dict__[self.name]
        except KeyError:
            raise AttributeError(self.name) from None
class Person:
    """Person whose ``name`` is managed by the ``String`` descriptor."""

    name = String("name")

    def __init__(self, name):
        # Routed through String.__set__, so the type check applies here.
        self.name = name
class SubPerson(Person):
    """Wraps the parent's descriptor-managed ``name`` in a logging property."""

    @property
    def name(self):
        print("Getting name")
        return super().name

    @name.setter
    def name(self, value):
        print("Setting name to", value)
        # Looked up on the class, the parent attribute is the descriptor
        # object itself.
        parent_descr = super(SubPerson, SubPerson).name
        parent_descr.__set__(self, value)

    @name.deleter
    def name(self):
        print("Deleting name")
        parent_descr = super(SubPerson, SubPerson).name
        parent_descr.__delete__(self)
if __name__ == "__main__":
    a = Person("Guido")
    print(a.name)
    a.name = "Dave"
    print(a.name)
    # A non-string value is rejected by the descriptor.
    try:
        a.name = 42
    except TypeError as e:
        print(e)
################################################################################ | |
## 8::extending_classes_with_mixins | |
class LoggedMappingMixin:
    """
    Add logging to get/set/delete operations for debugging.

    Each method prints a line and then hands the operation to the next
    class in the MRO.
    """

    __slots__ = ()

    def __getitem__(self, key):
        message = "Getting " + str(key)
        print(message)
        return super().__getitem__(key)

    def __setitem__(self, key, value):
        message = "Setting {} = {!r}".format(key, value)
        print(message)
        return super().__setitem__(key, value)

    def __delitem__(self, key):
        message = "Deleting " + str(key)
        print(message)
        return super().__delitem__(key)
class SetOnceMappingMixin:
    """
    Only allow a key to be set once.

    Assigning to a key that is already present raises KeyError.
    """

    __slots__ = ()

    def __setitem__(self, key, value):
        if key not in self:
            return super().__setitem__(key, value)
        raise KeyError(str(key) + " already set")
class StringKeysMappingMixin:
    """
    Restrict keys to strings only

    Assigning with a non-string key raises TypeError.
    """

    __slots__ = ()

    def __setitem__(self, key, value):
        if isinstance(key, str):
            return super().__setitem__(key, value)
        raise TypeError("keys must be strings")
# Examples | |
print("# ---- LoggedDict Example")


class LoggedDict(LoggedMappingMixin, dict):
    """dict that prints every get, set and delete."""


# One of each logged operation.
d = LoggedDict()
d["x"] = 23
print(d["x"])
del d["x"]
print("# ---- SetOnceDefaultDict Example")
from collections import defaultdict


class SetOnceDefaultDict(SetOnceMappingMixin, defaultdict):
    """defaultdict whose keys can be assigned only once."""


d = SetOnceDefaultDict(list)
# Appending mutates the stored lists; it does not assign to the key again.
d["x"].append(2)
d["y"].append(3)
d["x"].append(10)
# Assigning to an existing key is rejected.
try:
    d["x"] = 23
except KeyError as e:
    print(e)
print("# ---- StringOrderedDict Example")
from collections import OrderedDict


class StringOrderedDict(StringKeysMappingMixin, SetOnceMappingMixin, OrderedDict):
    """OrderedDict with string-only keys that can each be set once."""


d = StringOrderedDict()
d["x"] = 23
# A non-string key is rejected.
try:
    d[42] = 10
except TypeError as e:
    print(e)
# Assigning "x" a second time is rejected.
try:
    d["x"] = 42
except KeyError as e:
    print(e)
################################################################################ | |
## 8::extending_classes_with_mixins | |
class RestrictKeysMixin:
    """Mixin that rejects item assignment unless the key is of a configured type.

    The type is supplied through the required keyword-only argument
    ``_restrict_key_type``; every other argument is passed on to the next
    ``__init__`` in the MRO.
    """

    def __init__(self, *args, _restrict_key_type, **kwargs):
        self.__restrict_key_type = _restrict_key_type
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        allowed = self.__restrict_key_type
        if isinstance(key, allowed):
            super().__setitem__(key, value)
            return
        raise TypeError("Keys must be " + str(allowed))
# Example | |
class RDict(RestrictKeysMixin, dict):
    """dict whose item assignment is limited to keys of one type."""


# The restriction type is passed alongside the usual dict arguments.
d = RDict(_restrict_key_type=str)
e = RDict([("name", "Dave"), ("n", 37)], _restrict_key_type=str)
f = RDict(name="Dave", n=37, _restrict_key_type=str)
print(f)
# An integer key is rejected.
try:
    f[42] = 10
except TypeError as e:
    print(e)
################################################################################ | |
## 8::extending_classes_with_mixins | |
# Class decorator alternative to mixins | |
def LoggedMapping(cls):
    """Class decorator that prints every item get, set and delete on ``cls``.

    The class's current ``__getitem__``, ``__setitem__`` and ``__delitem__``
    are captured and replaced in place by wrappers that print a line and
    then call the captured method. Returns the same class object.
    """
    cls_getitem = cls.__getitem__
    cls_setitem = cls.__setitem__
    cls_delitem = cls.__delitem__

    def __getitem__(self, key):
        # (key,) stops a tuple key from being unpacked as several % arguments,
        # which would raise TypeError.
        print("Getting %s" % (key,))
        return cls_getitem(self, key)

    def __setitem__(self, key, value):
        print("Setting %s = %r" % (key, value))
        return cls_setitem(self, key, value)

    def __delitem__(self, key):
        print("Deleting %s" % (key,))
        return cls_delitem(self, key)

    cls.__getitem__ = __getitem__
    cls.__setitem__ = __setitem__
    cls.__delitem__ = __delitem__
    return cls
@LoggedMapping
class LoggedDict(dict):
    """dict with logging added by the LoggedMapping decorator."""


# One of each logged operation.
d = LoggedDict()
d["x"] = 23
print(d["x"])
del d["x"]
################################################################################ | |
## 8::how_to_define_an_interface_or_abstract_base_class | |
# Defining a simple abstract base class | |
from abc import ABCMeta, abstractmethod | |
class IStream(metaclass=ABCMeta):
    """Abstract stream interface: concrete classes must provide read() and write()."""

    @abstractmethod
    def read(self, maxbytes=-1):
        """Abstract; no behavior is defined at this level."""

    @abstractmethod
    def write(self, data):
        """Abstract; no behavior is defined at this level."""
# Example implementation | |
class SocketStream(IStream):
    """Concrete IStream whose methods only print a marker."""

    def read(self, maxbytes=-1):
        """Print "reading"; ``maxbytes`` is accepted but not used."""
        print("reading")

    def write(self, data):
        """Print "writing"; ``data`` is accepted but not used."""
        print("writing")
# Example of type checking | |
def serialize(obj, stream):
    """Check that ``stream`` is an IStream, then print a marker.

    ``obj`` is accepted but not used here.

    Raises:
        TypeError: if ``stream`` is not an IStream instance.
    """
    if isinstance(stream, IStream):
        print("serializing")
        return
    raise TypeError("Expected an IStream")
# Examples | |
if __name__ == "__main__":
    import io
    import sys

    # Attempt to instantiate ABC directly (doesn't work)
    try:
        a = IStream()
    except TypeError as e:
        print(e)

    # Instantiation of a concrete implementation
    a = SocketStream()
    a.read()
    a.write("data")

    # Passing to type-check function
    serialize(None, a)

    # Attempt to pass a file-like object to serialize (fails)
    try:
        serialize(None, sys.stdout)
    except TypeError as e:
        print(e)

    # Register file streams and retry
    IStream.register(io.IOBase)
    serialize(None, sys.stdout)
################################################################################ | |
## 8::how_to_define_an_interface_or_abstract_base_class | |
from abc import ABCMeta, abstractmethod | |
class A(metaclass=ABCMeta):
    """Shows @abstractmethod combined with property, classmethod and staticmethod.

    In each case @abstractmethod is the innermost decorator.
    """

    @property
    @abstractmethod
    def name(self):
        """Abstract property getter."""

    @name.setter
    @abstractmethod
    def name(self, value):
        """Abstract property setter."""

    @classmethod
    @abstractmethod
    def method1(cls):
        """Abstract class method."""

    @staticmethod
    @abstractmethod
    def method2():
        """Abstract static method."""
################################################################################ | |
## 8::how_to_define_more_than_one_constructor_in_a_class | |
import time | |
class Date:
    """Date with a primary constructor and a ``today()`` alternate constructor."""

    # Primary constructor
    def __init__(self, year, month, day):
        self.year, self.month, self.day = year, month, day

    # Alternate constructor
    @classmethod
    def today(cls):
        """Build an instance of ``cls`` from the current local date."""
        # The first three struct_time fields are year, month, day of month.
        return cls(*time.localtime()[:3])
if __name__ == "__main__":
    a = Date(2012, 12, 21)
    b = Date.today()
    print(a.year, a.month, a.day)
    print(b.year, b.month, b.day)

    class NewDate(Date):
        pass

    # today() is a classmethod, so it builds an instance of whichever class
    # it is called on.
    c = Date.today()
    d = NewDate.today()
    # Print the created instances: the labels talk about instances, and the
    # original passed the classes themselves, leaving c and d unused.
    print("Should be Date instance:", c)
    print("Should be NewDate instance:", d)
################################################################################ | |
## 8::how_to_define_more_than_one_constructor_in_a_class | |
import time | |
class Date:
    """Date whose ``today()`` builds the instance without calling ``__init__``."""

    # Primary constructor
    def __init__(self, year, month, day):
        self.year, self.month, self.day = year, month, day

    # Alternate constructor
    @classmethod
    def today(cls):
        """Create a bare instance of ``cls`` and fill it from the current local date."""
        now = time.localtime()
        obj = cls.__new__(cls)  # allocates without running __init__
        obj.year, obj.month, obj.day = now.tm_year, now.tm_mon, now.tm_mday
        return obj
if __name__ == "__main__":
    a = Date(2012, 12, 21)
    b = Date.today()
    print(a.year, a.month, a.day)
    print(b.year, b.month, b.day)

    class NewDate(Date):
        pass

    # today() is a classmethod, so it builds an instance of whichever class
    # it is called on.
    c = Date.today()
    d = NewDate.today()
    # Print the created instances: the labels talk about instances, and the
    # original passed the classes themselves, leaving c and d unused.
    print("Should be Date instance:", c)
    print("Should be NewDate instance:", d)
################################################################################ | |
## 8::how_to_encapsulate_names_in_a_class | |
# Example of using __ method name to implement a | |
# non-overrideable method | |
class B:
    """Uses double-underscore names so subclasses cannot accidentally override them."""

    def __init__(self):
        # Stored on the instance under the mangled name _B__private.
        self.__private = 0

    def __private_method(self):
        print("B.__private_method", self.__private)

    def public_method(self):
        """Call this class's own private method."""
        self.__private_method()
class C(B):
    """Subclass whose private names are mangled separately from B's."""

    def __init__(self):
        super().__init__()
        self.__private = 1  # Does not override B.__private

    # Does not override B.__private_method()
    def __private_method(self):
        print("C.__private_method")


# public_method() is inherited from B and calls B's private method.
c = C()
c.public_method()
################################################################################ | |
## 8::implementing_a_data_model_or_type_system | |
# Base class. Uses a descriptor to set a value | |
class Descriptor:
    """Base descriptor: stores the assigned value in the instance ``__dict__`` under ``name``.

    Extra keyword options become attributes of the descriptor itself.
    """

    def __init__(self, name=None, **opts):
        self.name = name
        vars(self).update(opts)

    def __set__(self, instance, value):
        vars(instance)[self.name] = value
# Descriptor for enforcing types | |
class Typed(Descriptor):
    """Descriptor that accepts only instances of ``expected_type``.

    Subclasses override ``expected_type``.
    """

    expected_type = type(None)

    def __set__(self, instance, value):
        if isinstance(value, self.expected_type):
            super().__set__(instance, value)
        else:
            raise TypeError("expected " + str(self.expected_type))
# Descriptor for enforcing values | |
class Unsigned(Descriptor):
    """Descriptor that rejects values below zero with ValueError."""

    def __set__(self, instance, value):
        is_negative = value < 0
        if is_negative:
            raise ValueError("Expected >= 0")
        super().__set__(instance, value)
class MaxSized(Descriptor):
    """Descriptor that requires ``len(value)`` to be below a mandatory ``size`` option."""

    def __init__(self, name=None, **opts):
        has_size = "size" in opts
        if not has_size:
            raise TypeError("missing size option")
        self.size = opts["size"]
        super().__init__(name, **opts)

    def __set__(self, instance, value):
        limit = self.size
        if len(value) >= limit:
            raise ValueError("size must be < " + str(limit))
        super().__set__(instance, value)
class Integer(Typed):
    """Accepts ``int`` values."""

    expected_type = int


class UnsignedInteger(Integer, Unsigned):
    """Integer check combined with the Unsigned check."""


class Float(Typed):
    """Accepts ``float`` values."""

    expected_type = float


class UnsignedFloat(Float, Unsigned):
    """Float check combined with the Unsigned check."""


class String(Typed):
    """Accepts ``str`` values."""

    expected_type = str


class SizedString(String, MaxSized):
    """String check combined with the MaxSized check."""
# Class decorator to apply constraints | |
def check_attributes(**kwargs):
    """Return a class decorator that installs one descriptor per keyword.

    A value that is already a Descriptor instance is given the keyword as
    its name and installed as-is; any other value is called with the keyword
    and the result is installed.
    """

    def decorate(cls):
        for attr, spec in kwargs.items():
            if not isinstance(spec, Descriptor):
                setattr(cls, attr, spec(attr))
                continue
            spec.name = attr
            setattr(cls, attr, spec)
        return cls

    return decorate
# A metaclass that applies checking | |
class checkedmeta(type):
    """Metaclass that names each Descriptor in a class body after its attribute."""

    def __new__(cls, clsname, bases, methods):
        # Attach attribute names to the descriptors
        for attr, member in methods.items():
            if isinstance(member, Descriptor):
                member.name = attr
        return super().__new__(cls, clsname, bases, methods)
# Testing code | |
def test(s):
    """Exercise the constraints on ``s``, printing the message of each expected rejection."""
    print(s.name)
    s.shares = 75
    print(s.shares)
    # (attribute, value to assign, exception type that is caught and printed)
    attempts = (
        ("shares", -10, ValueError),
        ("price", "a lot", TypeError),
        ("name", "ABRACADABRA", ValueError),
    )
    for attr, bad_value, expected_error in attempts:
        try:
            setattr(s, attr, bad_value)
        except expected_error as e:
            print(e)
# Various Examples: | |
if __name__ == "__main__":
    # The same Stock model is declared three ways and run through test().

    print("# --- Class with descriptors")

    class Stock:
        # Specify constraints
        name = SizedString("name", size=8)
        shares = UnsignedInteger("shares")
        price = UnsignedFloat("price")

        def __init__(self, name, shares, price):
            self.name, self.shares, self.price = name, shares, price

    s = Stock("ACME", 50, 91.1)
    test(s)

    print("# --- Class with class decorator")

    @check_attributes(
        name=SizedString(size=8), shares=UnsignedInteger, price=UnsignedFloat
    )
    class Stock:
        def __init__(self, name, shares, price):
            self.name, self.shares, self.price = name, shares, price

    s = Stock("ACME", 50, 91.1)
    test(s)

    print("# --- Class with metaclass")

    class Stock(metaclass=checkedmeta):
        name = SizedString(size=8)
        shares = UnsignedInteger()
        price = UnsignedFloat()

        def __init__(self, name, shares, price):
            self.name, self.shares, self.price = name, shares, price

    s = Stock("ACME", 50, 91.1)
    test(s)
################################################################################ | |
## 8::implementing_a_data_model_or_type_system | |
# Base class. Uses a descriptor to set a value | |
class Descriptor:
    """Base descriptor: stores the assigned value in the instance ``__dict__`` under ``name``.

    Extra keyword options become attributes of the descriptor itself.
    """

    def __init__(self, name=None, **opts):
        self.name = name
        vars(self).update(opts)

    def __set__(self, instance, value):
        vars(instance)[self.name] = value
def Typed(expected_type, cls=None):
    """Class decorator that adds an ``isinstance`` check to ``cls.__set__``.

    Called as ``Typed(expected_type)`` it returns a decorator; called with a
    class it wraps that class's current ``__set__`` in place and returns the
    class.
    """
    if cls is not None:
        inherited_set = cls.__set__

        def __set__(self, instance, value):
            if isinstance(value, expected_type):
                inherited_set(self, instance, value)
            else:
                raise TypeError("expected " + str(expected_type))

        cls.__set__ = __set__
        return cls

    def apply(cls):
        return Typed(expected_type, cls)

    return apply
def Unsigned(cls):
    """Class decorator that makes ``cls.__set__`` reject values below zero.

    The class's current ``__set__`` is wrapped in place; the class is returned.
    """
    inherited_set = cls.__set__

    def __set__(self, instance, value):
        is_negative = value < 0
        if is_negative:
            raise ValueError("Expected >= 0")
        inherited_set(self, instance, value)

    cls.__set__ = __set__
    return cls
def MaxSized(cls):
    """Class decorator that adds a mandatory ``size`` option and a length check.

    Both ``cls.__init__`` and ``cls.__set__`` are wrapped in place; the class
    is returned.
    """
    inherited_init = cls.__init__
    inherited_set = cls.__set__

    def __init__(self, name=None, **opts):
        has_size = "size" in opts
        if not has_size:
            raise TypeError("missing size option")
        self.size = opts["size"]
        inherited_init(self, name, **opts)

    def __set__(self, instance, value):
        limit = self.size
        if len(value) >= limit:
            raise ValueError("size must be < " + str(limit))
        inherited_set(self, instance, value)

    cls.__init__ = __init__
    cls.__set__ = __set__
    return cls
@Typed(int)
class Integer(Descriptor):
    """Accepts ``int`` values."""


@Unsigned
class UnsignedInteger(Integer):
    """Integer with the Unsigned check added."""


@Typed(float)
class Float(Descriptor):
    """Accepts ``float`` values."""


@Unsigned
class UnsignedFloat(Float):
    """Float with the Unsigned check added."""


@Typed(str)
class String(Descriptor):
    """Accepts ``str`` values."""


@MaxSized
class SizedString(String):
    """String with the MaxSized check added."""
# Class decorator to apply constraints | |
def check_attributes(**kwargs):
    """Return a class decorator that installs one descriptor per keyword.

    A value that is already a Descriptor instance is given the keyword as
    its name and installed as-is; any other value is called with the keyword
    and the result is installed.
    """

    def decorate(cls):
        for attr, spec in kwargs.items():
            if not isinstance(spec, Descriptor):
                setattr(cls, attr, spec(attr))
                continue
            spec.name = attr
            setattr(cls, attr, spec)
        return cls

    return decorate
# A metaclass that applies checking | |
class checkedmeta(type):
    """Metaclass that names each Descriptor in a class body after its attribute."""

    def __new__(cls, clsname, bases, methods):
        # Attach attribute names to the descriptors
        for attr, member in methods.items():
            if isinstance(member, Descriptor):
                member.name = attr
        return super().__new__(cls, clsname, bases, methods)
# Testing code | |
def test(s):
    """Exercise the constraints on ``s``, printing the message of each expected rejection."""
    print(s.name)
    s.shares = 75
    print(s.shares)
    # (attribute, value to assign, exception type that is caught and printed)
    attempts = (
        ("shares", -10, ValueError),
        ("price", "a lot", TypeError),
        ("name", "ABRACADABRA", ValueError),
    )
    for attr, bad_value, expected_error in attempts:
        try:
            setattr(s, attr, bad_value)
        except expected_error as e:
            print(e)
# Various Examples: | |
if __name__ == "__main__":
    # The same Stock model is declared three ways and run through test().

    print("# --- Class with descriptors")

    class Stock:
        # Specify constraints
        name = SizedString("name", size=8)
        shares = UnsignedInteger("shares")
        price = UnsignedFloat("price")

        def __init__(self, name, shares, price):
            self.name, self.shares, self.price = name, shares, price

    s = Stock("ACME", 50, 91.1)
    test(s)

    print("# --- Class with class decorator")

    @check_attributes(
        name=SizedString(size=8), shares=UnsignedInteger, price=UnsignedFloat
    )
    class Stock:
        def __init__(self, name, shares, price):
            self.name, self.shares, self.price = name, shares, price

    s = Stock("ACME", 50, 91.1)
    test(s)

    print("# --- Class with metaclass")

    class Stock(metaclass=checkedmeta):
        name = SizedString(size=8)
        shares = UnsignedInteger()
        price = UnsignedFloat()

        def __init__(self, name, shares, price):
            self.name, self.shares, self.price = name, shares, price

    s = Stock("ACME", 50, 91.1)
    test(s)
################################################################################ | |
## 8::implementing_custom_containers | |
# Example of a custom container | |
import bisect
import collections
import collections.abc


# The ABC aliases in ``collections`` were removed in Python 3.10; the
# abstract base classes live in ``collections.abc``.
class SortedItems(collections.abc.Sequence):
    """Read-only sequence that keeps its items in sorted order.

    Only ``__getitem__`` and ``__len__`` are written here; the remaining
    sequence methods come from the ``Sequence`` base class.
    """

    def __init__(self, initial=None):
        self._items = sorted(initial) if initial is not None else []

    # Required sequence methods
    def __getitem__(self, index):
        return self._items[index]

    def __len__(self):
        return len(self._items)

    # Method for adding an item in the right location
    def add(self, item):
        """Insert ``item`` at the position that keeps the items sorted."""
        bisect.insort(self._items, item)
if __name__ == "__main__":
    items = SortedItems([5, 1, 3])
    print(list(items))
    print(items[0])
    print(items[-1])
    # Each add() leaves the contents sorted.
    items.add(2)
    print(list(items))
    items.add(-10)
    print(list(items))
    # Slicing, membership, length and iteration.
    print(items[1:4])
    print(3 in items)
    print(len(items))
    for n in items:
        print(n)
################################################################################ | |
## 8::implementing_custom_containers | |
import collections
import collections.abc


# The ABC aliases in ``collections`` were removed in Python 3.10; the
# abstract base classes live in ``collections.abc``.
class Items(collections.abc.MutableSequence):
    """List-backed mutable sequence that prints each primitive operation.

    Only the five methods below are written here; the rest (append, count,
    remove, ...) come from the ``MutableSequence`` base class.
    """

    def __init__(self, initial=None):
        self._items = list(initial) if initial is not None else []

    # Required sequence methods
    def __getitem__(self, index):
        print("Getting:", index)
        return self._items[index]

    def __setitem__(self, index, value):
        print("Setting:", index, value)
        self._items[index] = value

    def __delitem__(self, index):
        print("Deleting:", index)
        del self._items[index]

    def insert(self, index, value):
        print("Inserting:", index, value)
        self._items.insert(index, value)

    def __len__(self):
        print("Len")
        return len(self._items)
if __name__ == "__main__":
    a = Items([1, 2, 3])
    print(len(a))
    # append, count and remove are not defined on Items itself.
    a.append(4)
    a.append(2)
    print(a.count(2))
    a.remove(3)
################################################################################ | |
## 8::implementing_stateful_objects_or_state_machines | |
class Connection:
    """State machine that changes state by reassigning the instance's ``__class__``.

    A new connection starts as a ClosedConnection. The operations here are
    placeholders that raise NotImplementedError.
    """

    def __init__(self):
        self.new_state(ClosedConnection)

    def new_state(self, state):
        """Switch this instance to the class ``state``."""
        self.__class__ = state

    def read(self):
        raise NotImplementedError()

    def write(self, data):
        raise NotImplementedError()

    def open(self):
        raise NotImplementedError()

    def close(self):
        raise NotImplementedError()
class ClosedConnection(Connection):
    """Closed state: only open() succeeds, switching to OpenConnection."""

    def read(self):
        raise RuntimeError("Not open")

    def write(self, data):
        raise RuntimeError("Not open")

    def open(self):
        self.new_state(OpenConnection)

    def close(self):
        raise RuntimeError("Already closed")
class OpenConnection(Connection):
    """Open state: read/write print a marker, close() switches to ClosedConnection."""

    def read(self):
        print("reading")

    def write(self, data):
        print("writing")

    def open(self):
        raise RuntimeError("Already open")

    def close(self):
        self.new_state(ClosedConnection)
# Example | |
if __name__ == "__main__":
    c = Connection()
    print(c)
    # Reading before the connection is opened is rejected.
    try:
        c.read()
    except RuntimeError as e:
        print(e)
    # Open, read, then close again, printing the object after each change.
    c.open()
    print(c)
    c.read()
    c.close()
    print(c)
################################################################################ | |
## 8::implementing_stateful_objects_or_state_machines | |
class Connection:
    """State machine that forwards every operation to its current state object.

    The state is kept in ``_state`` and is called with the connection as
    its first argument. A new connection starts in ClosedConnectionState.
    """

    def __init__(self):
        self.new_state(ClosedConnectionState)

    def new_state(self, newstate):
        """Make ``newstate`` the current state."""
        self._state = newstate

    # Delegate to the state class
    def read(self):
        state = self._state
        return state.read(self)

    def write(self, data):
        state = self._state
        return state.write(self, data)

    def open(self):
        state = self._state
        return state.open(self)

    def close(self):
        state = self._state
        return state.close(self)
# Connection state base class | |
class ConnectionState:
    """Base class for connection states.

    Every operation is a static method taking the connection as its first
    argument; the versions here are placeholders that raise
    NotImplementedError.
    """

    @staticmethod
    def read(conn):
        raise NotImplementedError()

    @staticmethod
    def write(conn, data):
        raise NotImplementedError()

    @staticmethod
    def open(conn):
        raise NotImplementedError()

    @staticmethod
    def close(conn):
        raise NotImplementedError()
# Implementation of different states | |
class ClosedConnectionState(ConnectionState):
    """Closed state: only open() succeeds, moving the connection to OpenConnectionState."""

    @staticmethod
    def read(conn):
        raise RuntimeError("Not open")

    @staticmethod
    def write(conn, data):
        raise RuntimeError("Not open")

    @staticmethod
    def open(conn):
        conn.new_state(OpenConnectionState)

    @staticmethod
    def close(conn):
        raise RuntimeError("Already closed")
class OpenConnectionState(ConnectionState):
    """Open state: read/write print a marker, close() moves back to ClosedConnectionState."""

    @staticmethod
    def read(conn):
        print("reading")

    @staticmethod
    def write(conn, data):
        print("writing")

    @staticmethod
    def open(conn):
        raise RuntimeError("Already open")

    @staticmethod
    def close(conn):
        conn.new_state(ClosedConnectionState)
# Example | |
if __name__ == "__main__":
    c = Connection()
    print(c)
    # Reading before the connection is opened is rejected.
    try:
        c.read()
    except RuntimeError as e:
        print(e)
    # Open, read, then close again, printing the object after each change.
    c.open()
    print(c)
    c.read()
    c.close()
    print(c)
################################################################################ | |
## 8::implementing_the_visitor_pattern | |
# Example of the visitor pattern | |
# --- The following classes represent nodes in an expression tree | |
class Node:
    """Base class of all expression-tree nodes."""


class UnaryOperator(Node):
    """Operator node with a single ``operand``."""

    def __init__(self, operand):
        self.operand = operand


class BinaryOperator(Node):
    """Operator node with ``left`` and ``right`` children."""

    def __init__(self, left, right):
        self.left, self.right = left, right


class Add(BinaryOperator):
    """Addition node."""


class Sub(BinaryOperator):
    """Subtraction node."""


class Mul(BinaryOperator):
    """Multiplication node."""


class Div(BinaryOperator):
    """Division node."""


class Negate(UnaryOperator):
    """Negation node."""


class Number(Node):
    """Leaf node holding ``value``."""

    def __init__(self, value):
        self.value = value
# --- The visitor base class | |
class NodeVisitor:
    """Visitor base class that dispatches on the node's class name."""

    def visit(self, node):
        """Call ``visit_<ClassName>(node)``, or ``generic_visit`` if there is none."""
        methname = "visit_" + type(node).__name__
        meth = getattr(self, methname, None)
        if meth is None:
            meth = self.generic_visit
        return meth(node)

    def generic_visit(self, node):
        """Fallback for node types without a handler: raises RuntimeError."""
        raise RuntimeError("No {} method".format("visit_" + type(node).__name__))


# --- Example 1: An expression evaluator
class Evaluator(NodeVisitor):
    """Computes the numeric value of an expression tree."""

    def visit_Number(self, node):
        return node.value

    def visit_Add(self, node):
        return self.visit(node.left) + self.visit(node.right)

    def visit_Sub(self, node):
        return self.visit(node.left) - self.visit(node.right)

    def visit_Mul(self, node):
        return self.visit(node.left) * self.visit(node.right)

    def visit_Div(self, node):
        return self.visit(node.left) / self.visit(node.right)

    def visit_Negate(self, node):
        # The operand is a node, not a number: evaluate it before negating.
        return -self.visit(node.operand)
# --- Example 2: Generate stack instructions | |
class StackCode(NodeVisitor):
    """Turns an expression tree into a list of stack-machine instruction tuples."""

    def generate_code(self, node):
        """Return the instruction list for the tree rooted at ``node``."""
        self.instructions = []
        self.visit(node)
        return self.instructions

    def visit_Number(self, node):
        self.instructions.append(("PUSH", node.value))

    def binop(self, node, instruction):
        """Emit code for both children, then the operator itself."""
        for child in (node.left, node.right):
            self.visit(child)
        self.instructions.append((instruction,))

    def unaryop(self, node, instruction):
        """Emit code for the operand, then the operator itself."""
        self.visit(node.operand)
        self.instructions.append((instruction,))

    def visit_Add(self, node):
        self.binop(node, "ADD")

    def visit_Sub(self, node):
        self.binop(node, "SUB")

    def visit_Mul(self, node):
        self.binop(node, "MUL")

    def visit_Div(self, node):
        self.binop(node, "DIV")

    def visit_Negate(self, node):
        self.unaryop(node, "NEG")
# --- Example of the above classes in action | |
# Representation of 1 + 2 * (3 - 4) / 5 | |
# Build the tree from the innermost sub-expression outwards.
t1 = Sub(Number(3), Number(4))
t2 = Mul(Number(2), t1)
t3 = Div(t2, Number(5))
t4 = Add(Number(1), t3)

# Evaluate the tree.
e = Evaluator()
print("Should get 0.6 :", e.visit(t4))

# Generate stack instructions for the same tree and list them.
s = StackCode()
code = s.generate_code(t4)
for c in code:
    print(c)
################################################################################ | |
## 8::implementing_the_visitor_pattern_without_recursion | |
# Example: Recursive implementation | |
from node import Node, NodeVisitor | |
class UnaryOperator(Node):
    """Operator node with a single ``operand``."""

    def __init__(self, operand):
        self.operand = operand


class BinaryOperator(Node):
    """Operator node with ``left`` and ``right`` children."""

    def __init__(self, left, right):
        self.left, self.right = left, right


class Add(BinaryOperator):
    """Addition node."""


class Sub(BinaryOperator):
    """Subtraction node."""


class Mul(BinaryOperator):
    """Multiplication node."""


class Div(BinaryOperator):
    """Division node."""


class Negate(UnaryOperator):
    """Negation node."""


class Number(Node):
    """Leaf node holding ``value``."""

    def __init__(self, value):
        self.value = value
# A sample visitor class that evaluates expressions | |
class Evaluator(NodeVisitor):
    """Evaluates an expression tree with recursive ``self.visit`` calls."""

    def visit_Number(self, node):
        return node.value

    def visit_Add(self, node):
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        return lhs + rhs

    def visit_Sub(self, node):
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        return lhs - rhs

    def visit_Mul(self, node):
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        return lhs * rhs

    def visit_Div(self, node):
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        return lhs / rhs

    def visit_Negate(self, node):
        operand = self.visit(node.operand)
        return -operand
if __name__ == "__main__":
    # 1 + 2*(3-4) / 5
    t1 = Sub(Number(3), Number(4))
    t2 = Mul(Number(2), t1)
    t3 = Div(t2, Number(5))
    t4 = Add(Number(1), t3)

    # Evaluate it
    e = Evaluator()
    print(e.visit(t4))  # Outputs 0.6

    # Blow it up: a left-nested chain of 99999 additions
    a = Number(0)
    for n in range(1, 100000):
        a = Add(a, Number(n))
    try:
        print(e.visit(a))
    except RuntimeError as e:
        print(e)
################################################################################ | |
## 8::implementing_the_visitor_pattern_without_recursion | |
# Example: Non-recursive implementation using yield | |
from node import Node, NodeVisitor | |
class UnaryOperator(Node):
    """Operator node with a single ``operand``."""

    def __init__(self, operand):
        self.operand = operand


class BinaryOperator(Node):
    """Operator node with ``left`` and ``right`` children."""

    def __init__(self, left, right):
        self.left, self.right = left, right


class Add(BinaryOperator):
    """Addition node."""


class Sub(BinaryOperator):
    """Subtraction node."""


class Mul(BinaryOperator):
    """Multiplication node."""


class Div(BinaryOperator):
    """Division node."""


class Negate(UnaryOperator):
    """Negation node."""


class Number(Node):
    """Leaf node holding ``value``."""

    def __init__(self, value):
        self.value = value
class Evaluator(NodeVisitor):
    """Evaluator whose operator methods are generators.

    Each operator method yields a child node, receives that child's value
    back through ``send()``, and finally yields its own result.
    """

    def visit_Number(self, node):
        return node.value

    def visit_Add(self, node):
        lhs = yield node.left
        rhs = yield node.right
        yield lhs + rhs

    def visit_Sub(self, node):
        lhs = yield node.left
        rhs = yield node.right
        yield lhs - rhs

    def visit_Mul(self, node):
        lhs = yield node.left
        rhs = yield node.right
        yield lhs * rhs

    def visit_Div(self, node):
        lhs = yield node.left
        rhs = yield node.right
        yield lhs / rhs

    def visit_Negate(self, node):
        operand = yield node.operand
        yield -operand
if __name__ == "__main__":
    # 1 + 2*(3-4) / 5
    t1 = Sub(Number(3), Number(4))
    t2 = Mul(Number(2), t1)
    t3 = Div(t2, Number(5))
    t4 = Add(Number(1), t3)

    # Evaluate it
    e = Evaluator()
    print(e.visit(t4))  # Outputs 0.6

    # Blow it up: a left-leaning chain of 99,999 nested Add nodes.
    a = Number(0)
    for n in range(1, 100000):
        a = Add(a, Number(n))
    try:
        print(e.visit(a))
    except RuntimeError as err:
        # The exception gets its own name: ``except ... as e`` would rebind
        # and then delete ``e``, destroying the evaluator created above.
        print(err)
################################################################################ | |
## 8::implementing_the_visitor_pattern_without_recursion | |
# Example: Modified non-recursive implementation using | |
# a special Visit() class to signal what should be visited next | |
import types | |
class Node:
    """Base class for tree nodes."""


class Visit:
    """Marker wrapping a node that the visitor loop should visit next."""

    def __init__(self, node):
        self.node = node


class NodeVisitor:
    """Visitor that walks a tree with an explicit stack instead of recursion.

    A ``visit_<ClassName>`` method may return a plain value, or be a
    generator that yields ``Visit(child)`` to request a child's value
    (delivered back through ``send()``) and yields its own result last.
    """

    def visit(self, node):
        """Evaluate ``node`` and return the last result produced.

        The stack holds three kinds of entries: running generators,
        ``Visit`` requests, and plain values. A plain value is popped into
        ``last_result`` and sent into the generator below it.
        """
        stack = [Visit(node)]
        last_result = None
        while stack:
            last = stack[-1]
            if isinstance(last, types.GeneratorType):
                # Only send() is guarded: a StopIteration raised by a plain
                # visit method must propagate rather than be mistaken for a
                # finished generator and pop an unrelated stack entry.
                try:
                    item = last.send(last_result)
                except StopIteration:
                    stack.pop()
                else:
                    stack.append(item)
                    last_result = None
            elif isinstance(last, Visit):
                stack.append(self._visit(stack.pop().node))
            else:
                last_result = stack.pop()
        return last_result

    def _visit(self, node):
        """Dispatch to ``visit_<ClassName>``, falling back to ``generic_visit``."""
        methname = "visit_" + type(node).__name__
        meth = getattr(self, methname, None)
        if meth is None:
            meth = self.generic_visit
        return meth(node)

    def generic_visit(self, node):
        """Raise RuntimeError naming the missing ``visit_<ClassName>`` method."""
        raise RuntimeError("No {} method".format("visit_" + type(node).__name__))
class UnaryOperator(Node):
    """Expression node that wraps a single operand."""

    def __init__(self, operand):
        self.operand = operand


class BinaryOperator(Node):
    """Expression node that holds a left and a right operand."""

    def __init__(self, left, right):
        self.left, self.right = left, right


class Add(BinaryOperator):
    """Addition node."""


class Sub(BinaryOperator):
    """Subtraction node."""


class Mul(BinaryOperator):
    """Multiplication node."""


class Div(BinaryOperator):
    """Division node."""


class Negate(UnaryOperator):
    """Arithmetic negation node."""


class Number(Node):
    """Leaf node that stores a literal value."""

    def __init__(self, value):
        self.value = value
class Evaluator(NodeVisitor):
    """Generator-based evaluator that requests child values via ``Visit``."""

    def visit_Number(self, node):
        return node.value

    def visit_Add(self, node):
        lhs = yield Visit(node.left)
        rhs = yield Visit(node.right)
        yield lhs + rhs

    def visit_Sub(self, node):
        lhs = yield Visit(node.left)
        rhs = yield Visit(node.right)
        yield lhs - rhs

    def visit_Mul(self, node):
        lhs = yield Visit(node.left)
        rhs = yield Visit(node.right)
        yield lhs * rhs

    def visit_Div(self, node):
        lhs = yield Visit(node.left)
        rhs = yield Visit(node.right)
        yield lhs / rhs

    def visit_Negate(self, node):
        operand = yield Visit(node.operand)
        yield -operand
if __name__ == "__main__":
    # 1 + 2*(3-4) / 5
    t1 = Sub(Number(3), Number(4))
    t2 = Mul(Number(2), t1)
    t3 = Div(t2, Number(5))
    t4 = Add(Number(1), t3)

    # Evaluate it
    e = Evaluator()
    print(e.visit(t4))  # Outputs 0.6

    # Blow it up: a left-leaning chain of 99,999 nested Add nodes.
    a = Number(0)
    for n in range(1, 100000):
        a = Add(a, Number(n))
    try:
        print(e.visit(a))
    except RuntimeError as err:
        # The exception gets its own name: ``except ... as e`` would rebind
        # and then delete ``e``, destroying the evaluator created above.
        print(err)
################################################################################ | |
## 8::implementing_the_visitor_pattern_without_recursion | |
# node.py | |
# | |
# Base class and non-recursive visitor implementation. | |
# Used by various example files. | |
import types | |
class Node:
    """Base class for tree nodes; the visitor treats any instance as work to visit."""


class NodeVisitor:
    """Visitor that walks a tree with an explicit stack instead of recursion.

    A ``visit_<ClassName>`` method may return a plain value, or be a
    generator that yields child ``Node`` objects (whose values are delivered
    back through ``send()``) and yields its own result last.
    """

    def visit(self, node):
        """Evaluate ``node`` and return the last result produced."""
        stack = [node]
        last_result = None
        while stack:
            last = stack[-1]
            if isinstance(last, types.GeneratorType):
                # Only send() is guarded: a StopIteration raised by a plain
                # visit method must propagate rather than be mistaken for a
                # finished generator and pop an unrelated stack entry.
                try:
                    item = last.send(last_result)
                except StopIteration:
                    stack.pop()
                else:
                    stack.append(item)
                    last_result = None
            elif isinstance(last, Node):
                stack.append(self._visit(stack.pop()))
            else:
                last_result = stack.pop()
        return last_result

    def _visit(self, node):
        """Dispatch to ``visit_<ClassName>``, falling back to ``generic_visit``."""
        methname = "visit_" + type(node).__name__
        meth = getattr(self, methname, None)
        if meth is None:
            meth = self.generic_visit
        return meth(node)

    def generic_visit(self, node):
        """Raise RuntimeError naming the missing ``visit_<ClassName>`` method."""
        raise RuntimeError("No {} method".format("visit_" + type(node).__name__))
################################################################################ | |
## 8::lazily_computed_attributes | |
class lazyproperty:
    """Descriptor that computes a value on first access and stores it on the instance.

    The result is saved under the wrapped function's own name.
    """

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, cls):
        # Access through the class returns the descriptor itself.
        if instance is None:
            return self
        result = self.func(instance)
        setattr(instance, self.func.__name__, result)
        return result
if __name__ == "__main__":
    import math

    class Circle:
        """Demo class; each property prints a message whenever its body actually runs."""

        def __init__(self, radius):
            self.radius = radius

        @lazyproperty
        def area(self):
            print("Computing area")
            return math.pi * self.radius ** 2

        @lazyproperty
        def perimeter(self):
            print("Computing perimeter")
            return 2 * math.pi * self.radius
################################################################################ | |
## 8::lazily_computed_attributes | |
def lazyproperty(func):
    """Return a property that caches ``func``'s result on the instance.

    The value is stored under ``"_lazy_" + func.__name__``. No setter is
    defined on the property.
    """
    cache_attr = "_lazy_" + func.__name__

    def lazy(self):
        if hasattr(self, cache_attr):
            return getattr(self, cache_attr)
        result = func(self)
        setattr(self, cache_attr, result)
        return result

    return property(lazy)
if __name__ == "__main__":
    import math

    class Circle:
        """Demo class; each property prints a message whenever its body actually runs."""

        def __init__(self, radius):
            self.radius = radius

        @lazyproperty
        def area(self):
            print("Computing area")
            return math.pi * self.radius ** 2

        @lazyproperty
        def perimeter(self):
            print("Computing perimeter")
            return 2 * math.pi * self.radius
################################################################################ | |
## 8::making_classes_support_comparison_operations | |
from functools import total_ordering | |
class Room:
    """A named rectangular room; ``square_feet`` is computed once at construction."""

    def __init__(self, name, length, width):
        self.name, self.length, self.width = name, length, width
        self.square_feet = length * width
@total_ordering
class House:
    """A house compared and ordered by the total square footage of its rooms.

    Only ``__eq__`` and ``__lt__`` are written out; ``total_ordering``
    supplies the remaining comparison operators.
    """

    def __init__(self, name, style):
        self.name = name
        self.style = style
        self.rooms = []

    @property
    def living_space_footage(self):
        """Sum of ``square_feet`` over all rooms added so far."""
        return sum(r.square_feet for r in self.rooms)

    def add_room(self, room):
        """Append ``room``; it must expose a ``square_feet`` attribute."""
        self.rooms.append(room)

    def __str__(self):
        return "{}: {} square foot {}".format(
            self.name, self.living_space_footage, self.style
        )

    def __eq__(self, other):
        # Operands without a footage are "not comparable", not an error:
        # returning NotImplemented lets ``house == 5`` evaluate to False.
        try:
            other_footage = other.living_space_footage
        except AttributeError:
            return NotImplemented
        return self.living_space_footage == other_footage

    def __lt__(self, other):
        # NotImplemented makes Python raise the usual TypeError for ordering
        # against an unsupported type.
        try:
            other_footage = other.living_space_footage
        except AttributeError:
            return NotImplemented
        return self.living_space_footage < other_footage
# Build a few houses, and add rooms to them.
h1 = House("h1", "Cape")
h1.add_room(Room("Master Bedroom", 14, 21))  # 294 sq ft
h1.add_room(Room("Living Room", 18, 20))  # 360 sq ft
h1.add_room(Room("Kitchen", 12, 16))  # 192 sq ft
h1.add_room(Room("Office", 12, 12))  # 144 sq ft -> h1 totals 990

h2 = House("h2", "Ranch")
h2.add_room(Room("Master Bedroom", 14, 21))
h2.add_room(Room("Living Room", 18, 20))
h2.add_room(Room("Kitchen", 12, 16))  # h2 totals 846

h3 = House("h3", "Split")
h3.add_room(Room("Master Bedroom", 14, 21))
h3.add_room(Room("Living Room", 18, 20))
h3.add_room(Room("Office", 12, 16))
h3.add_room(Room("Kitchen", 15, 17))  # 255 sq ft -> h3 totals 1101

houses = [h1, h2, h3]

# ">" and ">=" are not defined on House directly; total_ordering derives them.
print("Is h1 bigger than h2?", h1 > h2)  # prints True
print("Is h2 smaller than h3?", h2 < h3)  # prints True
print("Is h2 greater than or equal to h1?", h2 >= h1)  # prints False
print("Which one is biggest?", max(houses))  # prints 'h3: 1101 square foot Split'
print("Which is smallest?", min(houses))  # prints 'h2: 846 square foot Ranch'
################################################################################ | |
## 8::making_objects_support_the_context_manager_protocol | |
from socket import socket, AF_INET, SOCK_STREAM | |
class LazyConnection:
    """Context manager that opens a socket on entry and closes it on exit.

    Only one connection may be open at a time; entering while already
    connected raises RuntimeError.
    """

    def __init__(self, address, family=AF_INET, type=SOCK_STREAM):
        self.address = address
        # Honour the caller's choices (previously the defaults were hard-wired).
        self.family = family
        self.type = type
        self.sock = None

    def __enter__(self):
        if self.sock is not None:
            raise RuntimeError("Already connected")
        sock = socket(self.family, self.type)
        try:
            sock.connect(self.address)
        except BaseException:
            # Don't leak the descriptor, and don't leave the object stuck in
            # the "Already connected" state, when connect() fails.
            sock.close()
            raise
        self.sock = sock
        return sock

    def __exit__(self, exc_ty, exc_val, tb):
        self.sock.close()
        self.sock = None
if __name__ == "__main__":
    from functools import partial

    c = LazyConnection(("www.python.org", 80))
    # Connection closed
    with c as s:
        # c.__enter__() executes: connection open
        s.send(b"GET /index.html HTTP/1.0\r\n")
        s.send(b"Host: www.python.org\r\n")
        s.send(b"\r\n")
        # iter(callable, sentinel): keep calling s.recv(8192) until it returns b""
        resp = b"".join(iter(partial(s.recv, 8192), b""))
    # c.__exit__() executes: connection closed
    print("Got %d bytes" % len(resp))
################################################################################ | |
## 8::making_objects_support_the_context_manager_protocol | |
from socket import socket, AF_INET, SOCK_STREAM | |
class LazyConnection:
    """Context manager that can be entered more than once.

    Every ``__enter__`` opens a new socket and pushes it on a stack;
    ``__exit__`` closes the most recently opened one.
    """

    def __init__(self, address, family=AF_INET, type=SOCK_STREAM):
        self.address = address
        # Honour the caller's choices (previously the defaults were hard-wired).
        self.family = family
        self.type = type
        self.connections = []

    def __enter__(self):
        sock = socket(self.family, self.type)
        try:
            sock.connect(self.address)
        except BaseException:
            # Close the socket rather than leak it when connect() fails.
            sock.close()
            raise
        self.connections.append(sock)
        return sock

    def __exit__(self, exc_ty, exc_val, tb):
        self.connections.pop().close()
if __name__ == "__main__":
    # Example use
    from functools import partial

    conn = LazyConnection(("www.python.org", 80))
    with conn as s:
        s.send(b"GET /index.html HTTP/1.0\r\n")
        s.send(b"Host: www.python.org\r\n")
        s.send(b"\r\n")
        # iter(callable, sentinel): keep calling s.recv(8192) until it returns b""
        resp = b"".join(iter(partial(s.recv, 8192), b""))
    print("Got %d bytes" % len(resp))

    # The same manager entered twice: each ``as`` target is a separate socket.
    with conn as s1, conn as s2:
        s1.send(b"GET /downloads HTTP/1.0\r\n")
        s2.send(b"GET /index.html HTTP/1.0\r\n")
        s1.send(b"Host: www.python.org\r\n")
        s2.send(b"Host: www.python.org\r\n")
        s1.send(b"\r\n")
        s2.send(b"\r\n")
        resp1 = b"".join(iter(partial(s1.recv, 8192), b""))
        resp2 = b"".join(iter(partial(s2.recv, 8192), b""))
    print("resp1 got %d bytes" % len(resp1))
    print("resp2 got %d bytes" % len(resp2))
################################################################################ | |
## 8::managing_memory_in_cyclic_data_structures | |
import weakref | |
class Node:
    """Tree node that holds its children strongly and its parent weakly.

    The parent link is a ``weakref.ref``, so a child does not keep its
    parent alive; ``parent`` reads as ``None`` once the parent is gone.
    """

    def __init__(self, value):
        self.value = value
        self._parent = None
        self.children = []

    def __repr__(self):
        return "Node({!r})".format(self.value)

    # property that manages the parent as a weak-reference
    @property
    def parent(self):
        """The parent node, or ``None`` if unset or already collected."""
        return self._parent if self._parent is None else self._parent()

    @parent.setter
    def parent(self, node):
        # ``weakref.ref(None)`` raises TypeError, so detaching a node by
        # assigning None is handled explicitly.
        self._parent = None if node is None else weakref.ref(node)

    def add_child(self, child):
        """Append ``child`` and point its ``parent`` back at this node."""
        self.children.append(child)
        child.parent = self
if __name__ == "__main__":
    root = Node("parent")
    c1 = Node("c1")
    c2 = Node("c2")
    root.add_child(c1)
    root.add_child(c2)
    print(c1.parent)
    # The children hold only weak references to root, so nothing else in this
    # script keeps it referenced once the name is deleted.
    del root
    print(c1.parent)
################################################################################ | |
## 8::simplified_initialization_of_data_structures | |
class Structure:
    """Base class whose ``__init__`` maps positional arguments onto ``_fields``.

    Subclasses list attribute names in ``_fields``; exactly that many
    positional arguments must be supplied, otherwise TypeError is raised.
    """

    # Class variable that specifies expected fields
    _fields = []

    def __init__(self, *args):
        expected = len(self._fields)
        if len(args) != expected:
            raise TypeError("Expected {} arguments".format(expected))
        # Set the arguments
        for field, arg in zip(self._fields, args):
            setattr(self, field, arg)
# Example class definitions
if __name__ == "__main__":
    # Circle.area() below uses math.pi; without this import calling it would
    # raise NameError.
    import math

    class Stock(Structure):
        _fields = ["name", "shares", "price"]

    class Point(Structure):
        _fields = ["x", "y"]

    class Circle(Structure):
        _fields = ["radius"]

        def area(self):
            return math.pi * self.radius ** 2


if __name__ == "__main__":
    s = Stock("ACME", 50, 91.1)
    print(s.name, s.shares, s.price)

    p = Point(2, 3)
    print(p.x, p.y)

    c = Circle(4.5)
    print(c.radius)

    # Too few arguments: Structure.__init__ raises TypeError.
    try:
        s2 = Stock("ACME", 50)
    except TypeError as e:
        print(e)
################################################################################ | |
## 8::simplified_initialization_of_data_structures | |
class Structure:
    """Base class whose ``__init__`` accepts fields positionally or by keyword.

    Positional arguments fill ``_fields`` from the left; every remaining
    field is taken from the keyword arguments. Too many positionals or
    unknown keywords raise TypeError.
    """

    _fields = []

    def __init__(self, *args, **kwargs):
        field_count = len(self._fields)
        if len(args) > field_count:
            raise TypeError("Expected {} arguments".format(field_count))

        # Set all of the positional arguments
        for field, arg in zip(self._fields, args):
            setattr(self, field, arg)

        # Set the remaining keyword arguments
        remaining = self._fields[len(args) :]
        for field in remaining:
            setattr(self, field, kwargs.pop(field))

        # Check for any remaining unknown arguments
        if kwargs:
            raise TypeError("Invalid argument(s): {}".format(",".join(kwargs)))
# Example use
if __name__ == "__main__":

    class Stock(Structure):
        _fields = ["name", "shares", "price"]

    # The same three fields given positionally, mixed, and mostly by keyword.
    s1 = Stock("ACME", 50, 91.1)
    s2 = Stock("ACME", 50, price=91.1)
    s3 = Stock("ACME", shares=50, price=91.1)
################################################################################ | |
## 8::simplified_initialization_of_data_structures | |
class Structure:
    """Base class with required positional fields plus free-form extra keywords.

    Every name in ``_fields`` must be given positionally. Keyword arguments
    whose names are not in ``_fields`` become extra attributes; a keyword
    that repeats a field name raises TypeError.
    """

    # Class variable that specifies expected fields
    _fields = []

    def __init__(self, *args, **kwargs):
        field_count = len(self._fields)
        if len(args) != field_count:
            raise TypeError("Expected {} arguments".format(field_count))

        # Set the arguments
        for field, arg in zip(self._fields, args):
            setattr(self, field, arg)

        # Set the additional arguments (if any)
        for extra in set(kwargs).difference(self._fields):
            setattr(self, extra, kwargs.pop(extra))

        # Whatever is left collides with a declared field name.
        if kwargs:
            raise TypeError("Duplicate values for {}".format(",".join(kwargs)))
# Example use
if __name__ == "__main__":

    class Stock(Structure):
        _fields = ["name", "shares", "price"]

    s1 = Stock("ACME", 50, 91.1)
    # "date" is not a declared field, so it is stored as an extra attribute.
    s2 = Stock("ACME", 50, 91.1, date="8/2/2012")
################################################################################ | |
## 9::applying_decorators_to_class_and_static_methods | |
import time | |
from functools import wraps | |
# A simple decorator
def timethis(func):
    """Decorator that prints the elapsed seconds of each call to ``func``.

    The wrapped function's return value is passed through unchanged.
    Nothing is printed when ``func`` raises.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter() is a monotonic, high-resolution clock intended for
        # interval timing; time.time() can jump when the system clock changes.
        start = time.perf_counter()
        r = func(*args, **kwargs)
        end = time.perf_counter()
        print(end - start)
        return r

    return wrapper
# Class illustrating application of the decorator to different kinds of methods
class Spam:
    """Applies ``timethis`` to an instance method, a class method and a static method."""

    @timethis
    def instance_method(self, n):
        print(self, n)
        while n > 0:
            n -= 1

    # @timethis is listed below @classmethod, so it wraps the plain function
    # first and @classmethod is then applied to the timed wrapper.
    @classmethod
    @timethis
    def class_method(cls, n):
        print(cls, n)
        while n > 0:
            n -= 1

    # Same ordering for @staticmethod: timing wrapper innermost.
    @staticmethod
    @timethis
    def static_method(n):
        print(n)
        while n > 0:
            n -= 1
if __name__ == "__main__":
    # Each call prints its arguments, then the elapsed time reported by timethis.
    s = Spam()
    s.instance_method(10000000)
    Spam.class_method(10000000)
    Spam.static_method(10000000)
################################################################################ | |
## 9::avoiding_repetitive_property_methods | |
def typed_property(name, expected_type):
    """Build a property that type-checks assignments.

    The value is stored on the instance under ``"_" + name``. Assigning a
    value that is not an instance of ``expected_type`` raises TypeError.
    """
    private_name = "_" + name

    def read(self):
        return getattr(self, private_name)

    def write(self, value):
        if isinstance(value, expected_type):
            setattr(self, private_name, value)
            return
        raise TypeError("{} must be a {}".format(name, expected_type))

    return property(read, write)
# Example use
class Person:
    """Person whose ``name`` must be a str and ``age`` an int."""

    name = typed_property("name", str)
    age = typed_property("age", int)

    def __init__(self, name, age):
        # These assignments go through the property setters, so they are type-checked.
        self.name = name
        self.age = age


if __name__ == "__main__":
    p = Person("Dave", 39)
    p.name = "Guido"
    # A str is not an int, so the age setter raises TypeError.
    try:
        p.age = "Old"
    except TypeError as e:
        print(e)
################################################################################ | |
## 9::capturing_class_attribute_definition_order | |
# Example of capturing class definition order | |
from collections import OrderedDict | |
# A set of descriptors for various types
class Typed:
    """Descriptor that type-checks assignments before storing them.

    Accepted values are written to the instance ``__dict__`` under
    ``_name``; anything that is not an instance of ``_expected_type``
    raises TypeError.
    """

    _expected_type = type(None)

    def __init__(self, name=None):
        self._name = name

    def __set__(self, instance, value):
        if isinstance(value, self._expected_type):
            instance.__dict__[self._name] = value
            return
        raise TypeError("Expected {}".format(str(self._expected_type)))


class Integer(Typed):
    """Typed descriptor accepting ``int``."""

    _expected_type = int


class Float(Typed):
    """Typed descriptor accepting ``float``."""

    _expected_type = float


class String(Typed):
    """Typed descriptor accepting ``str``."""

    _expected_type = str
# Metaclass that uses an OrderedDict for class body
class OrderedMeta(type):
    """Metaclass that names ``Typed`` descriptors and records their definition order.

    The names of all ``Typed`` attributes, in class-body order, are stored
    on the new class as ``_order``.
    """

    def __new__(cls, clsname, bases, clsdict):
        typed_attrs = [
            (attr_name, attr)
            for attr_name, attr in clsdict.items()
            if isinstance(attr, Typed)
        ]
        # Tell each descriptor which attribute name it was bound to.
        for attr_name, attr in typed_attrs:
            attr._name = attr_name
        namespace = dict(clsdict)
        namespace["_order"] = [attr_name for attr_name, _ in typed_attrs]
        return type.__new__(cls, clsname, bases, namespace)

    @classmethod
    def __prepare__(cls, clsname, bases):
        return OrderedDict()
# Example class that uses the definition order to initialize members
class Structure(metaclass=OrderedMeta):
    """Base class that serializes the attributes listed in ``_order`` as CSV."""

    def as_csv(self):
        values = [str(getattr(self, field)) for field in self._order]
        return ",".join(values)
# Example use
class Stock(Structure):
    """Record with three type-checked fields."""

    # The descriptors are created without names; OrderedMeta fills in
    # ``_name`` from the attribute each one is bound to.
    name = String()
    shares = Integer()
    price = Float()

    def __init__(self, name, shares, price):
        self.name = name
        self.shares = shares
        self.price = price


if __name__ == "__main__":
    s = Stock("GOOG", 100, 490.1)
    print(s.name)
    print(s.as_csv())
    # "a lot" is not an int, so the Integer descriptor raises TypeError.
    try:
        t = Stock("AAPL", "a lot", 610.23)
    except TypeError as e:
        print(e)
################################################################################ | |
## 9::capturing_class_attribute_definition_order | |
# Example of a metaclass that rejects duplicate definitions | |
from collections import OrderedDict | |
class NoDupOrderedDict(OrderedDict):
    """OrderedDict that refuses to rebind an existing key.

    ``clsname`` is kept only so the error message can name the class whose
    body is being built.
    """

    def __init__(self, clsname):
        self.clsname = clsname
        super().__init__()

    def __setitem__(self, name, value):
        if name not in self:
            super().__setitem__(name, value)
            return
        raise TypeError("{} already defined in {}".format(name, self.clsname))
class OrderedMeta(type):
    """Metaclass that records definition order and rejects duplicate names.

    ``_order`` lists, in class-body order, every name that does not start
    with an underscore. The class body runs in a ``NoDupOrderedDict``.
    """

    def __new__(cls, clsname, bases, clsdict):
        public_names = [key for key in clsdict if key[0] != "_"]
        namespace = dict(clsdict)
        namespace["_order"] = public_names
        return type.__new__(cls, clsname, bases, namespace)

    @classmethod
    def __prepare__(cls, clsname, bases):
        return NoDupOrderedDict(clsname)
# Example
# NOTE: defining this class raises TypeError on purpose.
class A(metaclass=OrderedMeta):
    def spam(self):
        pass

    print("**** A type error is expected now:")

    # Binding "spam" a second time goes through NoDupOrderedDict.__setitem__,
    # which rejects names that are already present.
    def spam(self):
        pass
################################################################################ | |
## 9::defining_a_decorator_that_takes_an_optional_argument | |
from functools import wraps, partial | |
import logging | |
def logged(func=None, *, level=logging.DEBUG, name=None, message=None):
    """Decorator that logs each call; usable as ``@logged`` or ``@logged(...)``.

    ``level`` is the logging level, ``name`` the logger name (default: the
    function's module) and ``message`` the text logged (default: the
    function's name).
    """
    # Called with keyword options only: hand back a decorator that remembers
    # them and waits for the function.
    if func is None:
        return partial(logged, level=level, name=name, message=message)

    log = logging.getLogger(name or func.__module__)
    logmsg = message or func.__name__

    def wrapper(*args, **kwargs):
        log.log(level, logmsg)
        return func(*args, **kwargs)

    return wraps(func)(wrapper)
# Example use
@logged  # bare form: logged() receives add directly as ``func``
def add(x, y):
    return x + y


@logged()  # called form: func is None, so a partial is returned and then applied to sub
def sub(x, y):
    return x - y


@logged(level=logging.CRITICAL, name="example")
def spam():
    print("Spam!")


if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.DEBUG)
    add(2, 3)
    sub(2, 3)
    spam()
################################################################################ | |
## 9::defining_a_decorator_that_takes_arguments | |
from functools import wraps | |
import logging | |
def logged(level, name=None, message=None):
    """
    Build a decorator that logs every call to the decorated function.

    ``level`` is the logging level. ``name`` selects the logger and
    ``message`` the text; when omitted they fall back to the function's
    module and name.
    """

    def decorate(func):
        log = logging.getLogger(name or func.__module__)
        logmsg = message or func.__name__

        def wrapper(*args, **kwargs):
            log.log(level, logmsg)
            return func(*args, **kwargs)

        return wraps(func)(wrapper)

    return decorate
# Example use
@logged(logging.DEBUG)  # logger name and message default to add's module and name
def add(x, y):
    return x + y


@logged(logging.CRITICAL, "example")  # explicit logger name "example"
def spam():
    print("Spam!")


if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.DEBUG)
    print(add(2, 3))
    spam()
################################################################################ | |
## 9::defining_a_decorator_with_user_adjustable_attributes | |
from functools import wraps, partial | |
import logging | |
def attach_wrapper(obj, func=None):
    """Attach ``func`` to ``obj`` as an attribute named after the function.

    With ``func`` omitted, return a decorator bound to ``obj`` so the call
    can be written as ``@attach_wrapper(obj)``. Returns ``func`` unchanged.
    """
    if func is not None:
        setattr(obj, func.__name__, func)
        return func
    return partial(attach_wrapper, obj)
def logged(level, name=None, message=None):
    """
    Build a logging decorator whose level and message can be changed later.

    ``level`` is the logging level. ``name`` selects the logger and
    ``message`` the text; when omitted they fall back to the function's
    module and name. The returned wrapper carries ``set_level`` and
    ``set_message`` functions that update the values used by later calls.
    """

    def decorate(func):
        log = logging.getLogger(name or func.__module__)
        logmsg = message or func.__name__

        def wrapper(*args, **kwargs):
            log.log(level, logmsg)
            return func(*args, **kwargs)

        wrapper = wraps(func)(wrapper)

        # Attach setter functions
        @attach_wrapper(wrapper)
        def set_level(newlevel):
            nonlocal level
            level = newlevel

        @attach_wrapper(wrapper)
        def set_message(newmsg):
            nonlocal logmsg
            logmsg = newmsg

        return wrapper

    return decorate
# Example use
@logged(logging.DEBUG)  # logger name and message default to add's module and name
def add(x, y):
    return x + y


@logged(logging.CRITICAL, "example")  # explicit logger name "example"
def spam():
    print("Spam!")
# Example involving multiple decorators | |
import time | |
def timethis(func):
    """Decorator that prints ``func``'s name and the elapsed seconds of each call.

    The wrapped function's return value is passed through unchanged.
    Nothing is printed when ``func`` raises.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter() is a monotonic, high-resolution clock intended for
        # interval timing; time.time() can jump when the system clock changes.
        start = time.perf_counter()
        r = func(*args, **kwargs)
        end = time.perf_counter()
        print(func.__name__, end - start)
        return r

    return wrapper
# timethis outermost: the timing wrapper sits on top of the logging wrapper.
@timethis
@logged(logging.DEBUG)
def countdown(n):
    while n > 0:
        n -= 1


# logged outermost: the logging wrapper sits on top of the timing wrapper.
@logged(logging.DEBUG)
@timethis
def countdown2(n):
    while n > 0:
        n -= 1


if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.DEBUG)
    print(add(2, 3))

    # Change the log message
    add.set_message("Add called")
    print(add(2, 3))

    # Change the log level
    add.set_level(logging.WARNING)
    print(add(2, 3))

    # The setters are used with both decorator orderings.
    countdown(100000)
    countdown.set_level(logging.CRITICAL)
    countdown(100000)
    countdown2(100000)
    countdown2.set_level(logging.CRITICAL)
    countdown2(100000)
################################################################################ | |
## 9::defining_a_decorator_with_user_adjustable_attributes | |
# Alternate formulation using function attributes directly | |
from functools import wraps | |
import logging | |
def logged(level, name=None, message=None):
    """
    Build a logging decorator configured through attributes on the wrapper.

    ``level`` is the logging level. ``name`` selects the logger and
    ``message`` the text; when omitted they fall back to the function's
    module and name. The wrapper reads ``wrapper.level``, ``wrapper.logmsg``
    and ``wrapper.log`` on every call, so reassigning them takes effect.
    """

    def decorate(func):
        def wrapper(*args, **kwargs):
            wrapper.log.log(wrapper.level, wrapper.logmsg)
            return func(*args, **kwargs)

        wrapper = wraps(func)(wrapper)

        # Attach adjustable attributes
        wrapper.level = level
        wrapper.logmsg = message or func.__name__
        wrapper.log = logging.getLogger(name or func.__module__)
        return wrapper

    return decorate
# Example use
@logged(logging.DEBUG)
def add(x, y):
    return x + y


@logged(logging.CRITICAL, "example")
def spam():
    print("Spam!")


if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.DEBUG)
    print(add(2, 3))

    # Change the log message
    add.logmsg = "Add called"
    print(add(2, 3))

    # Change the log level
    add.level = logging.WARNING
    print(add(2, 3))
################################################################################ | |
## 9::defining_a_metaclass_that_takes_optional_arguments | |
# Example of a metaclass that takes optional arguments | |
class MyMeta(type):
    """Skeleton metaclass accepting ``debug`` and ``synchronize`` keyword arguments.

    The extra keywords come from the ``class`` statement. They are accepted
    by all three hooks and are not forwarded to ``type``.
    """

    # Optional
    @classmethod
    def __prepare__(cls, name, bases, *, debug=False, synchronize=False):
        # Custom processing
        return super().__prepare__(name, bases)

    # Required
    def __new__(cls, name, bases, ns, *, debug=False, synchronize=False):
        # Custom processing
        return super().__new__(cls, name, bases, ns)

    def __init__(self, name, bases, ns, *, debug=False, synchronize=False):
        # Custom processing
        super().__init__(name, bases, ns)


# Examples
class A(metaclass=MyMeta, debug=True, synchronize=True):
    pass


class B(metaclass=MyMeta):
    pass


class C(metaclass=MyMeta, synchronize=True):
    pass
################################################################################ | |
## 9::defining_classes_programmatically | |
# Example of making a class manually from parts | |
# Methods | |
# Plain module-level functions; they become methods once placed in the
# namespace of the class built from cls_dict.
def __init__(self, name, shares, price):
    self.name = name
    self.shares = shares
    self.price = price


def cost(self):
    return self.shares * self.price


# Maps attribute name -> function for the class body.
cls_dict = {
    "__init__": __init__,
    "cost": cost,
}
# Make a class | |
import types | |
# Arguments: class name, base classes, keyword arguments for the class
# statement, and a callback that populates the class namespace.
Stock = types.new_class("Stock", (), {}, lambda ns: ns.update(cls_dict))

if __name__ == "__main__":
    s = Stock("ACME", 50, 91.1)
    print(s)
    print(s.cost())
################################################################################ | |
## 9::defining_classes_programmatically | |
# An alternative formulation of namedtuples | |
import operator | |
import types | |
import sys | |
def named_tuple(classname, fieldnames):
    """Create a ``tuple`` subclass with a read-only property per field.

    Field ``i`` of ``fieldnames`` becomes a property returning item ``i``.
    Instantiating with the wrong number of arguments raises TypeError.
    """
    # Populate a dictionary of field property accessors
    namespace = {}
    for index, field in enumerate(fieldnames):
        namespace[field] = property(operator.itemgetter(index))

    # Make a __new__ function and add to the class dict
    def __new__(cls, *args):
        if len(args) != len(fieldnames):
            raise TypeError("Expected {} arguments".format(len(fieldnames)))
        return tuple.__new__(cls, args)

    namespace["__new__"] = __new__

    def fill_namespace(ns):
        ns.update(namespace)

    # Make the class
    cls = types.new_class(classname, (tuple,), {}, fill_namespace)
    # Report the caller's module as the defining module.
    caller_globals = sys._getframe(1).f_globals
    cls.__module__ = caller_globals["__name__"]
    return cls
if __name__ == "__main__":
    Point = named_tuple("Point", ["x", "y"])
    print(Point)
    p = Point(4, 5)
    print(len(p))
    print(p.x, p[0])
    print(p.y, p[1])
    # The field properties have no setter, so assignment raises AttributeError.
    try:
        p.x = 2
    except AttributeError as e:
        print(e)
    # p is a real tuple, so it can feed %-formatting directly.
    print("%s %s" % p)
################################################################################ | |
## 9::defining_context_managers_the_easy_way | |
import time | |
from contextlib import contextmanager | |
@contextmanager
def timethis(label):
    """Context manager that prints ``"<label>: <seconds>"`` when the block ends.

    The report is printed from ``finally``, so it appears even when the
    block raises.
    """
    # perf_counter() is a monotonic, high-resolution clock intended for
    # interval timing; time.time() can jump when the system clock changes.
    start = time.perf_counter()
    try:
        yield
    finally:
        end = time.perf_counter()
        print("{}: {}".format(label, end - start))
# Example use
# NOTE: runs at import time (no __main__ guard); it counts down from ten million.
with timethis("counting"):
    n = 10000000
    while n > 0:
        n -= 1
################################################################################ | |
## 9::defining_context_managers_the_easy_way | |
from contextlib import contextmanager | |
@contextmanager
def list_transaction(orig_list):
    """Yield a working copy of ``orig_list`` and commit it back on success.

    The copy-back statement follows the ``yield`` without a ``try``, so it
    is skipped when the block raises and ``orig_list`` stays untouched.
    """
    working = [*orig_list]
    yield working
    orig_list[:] = working
# Example
if __name__ == "__main__":
    items = [1, 2, 3]
    # Block completes normally: the appended values are committed to items.
    with list_transaction(items) as working:
        working.append(4)
        working.append(5)
    print(items)

    # Block raises: the changes made to the working copy are discarded.
    try:
        with list_transaction(items) as working:
            working.append(6)
            working.append(7)
            raise RuntimeError("oops")
    except RuntimeError as e:
        print(e)
    print(items)
################################################################################ | |
## 9::defining_decorators_as_classes | |
# Example of defining a decorator as a class | |
import types | |
from functools import wraps | |
class Profiled:
    """Class-based decorator that counts calls in ``ncalls``.

    ``__get__`` binds the instance when the decorated callable is used as a
    method.
    """

    def __init__(self, func):
        # Copies func's metadata onto this object and sets ``__wrapped__``.
        wraps(func)(self)
        self.ncalls = 0

    def __call__(self, *args, **kwargs):
        self.ncalls += 1
        target = self.__wrapped__
        return target(*args, **kwargs)

    def __get__(self, instance, cls):
        # Class access returns the Profiled object itself, so its ncalls is reachable.
        return self if instance is None else types.MethodType(self, instance)
# Example
@Profiled
def add(x, y):
    return x + y


class Spam:
    @Profiled
    def bar(self, x):
        print(self, x)


if __name__ == "__main__":
    print(add(2, 3))
    print(add(4, 5))
    print("ncalls:", add.ncalls)

    s = Spam()
    s.bar(1)
    s.bar(2)
    s.bar(3)
    # Read through the class: Profiled.__get__ returns the decorator object itself.
    print("ncalls:", Spam.bar.ncalls)
################################################################################ | |
## 9::defining_decorators_as_classes | |
# Reformulation using closures and function attributes | |
from functools import wraps | |
def profiled(func):
    """Wrap ``func`` so its calls are counted.

    The count lives in a closure variable; ``wrapper.ncalls()`` returns its
    current value. The counter is incremented before ``func`` is invoked.
    """
    ncalls = 0

    def current_count():
        return ncalls

    def wrapper(*args, **kwargs):
        nonlocal ncalls
        ncalls = ncalls + 1
        return func(*args, **kwargs)

    wrapper = wraps(func)(wrapper)
    wrapper.ncalls = current_count
    return wrapper
# Example
@profiled
def add(x, y):
    return x + y


class Spam:
    @profiled
    def bar(self, x):
        print(self, x)


if __name__ == "__main__":
    print(add(2, 3))
    print(add(4, 5))
    # ncalls is a function here (it reads the closure counter), hence the call.
    print("ncalls:", add.ncalls())

    s = Spam()
    s.bar(1)
    s.bar(2)
    s.bar(3)
    print("ncalls:", Spam.bar.ncalls())
################################################################################ | |
## 9::defining_decorators_as_classes | |
# Reformulation using closures and function attributes | |
# This example tests the composability of decorators | |
import time | |
from functools import wraps | |
def timethis(func):
    """Decorator that prints ``func``'s name and the elapsed seconds of each call.

    The wrapped function's return value is passed through unchanged.
    Nothing is printed when ``func`` raises.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter() is a monotonic, high-resolution clock intended for
        # interval timing; time.time() can jump when the system clock changes.
        start = time.perf_counter()
        r = func(*args, **kwargs)
        end = time.perf_counter()
        print(func.__name__, end - start)
        return r

    return wrapper
def profiled(func):
    """Wrap ``func`` so its calls are counted.

    The count lives in a closure variable; ``wrapper.ncalls()`` returns its
    current value. The counter is incremented before ``func`` is invoked.
    """
    ncalls = 0

    def current_count():
        return ncalls

    def wrapper(*args, **kwargs):
        nonlocal ncalls
        ncalls = ncalls + 1
        return func(*args, **kwargs)

    wrapper = wraps(func)(wrapper)
    wrapper.ncalls = current_count
    return wrapper
# Example
@profiled
def add(x, y):
    return x + y


class Spam:
    @profiled
    def bar(self, x):
        print(self, x)


# timethis is applied on top of profiled; the demo below still reads
# ncalls through the outer wrapper.
@timethis
@profiled
def countdown(n):
    while n > 0:
        n -= 1


if __name__ == "__main__":
    print(add(2, 3))
    print(add(4, 5))
    print("ncalls:", add.ncalls())

    s = Spam()
    s.bar(1)
    s.bar(2)
    s.bar(3)
    print("ncalls:", Spam.bar.ncalls())

    countdown(100000)
    countdown(10000000)
    print(countdown.ncalls())
################################################################################ | |
## 9::defining_decorators_as_part_of_a_class | |
from functools import wraps | |
class A:
    """Hosts two decorators: one reached through an instance, one through the class."""

    # Decorator as an instance method
    def decorator1(self, func):
        def wrapper(*args, **kwargs):
            print("Decorator 1")
            return func(*args, **kwargs)

        return wraps(func)(wrapper)

    # Decorator as a class method
    @classmethod
    def decorator2(cls, func):
        def wrapper(*args, **kwargs):
            print("Decorator 2")
            return func(*args, **kwargs)

        return wraps(func)(wrapper)
# Example
# As an instance method
a = A()


@a.decorator1
def spam():
    pass


# As a class method
@A.decorator2
def grok():
    pass


# Each call prints its decorator's banner before running the function.
spam()
grok()
################################################################################ | |
## 9::defining_decorators_as_part_of_a_class | |
# Property example | |
class Person:
    """Person whose ``first_name`` attribute only accepts strings."""

    # Start from an empty property, then attach accessors through the
    # property object's own getter/setter decorator methods.
    first_name = property()

    @first_name.getter
    def first_name(self):
        return self._first_name

    @first_name.setter
    def first_name(self, value):
        if isinstance(value, str):
            self._first_name = value
        else:
            raise TypeError("Expected a string")


p = Person()
p.first_name = "Dave"
print(p.first_name)
################################################################################ | |
## 9::disassembling_python_byte_code | |
# Example of manual disassembly of bytecode | |
import opcode | |
def generate_opcodes(codebytes):
    """
    Yield ``(op, oparg)`` pairs decoded from raw CPython bytecode.

    Decodes the fixed-width "wordcode" layout used since Python 3.6, in
    which every instruction is two bytes: the opcode followed by one
    argument byte.  (The previous implementation assumed the older layout
    of one opcode byte plus an optional two-byte argument, which misreads
    bytecode produced by current interpreters.)

    codebytes: a bytes-like object such as ``func.__code__.co_code``.

    ``oparg`` is None for opcodes below ``opcode.HAVE_ARGUMENT``.
    ``EXTENDED_ARG`` prefixes are not yielded; their value is folded into
    the argument of the instruction that follows them.
    """
    extended_arg = 0
    # Pair every opcode byte (even offsets) with its argument byte (odd offsets)
    for op, arg in zip(codebytes[0::2], codebytes[1::2]):
        if op >= opcode.HAVE_ARGUMENT:
            oparg = arg | extended_arg
            if op == opcode.EXTENDED_ARG:
                # Each prefix supplies the next-higher 8 bits of the argument
                extended_arg = oparg << 8
                continue
            extended_arg = 0
        else:
            oparg = None
            extended_arg = 0
        yield (op, oparg)
# Example | |
# Sample function whose compiled bytecode is disassembled below.
# (Documented with comments only, so the compiled code object is unchanged.)
def countdown(n):
    while n > 0:
        print("T-minus", n)
        n -= 1
    print("Blastoff!")


# Print each decoded instruction: numeric opcode, mnemonic, and argument
for op, oparg in generate_opcodes(countdown.__code__.co_code):
    print(op, opcode.opname[op], oparg)
################################################################################ | |
## 9::enforcing_an_argument_signature | |
# Example of code that enforces signatures on an __init__ function | |
from inspect import Signature, Parameter | |
def make_sig(*names):
    """
    Build an ``inspect.Signature`` with one positional-or-keyword
    parameter per name, in the order given.
    """
    kind = Parameter.POSITIONAL_OR_KEYWORD
    parameters = []
    for name in names:
        parameters.append(Parameter(name, kind))
    return Signature(parameters)
class Structure:
    """
    Base class whose ``__init__`` checks its arguments against the
    class-level ``__signature__`` and stores each one as an attribute.

    Subclasses override ``__signature__``; the base signature takes no
    arguments.
    """

    __signature__ = make_sig()

    def __init__(self, *args, **kwargs):
        # bind() raises TypeError for missing, extra or duplicated arguments
        bound = self.__signature__.bind(*args, **kwargs)
        for attr, supplied in bound.arguments.items():
            setattr(self, attr, supplied)
# Example use | |
class Stock(Structure):
    """Structure with the fields name, shares and price."""

    __signature__ = make_sig("name", "shares", "price")


class Point(Structure):
    """Structure with the fields x and y."""

    __signature__ = make_sig("x", "y")


# Example instantiation tests
if __name__ == "__main__":
    s1 = Stock("ACME", 100, 490.1)
    print(s1.name, s1.shares, s1.price)

    s2 = Stock(shares=100, name="ACME", price=490.1)
    print(s2.name, s2.shares, s2.price)

    # Each of these calls violates the signature and raises TypeError
    bad_calls = (
        (("ACME", 100), {}),  # Not enough args
        (("ACME", 100, 490.1, "12/21/2012"), {}),  # Too many args
        (("ACME", 100), {"name": "ACME", "price": 490.1}),  # Replicated args
    )
    for positional, keyword in bad_calls:
        try:
            Stock(*positional, **keyword)
        except TypeError as err:
            print(err)
################################################################################ | |
## 9::enforcing_an_argument_signature | |
# Example of building signatures in a metaclass | |
from inspect import Signature, Parameter | |
def make_sig(*names):
    """
    Return an ``inspect.Signature`` whose parameters are the given names,
    each accepted positionally or by keyword.
    """

    def as_parameter(name):
        return Parameter(name, Parameter.POSITIONAL_OR_KEYWORD)

    return Signature(list(map(as_parameter, names)))
class StructureMeta(type):
    """
    Metaclass that derives ``__signature__`` from a class's ``_fields``.

    The signature is built from the ``_fields`` visible on the finished
    class, so a subclass that does not redeclare ``_fields`` keeps its
    parent's fields instead of being given an empty signature.  A class
    with no ``_fields`` anywhere in its hierarchy gets an empty signature.
    """

    def __new__(cls, clsname, bases, clsdict):
        new_cls = super().__new__(cls, clsname, bases, clsdict)
        # getattr() follows the MRO, so inherited _fields are honoured
        new_cls.__signature__ = make_sig(*getattr(new_cls, "_fields", []))
        return new_cls
class Structure(metaclass=StructureMeta):
    """
    Base class for simple records.  Subclasses list their attribute names
    in ``_fields``; ``__init__`` binds the call arguments against the
    class's ``__signature__`` and stores each one as an attribute.
    """

    _fields = []

    def __init__(self, *args, **kwargs):
        # bind() raises TypeError for missing, extra or duplicated arguments
        bound = self.__signature__.bind(*args, **kwargs)
        for attr, supplied in bound.arguments.items():
            setattr(self, attr, supplied)
# Example | |
class Stock(Structure):
    """Record with the fields name, shares and price."""

    _fields = ["name", "shares", "price"]


class Point(Structure):
    """Record with the fields x and y."""

    _fields = ["x", "y"]


def _report_type_error(*args, **kwargs):
    """Try to build a Stock and print the TypeError if construction fails."""
    try:
        Stock(*args, **kwargs)
    except TypeError as err:
        print(err)


# Example instantiation tests
if __name__ == "__main__":
    s1 = Stock("ACME", 100, 490.1)
    print(s1.name, s1.shares, s1.price)

    s2 = Stock(shares=100, name="ACME", price=490.1)
    print(s2.name, s2.shares, s2.price)

    # Not enough args
    _report_type_error("ACME", 100)

    # Too many args
    _report_type_error("ACME", 100, 490.1, "12/21/2012")

    # Replicated args
    _report_type_error("ACME", 100, name="ACME", price=490.1)
################################################################################ | |
## 9::enforcing_coding_conventions_in_classes | |
# A metaclass that disallows mixed case identifier names | |
class NoMixedCaseMeta(type):
    """
    Metaclass that rejects class bodies defining any name containing
    upper-case characters.

    Raises TypeError naming the first offending attribute.
    """

    def __new__(cls, clsname, bases, clsdict):
        offender = next((key for key in clsdict if key != key.lower()), None)
        if offender is not None:
            raise TypeError("Bad attribute name: " + offender)
        return super().__new__(cls, clsname, bases, clsdict)
class Root(metaclass=NoMixedCaseMeta):
    """Base class; subclasses are created through NoMixedCaseMeta too."""


class A(Root):
    def foo_bar(self):  # Ok
        return None


print("**** About to generate a TypeError")


# Defining this class raises TypeError because of the mixed-case method name
class B(Root):
    def fooBar(self):  # TypeError
        return None
################################################################################ | |
## 9::enforcing_coding_conventions_in_classes | |
# Using a metaclass to issue warnings about signature mismatches | |
from inspect import signature | |
import logging | |
class MatchSignaturesMeta(type):
    """
    Metaclass that logs a warning when a public callable defined in a
    class body has a different signature from the attribute of the same
    name found on its parent classes.
    """

    def __init__(self, clsname, bases, clsdict):
        super().__init__(clsname, bases, clsdict)
        # Attribute lookups on this proxy skip the new class itself
        parent_view = super(self, self)
        for attr, member in clsdict.items():
            if attr.startswith("_"):
                continue
            if not callable(member):
                continue
            # Get the previous definition (if any) and compare the signatures
            inherited = getattr(parent_view, attr, None)
            if not inherited:
                continue
            inherited_sig = signature(inherited)
            member_sig = signature(member)
            if inherited_sig == member_sig:
                continue
            logging.warning(
                "Signature mismatch in %s. %s != %s",
                member.__qualname__,
                str(inherited_sig),
                str(member_sig),
            )
# Example | |
class Root(metaclass=MatchSignaturesMeta):
    """Base class; subclasses are checked by MatchSignaturesMeta."""


class A(Root):
    def foo(self, x, y):
        return None

    def spam(self, x, *, z):
        return None


# Class with redefined methods, but slightly different signatures
class B(A):
    def foo(self, a, b):
        return None

    def spam(self, x, z):
        return None
################################################################################ | |
## 9::enforcing_type_checking_on_a_function_using_a_decorator | |
from inspect import signature | |
from functools import wraps | |
def typeassert(*ty_args, **ty_kwargs):
    """
    Decorator factory that enforces types on selected arguments.

    The types are given positionally and/or by keyword, matching the
    decorated function's parameters.  A call whose argument is not an
    instance of its declared type raises TypeError.  Parameters without a
    declared type are not checked.
    """

    def decorate(func):
        if __debug__:
            # Map function argument names to supplied types
            sig = signature(func)
            declared = sig.bind_partial(*ty_args, **ty_kwargs).arguments

            def checked(*args, **kwargs):
                supplied = sig.bind(*args, **kwargs).arguments
                # Enforce type assertions across supplied arguments
                for name, value in supplied.items():
                    if name in declared and not isinstance(value, declared[name]):
                        raise TypeError(
                            "Argument {} must be {}".format(name, declared[name])
                        )
                return func(*args, **kwargs)

            return wraps(func)(checked)

        # If in optimized mode, disable type checking
        return func

    return decorate
# Examples | |
@typeassert(int, int)
def add(x, y):
    """Return x + y; both arguments must be ints."""
    return x + y


@typeassert(int, z=int)
def spam(x, y, z=42):
    """Print the arguments; x and z must be ints, y is unchecked."""
    print(x, y, z)


if __name__ == "__main__":
    print(add(2, 3))
    try:
        add(2, "hello")
    except TypeError as err:
        print(err)

    # y is not type-checked, so the string is accepted in the second call
    for accepted in ((1, 2, 3), (1, "hello", 3)):
        spam(*accepted)

    try:
        spam(1, "hello", "world")
    except TypeError as err:
        print(err)
################################################################################ | |
## 9::executing_code_with_local_side_effects | |
def test():
    """
    Run ``b = a + 1`` through exec() and read the result back from the
    namespace dictionary it executed in.

    The dictionary is passed to exec() explicitly.  Relying on exec()
    implicitly sharing a dictionary previously returned by locals() only
    works on interpreters older than Python 3.13; since PEP 667 each
    locals() call in a function returns an independent snapshot, so the
    lookup of ``b`` would fail there.
    """
    a = 13
    loc = locals()
    # Explicit namespaces: names assigned by the executed code land in `loc`
    exec("b = a + 1", globals(), loc)
    b = loc["b"]
    print(b)  # --> 14
def test1():
    """Show that exec() does not rebind the function's real local ``x``."""
    x = 0
    # The statement runs against a dictionary of locals, not the variable itself
    exec("x += 1")
    print(x)  # --> 0
def test2():
    """
    Print the dictionary returned by locals() before and after exec()
    runs ``x += 1``, then print the real local variable ``x``.

    NOTE(review): what the "after:" line shows depends on the interpreter
    version (locals()/exec() semantics changed in Python 3.13, PEP 667) --
    confirm on the target version.
    """
    x = 0
    loc = locals()
    print("before:", loc)
    exec("x += 1")
    print("after:", loc)
    print("x =", x)
def test3():
    """
    Print the dictionary captured from locals() three times: initially,
    after exec() runs ``x += 1``, and after locals() is called again.

    NOTE(review): whether the second locals() call alters the previously
    captured dictionary depends on the interpreter version (changed in
    Python 3.13, PEP 667) -- confirm on the target version.
    """
    x = 0
    loc = locals()
    print(loc)
    exec("x += 1")
    print(loc)
    # Result deliberately discarded; the call is made only for its effect on `loc`
    locals()
    print(loc)
def test4():
    """
    Run ``b = a + 1`` through exec() with explicit global and local
    dictionaries and print the ``b`` it produced.
    """
    a = 13
    namespace = dict(a=a)
    # Empty globals, our own dictionary as locals: the assignment lands there
    exec("b = a + 1", {}, namespace)
    print(namespace["b"])
if __name__ == "__main__":
    # Run every demonstration in order, each preceded by its banner
    demos = (
        (":::: Running test()", test),
        (":::: Running test1()", test1),
        (":::: Running test2()", test2),
        (":::: Running test3()", test3),
        (":::: Running test4()", test4),
    )
    for banner, demo in demos:
        print(banner)
        demo()
################################################################################ | |
## 9::initializing_class_members_at_definition_time | |
import operator | |
class StructTupleMeta(type):
    """
    Metaclass that gives each class one read-only property per entry of
    its ``_fields_`` list; the property returns the item at that entry's
    index.
    """

    def __init__(cls, *args, **kwargs):
        super().__init__(*args, **kwargs)
        position = 0
        for field in cls._fields_:
            accessor = operator.itemgetter(position)
            setattr(cls, field, property(accessor))
            position += 1
class StructTuple(tuple, metaclass=StructTupleMeta):
    """
    Tuple subclass with named fields.  Subclasses list field names in
    ``_fields_``; construction takes exactly one argument per field.

    Raises ValueError when the argument count does not match.
    """

    _fields_ = []

    def __new__(cls, *args):
        expected = len(cls._fields_)
        if len(args) == expected:
            return super().__new__(cls, args)
        raise ValueError("{} arguments required".format(expected))
# Examples | |
class Stock(StructTuple):
    """Tuple with the fields name, shares and price."""

    _fields_ = ["name", "shares", "price"]


class Point(StructTuple):
    """Tuple with the fields x and y."""

    _fields_ = ["x", "y"]


if __name__ == "__main__":
    s = Stock("ACME", 50, 91.1)
    for shown in (s, s[0], s.name, s.shares * s.price):
        print(shown)
    # The generated properties have no setter, so assignment fails
    try:
        setattr(s, "shares", 23)
    except AttributeError as err:
        print(err)
################################################################################ | |
## 9::monkeypatching_class_definitions | |
def log_getattribute(cls):
    """
    Class decorator that prints ``getting: <name>`` on every attribute
    lookup on instances, then defers to the class's original
    ``__getattribute__``.  Patches the class in place and returns it.
    """
    # Keep a reference to the implementation being replaced
    original = cls.__getattribute__

    def logged_getattribute(self, name):
        print("getting:", name)
        return original(self, name)

    setattr(cls, "__getattribute__", logged_getattribute)
    return cls
# Example use | |
@log_getattribute
class A:
    """Example class whose attribute lookups are logged."""

    def __init__(self, x):
        self.x = x

    def spam(self):
        return None


if __name__ == "__main__":
    a = A(42)
    # Both the attribute read and the method lookup are reported
    print(a.x)
    a.spam()
################################################################################ | |
## 9::multiple_dispatch_with_function_annotations | |
import inspect | |
import types | |
class MultiMethod:
    """
    Represents a single multimethod.
    """

    def __init__(self, name):
        self._methods = {}
        self.__name__ = name

    def register(self, meth):
        """
        Register a new method as a multimethod

        Every parameter except ``self`` must be annotated with a type,
        otherwise TypeError is raised.  A parameter with a default value
        also registers the shorter signature that omits it.
        """
        # Build a type-signature from the method's annotations
        type_signature = []
        for name, parm in inspect.signature(meth).parameters.items():
            if name == "self":
                continue
            annotation = parm.annotation
            if annotation is inspect.Parameter.empty:
                raise TypeError(
                    "Argument {} must be annotated with a type".format(name)
                )
            if not isinstance(annotation, type):
                raise TypeError("Argument {} annotation must be a type".format(name))
            if parm.default is not inspect.Parameter.empty:
                # Optional argument: the call may stop just before it
                self._methods[tuple(type_signature)] = meth
            type_signature.append(annotation)
        self._methods[tuple(type_signature)] = meth

    def __call__(self, *args):
        """
        Call a method based on type signature of the arguments
        """
        # args[0] is the instance; dispatch on the exact types of the rest
        arg_types = tuple(map(type, args[1:]))
        meth = self._methods.get(arg_types, None)
        if not meth:
            raise TypeError("No matching method for types {}".format(arg_types))
        return meth(*args)

    def __get__(self, instance, cls):
        """
        Descriptor method needed to make calls work in a class
        """
        if instance is None:
            return self
        return types.MethodType(self, instance)
class MultiDict(dict):
    """
    Special dictionary to build multimethods in a metaclass
    """

    def __setitem__(self, key, value):
        if key not in self:
            super().__setitem__(key, value)
            return

        # If key already exists, it must be a multimethod or callable
        existing = self[key]
        if isinstance(existing, MultiMethod):
            existing.register(value)
            return

        # Second definition of the name: promote both to a multimethod
        combined = MultiMethod(key)
        combined.register(existing)
        combined.register(value)
        super().__setitem__(key, combined)
class MultipleMeta(type):
    """
    Metaclass that allows multiple dispatch of methods
    """

    @classmethod
    def __prepare__(cls, clsname, bases):
        # The class body executes in a MultiDict, which merges repeated
        # definitions of one name into a MultiMethod
        return MultiDict()

    def __new__(cls, clsname, bases, clsdict):
        # Hand type.__new__ a plain dict copy of the collected namespace
        namespace = dict(clsdict)
        return type.__new__(cls, clsname, bases, namespace)
# Some example classes that use multiple dispatch | |
class Spam(metaclass=MultipleMeta):
    """Example class with two ``bar`` overloads chosen by argument types."""

    def bar(self, x: int, y: int):
        print("Bar 1:", x, y)

    def bar(self, s: str, n: int = 0):
        print("Bar 2:", s, n)


# Example: overloaded __init__
import time


class Date(metaclass=MultipleMeta):
    """Date built either from (year, month, day) or from today's local date."""

    def __init__(self, year: int, month: int, day: int):
        self.year = year
        self.month = month
        self.day = day

    def __init__(self):
        # No arguments: delegate to the three-int overload using today's date
        t = time.localtime()
        self.__init__(t.tm_year, t.tm_mon, t.tm_mday)


if __name__ == "__main__":
    s = Spam()
    for call_args in ((2, 3), ("hello",), ("hello", 5)):
        s.bar(*call_args)
    # No overload accepts (int, str)
    try:
        s.bar(2, "hello")
    except TypeError as err:
        print(err)

    # Overloaded __init__
    d = Date(2012, 12, 21)
    print(d.year, d.month, d.day)

    # Get today's date
    e = Date()
    print(e.year, e.month, e.day)
################################################################################ | |
## 9::multiple_dispatch_with_function_annotations | |
# Alternate formulation using decorators | |
import types | |
class multimethod:
    """
    Decorator-based multimethod.  The decorated function is the fallback;
    further implementations are attached with ``@name.match(*types)`` and
    chosen by the exact types of the call's arguments after ``self``.
    """

    def __init__(self, func):
        self._methods = {}
        self.__name__ = func.__name__
        self._default = func

    def match(self, *types):
        """Return a decorator registering a function for the given types."""

        def register(func):
            optional = len(func.__defaults__) if func.__defaults__ else 0
            # Also register each shorter signature that omits trailing
            # arguments which have defaults
            for dropped in range(optional + 1):
                key = types[: len(types) - dropped]
                self._methods[key] = func
            # Return the multimethod so the class attribute keeps pointing at it
            return self

        return register

    def __call__(self, *args):
        # args[0] is the instance; dispatch on the types of the rest
        signature_key = tuple(map(type, args[1:]))
        handler = self._methods.get(signature_key, None) or self._default
        return handler(*args)

    def __get__(self, instance, cls):
        # Bind to the instance so calls through it pass `self` along
        if instance is None:
            return self
        return types.MethodType(self, instance)
# Example use | |
class Spam:
    """Example class with ``bar`` overloads registered via ``bar.match``."""

    @multimethod
    def bar(self, *args):
        # Default method called if no match
        raise TypeError("No matching method for bar")

    @bar.match(int, int)
    def bar(self, x, y):
        print("Bar 1:", x, y)

    @bar.match(str, int)
    def bar(self, s, n=0):
        print("Bar 2:", s, n)


if __name__ == "__main__":
    s = Spam()
    for call_args in ((2, 3), ("hello",), ("hello", 5)):
        s.bar(*call_args)
    # (int, str) matches nothing, so the default method raises
    try:
        s.bar(2, "hello")
    except TypeError as err:
        print(err)
################################################################################ | |
## 9::parsing_and_analyzing_python_source | |
import ast | |
class CodeAnalyzer(ast.NodeVisitor):
    """
    AST visitor that collects the identifiers appearing in Name nodes,
    grouped by usage: ``loaded``, ``stored`` and ``deleted``.
    """

    def __init__(self):
        self.loaded = set()
        self.stored = set()
        self.deleted = set()

    def visit_Name(self, node):
        # Pick the bucket matching the node's expression context
        buckets = (
            (ast.Load, self.loaded),
            (ast.Store, self.stored),
            (ast.Del, self.deleted),
        )
        for context_type, names in buckets:
            if isinstance(node.ctx, context_type):
                names.add(node.id)
                break
# Sample usage | |
if __name__ == "__main__":
    # Some python code
    code = """
for i in range(10):
    print(i)
del i
"""
    # Parse into an AST and feed it to the analyzer
    analyzer = CodeAnalyzer()
    analyzer.visit(ast.parse(code, mode="exec"))

    # Report how each name was used
    for label, names in (
        ("Loaded:", analyzer.loaded),
        ("Stored:", analyzer.stored),
        ("Deleted:", analyzer.deleted),
    ):
        print(label, names)
################################################################################ | |
## 9::parsing_and_analyzing_python_source | |
# namelower.py | |
import ast | |
import inspect | |
# Node visitor that lowers globally accessed names into | |
# the function body as local variables. | |
class NameLower(ast.NodeVisitor):
    """
    Visitor that prepends, to each function definition it visits,
    assignments copying the chosen global names into local variables.

    The last function definition node visited is kept in ``self.func``.
    """

    def __init__(self, lowered_names):
        self.lowered_names = lowered_names

    def visit_FunctionDef(self, node):
        # Compile some assignments to lower the constants
        assignments = "\n".join(
            "{0} = __globals['{0}']".format(name) for name in self.lowered_names
        )
        prologue = ast.parse("__globals = globals()\n" + assignments, mode="exec")

        # Inject new statements at the front of the function body, in order
        for offset, statement in enumerate(prologue.body):
            node.body.insert(offset, statement)

        # Save the function object
        self.func = node
# Decorator that turns global names into locals | |
def lower_names(*namelist):
    """
    Decorator factory that turns the listed global names into locals.

    The decorated function's source is re-parsed, rewritten with
    NameLower, recompiled, and the resulting code object is installed on
    the original function, which is then returned.
    """

    def lower(func):
        source_lines = inspect.getsource(func).splitlines()

        # Skip source lines prior to the @lower_names decorator
        for decorator_at, text in enumerate(source_lines):
            if "@lower_names" in text:
                break
        remaining = source_lines[decorator_at + 1 :]
        source = "\n".join(remaining)

        # Hack to deal with indented code
        if source.startswith((" ", "\t")):
            source = "if 1:\n" + source
        tree = ast.parse(source, mode="exec")

        # Transform the AST
        NameLower(namelist).visit(tree)

        # Execute the modified AST in a scratch namespace
        scratch = {}
        exec(compile(tree, "", "exec"), scratch, scratch)

        # Pull out the modified code object
        rebuilt = scratch[func.__name__]
        func.__code__ = rebuilt.__code__
        return func

    return lower
# Example of use | |
INCR = 1


# Baseline: INCR is looked up as a global on every iteration
def countdown1(n):
    while n > 0:
        n -= INCR


# Same loop, with INCR copied into a local by the decorator
@lower_names("INCR")
def countdown2(n):
    while n > 0:
        n -= INCR


if __name__ == "__main__":
    import time

    print("Running a performance check")
    contenders = (
        ("countdown1:", countdown1),
        ("countdown2:", countdown2),
    )
    for label, contender in contenders:
        start = time.time()
        contender(100000000)
        end = time.time()
        print(label, end - start)
################################################################################ | |
## 9::preserving_function_metadata_when_writing_decorators | |
import time | |
from functools import wraps | |
def timethis(func):
    """
    Decorator that reports the execution time.

    Each call prints the wrapped function's name and the elapsed seconds,
    then returns the wrapped function's result.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Use the monotonic high-resolution timer: time.time() follows the
        # wall clock, which can be adjusted while the call is running.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(func.__name__, end - start)
        return result

    return wrapper
if __name__ == "__main__":

    @timethis
    def countdown(n: int):
        """
        Counts down
        """
        while n > 0:
            n -= 1

    countdown(100000)

    # Metadata copied onto the wrapper by functools.wraps
    for label, value in (
        ("Name:", countdown.__name__),
        ("Docstring:", repr(countdown.__doc__)),
        ("Annotations:", countdown.__annotations__),
    ):
        print(label, value)
################################################################################ | |
## 9::unwrapping_a_decorator | |
# Example of unwrapping a decorator | |
from functools import wraps | |
def decorator1(func):
    """Wrap ``func`` so each call first prints "Decorator 1"."""

    def announce_then_call(*args, **kwargs):
        print("Decorator 1")
        return func(*args, **kwargs)

    # wraps() copies the metadata and records func as __wrapped__
    return wraps(func)(announce_then_call)
def decorator2(func):
    """Wrap ``func`` so each call first prints "Decorator 2"."""

    def announce_then_call(*args, **kwargs):
        print("Decorator 2")
        return func(*args, **kwargs)

    # wraps() copies the metadata and records func as __wrapped__
    return wraps(func)(announce_then_call)
@decorator1
@decorator2
def add(x, y):
    """Return x + y (wrapped by decorator2, then by decorator1)."""
    return x + y


# Calling wrapped function: goes through both decorators
print(add(2, 3))

# Calling through __wrapped__: skips only the outermost wrapper (decorator1),
# so this reaches the function that decorator1 received
print(add.__wrapped__(2, 3))
################################################################################ | |
## 9::using_metaclasses_to_control_instance_creation | |
# example1.py | |
# | |
# Not allowing direct instantiation | |
class NoInstances(type):
    """Metaclass whose classes cannot be instantiated by calling them."""

    def __call__(self, *args, **kwargs):
        # Calling the class lands here; refuse regardless of arguments
        refusal = TypeError("Can't instantiate directly")
        raise refusal
class Spam(metaclass=NoInstances):
    """Example class used only through its static method."""

    @staticmethod
    def grok(x):
        print("Spam.grok")


if __name__ == "__main__":
    # Instantiation is refused by the metaclass ...
    try:
        s = Spam()
    except TypeError as err:
        print(err)
    # ... but static methods remain callable on the class itself
    Spam.grok(42)
################################################################################ | |
## 9::using_metaclasses_to_control_instance_creation | |
# example2.py | |
# | |
# Singleton | |
class Singleton(type):
    """
    Metaclass that creates at most one instance per class: the first call
    builds the instance, every later call returns that same object.
    """

    def __init__(self, *args, **kwargs):
        # Per-class slot holding the single instance once it exists
        self.__instance = None
        super().__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        if self.__instance is None:
            self.__instance = super().__call__(*args, **kwargs)
        return self.__instance
class Spam(metaclass=Singleton):
    """Example singleton; the message shows how often __init__ runs."""

    def __init__(self):
        print("Creating Spam")


if __name__ == "__main__":
    # Two calls, one object
    a, b = Spam(), Spam()
    print(a is b)
################################################################################ | |
## 9::using_metaclasses_to_control_instance_creation | |
# example3.py | |
# | |
# Cached instances | |
import weakref | |
class Cached(type):
    """
    Metaclass that caches instances by their positional constructor
    arguments.  The cache holds weak references, so an entry lasts only
    while its instance is referenced elsewhere.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__cache = weakref.WeakValueDictionary()

    def __call__(self, *args):
        existing = self.__cache.get(args)
        if existing is not None:
            return existing
        # No live instance for these arguments yet: create and remember it
        obj = super().__call__(*args)
        self.__cache[args] = obj
        return obj
class Spam(metaclass=Cached):
    """Example cached class; the message shows when __init__ really runs."""

    def __init__(self, name):
        print("Creating Spam({!r})".format(name))
        self.name = name


if __name__ == "__main__":
    a = Spam("foo")
    b = Spam("bar")
    # Different arguments give different objects
    print("a is b:", a is b)
    # Same arguments give back the cached object
    c = Spam("foo")
    print("a is c:", a is c)
################################################################################ | |
## single dispatch and registering | |
from functools import singledispatch | |
@singledispatch
def fun(arg, verbose=False):
    """Print ``arg``; generic fallback used when no registered type matches."""
    if verbose:
        print("Let me just say,", end=" ")
    print(arg)


@fun.register
def _(arg: int, verbose=False):
    # Implementation selected for int arguments (taken from the annotation)
    if verbose:
        print("Strength in numbers, eh?", end=" ")
    print(arg)


@fun.register
def _(arg: list, verbose=False):
    # Implementation selected for list arguments: one "index item" line each
    if verbose:
        print("Enumerate this:")
    for position, item in enumerate(arg):
        print(position, item)
################################################################################ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment