Last active
April 6, 2024 22:47
-
-
Save skrawcz/6b21ceb0789c5c0d2ec42885e3362093 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This module demonstrates a telephone application | |
using Burr that: | |
- captions an image | |
- creates caption embeddings (for analysis) | |
- creates a new image based on the created caption | |
""" | |
import os | |
import uuid | |
from hamilton import dataflows, driver | |
import requests | |
from burr.core import Action, ApplicationBuilder, State, default, expr | |
from burr.core.action import action | |
from burr.lifecycle import PostRunStepHook | |
# import hamilton modules | |
caption_images = dataflows.import_module("caption_images", "elijahbenizzy") | |
generate_images = dataflows.import_module("generate_images", "elijahbenizzy") | |
@action(
    reads=["current_image_location"],
    writes=["current_image_caption", "image_location_history"],
)
def image_caption(state: State, caption_image_driver: driver.Driver) -> tuple[dict, State]:
    """Caption the image that the state currently points at.

    Reads ``current_image_location`` from the state and executes the bound
    Hamilton driver, requesting its ``generated_caption`` output with the
    location passed in as the ``image_url`` input.

    :param state: Burr state holding ``current_image_location``.
    :param caption_image_driver: Hamilton driver used to produce the caption.
    :return: The raw driver result, and a state in which
        ``current_image_caption`` is set to the new caption and the captioned
        image location has been appended to ``image_location_history``.
    """
    image_location = state["current_image_location"]
    result = caption_image_driver.execute(
        ["generated_caption"],
        inputs={"image_url": image_location},
    )
    # could save to S3 here.
    captioned_state = state.update(current_image_caption=result["generated_caption"])
    # Record which image this caption was produced from.
    return result, captioned_state.append(image_location_history=image_location)
@action(
    reads=["current_image_caption"],
    writes=["caption_analysis"],
)
def caption_embeddings(state: State, caption_image_driver: driver.Driver) -> tuple[dict, State]:
    """Compute analysis metadata for the current caption.

    Executes the bound Hamilton driver for its ``metadata`` output, overriding
    the ``generated_caption`` node with the caption already held in the state
    rather than supplying an image input.

    :param state: Burr state holding ``current_image_caption``.
    :param caption_image_driver: Hamilton driver used to produce the metadata.
    :return: The raw driver result, and a state with the ``metadata`` value
        appended to ``caption_analysis``.
    """
    # NOTE(review): presumably the metadata carries the caption embeddings
    # (see the module docstring) -- confirm against the dataflow.
    known_caption = state["current_image_caption"]
    result = caption_image_driver.execute(
        ["metadata"],
        overrides={"generated_caption": known_caption},
    )
    # could save to S3 here.
    analysed_state = state.append(caption_analysis=result["metadata"])
    return result, analysed_state
@action(
    reads=["current_image_caption"],
    writes=["current_image_location", "image_caption_history"],
)
def image_generation(state: State, generate_image_driver: driver.Driver) -> tuple[dict, State]:
    """Generate a new image from the current caption.

    Reads ``current_image_caption`` from the state and executes the bound
    Hamilton driver, requesting its ``generated_image`` output with the
    caption passed in as the ``image_generation_prompt`` input.

    :param state: Burr state holding ``current_image_caption``.
    :param generate_image_driver: Hamilton driver used to produce the image.
    :return: The raw driver result, and a state in which
        ``current_image_location`` is set to the generated image and the
        caption used as the prompt has been appended to
        ``image_caption_history``.
    """
    prompt = state["current_image_caption"]
    result = generate_image_driver.execute(
        ["generated_image"],
        inputs={"image_generation_prompt": prompt},
    )
    # could save to S3 here.
    regenerated_state = state.update(current_image_location=result["generated_image"])
    # Record which caption this image was produced from.
    return result, regenerated_state.append(image_caption_history=prompt)
@action(
    reads=["image_location_history", "image_caption_history", "caption_analysis"],
    writes=[]
)
def terminal_step(state: State) -> tuple[dict, State]:
    """Collect the accumulated histories as the final result.

    :param state: Burr state holding ``image_location_history``,
        ``image_caption_history`` and ``caption_analysis``.
    :return: A dict with those three entries copied out of the state, and the
        state itself, unchanged.
    """
    reported_keys = (
        "image_location_history",
        "image_caption_history",
        "caption_analysis",
    )
    result = {key: state[key] for key in reported_keys}
    # could save to S3 here.
    return result, state
def build_application(starting_image: str = "statemachine.png",
                      number_of_images_to_caption: int = 4):
    """Assemble the Burr application that plays image telephone.

    Two Hamilton drivers are built (one over the ``caption_images`` module,
    one over ``generate_images``) and bound to the actions that need them.
    The application starts at ``caption`` and cycles
    ``caption -> analysis -> generate -> caption`` until ``analysis`` can
    transition to ``terminal``.

    :param starting_image: Initial value of ``current_image_location``.
    :param number_of_images_to_caption: ``analysis`` transitions to
        ``terminal`` once ``image_caption_history`` has exactly this many
        entries.
    :return: The built Burr application.
    """
    # instantiate hamilton drivers and then bind them to the actions.
    caption_builder = driver.Builder().with_config({"include_embeddings": True})
    caption_image_driver = caption_builder.with_modules(caption_images).build()
    generate_builder = driver.Builder().with_config({})
    generate_image_driver = generate_builder.with_modules(generate_images).build()

    initial_state = {
        "current_image_location": starting_image,
        "current_image_caption": "",
        "image_location_history": [],
        "image_caption_history": [],
        "caption_analysis": [],
    }
    actions = {
        "caption": image_caption.bind(caption_image_driver=caption_image_driver),
        "analysis": caption_embeddings.bind(caption_image_driver=caption_image_driver),
        "generate": image_generation.bind(generate_image_driver=generate_image_driver),
        "terminal": terminal_step,
    }
    # NOTE(review): image_caption_history is only appended to by the generate
    # action, so this appears to allow one more caption than the requested
    # number before stopping -- confirm that is intended.
    stop_condition = expr(f"len(image_caption_history) == {number_of_images_to_caption}")

    builder = ApplicationBuilder().with_state(**initial_state)
    builder = builder.with_actions(**actions)
    # The conditional transition is listed before the default one leaving "analysis".
    builder = builder.with_transitions(
        ("caption", "analysis", default),
        ("analysis", "terminal", stop_condition),
        ("analysis", "generate", default),
        ("generate", "caption", default),
    )
    builder = builder.with_entrypoint("caption")
    builder = builder.with_tracker(project="image-telephone")
    return builder.build()
if __name__ == "__main__":
    # Script entry point: build the application, render its state machine to
    # "statemachine.png", then run it to the "terminal" action and print the
    # final state.
    import random

    # Randomly pick one of the two ways of driving the application shown below.
    coin_flip = random.choice([True, False])
    # app = build_application("path/to/my/image.png")
    app = build_application()
    app.visualize(
        output_file_path="statemachine", include_conditions=True, view=True, format="png"
    )
    if coin_flip:
        last_action, result, state = app.run(halt_after=["terminal"])
        # save to S3 / download images etc.
        print(state)
    else:
        # alternate way to run:
        while True:
            # Bound to `ran_action` rather than `action` so the module-level
            # `action` decorator imported from burr.core.action is not rebound.
            ran_action, result, state = app.step()
            print("action=====\n", ran_action)
            print("result=====\n", result)
            # you could save S3 / download images etc. here.
            if ran_action.name == "terminal":
                break
        print(state)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment