Skip to content

Instantly share code, notes, and snippets.

@sevagh
Created May 6, 2025 01:46
Show Gist options
  • Save sevagh/23209e01b906bf25cccb4532fb3fd81f to your computer and use it in GitHub Desktop.
Save sevagh/23209e01b906bf25cccb4532fb3fd81f to your computer and use it in GitHub Desktop.
Replace a large Concat operator with multiple smaller Concat operators to bypass a WebGPU buffer limitation
import onnx
from onnx import helper, numpy_helper, shape_inference
import uuid
# Maximum number of inputs any single Concat node may have after patching.
# Deliberately conservative, to stay within WebGPU's buffer cap.
# NOTE(review): presumably tied to WebGPU's per-stage storage-buffer binding
# limit (each Concat input needs its own buffer) — confirm before raising it.
INPUT_LIMIT: int = 4
def split_concat_node(model, node, limit=None):
    """Rewrite one wide ONNX Concat node as a tree of narrower Concat nodes.

    The inputs are grouped, in order, into chunks of at most ``limit`` tensors.
    Each chunk is concatenated into an intermediate tensor, and the process
    repeats on the intermediates until at most ``limit`` remain. A final Concat
    then writes the original output tensor name.

    Concatenation along a single axis is associative and input order is
    preserved, so the rewritten subgraph computes the same result as the
    original node. No generated node has more than ``limit`` inputs.

    Args:
        model: The ``onnx.ModelProto`` containing ``node``. Not used here;
            kept for interface compatibility.
        node: The ``Concat`` ``NodeProto`` to split.
        limit: Maximum number of inputs per generated Concat node. Defaults to
            the module-level ``INPUT_LIMIT``. Must be at least 2.

    Returns:
        A list of new ``NodeProto`` objects in topological order: every node's
        inputs come either from the original graph or from an earlier node in
        the list. The last node writes ``node.output[0]``.

    Raises:
        ValueError: If ``limit`` is below 2, or ``node`` has no ``axis``
            attribute.
    """
    # Resolved at call time rather than as a default argument, so changes to
    # the module-level constant after import are still honoured.
    if limit is None:
        limit = INPUT_LIMIT
    if limit < 2:
        # With limit 1 the reduction below could never shrink the input list.
        raise ValueError(f"limit must be at least 2, got {limit}")

    # Read the axis once instead of re-extracting it for every generated node.
    axis = next((attr.i for attr in node.attribute if attr.name == "axis"), None)
    if axis is None:
        raise ValueError(f"Concat node {node.name!r} has no 'axis' attribute")

    # Exporters often leave node names empty. Fall back to the output tensor
    # name, which is unique within a graph, so that node names generated for
    # different split Concat nodes do not collide.
    base = node.name or node.output[0]

    new_nodes = []
    current = list(node.input)
    level = 0
    # A single pass of chunking is not enough in general: 20 inputs with a
    # limit of 4 gives 5 chunks, still too many for one Concat. Each pass
    # shrinks the list to ceil(len / limit) < len (since limit >= 2), so the
    # loop terminates.
    while len(current) > limit:
        next_level = []
        for chunk_index, start in enumerate(range(0, len(current), limit)):
            chunk = current[start:start + limit]
            if len(chunk) == 1:
                # A one-input Concat is a no-op; forward the tensor unchanged.
                next_level.append(chunk[0])
                continue
            chunk_name = f"{base}_l{level}_chunk{chunk_index}"
            # Random suffix keeps the intermediate tensor name from clashing
            # with any tensor name already present in the graph.
            chunk_output = f"{chunk_name}_{uuid.uuid4().hex[:8]}"
            new_nodes.append(
                helper.make_node(
                    'Concat',
                    inputs=chunk,
                    outputs=[chunk_output],
                    name=chunk_name,
                    axis=axis,
                )
            )
            next_level.append(chunk_output)
        current = next_level
        level += 1

    # The final node reuses the original output name, so downstream consumers
    # of that tensor need no rewiring.
    new_nodes.append(
        helper.make_node(
            'Concat',
            inputs=current,
            outputs=[node.output[0]],
            name=f"{base}_final",
            axis=axis,
        )
    )
    return new_nodes
def patch_model(model_path, output_path):
    """Load an ONNX model, split every over-wide Concat node, and save it.

    Each node in the model's top-level graph that is a ``Concat`` with more
    than ``INPUT_LIMIT`` inputs is replaced by the nodes returned from
    ``split_concat_node``. All other nodes are kept as-is, in their original
    order. Shape inference is then run, and the result is written to
    ``output_path``.

    Args:
        model_path: Path of the ONNX model to read.
        output_path: Path the patched model is written to.
    """
    model = onnx.load(model_path)
    graph = model.graph

    def is_wide_concat(candidate):
        # Only Concat nodes exceeding the input cap need rewriting.
        return candidate.op_type == 'Concat' and len(candidate.input) > INPUT_LIMIT

    rewritten = []
    for original in graph.node:
        replacement = (
            split_concat_node(model, original)
            if is_wide_concat(original)
            else [original]
        )
        rewritten.extend(replacement)

    # Swap in the rewritten node list wholesale. Only the top-level
    # graph.node list is rewritten; subgraphs held in node attributes are not
    # visited.
    graph.ClearField('node')
    graph.node.extend(rewritten)

    # Optional step, for debugging/verification: annotates tensors with
    # inferred shapes, including the newly created intermediate Concat outputs.
    patched = shape_inference.infer_shapes(model)
    onnx.save(patched, output_path)
    print(f"Saved patched model to: {output_path}")
if __name__ == "__main__":
    import argparse

    # Both paths are optional positionals that default to the original
    # hard-coded basicpitch locations, so running the script with no
    # arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Split wide ONNX Concat nodes to stay within WebGPU buffer limits."
    )
    parser.add_argument(
        "model_path",
        nargs="?",
        default="./onnx-models/basicpitch/model.onnx",
        help="Input ONNX model (default: %(default)s).",
    )
    parser.add_argument(
        "output_path",
        nargs="?",
        default="./onnx-models/basicpitch/model_patched_webgpu.onnx",
        help="Where to write the patched model (default: %(default)s).",
    )
    args = parser.parse_args()
    patch_model(args.model_path, args.output_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment