Last active
May 3, 2018 13:51
-
-
Save rmax/f4e912aadc8447463209d4fe672c260c to your computer and use it in GitHub Desktop.
An Elasticsearch reader for Dask
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dask import delayed | |
from elasticsearch import Elasticsearch | |
from elasticsearch.helpers import scan | |
def read_elasticsearch(query=None, npartitions=8, client_cls=None,
                       client_kwargs=None, **kwargs):
    """Reads documents from Elasticsearch.

    The search is split into ``npartitions`` sliced scrolls, one per
    returned ``Delayed`` object, so that partitions can be fetched in
    parallel by the workers.

    By default, documents are sorted by ``_doc``. For more information see the
    scrolling section in Elasticsearch documentation.

    Parameters
    ----------
    query : dict, optional
        Search query. The dictionary is copied, never modified in place.
    npartitions : int, optional
        Number of partitions, default is 8. With a single partition no
        ``slice`` clause is added to the query. A value below 1 yields an
        empty list.
    client_cls : elasticsearch.Elasticsearch, optional
        Elasticsearch client class.
    client_kwargs : dict, optional
        Elasticsearch client parameters.
    **kwargs
        Additional keyword arguments are passed to the
        ``elasticsearch.helpers.scan`` function.

    Returns
    -------
    out : List[Delayed]
        A list of ``dask.Delayed`` objects.

    Examples
    --------
    Get all documents in elasticsearch.

    >>> docs = dask.bag.from_delayed(read_elasticsearch())

    Get documents matching a given query.

    >>> query = {"query": {"match_all": {}}}
    >>> docs = dask.bag.from_delayed(read_elasticsearch(query, index="myindex", doc_type="stuff"))
    """
    # Work on a shallow copy: adding the default sort below must not leak
    # into the dictionary owned by the caller.
    query = dict(query) if query else {}
    # Sorting by _doc is preferred for scrolling.
    query.setdefault('sort', ['_doc'])

    if client_cls is None:
        client_cls = Elasticsearch

    if npartitions == 1:
        # NOTE(review): Elasticsearch is expected to reject a slice whose
        # ``max`` is not greater than 1, so a single partition is read with a
        # plain (unsliced) scroll, which covers the same documents.
        partition_queries = [query]
    else:
        partition_queries = [
            dict(query, slice={'id': idx, 'max': npartitions})
            for idx in range(npartitions)
        ]

    # Only the client class and its kwargs travel to the worker; the client
    # itself is built there by ``_elasticsearch_scan``.
    scan_task = delayed(_elasticsearch_scan)
    return [
        scan_task(client_cls, client_kwargs,
                  **dict(kwargs, query=partition_query))
        for partition_query in partition_queries
    ]
def _elasticsearch_scan(client_cls, client_kwargs, **params):
    """Run one scan and return everything it yields as a list.

    Parameters
    ----------
    client_cls : callable
        Called to build the client, with ``client_kwargs`` as keyword
        arguments.
    client_kwargs : dict or None
        Keyword arguments for ``client_cls``; a falsy value means none.
    **params
        Forwarded unchanged to ``scan``.

    Returns
    -------
    out : list
        The items produced by ``scan``, fully materialized.
    """
    # The client is created here, inside the worker's process, because it
    # cannot be serialized and shipped along with the task.
    if client_kwargs:
        es_client = client_cls(**client_kwargs)
    else:
        es_client = client_cls()
    hits = scan(es_client, **params)
    return list(hits)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
In [2]: import dask.bag | |
In [3]: from dask_elasticsearch import read_elasticsearch | |
In [4]: docs = dask.bag.from_delayed(read_elasticsearch(index="myindex")) | |
In [5]: from dask.diagnostics import progress | |
In [6]: progress.ProgressBar().register() | |
In [7]: docs | |
dask.bag<bag-fro..., npartitions=8> | |
In [8]: docs.count().compute() | |
[########################################] | 100% Completed | 0.4s | |
346 | |
In [9]: docs.map_partitions(len).compute() | |
[########################################] | 100% Completed | 0.3s | |
(46, 30, 33, 71, 66, 30, 33, 37) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment