Last active: July 4, 2025 00:16
Save blepping/d0f6a26b1f59ed705999945821a3ee8a to your computer and use it in GitHub Desktop.
Some ComfyUI nodes for ACE
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# By https://github.com/blepping
# License: Apache2
#
# Place this file in your custom_nodes directory and it should load automatically.
import math
import sys

import torch
# Latent template for silence: one value per (channel, frequency band) pair.
# The final shape is (1, 8, 16, 1), so it broadcasts over batch and time when
# added to an ACE-Step latent of shape (batch, 8, 16, length).
# NOTE(review): the values look like they were measured from encoded silent
# audio; their origin is not documented here.
SILENCE = (
    torch.tensor(
        [
            [-0.6462, -1.2132, -1.3026, -1.2432, -1.2455, -1.2162, -1.2184, -1.2114, -1.2153, -1.2144, -1.2130, -1.2115, -1.2063, -1.1918, -1.1154, -0.7924],
            [0.0473, -0.3690, -0.6507, -0.5677, -0.6139, -0.5863, -0.5783, -0.5746, -0.5748, -0.5763, -0.5774, -0.5760, -0.5714, -0.5560, -0.5393, -0.3263],
            [-1.3019, -1.9225, -2.0812, -2.1188, -2.1298, -2.1227, -2.1080, -2.1133, -2.1096, -2.1077, -2.1118, -2.1141, -2.1168, -2.1134, -2.0720, -1.7442],
            [-4.4184, -5.5253, -5.7387, -5.7961, -5.7819, -5.7850, -5.7980, -5.8083, -5.8197, -5.8202, -5.8231, -5.8305, -5.8313, -5.8153, -5.6875, -4.7317],
            [1.5986, 2.0669, 2.0660, 2.0476, 2.0330, 2.0271, 2.0252, 2.0268, 2.0289, 2.0260, 2.0261, 2.0252, 2.0240, 2.0220, 1.9828, 1.6429],
            [-0.4177, -0.9632, -1.0095, -1.0597, -1.0462, -1.0640, -1.0607, -1.0604, -1.0641, -1.0636, -1.0631, -1.0594, -1.0555, -1.0466, -1.0139, -0.8284],
            [-0.7686, -1.0507, -1.3932, -1.4880, -1.5199, -1.5377, -1.5333, -1.5320, -1.5307, -1.5319, -1.5360, -1.5383, -1.5398, -1.5381, -1.4961, -1.1732],
            [0.0199, -0.0880, -0.4010, -0.3936, -0.4219, -0.4026, -0.3907, -0.3940, -0.3961, -0.3947, -0.3941, -0.3929, -0.3889, -0.3741, -0.3432, -0.1690],
        ],
        dtype=torch.float32,
        device="cpu",
    )
    .unsqueeze(0)  # leading batch axis
    .unsqueeze(-1)  # trailing time axis
)
# Blend functions keyed by mode name, each called as ``fn(a, b, t)``.
# None until _ensure_blend_modes() fills it in.
# NOTE(review): neither this table nor _ensure_blend_modes is referenced by the
# nodes defined in this file.
BLEND_MODES = None


def _ensure_blend_modes():
    """Populate the module-level ``BLEND_MODES`` table if it is still unset.

    If a ``"bleh"`` entry is registered in the blepping integrations mapping, its
    ``py.latent_utils.BLENDING_MODES`` table is used. The mapping is looked up in
    ``sys.modules["_blepping_integrations"]`` first, then in the
    ``_blepping_integrations`` attribute of ComfyUI's ``nodes`` module.
    Otherwise a small built-in fallback table is installed. Calls after the
    first are no-ops.
    """
    global BLEND_MODES
    if BLEND_MODES is not None:
        return
    integrations = sys.modules.get("_blepping_integrations", {})
    if not integrations:
        # ``nodes`` is ComfyUI's top-level module. It was referenced here without
        # ever being imported; import it lazily and treat its absence (running
        # outside ComfyUI) as "no integrations registered".
        try:
            import nodes
        except ImportError:
            nodes = None
        integrations = getattr(nodes, "_blepping_integrations", {})
    bleh = integrations.get("bleh")
    if bleh is not None:
        BLEND_MODES = bleh.py.latent_utils.BLENDING_MODES
        return
    BLEND_MODES = {
        "lerp": torch.lerp,
        "a_only": lambda a, _b, _t: a,
        "b_only": lambda _a, b, _t: b,
        "subtract_b": lambda a, b, t: a - b * t,
    }
