Skip to content

Instantly share code, notes, and snippets.

@blepping
Last active July 4, 2025 00:16
Show Gist options
  • Save blepping/d0f6a26b1f59ed705999945821a3ee8a to your computer and use it in GitHub Desktop.
Save blepping/d0f6a26b1f59ed705999945821a3ee8a to your computer and use it in GitHub Desktop.
Some ComfyUI nodes for ACE
# By https://github.com/blepping
# License: Apache2
#
# Place this file in your custom_nodes directory and it should load automatically.
import math
import torch
# Per-(channel, frequency band) latent values used as "silence": SilentLatentNode
# adds this to an all-zero latent of shape (batch, 8, 16, frames).
# The table is 8 rows (channels) x 16 columns (frequency bands); the trailing
# [None, ..., None] makes it (1, 8, 16, 1) so it broadcasts over batch and time.
# NOTE(review): values presumably measured by encoding silent audio with the
# ACE-Steps VAE -- confirm provenance before editing.
SILENCE = torch.tensor((
(-0.6462, -1.2132, -1.3026, -1.2432, -1.2455, -1.2162, -1.2184, -1.2114, -1.2153, -1.2144, -1.2130, -1.2115, -1.2063, -1.1918, -1.1154, -0.7924),
( 0.0473, -0.3690, -0.6507, -0.5677, -0.6139, -0.5863, -0.5783, -0.5746, -0.5748, -0.5763, -0.5774, -0.5760, -0.5714, -0.5560, -0.5393, -0.3263),
(-1.3019, -1.9225, -2.0812, -2.1188, -2.1298, -2.1227, -2.1080, -2.1133, -2.1096, -2.1077, -2.1118, -2.1141, -2.1168, -2.1134, -2.0720, -1.7442),
(-4.4184, -5.5253, -5.7387, -5.7961, -5.7819, -5.7850, -5.7980, -5.8083, -5.8197, -5.8202, -5.8231, -5.8305, -5.8313, -5.8153, -5.6875, -4.7317),
( 1.5986, 2.0669, 2.0660, 2.0476, 2.0330, 2.0271, 2.0252, 2.0268, 2.0289, 2.0260, 2.0261, 2.0252, 2.0240, 2.0220, 1.9828, 1.6429),
(-0.4177, -0.9632, -1.0095, -1.0597, -1.0462, -1.0640, -1.0607, -1.0604, -1.0641, -1.0636, -1.0631, -1.0594, -1.0555, -1.0466, -1.0139, -0.8284),
(-0.7686, -1.0507, -1.3932, -1.4880, -1.5199, -1.5377, -1.5333, -1.5320, -1.5307, -1.5319, -1.5360, -1.5383, -1.5398, -1.5381, -1.4961, -1.1732),
( 0.0199, -0.0880, -0.4010, -0.3936, -0.4219, -0.4026, -0.3907, -0.3940, -0.3961, -0.3947, -0.3941, -0.3929, -0.3889, -0.3741, -0.3432, -0.169),
), dtype=torch.float32, device="cpu")[None, ..., None]
# Lazily-populated table of blend functions; filled by _ensure_blend_modes().
BLEND_MODES = None


def _ensure_blend_modes():
    """Populate the module-level ``BLEND_MODES`` table if it is still ``None``.

    If a "bleh" entry is present in the ``_blepping_integrations`` registry,
    its ``py.latent_utils.BLENDING_MODES`` table is used. Otherwise a small
    built-in table (lerp, a_only, b_only, subtract_b) whose entries take
    ``(a, b, t)`` is installed. Calls after the table is set are no-ops.
    """
    import sys  # local: sys is not among this file's top-level imports

    global BLEND_MODES
    if BLEND_MODES is not None:
        return
    # The registry may be stored in sys.modules or as an attribute of ComfyUI's
    # ``nodes`` module. ``nodes`` is looked up through sys.modules instead of
    # being referenced as a bare (never imported) name; when it is absent,
    # getattr(None, ..., {}) falls back to an empty registry.
    # NOTE(review): the registry is assumed to be dict-like (it is used via .get).
    integrations = sys.modules.get("_blepping_integrations", {}) or getattr(
        sys.modules.get("nodes"),
        "_blepping_integrations",
        {},
    )
    bleh = integrations.get("bleh")
    if bleh is not None:
        BLEND_MODES = bleh.py.latent_utils.BLENDING_MODES
    else:
        BLEND_MODES = {
            "lerp": torch.lerp,
            "a_only": lambda a, _b, _t: a,
            "b_only": lambda _a, b, _t: b,
            "subtract_b": lambda a, b, t: a - b * t,
        }
def normalize_to_scale(latent, target_min, target_max, *, dim=(-3, -2, -1)):
    """Linearly rescale ``latent`` so each slice spans ``[target_min, target_max]``.

    The minimum and maximum are taken over ``dim`` (with keepdim), so every
    slice over the remaining dimensions is normalized independently.

    Args:
        latent: Tensor to normalize. It is not modified.
        target_min: Value the slice minimum maps to.
        target_max: Value the slice maximum maps to.
        dim: Dimension(s) to compute the min/max over.

    Returns:
        A new tensor with the result clamped to ``[target_min, target_max]``.
        Slices whose values are all equal (zero range) map to ``target_min``
        instead of becoming NaN.
    """
    min_val = latent.amin(dim=dim, keepdim=True)
    max_val = latent.amax(dim=dim, keepdim=True)
    value_range = max_val - min_val
    # A constant slice has zero range and would compute 0 / 0 = NaN. Dividing
    # by 1 there instead leaves (latent - min) == 0, i.e. target_min.
    value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
    normalized = (latent - min_val).div_(value_range)
    return (
        normalized.mul_(target_max - target_min)
        .add_(target_min)
        .clamp_(target_min, target_max)
    )
# Latent frames per second of audio (about 10.77): seconds are multiplied by it
# to get a latent length, and a latent length is divided by it to get seconds.
# NOTE(review): presumably 44.1 kHz sample rate / 512-sample hop / 8x temporal
# downsampling -- confirm against the ACE-Steps VAE.
TEMPORAL_SCALE_FACTOR = 44100 / (512 * 8)
class SilentLatentNode:
    """ComfyUI node that creates an ACE-Steps latent filled with the SILENCE pattern."""

    DESCRIPTION = "Creates a latent filled with values corresponding to silent audio. The length and batch size can be set directly or copied from an optional reference latent."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("LATENT",)

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1, "tooltip": "Number of seconds to generate. Ignored if optional latent input is connected."}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "Batch size to generate. Ignored if optional latent input is connected."}),
            },
            "optional": {
                "ref_latent_opt": ("LATENT", {"tooltip": "When connected the other parameters are ignored and the latent output will match the length/batch size of the reference."}),
            },
        }

    @classmethod
    def go(cls, *, seconds: float, batch_size: int, ref_latent_opt=None) -> tuple:
        """Build a silent latent.

        Args:
            seconds: Audio duration, converted to latent frames with
                TEMPORAL_SCALE_FACTOR (truncated toward zero). Ignored when
                ``ref_latent_opt`` is given.
            batch_size: Number of batch items. Ignored when ``ref_latent_opt``
                is given.
            ref_latent_opt: Optional LATENT dict whose "samples" shape is copied.

        Returns:
            A 1-tuple holding a LATENT dict with float32 CPU "samples" and
            "type": "audio".
        """
        if ref_latent_opt is not None:
            shape = tuple(ref_latent_opt["samples"].shape)
        else:
            shape = (batch_size, 8, 16, int(seconds * TEMPORAL_SCALE_FACTOR))
        latent = torch.zeros(shape, device="cpu", dtype=torch.float32)
        # SILENCE is (1, 8, 16, 1) and broadcasts over batch and time, so this
        # in-place add requires a (batch, 8, 16, frames) latent.
        latent += SILENCE
        return ({"samples": latent, "type": "audio"},)
class VisualizeLatentNode:
    """ComfyUI node that renders an ACE-Steps latent as an IMAGE for inspection."""

    DESCRIPTION = "Renders an ACE-Steps latent as an image so you can inspect it."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("IMAGE",)

    # Modes rendered as grayscale; every other mode packs channels into RGB.
    _MONO_MODES = frozenset(("split", "combined", "split_flip", "combined_flip"))

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "latent": ("LATENT",),
                "scale_secs": (
                    "INT",
                    {
                        "default": 0, "min": 0, "max": 1000,
                        "tooltip": "Horizontal scale. Number of pixels that corresponds to one second of audio. You can use 0 for no scaling which is roughly 11 pixels per second.",
                    },
                ),
                "scale_vertical": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 1024,
                        "tooltip": "Pixel expansion factor for channels (or frequency bands if you have swap_channels_freqs mode enabled).",
                    },
                ),
                "swap_channels_freqs": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Swaps the order of channels and frequency in the vertical dimension. When enabled, scale_vertical applies to frequency bands.",
                    },
                ),
                "normalize_dims": (
                    "STRING",
                    {
                        "default": "-1",
                        "tooltip": "Dimensions the latent scale is normalized using. Must be a comma-separated list. The default setting normalizes the channels and frequency bands independently per batch, you can try -3, -2, -1 if you want to see the relative differences better.",
                    },
                ),
                "mode": (
                    ("split", "combined", "brg", "rgb", "bgr", "split_flip", "combined_flip", "brg_flip", "rgb_flip", "bgr_flip"), {
                        "default": "split",
                        "tooltip": "Split shows a monochrome view of each channel/freq, combined shows the average. Flip means invert the energy in the channel (i.e. white -> black). The other modes put the latent channels into the RGB channels of the preview image.",
                    },
                ),
            },
        }

    @staticmethod
    def _parse_dims(spec: str) -> tuple:
        """Parse a comma-separated list of ints, ignoring blank entries (e.g. a trailing comma)."""
        return tuple(int(part) for part in spec.split(",") if part.strip())

    @staticmethod
    def _scale_time(samples, temporal: int, scale_secs: int, color_mode: bool):
        """Resize the last (time) axis to ``scale_secs`` pixels per second of audio."""
        # Keep at least one pixel: very short latents could otherwise round to 0,
        # which interpolate() rejects.
        new_temporal = max(1, round((temporal / TEMPORAL_SCALE_FACTOR) * scale_secs))
        # interpolate() takes (N, C, H, W); mono samples are 3D, so add a channel dim.
        resized = torch.nn.functional.interpolate(
            samples if color_mode else samples.unsqueeze(1),
            size=(samples.shape[-2], new_temporal),
            mode="nearest-exact",
        )
        return resized if color_mode else resized.squeeze(1)

    @staticmethod
    def _to_rgb(samples, mode: str):
        """Pack dim-1 entries into zero-padded RGB triples and stack the triples vertically."""
        group_count = math.ceil(samples.shape[1] / 3)
        pad = group_count * 3 - samples.shape[1]
        if pad:
            samples = torch.cat(
                (samples, samples.new_zeros(samples.shape[0], pad, *samples.shape[-2:])),
                dim=1,
            )
        # (B, 3 * groups, H, W) -> (B, 3, groups * H, W) -> (B, groups * H, W, 3)
        samples = torch.cat(samples.chunk(group_count, dim=1), dim=2).movedim(1, -1)
        if mode.startswith("bgr"):
            samples = samples.flip(-1)
        elif mode.startswith("brg"):
            samples = samples.roll(-1, -1)
        return samples

    @classmethod
    def go(cls, *, latent, scale_secs, scale_vertical, swap_channels_freqs, normalize_dims, mode) -> tuple:
        """Render ``latent["samples"]`` as an IMAGE tensor.

        The 4D samples are read as (batch, channels, freqs, time). Values are
        normalized to [0, 1] over ``normalize_dims``, optionally inverted
        ("_flip" modes), and the time axis becomes the image width.

        Returns:
            A 1-tuple holding a (batch, height, width, 3) float tensor.

        Raises:
            ValueError: if the samples are not 4D or ``normalize_dims`` has a
                non-integer entry.
        """
        dims = cls._parse_dims(normalize_dims)
        samples = latent["samples"].to(dtype=torch.float32, device="cpu")
        if samples.ndim != 4:
            raise ValueError("Expected an ACE-Steps latent with 4 dimensions")
        color_mode = mode not in cls._MONO_MODES
        batch, temporal = samples.shape[0], samples.shape[-1]
        samples = normalize_to_scale(samples, 0.0, 1.0, dim=dims)
        if mode.endswith("_flip"):
            samples = 1.0 - samples
        if swap_channels_freqs:
            samples = samples.movedim(2, 1)
        if mode.startswith("combined"):
            samples = samples.mean(dim=1, keepdim=True)
        if scale_vertical != 1:
            samples = samples.repeat_interleave(scale_vertical, dim=2)
        if not color_mode:
            # Stack every row of dims 1 and 2 vertically: (batch, rows, time).
            samples = samples.reshape(batch, -1, temporal)
        if scale_secs > 0:
            samples = cls._scale_time(samples, temporal, scale_secs, color_mode)
        if not color_mode:
            # Grayscale: repeat the single value into all three colour channels.
            return (samples[..., None].expand(*samples.shape, 3),)
        return (cls._to_rgb(samples, mode),)
