Last active
October 31, 2020 06:34
-
-
Save zzzeek/1f7bd3cc19e7ebea8f364e5b40577067 to your computer and use it in GitHub Desktop.
buildything with lambdas
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# note: we are using SQLAlchemy that includes _cache_key
# currently at:
# https://gerrit.sqlalchemy.org/#/c/sqlalchemy/sqlalchemy/+/1204/5/
import typing

from sqlalchemy import bindparam
from sqlalchemy import Column
from sqlalchemy import func
from sqlalchemy import inspection
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import util
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import elements
from sqlalchemy.sql import type_api
class BakedSelect:
    """Record the steps used to build a SELECT so the result can be cached.

    BakedSelect is mostly like BakedQuery but a little more directed; it will
    build objects that are equivalent to both Core select() as well as
    ORM Query.

    It will also not be called BakedQuery; it will be an object that comes
    automatically when you use select() and pass lambdas or callables to it.

    Every generative call (the constructor, ``where()``, ``order_by()``)
    appends one step to ``self.steps``.  A step is built either from a single
    callable or from one or more plain expression objects.
    """

    def __init__(self, *args):
        """Start a new statement; ``args`` become the "columns" step.

        At least one argument is required: either a single callable, or one
        or more expression objects.
        """
        self.steps = []
        self._append_step("columns", *args)

    def _append_step(self, name, arg0, *args, **kw):
        """Append a step called ``name``, choosing the step type from ``arg0``.

        :raises TypeError: if ``arg0`` is callable and any further positional
         or keyword argument is given; a callable must be the only argument.
        """
        if callable(arg0):
            # explicit raise rather than ``assert`` so that the check is
            # still performed when Python runs with -O
            if args or kw:
                raise TypeError(
                    "a callable must be the only argument of a %r step"
                    % (name,)
                )
            self._append_lambda_step(name, arg0)
        else:
            self._append_plain_step(name, arg0, *args, **kw)

    def _append_lambda_step(self, name, fn):
        """Append a step for ``fn``; its code object is the step's cache key."""
        self.steps.append(CallableStep(name, fn.__code__, fn))

    def _append_plain_step(self, name, *args, **kw):
        """Append a step made of expression objects.

        Each argument must provide ``_cache_key(bindparams=...)``; the
        ``binds`` list is handed to every call and stored on the step.
        """
        binds = []
        cache_key = tuple(arg._cache_key(bindparams=binds) for arg in args)
        self.steps.append(ClauseElementStep(name, cache_key, binds, args, kw))

    def where(self, *args):
        """Append a "where" step and return this same object."""
        self._append_step("where", *args)
        return self

    def order_by(self, *args):
        """Append an "order_by" step and return this same object."""
        self._append_step("order_by", *args)
        return self

    @property
    def _cache_key(self):
        """Tuple of ``(step name, step cache key)`` pairs, in step order."""
        return tuple((step.name, step.cache_key) for step in self.steps)

    def result(self, cache):
        """Return a BakedResult for this query, assembled against ``cache``."""
        return BakedResult(self, cache)
class CachedStep:
    """Base class for a single recorded step of a baked query."""

    def invoke(self, baked_result, index):
        """invoke this step against a BakedResult.

        ``index`` is the position of this step within the query's list of
        steps.  Subclasses must override this method.
        """
        raise NotImplementedError()
class _CallableStepFields(typing.NamedTuple):
    # Field layout of CallableStep.  It is declared on its own because a
    # class statement that lists typing.NamedTuple next to another base
    # raises TypeError on Python 3.9+ (older versions silently dropped the
    # other base).
    name: str
    cache_key: object
    fn: typing.Callable


class CallableStep(CachedStep, _CallableStepFields):
    """A cache step that consists of a callable that is invoked at
    assembly time.

    Fields, in positional order: ``name`` (the kind of step),
    ``cache_key`` (what identifies this step inside the query's cache key)
    and ``fn`` (the callable itself).
    """

    def invoke(self, baked_result, index):
        """invoke this step against a BakedResult.

        Track objects referenced inside of the lambda and create bindparams
        ahead of time for literal values.  If bindparams are produced, the
        function is rebuilt with globals and closure that refer to the
        bindparams, the trackers are added to ``baked_result.trackers`` and
        the rebuilt function is invoked; otherwise the original function is
        invoked unchanged.

        :return: the function's return value passed through ``util.to_list``.
        """
        fn = self.fn
        code = fn.__code__

        # work on a copy; the function's real globals are never modified
        new_globals = fn.__globals__.copy()
        trackers = []

        # NOTE(review): co_names also contains attribute names used by the
        # code, so a global sharing its name with an accessed attribute is
        # examined here as well.
        for name in code.co_names:
            if name not in new_globals:
                continue
            bound_value = elements._clause_element_as_expr(new_globals[name])
            if _is_simple_literal(bound_value):
                bind = bindparam(
                    name,
                    unique=True,
                    type_=type_api._resolve_value_to_type(bound_value),
                )
                new_globals[name] = bind
                trackers.append(GlobalsTracker(index, name, bind))

        # one replacement value per closure cell, in co_freevars order
        cell_values = []
        for closure_index, (fv, cell) in enumerate(
            zip(code.co_freevars, fn.__closure__ or ())
        ):
            bound_value = elements._clause_element_as_expr(cell.cell_contents)
            if _is_simple_literal(bound_value):
                bind = bindparam(
                    fv,
                    unique=True,
                    type_=type_api._resolve_value_to_type(bound_value),
                )
                cell_values.append(bind)
                trackers.append(
                    ClosureTracker(index, fv, bind, closure_index)
                )
            else:
                cell_values.append(cell.cell_contents)

        if not trackers:
            # nothing was turned into a bound parameter; call fn as given
            return util.to_list(fn())

        baked_result.trackers.extend(trackers)
        new_fn = _rewrite_code_obj(fn, cell_values, new_globals)
        return util.to_list(new_fn())
class _ClauseElementStepFields(typing.NamedTuple):
    # Field layout of ClauseElementStep.  It is declared on its own because
    # a class statement that lists typing.NamedTuple next to another base
    # raises TypeError on Python 3.9+ (older versions silently dropped the
    # other base).
    name: str
    cache_key: typing.Tuple[typing.Any, ...]
    binds: typing.Sequence[elements.BindParameter]
    args: typing.Tuple[typing.Any, ...]
    kw: typing.Dict[str, typing.Any]


class ClauseElementStep(CachedStep, _ClauseElementStepFields):
    """A cache step that consists of a sequence of ClauseElement objects.

    Fields, in positional order: ``name`` (the kind of step), ``cache_key``
    (one entry per expression), ``binds`` (the bound parameters collected
    while the cache key was built), ``args`` (the expressions themselves)
    and ``kw`` (keyword arguments given with them).
    """

    def invoke(self, baked_result, index):
        """invoke this step against a BakedResult.

        Adds one BindParamTracker per entry of ``binds`` to
        ``baked_result.trackers`` and returns ``args`` unchanged.
        """
        baked_result.trackers.extend(
            BindParamTracker(index, bind_index, bind._identifying_key, bind)
            for bind_index, bind in enumerate(self.binds)
        )
        return self.args
class Tracker:
    """Tracks the value of a single BindParameter object.

    Tracker objects are constructed for every bound value when the
    baked query is first assembled.  They store the information needed in
    order to re-acquire an updated bound value from a new invocation of the
    same cached statement.  The trackers are stored in the cache along with
    the cached statement.  When a new BakedQuery is built up again consisting
    of a new series of lambdas and/or expressions, the Tracker objects extract
    updated bound values from the new steps and match them up to the steps
    that were cached.
    """

    def extract_parameter_value(self, result):
        """Store this tracker's current value into ``result.parameters``.

        Subclasses must override this method.
        """
        raise NotImplementedError()
class _BindParamTrackerFields(typing.NamedTuple):
    # Field layout of BindParamTracker.  It is declared on its own because a
    # class statement that lists typing.NamedTuple next to another base
    # raises TypeError on Python 3.9+ (older versions silently dropped the
    # other base).
    step_index: int
    bind_index: int
    name: str
    value: elements.BindParameter


class BindParamTracker(Tracker, _BindParamTrackerFields):
    """tracks BindParameter objects inside of a Core expression.

    The expression system includes a feature inside the _cache_key()
    function that gathers up BindParameter objects in deterministic order
    as the cache key is formulated.  This tracker is used when a plain
    expression without a lambda is given.

    Fields, in positional order: ``step_index`` (which step of the query),
    ``bind_index`` (position inside that step's ``binds``), ``name`` (the
    key used in ``result.parameters``) and ``value`` (the BindParameter seen
    when the tracker was created).
    """

    def extract_parameter_value(self, result):
        """Copy the current value of the tracked bind into the result.

        The bind is looked up by position in the *current* query's step, so
        the value comes from the query being run, not from the cached one.
        """
        binds = result.query.steps[self.step_index].binds
        result.parameters[self.name] = binds[self.bind_index].value
class _GlobalsTrackerFields(typing.NamedTuple):
    # Field layout of GlobalsTracker.  It is declared on its own because a
    # class statement that lists typing.NamedTuple next to another base
    # raises TypeError on Python 3.9+ (older versions silently dropped the
    # other base).
    step_index: int
    name: str
    value: elements.BindParameter


class GlobalsTracker(Tracker, _GlobalsTrackerFields):
    """tracks literal values inside the __globals__ of a function.

    Fields, in positional order: ``step_index`` (which step of the query
    holds the function), ``name`` (the global's name) and ``value`` (the
    BindParameter that stands in for the global).
    """

    def extract_parameter_value(self, result):
        """Read the global from the current step's function.

        The value is stored in ``result.parameters`` under the
        ``_identifying_key`` of the tracked BindParameter.
        """
        current_fn = result.query.steps[self.step_index].fn
        key = self.value._identifying_key
        result.parameters[key] = current_fn.__globals__[self.name]
class _ClosureTrackerFields(typing.NamedTuple):
    # Field layout of ClosureTracker.  It is declared on its own because a
    # class statement that lists typing.NamedTuple next to another base
    # raises TypeError on Python 3.9+ (older versions silently dropped the
    # other base).
    step_index: int
    name: str
    value: elements.BindParameter
    closure_index: int


class ClosureTracker(Tracker, _ClosureTrackerFields):
    """tracks literal values inside the __closure__ of a function.

    Fields, in positional order: ``step_index`` (which step of the query
    holds the function), ``name`` (the free variable's name), ``value`` (the
    BindParameter that stands in for it) and ``closure_index`` (position of
    the variable's cell inside ``fn.__closure__``).
    """

    def extract_parameter_value(self, result):
        """Read the closure cell from the current step's function.

        The value is stored in ``result.parameters`` under the
        ``_identifying_key`` of the tracked BindParameter.
        """
        current_fn = result.query.steps[self.step_index].fn
        key = self.value._identifying_key
        cell = current_fn.__closure__[self.closure_index]
        result.parameters[key] = cell.cell_contents
class BakedResult:
    """BakedResult is like the existing BakedResult except it will meet an
    interface for a generic "result" object that will be used in all cases;
    Core select(), ORM "select()", baked or not, etc.  E.g. it will be
    transparent.  The "cache" will be associated with the engine and/or
    dialect.

    After construction, ``statement`` holds the compiled statement and
    ``parameters`` holds the values gathered by the trackers.
    """

    def __init__(self, query, cache):
        """Assemble ``query`` right away, reading from / writing to ``cache``.

        ``cache`` is used like a mapping from the query's ``_cache_key`` to
        a ``(compiled statement, trackers)`` pair.
        """
        self.cache = cache
        self.query = query
        self.parameters = {}
        self.trackers = []
        self._assemble()

    def _assemble(self):
        """Fetch or build the compiled statement, then collect parameters."""
        key = self.query._cache_key
        if key in self.cache:
            print("return from cache")
            # the cached tracker list replaces the empty one from __init__
            self.statement, self.trackers = self.cache[key]
        else:
            print("generate new query with steps")
            self.statement = self._build_statement()
            self.cache[key] = (self.statement, self.trackers)

        # in both cases the values come from the query given to this result
        for tracker in self.trackers:
            tracker.extract_parameter_value(self)

    def _build_statement(self):
        """Invoke every step of the query and return the compiled statement.

        Invoking the steps fills ``self.trackers`` as a side effect.
        """
        stmt = select()
        for idx, step in enumerate(self.query.steps):
            result = step.invoke(self, idx)
            if step.name == "columns":
                # TODO: this should be a single call
                for c in result:
                    stmt.append_column(c)
            elif step.name == "where":
                stmt = stmt.where(*result)
            elif step.name == "order_by":
                stmt = stmt.order_by(*result)
        return stmt.compile(compile_kwargs={"prevent_implicit_binds": True})
def _is_simple_literal(value):
    """Return True if ``value`` should be treated as a plain literal.

    That is the case when ``inspection.inspect()`` has nothing to report for
    it, it is not a ``Visitable`` and it does not offer
    ``__clause_element__``.
    """
    insp = inspection.inspect(value, raiseerr=False)
    return (
        insp is None
        and not isinstance(value, elements.Visitable)
        and not hasattr(value, "__clause_element__")
    )
def _rewrite_code_obj(f, cell_values, globals_): | |
"""Return a copy of f, with a new closure and new globals | |
yes it works in pypy :P | |
""" | |
argrange = range(len(cell_values)) | |
code = "def make_cells():\n" | |
if cell_values: | |
code += " (%s) = (%s)\n" % ( | |
", ".join("i%d" % i for i in argrange), | |
", ".join("o%d" % i for i in argrange), | |
) | |
code += " def closure():\n" | |
code += " return %s\n" % ", ".join("i%d" % i for i in argrange) | |
code += " return closure.__closure__" | |
vars_ = {"o%d" % i: cell_values[i] for i in argrange} | |
exec(code, vars_, vars_) | |
closure = vars_["make_cells"]() | |
func = type(f)(f.__code__, globals_, f.__name__, f.__defaults__, closure) | |
func.__annotations__ = f.__annotations__ | |
func.__doc__ = f.__doc__ | |
func.__kwdefaults__ = f.__kwdefaults__ | |
func.__module__ = f.__module__ | |
return func | |
# demo!
Base = declarative_base()


class User(Base):
    """Mapped class used by the demo queries below."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    name = Column(String)
    x_value = Column(Integer)
def go_lambdas(name, x, y):
    """Build the demo query with every part given as a lambda.

    ``name``, ``x`` and ``y`` reach the lambdas as closure variables, while
    ``q`` is looked up as a module-level global, so both ways of referring
    to a literal value are exercised.
    """
    return (
        BakedSelect(lambda: (func.bar(User.id), User.name))
        .where(lambda: User.name == name)
        .where(lambda: User.x_value == x + y + q)
        .order_by(lambda: User.id + y)
    )
def go_plain(name, x, y):
    """Build the same demo query as go_lambdas() from plain expressions.

    The expressions are evaluated right here, so ``q`` (a module-level
    global) is read at the time this function is called.
    """
    return (
        BakedSelect(func.bar(User.id), User.name)
        .where(User.name == name)
        .where(User.x_value == x + y + q)
        .order_by(User.id + y)
    )
# one cache shared by all demo runs
cache = {}


def _show(go, *args):
    """Build a query with ``go(*args)``, resolve it against ``cache`` and
    print the statement followed by its parameters."""
    result = go(*args).result(cache)
    print(result.statement)
    print(result.statement.construct_params(result.parameters))


for go in (go_lambdas, go_plain):
    # ``q`` has to stay a module-level global: go_lambdas() and go_plain()
    # both read it from the module namespace.
    q = 18
    _show(go, "name1", 5, 5)

    # same query shape, different argument values
    _show(go, "name2", 10, 7)

    # same query shape again, this time the global has changed as well
    q = 25
    _show(go, "name3", 10, 9)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
prints: