Some ComfyUI nodes for ACE
# By https://github.com/blepping
# License: Apache2
#
# Place this file in your custom_nodes directory and it should load automatically.
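#
# Provides utility nodes for working with ACE-Steps audio latents and AUDIO in
# ComfyUI; see NODE_CLASS_MAPPINGS at the bottom for the full list of nodes.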
import math
import torch
import sys
import nodes
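# Constant used by SilentLatentNode to build a latent representing silence:
# one value per ACE-Steps latent channel (8) and frequency band (16). The
# trailing indexing reshapes it to (1, 8, 16, 1) so it broadcasts over the
# batch and temporal dimensions.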
SILENCE = torch.tensor(
(
(
-0.6462,
-1.2132,
-1.3026,
-1.2432,
-1.2455,
-1.2162,
-1.2184,
-1.2114,
-1.2153,
-1.2144,
-1.2130,
-1.2115,
-1.2063,
-1.1918,
-1.1154,
-0.7924,
),
(
0.0473,
-0.3690,
-0.6507,
-0.5677,
-0.6139,
-0.5863,
-0.5783,
-0.5746,
-0.5748,
-0.5763,
-0.5774,
-0.5760,
-0.5714,
-0.5560,
-0.5393,
-0.3263,
),
(
-1.3019,
-1.9225,
-2.0812,
-2.1188,
-2.1298,
-2.1227,
-2.1080,
-2.1133,
-2.1096,
-2.1077,
-2.1118,
-2.1141,
-2.1168,
-2.1134,
-2.0720,
-1.7442,
),
(
-4.4184,
-5.5253,
-5.7387,
-5.7961,
-5.7819,
-5.7850,
-5.7980,
-5.8083,
-5.8197,
-5.8202,
-5.8231,
-5.8305,
-5.8313,
-5.8153,
-5.6875,
-4.7317,
),
(
1.5986,
2.0669,
2.0660,
2.0476,
2.0330,
2.0271,
2.0252,
2.0268,
2.0289,
2.0260,
2.0261,
2.0252,
2.0240,
2.0220,
1.9828,
1.6429,
),
(
-0.4177,
-0.9632,
-1.0095,
-1.0597,
-1.0462,
-1.0640,
-1.0607,
-1.0604,
-1.0641,
-1.0636,
-1.0631,
-1.0594,
-1.0555,
-1.0466,
-1.0139,
-0.8284,
),
(
-0.7686,
-1.0507,
-1.3932,
-1.4880,
-1.5199,
-1.5377,
-1.5333,
-1.5320,
-1.5307,
-1.5319,
-1.5360,
-1.5383,
-1.5398,
-1.5381,
-1.4961,
-1.1732,
),
(
0.0199,
-0.0880,
-0.4010,
-0.3936,
-0.4219,
-0.4026,
-0.3907,
-0.3940,
-0.3961,
-0.3947,
-0.3941,
-0.3929,
-0.3889,
-0.3741,
-0.3432,
-0.169,
),
),
dtype=torch.float32,
device="cpu",
)[None, ..., None]
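# Blend modes are resolved lazily so the optional ComfyUI-bleh integration can
# be picked up when it is installed; otherwise a small built-in set is used.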
BLEND_MODES = None
def _ensure_blend_modes():
global BLEND_MODES
if BLEND_MODES is None:
bi = sys.modules.get("_blepping_integrations", {}) or getattr(
nodes,
"_blepping_integrations",
{},
)
bleh = bi.get("bleh")
if bleh is not None:
BLEND_MODES = bleh.py.latent_utils.BLENDING_MODES
else:
BLEND_MODES = {
"lerp": torch.lerp,
"a_only": lambda a, _b, _t: a,
"b_only": lambda _a, b, _t: b,
"subtract_b": lambda a, b, t: a - b * t,
}
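# Rescale a tensor so its values span [target_min, target_max] over the given
# dimensions.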
def normalize_to_scale(latent, target_min, target_max, *, dim=(-3, -2, -1)):
min_val, max_val = (
latent.amin(dim=dim, keepdim=True),
latent.amax(dim=dim, keepdim=True),
)
normalized = (latent - min_val).div_(max_val - min_val)
return (
normalized.mul_(target_max - target_min)
.add_(target_min)
.clamp_(target_min, target_max)
)
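# Coerce an AUDIO waveform into a 3D (batch, channels, samples) tensor,
# optionally copying it to the CPU and duplicating mono audio to stereo.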
def fixup_waveform(
waveform: torch.Tensor,
*,
copy: bool = True,
move_to_cpu: bool = True,
ensure_stereo: bool = False,
) -> torch.Tensor:
if move_to_cpu:
waveform = waveform.to(device="cpu", copy=copy)
if waveform.ndim == 2:
waveform = waveform[None]
elif waveform.ndim == 1:
waveform = waveform[None, None]
if ensure_stereo and waveform.shape[1] == 1:
waveform = waveform.repeat(1, 2, 1)
return waveform
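# ACE-Steps latents have roughly 44100 / 512 / 8 (about 10.77) temporal
# positions per second of audio; used to convert between seconds and latent length.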
TEMPORAL_SCALE_FACTOR = 44100 / 512 / 8
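# Generates a LATENT filled with the ACE-Steps encoding of silence, either for a
# requested duration/batch size or matching the shape of a reference latent.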
class SilentLatentNode:
FUNCTION = "go"
CATEGORY = "audio/acetricks"
RETURN_TYPES = ("LATENT",)
@classmethod
def INPUT_TYPES(cls) -> dict:
return {
"required": {
"seconds": (
"FLOAT",
{
"default": 120.0,
"min": 1.0,
"max": 1000.0,
"step": 0.1,
"tooltip": "Number of seconds to generate. Ignored if optional latent input is connected.",
},
),
"batch_size": (
"INT",
{
"default": 1,
"min": 1,
"max": 4096,
"tooltip": "Batch size to generate. Ignored if optional latent input is connected.",
},
),
},
"optional": {
"ref_latent_opt": (
"LATENT",
{
"tooltip": "When connected the other parameters are ignored and the latent output will match the length/batch size of the reference."
},
),
},
}
@classmethod
def go(cls, *, seconds: float, batch_size: int, ref_latent_opt=None) -> tuple[dict]:
if ref_latent_opt is not None:
latent = torch.zeros(
ref_latent_opt["samples"].shape, device="cpu", dtype=torch.float32
)
else:
length = int(seconds * TEMPORAL_SCALE_FACTOR)
latent = torch.zeros(
batch_size, 8, 16, length, device="cpu", dtype=torch.float32
)
latent += SILENCE
return ({"samples": latent, "type": "audio"},)
class VisualizeLatentNode:
FUNCTION = "go"
CATEGORY = "audio/acetricks"
RETURN_TYPES = ("IMAGE",)
@classmethod
def INPUT_TYPES(cls) -> dict:
return {
"required": {
"latent": ("LATENT",),
"scale_secs": (
"INT",
{
"default": 0,
"min": 0,
"max": 1000,
"tooltip": "Horizontal scale. Number of pixels that corresponds to one second of audio. You can use 0 for no scaling which is roughly 11 pixels per second.",
},
),
"scale_vertical": (
"INT",
{
"default": 1,
"min": 1,
"max": 1024,
"tooltip": "Pixel expansion factor for channels (or frequency bands if you have swap_channels_freqs mode enabled).",
},
),
"swap_channels_freqs": (
"BOOLEAN",
{
"default": False,
"tooltip": "Swaps the order of channels and frequency in the vertical dimension. When enabled, scale_vertical applies to frequency bands.",
},
),
"normalize_dims": (
"STRING",
{
"default": "-1",
"tooltip": "Dimensions the latent scale is normalized using. Must be a comma-separated list. The default setting normalizes the channels and frequency bands independently per batch, you can try -3, -2, -1 if you want to see the relative differences better.",
},
),
"mode": (
(
"split",
"combined",
"brg",
"rgb",
"bgr",
"split_flip",
"combined_flip",
"brg_flip",
"rgb_flip",
"bgr_flip",
),
{
"default": "split",
"tooltip": "Split shows a monochrome view of of each channel/freq, combined shows the average. Flip means invert the energy in the channel (i.e. white -> black). The other modes put the latent channels into the RGB channels of the preview image.",
},
),
},
}
@classmethod
def go(
cls,
*,
latent,
scale_secs,
scale_vertical,
swap_channels_freqs,
normalize_dims,
mode,
) -> tuple:
normalize_dims = normalize_dims.strip()
normalize_dims = (
()
if not normalize_dims
else tuple(int(dim) for dim in normalize_dims.split(","))
)
samples = latent["samples"].to(dtype=torch.float32, device="cpu")
if samples.ndim != 4:
raise ValueError("Expected an ACE-Steps latent with 4 dimensions")
color_mode = mode not in {"split", "combined", "split_flip", "combined_flip"}
batch, channels, freqs, temporal = samples.shape
samples = normalize_to_scale(samples, 0.0, 1.0, dim=normalize_dims)
if mode.endswith("_flip"):
samples = 1.0 - samples
if swap_channels_freqs:
samples = samples.movedim(2, 1)
if mode.startswith("combined"):
samples = samples.mean(dim=1, keepdim=True)
if scale_vertical != 1:
samples = samples.repeat_interleave(scale_vertical, dim=2)
if not color_mode:
samples = samples.reshape(batch, -1, temporal)
if scale_secs > 0:
new_temporal = round((temporal / TEMPORAL_SCALE_FACTOR) * scale_secs)
samples = torch.nn.functional.interpolate(
samples.unsqueeze(1) if not color_mode else samples,
size=(samples.shape[-2], new_temporal),
mode="nearest-exact",
)
if not color_mode:
samples = samples.squeeze(1)
if not color_mode:
return (samples[..., None].expand(*samples.shape, 3),)
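        # Color modes: pad the channel count up to a multiple of 3, stack each
        # group of 3 channels vertically, then treat the groups as RGB planes.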
rgb_count = math.ceil(samples.shape[1] / 3)
channels_pad = rgb_count * 3 - samples.shape[1]
samples = torch.cat(
(
samples,
samples.new_zeros(samples.shape[0], channels_pad, *samples.shape[-2:]),
),
dim=1,
)
samples = torch.cat(samples.chunk(rgb_count, dim=1), dim=2).movedim(1, -1)
if mode.startswith("bgr"):
samples = samples.flip(-1)
elif mode.startswith("brg"):
samples = samples.roll(-1, -1)
return (samples,)
class SplitOutLyricsNode:
DESCRIPTION = "Allows splitting out lyrics and lyrics strength from ACE-Steps CONDITIONING objects. Note that you will only be able to join it back again if it is the same shape."
FUNCTION = "go"
CATEGORY = "audio/acetricks"
RETURN_TYPES = ("CONDITIONING", "CONDITIONING_ACE_LYRICS")
@classmethod
def INPUT_TYPES(cls) -> dict:
return {
"required": {
"conditioning": ("CONDITIONING",),
"add_fake_pooled": ("BOOLEAN", {"default": True}),
},
}
@classmethod
def go(cls, *, conditioning, add_fake_pooled) -> tuple:
tags_result, lyrics_result = [], []
for cond_t, cond_d in conditioning:
cond_d = cond_d.copy()
cond_lyr = cond_d.pop("conditioning_lyrics", None)
cond_lyrstr = cond_d.pop("lyrics_strength", None)
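            # Assumption: a dummy pooled_output is added so the tags-only
            # CONDITIONING keeps working with nodes that expect that key.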
if add_fake_pooled:
cond_d["pooled_output"] = cond_t.new_zeros(1, 1)
tags_result.append([cond_t.clone(), cond_d])
lyrics_result.append(
{
"conditioning_lyrics": cond_lyr.clone(),
"lyrics_strength": cond_lyrstr,
}
)
return (tags_result, lyrics_result)
class JoinLyricsNode:
DESCRIPTION = "Allows joining CONDITIONING_ACE_LYRICS back into CONDITIONING. Will overwrite any lyrics that exist. Must be the same shape as the conditioning the lyrics were split from."
FUNCTION = "go"
CATEGORY = "audio/acetricks"
RETURN_TYPES = ("CONDITIONING",)
@classmethod
def INPUT_TYPES(cls) -> dict:
return {
"required": {
"conditioning_tags": ("CONDITIONING",),
"conditioning_lyrics": ("CONDITIONING_ACE_LYRICS",),
},
}
@classmethod
def go(cls, *, conditioning_tags: list, conditioning_lyrics: list) -> tuple[list]:
ct_len, cl_len = len(conditioning_tags), len(conditioning_lyrics)
if ct_len != cl_len:
raise ValueError(
f"Different lengths for tags {ct_len} vs conditioning lyrics {cl_len}"
)
if ct_len > 0 and conditioning_lyrics[0].get("conditioning_lyrics") is None:
raise ValueError(
"conditioning_lyrics missing items, cannot combine with it."
)
result = [
[
cond_t.clone(),
cond_d.copy()
| {
"conditioning_lyrics": cond_l["conditioning_lyrics"].clone(),
"lyrics_strength": cond_l["lyrics_strength"],
"pooled_output": None,
},
]
for (cond_t, cond_d), cond_l in zip(conditioning_tags, conditioning_lyrics)
]
return (result,)
class SetAudioDtypeNode:
DESCRIPTION = "Advanced node that allows the datatype of the audio waveform. The 16 and 8 bit types are not recommended."
FUNCTION = "go"
CATEGORY = "audio/acetricks"
RETURN_TYPES = ("AUDIO",)
_ALLOWED_DTYPES = (
"float64",
"float32",
"float16",
"bfloat16",
"float8_e4m3fn",
"float8_e5m2",
)
@classmethod
def INPUT_TYPES(cls) -> dict:
return {
"required": {
"audio": ("AUDIO",),
"dtype": (
cls._ALLOWED_DTYPES,
{"default": "float64", "tooltip": "TBD"},
),
},
}
@classmethod
def go(cls, *, audio: dict, dtype: str) -> tuple[dict]:
if dtype not in cls._ALLOWED_DTYPES:
raise ValueError("Bad dtype")
waveform = audio["waveform"]
dt = getattr(torch, dtype)
if waveform.dtype == dt:
return (audio,)
return (audio | {"waveform": waveform.to(dtype=dt)},)
class AudioLevelsNode:
DESCRIPTION = "The values in the waveform range for -1 to 1. This node allows you to scale audio to a percentage of that range."
FUNCTION = "go"
CATEGORY = "audio/acetricks"
RETURN_TYPES = ("AUDIO",)
@classmethod
def INPUT_TYPES(cls) -> dict:
return {
"required": {
"audio": ("AUDIO",),
"scale": (
"FLOAT",
{
"default": 0.95,
"min": 0.0,
"max": 1.0,
"tooltip": "Percentage where 1.0 indicates 100% of the maximum allowed value in an audio tensor. You can use 1.0 to make it as loud as possible without actually clipping.",
},
),
"per_channel": (
"BOOLEAN",
{
"default": False,
"tooltip": "When enabled, the levels for each channel will be scaled independently. For multi-channel audio (like stereo) enabling this will not preserve the relative levels between the channels so probably should be left disabled most of the time.",
},
),
},
}
@classmethod
def go(cls, *, audio: dict, scale: float, per_channel: bool) -> tuple[dict]:
waveform = audio["waveform"].to(device="cpu", copy=True)
if waveform.ndim == 1:
waveform = waveform[None, None, ...]
elif waveform.ndim == 2:
waveform = waveform[None, ...]
elif waveform.ndim != 3:
raise ValueError("Unexpected number of dimensions in waveform!")
max_val = (
waveform.abs().flatten(start_dim=2 if per_channel else 1).max(dim=-1).values
)
max_val = max_val[..., None] if per_channel else max_val[..., None, None]
# Max could be 0, multiplying by 0 is fine in that case.
waveform *= (scale / max_val).nan_to_num()
return (audio | {"waveform": waveform.clamp(-1.0, 1.0)},)
class AudioAsLatentNode:
DESCRIPTION = "This node allows you to rearrange AUDIO to look like a LATENT. Can be useful if you want to apply some latent operations to AUDIO. Can be reversed with the ACETricks LatentAsAudio node."
FUNCTION = "go"
CATEGORY = "audio/acetricks"
RETURN_TYPES = ("LATENT",)
@classmethod
def INPUT_TYPES(cls) -> dict:
return {
"required": {
"audio": ("AUDIO",),
"use_width": (
"BOOLEAN",
{
"default": True,
"tooltip": "When enabled, you'll get a 4 channel with height 1 and the audio audio data in the width dimension, otherwise the opposite.",
},
),
},
}
@classmethod
def go(cls, *, audio: dict, use_width: bool) -> tuple:
waveform = audio["waveform"].to(device="cpu", copy=True)
if waveform.ndim == 1:
waveform = waveform[None, None, ...]
elif waveform.ndim == 2:
waveform = waveform[None, ...]
elif waveform.ndim != 3:
raise ValueError("Unexpected number of dimensions in waveform!")
waveform = waveform.unsqueeze(2) if use_width else waveform[..., None]
return ({"samples": waveform},)
class LatentAsAudioNode:
DESCRIPTION = "This node lets you rearrange a LATENT to look like AUDIO. Mainly useful for getting back after using the ACETricks AudioAsLatent node and performing some operations. If you connect the optional audio input it will use whatever non-waveform parameters exist in it (can be stuff like the sample rate), otherwise it will just add sample_rate: 41000 and the waveform."
FUNCTION = "go"
CATEGORY = "audio/acetricks"
RETURN_TYPES = ("AUDIO",)
@classmethod
def INPUT_TYPES(cls) -> dict:
return {
"required": {
"latent": ("LATENT",),
"values_mode": (
("rescale", "clamp"),
{"default": "rescale"},
),
"use_width": (
"BOOLEAN",
{
"default": True,
"tooltip": "When enabled, takes the audio data from the first item in the width dimension, otherwise height.",
},
),
},
"optional": {
"audio_opt": (
"AUDIO",
{
"tooltip": "Optional audio to use as a reference for sample rate and possibly other values."
},
),
},
}
@classmethod
def go(
cls,
*,
latent: dict,
values_mode: str,
use_width: bool,
audio_opt: dict | None = None,
) -> tuple:
samples = latent["samples"]
if samples.ndim != 4:
raise ValueError("Expected a 4D latent but didn't get one")
samples = (samples[..., 0, :] if use_width else samples[..., 0]).to(
device="cpu", copy=True
)
if audio_opt is None:
audio_opt = {"sample_rate": 44100}
result = audio_opt | {"waveform": samples}
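        # In rescale mode the waveform is only renormalized (via AudioLevelsNode)
        # when it actually contains values outside the valid -1..1 range.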
if values_mode == "clamp":
result["waveform"] = samples.clamp(-1.0, 1.0)
elif torch.any(samples.abs() > 1.0):
return AudioLevelsNode.go(audio=result, per_channel=False, scale=1.0)
return (result,)
class MonoToStereoNode:
DESCRIPTION = "Can convert mono AUDIO to stereo. It will leave AUDIO that's already stereo alone. Note: Always adds a batch dimension if it doesn't exist and moves to the CPU device."
FUNCTION = "go"
CATEGORY = "audio/acetricks"
RETURN_TYPES = ("AUDIO",)
@classmethod
def INPUT_TYPES(cls) -> dict:
return {"required": {"audio": ("AUDIO",)}}
@classmethod
def go(cls, *, audio: dict) -> tuple:
waveform = audio["waveform"].to(device="cpu")
if waveform.ndim == 2:
waveform = waveform[None]
elif waveform.ndim == 1:
waveform = waveform[None, None]
channels = waveform.shape[1]
audio = audio.copy()
if channels == 1:
waveform = waveform.repeat(1, 2, 1)
audio["waveform"] = waveform
return (audio,)
class AudioBlendNode:
DESCRIPTION = "Blends two AUDIO inputs together. If you have ComfyUI-bleh installed you will have access to many additional blend modes."
FUNCTION = "go"
CATEGORY = "audio/acetricks"
RETURN_TYPES = ("AUDIO",)
@classmethod
def INPUT_TYPES(cls) -> dict:
_ensure_blend_modes()
assert BLEND_MODES is not None # Make static analysis happy.
return {
"required": {
"audio_a": ("AUDIO",),
"audio_b": ("AUDIO",),
"audio_b_strength": (
"FLOAT",
{
"default": 0.5,
"min": -1000.0,
"max": 1000.0,
},
),
"blend_mode": (
tuple(BLEND_MODES.keys()),
{
"default": "lerp",
},
),
"length_mismatch_mode": (
("shrink", "blend"),
{
"default": "shrink",
"tooltip": "Shrink mode will return audio matching whatever the shortest input was. Blend will blend up to the shortest input's size and use unblended longer input to fill the rest. Note that this adjustment occurs before blending.",
},
),
"normalization_mode": (
("clamp", "levels", "levels_per_channel", "none"),
{
"default": "levels",
"tooltip": "Clamp will just clip the result to ensure it is within the permitted range. Levels will rebalance it so the maximum value is the maximum value for the permitted range. Levels per channel is the same, except the maximum value is determined separately per channel. Setting this to none is not recommended unless you are planning to do your own normalization as it may leave invalid values in the audio latent.",
},
),
"result_template": (
("a", "b"),
{
"default": "a",
"tooltip": "AUDIOs contain metadata like sampling rate. The result will be based on the metadata from the audio input you select here, with the blended result as the waveform in it.",
},
),
}
}
@classmethod
def go(
cls,
*,
audio_a: dict,
audio_b: dict,
audio_b_strength: float,
blend_mode: str,
length_mismatch_mode: str,
normalization_mode: str,
result_template: str,
) -> tuple:
wa = fixup_waveform(audio_a["waveform"])
wb = fixup_waveform(audio_b["waveform"])
if wa.dtype != wb.dtype:
wa = wa.to(dtype=torch.float32)
wb = wb.to(dtype=torch.float32)
if wa.shape[:-1] != wb.shape[:-1]:
errstr = f"Unexpected batch or channels shape mismatch in audio. audio_a has shape {wa.shape}, audio_b has shape {wb.shape}"
raise ValueError(errstr)
assert BLEND_MODES is not None # Make static analysis happy.
blend_function = BLEND_MODES[blend_mode]
walen, wblen = wa.shape[-1], wb.shape[-1]
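        # Reconcile mismatched lengths before blending: either truncate both
        # inputs to the shorter one, or pad the shorter input with the longer
        # input's tail.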
if walen != wblen:
if length_mismatch_mode == "shrink":
minlen = min(walen, wblen)
wa = wa[..., :minlen]
wb = wb[..., :minlen]
elif walen > wblen:
wb_temp = wa.clone()
wb_temp[..., :wblen] = wb
wb = wb_temp
else:
wa_temp = wb.clone()
wa_temp[..., :walen] = wa
wa = wa_temp
walen = wblen = wa.shape[-1]
result = blend_function(wa, wb, audio_b_strength)
result_audio = audio_a.copy() if result_template == "a" else audio_b.copy()
if normalization_mode == "clamp":
result = result.clamp_(min=-1.0, max=1.0)
elif normalization_mode in {"levels", "levels_per_channel"}:
result = AudioLevelsNode.go(
audio={"waveform": result},
scale=1.0,
per_channel=normalization_mode == "levels_per_channel",
)[0]["waveform"]
result_audio["waveform"] = result
return (result_audio,)
class AudioFromBatchNode:
DESCRIPTION = "Can be used to extract batch items from AUDIO."
FUNCTION = "go"
CATEGORY = "audio/acetricks"
RETURN_TYPES = ("AUDIO",)
@classmethod
def INPUT_TYPES(cls) -> dict:
return {
"required": {
"audio": ("AUDIO",),
"start": (
"INT",
{
"default": 0,
"tooltip": "Start index (zero-based). Negative indexes count from the end.",
},
),
"length": ("INT", {"default": 1, "min": 0}),
}
}
@classmethod
def go(cls, *, audio: dict, start: int, length: int) -> tuple:
waveform = audio["waveform"]
        if waveform.ndim != 3:
raise ValueError("Expected 3D waveform")
batch = waveform.shape[0]
if start < 0:
start = batch + start
if start < 0:
raise ValueError("Start index is out of range")
new_waveform = waveform[start : start + length].clone()
return (audio | {"waveform": new_waveform},)
NODE_CLASS_MAPPINGS = {
"ACETricks SilentLatent": SilentLatentNode,
"ACETricks VisualizeLatent": VisualizeLatentNode,
"ACETricks CondSplitOutLyrics": SplitOutLyricsNode,
"ACETricks CondJoinLyrics": JoinLyricsNode,
"ACETricks SetAudioDtype": SetAudioDtypeNode,
"ACETricks AudioLevels": AudioLevelsNode,
"ACETricks AudioAsLatent": AudioAsLatentNode,
"ACETricks LatentAsAudio": LatentAsAudioNode,
"ACETricks MonoToStereo": MonoToStereoNode,
"ACETricks AudioBlend": AudioBlendNode,
"ACETricks AudioFromBatch": AudioFromBatchNode,
}
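# Example (sketch, not part of the node definitions): round-tripping AUDIO
# through the latent helpers from Python rather than from the graph. Assumes
# `audio` is a normal ComfyUI AUDIO dict with "waveform" and "sample_rate" keys.
#
#   latent = AudioAsLatentNode.go(audio=audio, use_width=True)[0]
#   latent["samples"] *= 0.5  # any latent-space operation
#   audio_out = LatentAsAudioNode.go(
#       latent=latent, values_mode="rescale", use_width=True, audio_opt=audio
#   )[0]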