class SplitOutLyricsNode:
    """ComfyUI node that separates lyrics data from ACE-Steps CONDITIONING entries."""

    DESCRIPTION = "Allows splitting out lyrics and lyrics strength from ACE-Steps CONDITIONING objects. Note that you will only be able to join it back again if it is the same shape."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("CONDITIONING","CONDITIONING_ACE_LYRICS")

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "conditioning": ("CONDITIONING",),
                "add_fake_pooled": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "When enabled, a zero-filled pooled_output is added to the tags conditioning.",
                    },
                ),
            },
        }

    @classmethod
    def go(cls, *, conditioning, add_fake_pooled) -> tuple:
        """Split each ``[tensor, options]`` conditioning entry into tags and lyrics.

        Args:
            conditioning: List of ``[tensor, dict]`` pairs. The dicts are copied,
                not mutated.
            add_fake_pooled: Store ``cond_t.new_zeros(1, 1)`` as "pooled_output"
                in each tags entry.

        Returns:
            ``(tags, lyrics)``:
            - ``tags`` is the conditioning without the "conditioning_lyrics" and
              "lyrics_strength" keys.
            - ``lyrics`` holds one ``{"conditioning_lyrics", "lyrics_strength"}``
              dict per entry. Values are ``None`` when the entry had no lyrics.
        """
        tags_result, lyrics_result = [], []
        for cond_t, cond_d in conditioning:
            cond_d = cond_d.copy()
            cond_lyr = cond_d.pop("conditioning_lyrics", None)
            cond_lyrstr = cond_d.pop("lyrics_strength", None)
            if add_fake_pooled:
                cond_d["pooled_output"] = cond_t.new_zeros(1, 1)
            tags_result.append([cond_t.clone(), cond_d])
            lyrics_result.append({
                # Entries without lyrics are passed through as None rather than
                # crashing on None.clone().
                "conditioning_lyrics": None if cond_lyr is None else cond_lyr.clone(),
                "lyrics_strength": cond_lyrstr,
            })
        return (tags_result, lyrics_result)
class JoinLyricsNode:
    """ComfyUI node that merges CONDITIONING_ACE_LYRICS back into CONDITIONING."""

    DESCRIPTION = "Allows joining CONDITIONING_ACE_LYRICS back into CONDITIONING. Will overwrite any lyrics that exist. Must be the same shape as the conditioning the lyrics were split from."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("CONDITIONING",)

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "conditioning_tags": ("CONDITIONING",),
                "conditioning_lyrics": ("CONDITIONING_ACE_LYRICS",),
            },
        }

    @classmethod
    def go(cls, *, conditioning_tags, conditioning_lyrics) -> tuple:
        """Pair tag entries with lyrics entries by position and merge them.

        Each output entry is ``[cloned tensor, options]``. The options are a new
        dict with "conditioning_lyrics" (cloned), "lyrics_strength" and
        ``"pooled_output": None`` overwriting any existing values.

        Raises:
            ValueError: if the two lists differ in length, or if any lyrics
                entry has no "conditioning_lyrics" tensor.
        """
        ct_len, cl_len = len(conditioning_tags), len(conditioning_lyrics)
        if ct_len != cl_len:
            raise ValueError(f"Different lengths for tags {ct_len} vs conditioning lyrics {cl_len}")
        # Validate every entry, not just the first, so a later missing item
        # raises a clear error instead of failing on None.clone().
        if any(cond_l.get("conditioning_lyrics") is None for cond_l in conditioning_lyrics):
            raise ValueError("conditioning_lyrics missing items, cannot combine with it.")
        result = [
            [
                cond_t.clone(),
                # dict | dict already builds a new dict; the input is not mutated.
                cond_d | {
                    "conditioning_lyrics": cond_l["conditioning_lyrics"].clone(),
                    "lyrics_strength": cond_l["lyrics_strength"],
                    "pooled_output": None,
                },
            ]
            for (cond_t, cond_d), cond_l in zip(conditioning_tags, conditioning_lyrics)
        ]
        return (result,)
class SetAudioDtypeNode:
    """ComfyUI node that converts an AUDIO waveform to a chosen torch dtype."""

    DESCRIPTION = "Advanced node that allows setting the datatype of the audio waveform. The 16 and 8 bit types are not recommended."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("AUDIO",)
    # Names of torch dtypes offered in the UI; resolved with getattr(torch, name).
    _ALLOWED_DTYPES = ("float64", "float32", "float16", "bfloat16", "float8_e4m3fn", "float8_e5m2",)

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "audio": ("AUDIO",),
                "dtype": (
                    cls._ALLOWED_DTYPES,
                    {
                        "default": "float64",
                        "tooltip": "Datatype to convert the audio waveform to. The audio is passed through unchanged if the waveform already uses this datatype.",
                    },
                ),
            },
        }

    @classmethod
    def go(cls, *, audio, dtype) -> tuple:
        """Return ``audio`` with its "waveform" cast to ``dtype``.

        Returns:
            A 1-tuple. It holds the same ``audio`` object when no conversion is
            needed; otherwise a new dict with the converted waveform and all
            other keys kept.

        Raises:
            ValueError: if ``dtype`` is not one of ``_ALLOWED_DTYPES``.
        """
        if dtype not in cls._ALLOWED_DTYPES:
            raise ValueError("Bad dtype")
        waveform = audio["waveform"]
        target = getattr(torch, dtype)
        if waveform.dtype == target:
            return (audio,)
        return (audio | {"waveform": waveform.to(dtype=target)},)
class AudioLevelsNode:
    """ComfyUI node that rescales an AUDIO waveform so its peak hits a target level."""

    DESCRIPTION = "The values in the waveform range from -1 to 1. This node allows you to scale audio to a percentage of that range."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("AUDIO",)

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "audio": ("AUDIO",),
                "scale": (
                    "FLOAT",
                    {
                        "default": 0.95,
                        "min": 0.0,
                        "max": 1.0,
                        "tooltip": "Percentage where 1.0 indicates 100% of the maximum allowed value in an audio tensor. You can use 1.0 to make it as loud as possible without actually clipping.",
                    },
                ),
                "per_channel": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "When enabled, the levels for each channel will be scaled independently. For multi-channel audio (like stereo) enabling this will not preserve the relative levels between the channels so probably should be left disabled most of the time.",
                    },
                ),
            },
        }

    @classmethod
    def go(cls, *, audio: dict, scale: float, per_channel: bool) -> tuple:
        """Scale the waveform so its peak absolute value equals ``scale``.

        Args:
            audio: AUDIO dict. A 1D or 2D "waveform" is promoted to 3D by adding
                leading dimensions.
            scale: Target peak (1.0 means a peak magnitude of 1.0).
            per_channel: Measure the peak per (dim 0, dim 1) slice instead of
                per dim-0 item.

        Returns:
            A 1-tuple with a copy of ``audio`` whose "waveform" is a rescaled
            CPU copy clamped to [-1, 1]. An empty waveform is returned as a 3D
            CPU copy without scaling.

        Raises:
            ValueError: if the waveform is not 1D, 2D or 3D.
        """
        waveform = audio["waveform"].to(device="cpu", copy=True)
        if waveform.ndim == 1:
            waveform = waveform[None, None, ...]
        elif waveform.ndim == 2:
            waveform = waveform[None, ...]
        elif waveform.ndim != 3:
            raise ValueError("Unexpected number of dimensions in waveform!")
        if waveform.numel() == 0:
            # Nothing to scale, and max() over a zero-size dimension raises.
            return (audio | {"waveform": waveform},)
        max_val = waveform.abs().flatten(start_dim=2 if per_channel else 1).max(dim=-1).values
        max_val = max_val[..., None] if per_channel else max_val[..., None, None]
        # A silent slice has max 0, so scale / 0 is inf (or NaN when scale is 0).
        # nan_to_num makes that finite (or 0), and 0 * finite == 0.
        waveform *= (scale / max_val).nan_to_num()
        return (audio | {"waveform": waveform.clamp(-1.0, 1.0)},)
class AudioAsLatentNode:
    """ComfyUI node that reshapes an AUDIO waveform into a 4D LATENT."""

    DESCRIPTION = "This node allows you to rearrange AUDIO to look like a LATENT. Can be useful if you want to apply some latent operations to AUDIO. Can be reversed with the ACETricks LatentAsAudio node."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("LATENT",)

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "audio": ("AUDIO",),
                "use_width": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "When enabled, you'll get a 4D latent with height 1 and the audio data in the width dimension, otherwise the opposite.",
                    },
                ),
            },
        }

    @classmethod
    def go(cls, *, audio: dict, use_width: bool) -> tuple:
        """Return a LATENT whose "samples" is a 4D CPU copy of the waveform.

        A 1D/2D waveform is first promoted to 3D by adding leading dimensions.
        A size-1 axis is then inserted: before the last axis when ``use_width``
        (giving height 1), otherwise after it (giving width 1).

        Raises:
            ValueError: if the waveform is not 1D, 2D or 3D.
        """
        waveform = audio["waveform"].to(device="cpu", copy=True)
        if waveform.ndim == 1:
            waveform = waveform[None, None, ...]
        elif waveform.ndim == 2:
            waveform = waveform[None, ...]
        elif waveform.ndim != 3:
            raise ValueError("Unexpected number of dimensions in waveform!")
        # 3D -> 4D: insert the size-1 axis at position 2 (height) or 3 (width).
        waveform = waveform.unsqueeze(2 if use_width else 3)
        return ({"samples": waveform},)
class LatentAsAudioNode:
DESCRIPTION = "This node lets you rearrange a LATENT to look like AUDIO. Mainly useful for getting back after using the ACETricks AudioAsLatent node and performing some operations. If you connect the optional audio input it will use whatever non-waveform parameters exist in it (can be stuff like the sample rate), otherwise it will just add sample_rate: 41000 and the waveform."
FUNCTION = "go"
CATEGORY = "audio/acetricks"
RETURN_TYPES = ("AUDIO",)
@classmethod
def INPUT_TYPES(cls) -> dict:
return {
"required": {
"latent": ("LATENT",),
"values_mode": (
("rescale", "clamp"),
{"default": "rescale"},
),
"use_width": (
"BOOLEAN",
{
"default": True,
"tooltip": "When enabled, takes the audio data from the first item in the width dimension, otherwise height.",
},
),
},
"optional": {
"audio_opt": (
"AUDIO",
{"tooltip": "Optional audio to use as a reference for sample rate and possibly other values."}
),
},
}
@classmethod
def go(cls, *, latent: dict, values_mode: str, use_width: bool, audio_opt: dict | None=None) -> tuple:
samples = latent["samples"]
if samples.ndim != 4:
raise ValueError("Expected a 4D latent but didn't get one")
samples = (samples[..., 0, :] if use_width else samples[..., 0]).to(device="cpu", copy=True)
if audio_opt is None:
audio_opt = {"sample_rate": 44100}
result = audio_opt | {"waveform": samples}
if values_mode == "clamp":
result["waveform"] = samples.clamp(-1.0, 1.0)
elif torch.any(samples.abs() > 1.0):
return AudioLevelsNode.go(audio=result, per_channel=False, scale=1.0)
return (result,)
# Node-type name -> node class for every node in this file. Presumably picked
# up by ComfyUI's custom_nodes loader under this exact variable name.
NODE_CLASS_MAPPINGS = {
    node_key: node_cls
    for node_key, node_cls in (
        ("ACETricks SilentLatent", SilentLatentNode),
        ("ACETricks VisualizeLatent", VisualizeLatentNode),
        ("ACETricks CondSplitOutLyrics", SplitOutLyricsNode),
        ("ACETricks CondJoinLyrics", JoinLyricsNode),
        ("ACETricks SetAudioDtype", SetAudioDtypeNode),
        ("ACETricks AudioLevels", AudioLevelsNode),
        ("ACETricks AudioAsLatent", AudioAsLatentNode),
        ("ACETricks LatentAsAudio", LatentAsAudioNode),
    )
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment