Ollama creates the prompt, Stable Audio renders music based on the prompt. Wash, rinse, repeat.
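Prerequisites (assumed here, not stated in the gist): a local Ollama server with the llama3 model pulled, plus the Python packages ollama, pydub, torch, torchaudio, einops, and stable-audio-tools; pydub also needs ffmpeg on the PATH for WAV export, and the stabilityai/stable-audio-open-1.0 weights may require accepting the model license on Hugging Face before they can be downloaded.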
import hashlib
import os

import ollama
import torch
import torchaudio
from einops import rearrange
from pydub import AudioSegment
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond

# globals
NUM_LOOPS = 100
MUSIC_PROMPT = ('Write a single-sentence prompt for a music creation program. '
                'Do not enclose the sentence in quotation marks. The bpm should be 128bpm. '
                'The genre should be a mix of the following genres: Electronic, Ambient, '
                'Chill-out, Jazz, Lo-fi Instrumental, Hip-Hop Beats.')
def main():
    for _ in range(NUM_LOOPS):
        generate_loop()
def generate_loop():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "cpu"  # overrides the line above; comment this out if you have a capable GPU/many CUDA cores

    # Download the model (the weights are cached on disk after the first call,
    # but the model is rebuilt on every loop; hoist this to module level to load it once)
    model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
    sample_rate = model_config["sample_rate"]
    sample_size = model_config["sample_size"]
    model = model.to(device)

    generated_prompt = get_prompt()
    prompt_hash = hashlib.md5(generated_prompt.encode())  # filename key; avoids shadowing the built-in hash()
    os.makedirs("output", exist_ok=True)  # torchaudio.save does not create the directory
    # Set up text and timing conditioning
    conditioning = [{
        "prompt": generated_prompt,
        "seconds_start": 0,
        "seconds_total": 37.5
    }]
    # Generate stereo audio
    output = generate_diffusion_cond(
        model,
        steps=100,
        cfg_scale=7,
        conditioning=conditioning,
        sample_size=sample_size,
        sigma_min=0.3,
        sigma_max=500,
        sampler_type="dpmpp-3m-sde",
        device=device
    )
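    # These sampler settings (100 steps, cfg_scale=7, dpmpp-3m-sde) follow the
    # published Stable Audio Open usage example; more steps trade speed for quality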
    # Rearrange audio batch to a single sequence
    output = rearrange(output, "b d n -> d (b n)")

    # Peak normalize, clip, convert to int16, and save to file
    output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
    torchaudio.save(f"output/{prompt_hash.hexdigest()}.wav", output, sample_rate)
    # Crop the audio: the model renders a full sample_size window, which runs
    # longer than the 37.5 s requested in the conditioning
    full_song = AudioSegment.from_wav(f"output/{prompt_hash.hexdigest()}.wav")
    crop_ms = int(37.5 * 1000)  # pydub slices in milliseconds
    crop = full_song[:crop_ms]
    crop.export(f"output/{prompt_hash.hexdigest()}.wav", format="wav")
    # Save the prompt into a metadata manifest next to the audio
    with open(f"output/{prompt_hash.hexdigest()}.txt", "a") as f:
        f.write(generated_prompt)
def get_prompt():
    response = ollama.chat(model='llama3', messages=[
        {
            'role': 'user',
            'content': MUSIC_PROMPT,
        },
    ])
    print('\n---\n' + response['message']['content'] + '\n---\n')
    return response['message']['content']
if __name__ == "__main__":
    main()

# fin
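To audition a batch as one continuous track, the generated loops can be stitched together with pydub. A minimal sketch, not part of the original gist; the stitch_mix name, the 500 ms crossfade, and the mix.wav filename are all arbitrary choices:

# stitch_mix.py -- sketch: concatenate every generated loop into one mix
import glob

from pydub import AudioSegment

loops = sorted(glob.glob("output/*.wav"))
mix = AudioSegment.from_wav(loops[0])
for path in loops[1:]:
    # 500 ms crossfade between consecutive loops; adjust to taste
    mix = mix.append(AudioSegment.from_wav(path), crossfade=500)
mix.export("mix.wav", format="wav")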