def normalize_to_scale(latent, target_min, target_max, *, dim=(-3, -2, -1)):
    """Min/max rescale ``latent`` into ``[target_min, target_max]``.

    The minimum and maximum are taken over ``dim`` (kept as size-1 dims so they
    broadcast back), so every slice over the remaining dimensions is scaled
    independently. Slices whose values are all equal map to ``target_min``
    instead of turning into NaN through a 0/0 division.

    Args:
        latent: Floating point tensor to rescale; it is not modified.
        target_min: Lower bound of the output range.
        target_max: Upper bound of the output range.
        dim: Dimension(s) reduced over when finding the min/max.

    Returns:
        A new tensor with the same shape as ``latent``, clamped to the target range.
    """
    min_val = latent.amin(dim=dim, keepdim=True)
    max_val = latent.amax(dim=dim, keepdim=True)
    value_range = max_val - min_val
    # A constant slice (for example a silent latent along the time axis) has a
    # zero range. Dividing by 1 there keeps (latent - min_val) == 0, so the
    # slice lands on target_min rather than NaN.
    value_range = value_range.masked_fill(value_range == 0, 1.0)
    normalized = (latent - min_val).div_(value_range)
    return (
        normalized.mul_(target_max - target_min)
        .add_(target_min)
        .clamp_(target_min, target_max)
    )
# Latent frames per second of audio (about 10.77). In this file,
# seconds * TEMPORAL_SCALE_FACTOR gives a latent length, and
# length / TEMPORAL_SCALE_FACTOR converts a length back to seconds.
# NOTE(review): 44100 matches the sample rate used elsewhere in this file;
# 512 and 8 are presumably the hop size and temporal downsampling — confirm.
TEMPORAL_SCALE_FACTOR = 44100 / (512 * 8)
class SilentLatentNode:
    """ComfyUI node producing an ACE-Step audio latent filled with SILENCE.

    The latent is sized from ``seconds``/``batch_size``, or copied from the shape
    of the optional reference latent when one is connected.
    """

    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("LATENT",)

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1, "tooltip": "Number of seconds to generate. Ignored if optional latent input is connected."}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "Batch size to generate. Ignored if optional latent input is connected."}),
            },
            "optional": {
                "ref_latent_opt": ("LATENT", {"tooltip": "When connected the other parameters are ignored and the latent output will match the length/batch size of the reference."}),
            },
        }

    @classmethod
    def go(cls, *, seconds: float, batch_size: int, ref_latent_opt=None) -> tuple:
        """Create the silent latent.

        Returns:
            A 1-tuple holding ``{"samples": tensor, "type": "audio"}``. The tensor
            is float32 on the CPU, and each (channel, frequency) row is constant
            along the last axis, set to the matching SILENCE value.

        Raises:
            ValueError: if the reference latent's shape cannot hold the
                (8 channel, 16 frequency band) SILENCE template.
        """
        if ref_latent_opt is not None:
            shape = tuple(ref_latent_opt["samples"].shape)
            # SILENCE is (1, 8, 16, 1) and is added in place below. Reject shapes
            # it cannot broadcast into, rather than failing with a cryptic
            # broadcasting error.
            if len(shape) < 4 or shape[-3:-1] != (8, 16):
                raise ValueError(f"Reference latent must have shape (batch, 8, 16, length), got {shape}")
        else:
            shape = (batch_size, 8, 16, int(seconds * TEMPORAL_SCALE_FACTOR))
        latent = torch.zeros(shape, device="cpu", dtype=torch.float32)
        latent += SILENCE
        return ({"samples": latent, "type": "audio"},)
class VisualizeLatentNode:
    """ComfyUI node rendering a 4D ACE-Step latent as an IMAGE for inspection.

    The latent is min/max normalized to [0, 1] over the requested dimensions,
    and time runs along the horizontal axis. Monochrome modes stack every
    channel/frequency row vertically. Colour modes pack groups of three
    channels into the R, G and B components.
    """

    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("IMAGE",)

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "latent": ("LATENT",),
                "scale_secs": (
                    "INT",
                    {
                        "default": 0, "min": 0, "max": 1000,
                        "tooltip": "Horizontal scale. Number of pixels that corresponds to one second of audio. You can use 0 for no scaling which is roughly 11 pixels per second.",
                    },
                ),
                "scale_vertical": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 1024,
                        "tooltip": "Pixel expansion factor for channels (or frequency bands if you have swap_channels_freqs mode enabled).",
                    },
                ),
                "swap_channels_freqs": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Swaps the order of channels and frequency in the vertical dimension. When enabled, scale_vertical applies to frequency bands.",
                    },
                ),
                "normalize_dims": (
                    "STRING",
                    {
                        "default": "-1",
                        "tooltip": "Dimensions the latent scale is normalized using. Must be a comma-separated list. The default setting normalizes the channels and frequency bands independently per batch, you can try -3, -2, -1 if you want to see the relative differences better.",
                    },
                ),
                "mode": (
                    ("split", "combined", "brg", "rgb", "bgr", "split_flip", "combined_flip", "brg_flip", "rgb_flip", "bgr_flip"), {
                        "default": "split",
                        "tooltip": "Split shows a monochrome view of each channel/freq, combined shows the average. Flip means invert the energy in the channel (i.e. white -> black). The other modes put the latent channels into the RGB channels of the preview image.",
                    },
                ),
            },
        }

    @classmethod
    def go(cls, *, latent, scale_secs, scale_vertical, swap_channels_freqs, normalize_dims, mode) -> tuple:
        """Convert ``latent["samples"]`` into a preview image.

        Returns:
            A 1-tuple holding a float32 CPU tensor of shape (batch, height, width, 3).

        Raises:
            ValueError: if the samples are not 4D, or if an entry of
                ``normalize_dims`` is not an integer.
        """
        # Blank entries are skipped so "", "-1," and "-3, -2, -1" all parse.
        normalize_dims = tuple(int(part) for part in normalize_dims.split(",") if part.strip())
        samples = latent["samples"].to(dtype=torch.float32, device="cpu")
        if samples.ndim != 4:
            raise ValueError("Expected an ACE-Steps latent with 4 dimensions")
        color_mode = mode not in {"split", "combined", "split_flip", "combined_flip"}
        batch, temporal = samples.shape[0], samples.shape[-1]
        samples = normalize_to_scale(samples, 0.0, 1.0, dim=normalize_dims)
        if mode.endswith("_flip"):
            samples = 1.0 - samples
        if swap_channels_freqs:
            # (batch, channels, freqs, time) -> (batch, freqs, channels, time)
            samples = samples.movedim(2, 1)
        if mode.startswith("combined"):
            samples = samples.mean(dim=1, keepdim=True)
        if scale_vertical != 1:
            # Repeats each row of dim 2 (frequency bands normally, channels when swapped).
            samples = samples.repeat_interleave(scale_vertical, dim=2)
        if not color_mode:
            # Stack every dim-1 slice vertically: (batch, rows, time).
            samples = samples.reshape(batch, -1, temporal)
        if scale_secs > 0:
            # temporal / TEMPORAL_SCALE_FACTOR is the clip length in seconds. Keep at
            # least one column so very short clips never request a zero width.
            new_temporal = max(1, round(temporal / TEMPORAL_SCALE_FACTOR * scale_secs))
            # interpolate() wants (batch, channels, H, W); monochrome data gets a
            # temporary channel dim.
            src = samples if color_mode else samples.unsqueeze(1)
            samples = torch.nn.functional.interpolate(
                src,
                size=(src.shape[-2], new_temporal),
                mode="nearest-exact",
            )
            if not color_mode:
                samples = samples.squeeze(1)
        if not color_mode:
            # Grey image: replicate the single plane into R, G and B.
            return (samples[..., None].expand(*samples.shape, 3),)
        # Colour modes: zero-pad the channel count up to a multiple of three, then
        # stack each group of three channels vertically as one RGB band.
        rgb_count = math.ceil(samples.shape[1] / 3)
        channels_pad = rgb_count * 3 - samples.shape[1]
        samples = torch.cat((samples, samples.new_zeros(samples.shape[0], channels_pad, *samples.shape[-2:])), dim=1)
        samples = torch.cat(samples.chunk(rgb_count, dim=1), dim=2).movedim(1, -1)
        if mode.startswith("bgr"):
            # Reverse: channel 0 -> blue, channel 2 -> red.
            samples = samples.flip(-1)
        elif mode.startswith("brg"):
            # Rotate: channel 0 -> blue, 1 -> red, 2 -> green.
            samples = samples.roll(-1, -1)
        return (samples,)
class SplitOutLyricsNode:
    """ComfyUI node separating lyrics entries from ACE-Step CONDITIONING.

    Each conditioning item's ``conditioning_lyrics`` and ``lyrics_strength`` keys
    are moved into a parallel CONDITIONING_ACE_LYRICS list. Items without lyrics
    yield ``None`` entries instead of failing.
    """

    DESCRIPTION = "Allows splitting out lyrics and lyrics strength from ACE-Steps CONDITIONING objects. Note that you will only be able to join it back again if it is the same shape."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING_ACE_LYRICS")

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "conditioning": ("CONDITIONING",),
                "add_fake_pooled": ("BOOLEAN", {"default": True}),
            },
        }

    @classmethod
    def go(cls, *, conditioning, add_fake_pooled) -> tuple:
        """Split lyrics out of ``conditioning``.

        Args:
            conditioning: Iterable of ``[tensor, options_dict]`` pairs.
            add_fake_pooled: When true, set ``pooled_output`` in each tags dict to a
                (1, 1) zero tensor. NOTE(review): presumably for consumers that
                expect a pooled output; JoinLyricsNode resets it to None.

        Returns:
            ``(tags, lyrics)``. ``tags`` holds cloned tensors with shallow-copied
            dicts minus the lyrics keys. ``lyrics`` holds dicts with
            ``conditioning_lyrics`` (a cloned tensor, or None when absent) and
            ``lyrics_strength``. The input is not modified.
        """
        tags_result, lyrics_result = [], []
        for cond_t, cond_d in conditioning:
            cond_d = cond_d.copy()
            cond_lyr = cond_d.pop("conditioning_lyrics", None)
            cond_lyrstr = cond_d.pop("lyrics_strength", None)
            if add_fake_pooled:
                cond_d["pooled_output"] = cond_t.new_zeros(1, 1)
            tags_result.append([cond_t.clone(), cond_d])
            lyrics_result.append({
                # Conditioning without lyrics used to crash here on None.clone().
                "conditioning_lyrics": None if cond_lyr is None else cond_lyr.clone(),
                "lyrics_strength": cond_lyrstr,
            })
        return (tags_result, lyrics_result)
class JoinLyricsNode:
    """ComfyUI node merging CONDITIONING_ACE_LYRICS back into CONDITIONING.

    This is the counterpart of SplitOutLyricsNode. Items are paired by position,
    so both inputs must have the same length.
    """

    DESCRIPTION = "Allows joining CONDITIONING_ACE_LYRICS back into CONDITIONING. Will overwrite any lyrics that exist. Must be the same shape as the conditioning the lyrics were split from."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("CONDITIONING",)

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "conditioning_tags": ("CONDITIONING",),
                "conditioning_lyrics": ("CONDITIONING_ACE_LYRICS",),
            },
        }

    @classmethod
    def go(cls, *, conditioning_tags, conditioning_lyrics) -> tuple:
        """Rebuild conditioning with lyrics attached.

        Each output item is ``[cloned tensor, options]``, where ``options`` is a
        new dict with ``conditioning_lyrics`` (cloned), ``lyrics_strength`` and
        ``pooled_output=None`` overlaid on the tags dict.

        Raises:
            ValueError: if the inputs differ in length, or any lyrics entry has no
                ``conditioning_lyrics`` tensor.
        """
        ct_len, cl_len = len(conditioning_tags), len(conditioning_lyrics)
        if ct_len != cl_len:
            raise ValueError(f"Different lengths for tags {ct_len} vs conditioning lyrics {cl_len}")
        # Check every item, not just the first, so a later None entry fails with
        # this message instead of an AttributeError from .clone().
        if any(cond_l.get("conditioning_lyrics") is None for cond_l in conditioning_lyrics):
            raise ValueError("conditioning_lyrics missing items, cannot combine with it.")
        result = [
            [
                cond_t.clone(),
                # dict | dict already builds a new dict, so the input is not mutated.
                cond_d | {
                    "conditioning_lyrics": cond_l["conditioning_lyrics"].clone(),
                    "lyrics_strength": cond_l["lyrics_strength"],
                    "pooled_output": None,
                },
            ]
            for (cond_t, cond_d), cond_l in zip(conditioning_tags, conditioning_lyrics)
        ]
        return (result,)
class SetAudioDtypeNode:
    """ComfyUI node converting an AUDIO waveform to a chosen floating point dtype."""

    DESCRIPTION = "Advanced node that allows changing the datatype of the audio waveform. The 16 and 8 bit types are not recommended."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("AUDIO",)
    # Names of torch dtypes accepted by ``go``.
    _ALLOWED_DTYPES = ("float64", "float32", "float16", "bfloat16", "float8_e4m3fn", "float8_e5m2")

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "audio": ("AUDIO",),
                "dtype": (cls._ALLOWED_DTYPES, {"default": "float64", "tooltip": "Datatype to convert the audio waveform to."}),
            },
        }

    @classmethod
    def go(cls, *, audio, dtype) -> tuple:
        """Return ``audio`` with its waveform in ``dtype``.

        The input dict is returned unchanged (the same object) when the waveform
        already has that dtype. Otherwise a shallow copy with a converted waveform
        is returned.

        Raises:
            ValueError: if ``dtype`` is not in ``_ALLOWED_DTYPES``.
        """
        if dtype not in cls._ALLOWED_DTYPES:
            raise ValueError("Bad dtype")
        waveform = audio["waveform"]
        dt = getattr(torch, dtype)
        if waveform.dtype == dt:
            return (audio,)
        return (audio | {"waveform": waveform.to(dtype=dt)},)
class AudioLevelsNode:
    """ComfyUI node peak-normalizing an AUDIO waveform to a fraction of [-1, 1]."""

    DESCRIPTION = "The values in the waveform range from -1 to 1. This node allows you to scale audio to a percentage of that range."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("AUDIO",)

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "audio": ("AUDIO",),
                "scale": (
                    "FLOAT",
                    {
                        "default": 0.95,
                        "min": 0.0,
                        "max": 1.0,
                        "tooltip": "Percentage where 1.0 indicates 100% of the maximum allowed value in an audio tensor. You can use 1.0 to make it as loud as possible without actually clipping.",
                    },
                ),
                "per_channel": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "When enabled, the levels for each channel will be scaled independently. For multi-channel audio (like stereo) enabling this will not preserve the relative levels between the channels so probably should be left disabled most of the time.",
                    },
                ),
            },
        }

    @classmethod
    def go(cls, *, audio: dict, scale: float, per_channel: bool) -> tuple:
        """Scale ``audio["waveform"]`` so its peak magnitude equals ``scale``.

        1D and 2D waveforms are promoted to 3D by adding leading dims, so the
        returned waveform is always 3D, on the CPU, and a copy of the input.
        Dim 1 is treated as channels. Peaks are measured per (dim 0, dim 1)
        slice when ``per_channel`` is set, otherwise per dim-0 item. Other keys
        of ``audio`` are passed through.

        Raises:
            ValueError: if the waveform is not 1D, 2D or 3D.
        """
        waveform = audio["waveform"].to(device="cpu", copy=True)
        if waveform.ndim == 1:
            waveform = waveform[None, None, ...]
        elif waveform.ndim == 2:
            waveform = waveform[None, ...]
        elif waveform.ndim != 3:
            raise ValueError("Unexpected number of dimensions in waveform!")
        # Peak magnitude: shape (B, C, 1) per channel, or (B, 1, 1) per batch item.
        max_val = waveform.abs().flatten(start_dim=2 if per_channel else 1).max(dim=-1).values
        max_val = max_val[..., None] if per_channel else max_val[..., None, None]
        # An all-zero slice gives scale / 0 = inf (or NaN when scale is 0).
        # nan_to_num turns that into a finite gain, and zero samples times any
        # finite gain stay zero.
        waveform *= (scale / max_val).nan_to_num()
        return (audio | {"waveform": waveform.clamp(-1.0, 1.0)},)
class AudioAsLatentNode:
    """ComfyUI node reshaping an AUDIO waveform into a 4D LATENT.

    A 3D waveform gains a size-1 axis: height when ``use_width`` is set (audio
    samples along width), otherwise width (audio samples along height).
    LatentAsAudioNode performs the reverse.
    """

    DESCRIPTION = "This node allows you to rearrange AUDIO to look like a LATENT. Can be useful if you want to apply some latent operations to AUDIO. Can be reversed with the ACETricks LatentAsAudio node."
    FUNCTION = "go"
    CATEGORY = "audio/acetricks"
    RETURN_TYPES = ("LATENT",)

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "audio": ("AUDIO",),
                "use_width": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "When enabled, you'll get a 4D latent with height 1 and the audio data in the width dimension, otherwise the opposite.",
                    },
                ),
            },
        }

    @classmethod
    def go(cls, *, audio: dict, use_width: bool) -> tuple:
        """Return ``({"samples": tensor},)`` with a 4D CPU copy of the waveform.

        1D and 2D waveforms are first promoted to 3D by adding leading dims.
        The result is (d0, d1, 1, T) when ``use_width`` is set, else (d0, d1, T, 1).

        Raises:
            ValueError: if the waveform is not 1D, 2D or 3D.
        """
        waveform = audio["waveform"].to(device="cpu", copy=True)
        if waveform.ndim == 1:
            waveform = waveform[None, None, ...]
        elif waveform.ndim == 2:
            waveform = waveform[None, ...]
        elif waveform.ndim != 3:
            raise ValueError("Unexpected number of dimensions in waveform!")
        waveform = waveform.unsqueeze(2) if use_width else waveform[..., None]
        return ({"samples": waveform},)
class LatentAsAudioNode: | |
DESCRIPTION = "This node lets you rearrange a LATENT to look like AUDIO. Mainly useful for getting back after using the ACETricks AudioAsLatent node and performing some operations. If you connect the optional audio input it will use whatever non-waveform parameters exist in it (can be stuff like the sample rate), otherwise it will just add sample_rate: 41000 and the waveform." | |
FUNCTION = "go" | |
CATEGORY = "audio/acetricks" | |
RETURN_TYPES = ("AUDIO",) | |
@classmethod | |
def INPUT_TYPES(cls) -> dict: | |
return { | |
"required": { | |
"latent": ("LATENT",), | |
"values_mode": ( | |
("rescale", "clamp"), | |
{"default": "rescale"}, | |
), | |
"use_width": ( | |
"BOOLEAN", | |
{ | |
"default": True, | |
"tooltip": "When enabled, takes the audio data from the first item in the width dimension, otherwise height.", | |
}, | |
), | |
}, | |
"optional": { | |
"audio_opt": ( | |
"AUDIO", | |
{"tooltip": "Optional audio to use as a reference for sample rate and possibly other values."} | |
), | |
}, | |
} | |
@classmethod | |
def go(cls, *, latent: dict, values_mode: str, use_width: bool, audio_opt: dict | None=None) -> tuple: | |
samples = latent["samples"] | |
if samples.ndim != 4: | |
raise ValueError("Expected a 4D latent but didn't get one") | |
samples = (samples[..., 0, :] if use_width else samples[..., 0]).to(device="cpu", copy=True) | |
if audio_opt is None: | |
audio_opt = {"sample_rate": 44100} | |
result = audio_opt | {"waveform": samples} | |
if values_mode == "clamp": | |
result["waveform"] = samples.clamp(-1.0, 1.0) | |
elif torch.any(samples.abs() > 1.0): | |
return AudioLevelsNode.go(audio=result, per_channel=False, scale=1.0) | |
return (result,) | |
# Node type name -> node class for every node defined in this file; ComfyUI's
# custom-node loader looks up this module-level name. All names share the
# "ACETricks " prefix.
NODE_CLASS_MAPPINGS = {
    f"ACETricks {suffix}": node_cls
    for suffix, node_cls in (
        ("SilentLatent", SilentLatentNode),
        ("VisualizeLatent", VisualizeLatentNode),
        ("CondSplitOutLyrics", SplitOutLyricsNode),
        ("CondJoinLyrics", JoinLyricsNode),
        ("SetAudioDtype", SetAudioDtypeNode),
        ("AudioLevels", AudioLevelsNode),
        ("AudioAsLatent", AudioAsLatentNode),
        ("LatentAsAudio", LatentAsAudioNode),
    )
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.