Last active
April 13, 2021 10:35
-
-
Save AndrewIngram/b1a6e66ce92d2d0befd2f2f65eb62ca5 to your computer and use it in GitHub Desktop.
Proper cursors with Graphene + Django. Graphene-Django's stock connections use limit/offset logic under the hood, making the whole cursor-based connection modelling kinda pointless.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
import operator | |
from base64 import b64decode as _unbase64 | |
from base64 import b64encode as _base64 | |
from functools import reduce | |
from django.db.models import Q | |
from graphene import relay | |
from graphql_relay.connection import connectiontypes | |
def base64(s):
    """Return the base64 encoding of the text ``s`` as text.

    ``s`` is turned into UTF-8 bytes, base64-encoded, and the encoded bytes
    are decoded back into a ``str``.
    """
    raw = s.encode('utf-8')
    encoded = _base64(raw)
    return encoded.decode('utf-8')
def unbase64(s):
    """Return the text obtained by base64-decoding ``s``.

    The decoded bytes are interpreted as UTF-8.
    """
    raw = _unbase64(s)
    return raw.decode('utf-8')
def get_attribute(instance, name):
    """Look up ``name`` on ``instance``, following ``__``-separated paths.

    The whole of ``name`` is tried as a plain attribute first. When that
    fails, the part before the first ``__`` is resolved on the current object
    and the search continues on that value with the remainder of the path.

    Returns the attribute value, or ``None`` when any step of the path is
    missing.
    """
    target = instance
    path = name
    while True:
        # A direct hit wins, even if the remaining path still contains "__".
        if hasattr(target, path):
            return getattr(target, path)

        head, separator, rest = path.partition("__")
        # No separator left means there is nothing further to descend into.
        if not separator or not hasattr(target, head):
            return None

        target = getattr(target, head)
        path = rest
def attr_from_sort(sort):
    """Return the attribute name of a single ordering term.

    A leading ``-`` (the descending-order marker) is dropped; any other term
    is returned unchanged.
    """
    return sort[1:] if sort[0] == '-' else sort
def build_q_objects(sort, cursor_parts):
    """Build the ``Q`` clauses that select rows positioned after a cursor.

    For ``sort == (a, b, c)`` the returned list is, in this order::

        [Q(a__gt=A), Q(a=A, b__gt=B), Q(a=A, b=B, c__gt=C)]

    where the capital letters are the values stored in ``cursor_parts``
    under each attribute name. A term written with a leading ``-`` uses
    ``__lt`` instead of ``__gt``. ``filter_queryset`` ORs the clauses
    together.
    """
    clauses = []
    prefix = sort
    while True:
        # The last term of the prefix is compared strictly ...
        pivot = prefix[-1]
        pivot_attr = attr_from_sort(pivot)
        comparison = 'lt' if pivot[0] == '-' else 'gt'
        lookups = {f"{pivot_attr}__{comparison}": cursor_parts[pivot_attr]}

        # ... while every earlier term must equal the cursor's value.
        for tied in prefix[0:-1]:
            tied_attr = attr_from_sort(tied)
            lookups[tied_attr] = cursor_parts[tied_attr]

        clauses.append(Q(**lookups))

        if len(prefix) == 1:
            break
        prefix = prefix[0:-1]

    # Clauses were collected longest-prefix first; return shortest first.
    clauses.reverse()
    return clauses
def cursor_string_from_parts(parts, sort):
    """Serialise cursor values into a cursor string.

    The values in ``parts`` (keyed by attribute name) are taken in the order
    of the terms in ``sort``, joined with ``|`` and base64-encoded.
    """
    ordered_values = [parts[attr_from_sort(term)] for term in sort]
    return base64('|'.join(ordered_values))
def parts_from_cursor_string(cursor, sort):
    """Decode a cursor string into a mapping of attribute name to value.

    Args:
        cursor: Base64 text in the layout written by
            ``cursor_string_from_parts`` (``|``-separated values, one per
            sort term, in sort order).
        sort: Ordering terms; each term names the attribute that the segment
            at the same position belongs to.

    Returns:
        A dict mapping every attribute named in ``sort`` to its string
        segment from the cursor.

    Raises:
        ValueError: If the decoded cursor has fewer segments than ``sort``
            has terms. Errors raised by ``unbase64`` for undecodable input
            propagate unchanged.
    """
    decoded = unbase64(cursor)
    attrs = [attr_from_sort(term) for term in sort]

    # Split at most len(attrs) - 1 times so that the final segment keeps any
    # "|" characters that belong to its value instead of being truncated at
    # the first one. With a single sort term the whole payload is the value.
    bits = decoded.split('|', max(len(attrs) - 1, 0))

    if len(bits) < len(attrs):
        raise ValueError(
            f"Invalid cursor: expected {len(attrs)} part(s), got {len(bits)}"
        )

    return dict(zip(attrs, bits))
def cursor_string_from_obj(obj, sort):
    """Build the cursor string describing ``obj``'s position under ``sort``.

    For every term in ``sort`` the matching attribute is read from ``obj``
    with ``get_attribute``. ``datetime.datetime`` values are rendered with
    ``isoformat()``; everything else goes through ``str()``. The collected
    values are then serialised by ``cursor_string_from_parts``.
    """
    values = {}
    for term in sort:
        key = attr_from_sort(term)
        raw = get_attribute(obj, key)
        values[key] = (
            raw.isoformat() if isinstance(raw, datetime.datetime) else str(raw)
        )
    return cursor_string_from_parts(values, sort)
def filter_queryset(qs, cursor, sort):
    """Restrict ``qs`` to the rows that come after ``cursor`` under ``sort``.

    The cursor is decoded into per-attribute values, turned into one ``Q``
    clause per sort-prefix, and the clauses are OR-ed into a single filter.
    """
    clauses = build_q_objects(sort, parts_from_cursor_string(cursor, sort))
    combined = reduce(operator.or_, clauses)
    return qs.filter(combined)
class QuerysetConnectionField(relay.ConnectionField):
    """Connection field that pages through a queryset using value cursors.

    The connection type handed to the field must provide two callables:

    * ``get_queryset(root, info, **args)`` -- returns the object that is
      ordered (``order_by``), filtered, counted and sliced here.
    * ``get_sort(**args)`` -- returns the ordering terms passed to
      ``order_by``; the same terms define the layout of every cursor.

    Only forward pagination (``first`` / ``after``) is implemented.
    """

    def __init__(self, type, *args, **kwargs):
        """Create the field and check that ``type`` has the required hooks.

        Args:
            type: The connection type; must have ``get_queryset`` and
                ``get_sort`` attributes.
            *args: Forwarded unchanged to ``relay.ConnectionField``.
            **kwargs: Forwarded unchanged to ``relay.ConnectionField``.
        """
        super().__init__(type, *args, **kwargs)

        # Validate class methods
        assert hasattr(type, 'get_queryset'), (
            f'Connection type {type} needs a `get_queryset` method'
        )
        assert hasattr(type, 'get_sort'), (
            f'Connection type {type} needs a `get_sort` method'
        )

    @classmethod
    def connection_resolver(cls, resolver, connection_type, root, info, **args):
        """Resolve one page of the connection.

        Args:
            resolver: Not used; the data comes from
                ``connection_type.get_queryset``.
            connection_type: The connection type, possibly wrapped in an
                object exposing the real type as ``of_type``.
            root: Passed through to ``get_queryset``.
            info: Passed through to ``get_queryset``; ``info.field_name`` is
                used in error messages.
            **args: Field arguments. ``first``, ``last`` and ``after`` are
                read here; all of them are also passed to ``get_sort`` and
                ``get_queryset``.

        Returns:
            A ``connection_type`` instance carrying ``edges`` and
            ``page_info``.
        """
        # Unwrap a wrapping type to reach the connection class itself.
        if hasattr(connection_type, 'of_type'):
            connection_type = connection_type.of_type

        first = args.get('first')
        last = args.get('last')
        after = args.get('after')
        # NOTE: the `before` argument is not read; backward pagination is
        # not implemented.
        sort = connection_type.get_sort(**args)

        # Validate connection arguments
        assert first or last, (
            f'You must provide a `first` or `last` value to properly paginate '
            f'the `{info.field_name}` connection.'
        )
        assert not (first and last), (
            f'You cannot define both `first` and `last` values on '
            f'`{info.field_name}` connection.'
        )
        assert not last, '`last` argument is not supported'

        qs = connection_type.get_queryset(root, info, **args).order_by(*sort)

        if after:
            qs = filter_queryset(qs, after, sort)

        # Count the rows remaining after the cursor *before* slicing, so the
        # count can be compared with the requested page size below.
        total_length = qs.count()

        if first:
            qs = qs[:first]

        edge_type = connection_type.Edge or connectiontypes.Edge

        edges = [
            edge_type(
                node=node,
                cursor=cursor_string_from_obj(node, sort),
            )
            for node in qs.iterator()
        ]

        first_edge_cursor = edges[0].cursor if edges else None
        last_edge_cursor = edges[-1].cursor if edges else None

        page_info = relay.PageInfo(
            start_cursor=first_edge_cursor,
            end_cursor=last_edge_cursor,
            has_previous_page=False,  # TODO
            has_next_page=isinstance(first, int) and (total_length > first),
        )

        return connection_type(
            edges=edges,
            page_info=page_info,
        )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class User(graphene.ObjectType):
    """GraphQL object type for a user, exposing a required ``name`` string."""

    name = graphene.String(
        required=True,
    )
class AllUsersConnection(graphene.relay.Connection):
    """Connection over all users, ordered by ``name``.

    Supplies the ``get_queryset`` / ``get_sort`` hooks that
    ``QuerysetConnectionField`` requires.
    """

    class Meta:
        node = User

    @classmethod
    def get_sort(cls, **kwargs):
        """Return the ordering terms for this connection."""
        return ("name",)

    @classmethod
    def get_queryset(cls, root, info, **kwargs):
        """Return every ``DjangoUser`` row."""
        # Here `root` would be the instance of the Query type.
        all_users = DjangoUser.objects.all()
        return all_users
class Query(graphene.ObjectType):
    """Query type exposing all users through a ``QuerysetConnectionField``."""

    all_users = QuerysetConnectionField(
        AllUsersConnection,
        required=True,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment