Skip to content

Instantly share code, notes, and snippets.

@greggman
Last active July 24, 2025 18:25
Show Gist options
  • Save greggman/e73f3da8c8824c4a55ae5189164e60fb to your computer and use it in GitHub Desktop.
Save greggman/e73f3da8c8824c4a55ae5189164e60fb to your computer and use it in GitHub Desktop.
WebGPU: Compute Shader Mipmap Generation (v1b)
/* Let the browser pick light or dark built-in colors per the user's preference. */
:root {
color-scheme: light dark;
}
/* Mip-level previews: shown at a fixed 128px width with nearest-neighbor
   scaling (image-rendering: pixelated) so individual texels stay visible. */
canvas, img {
border: 1px solid gray;
margin: 5px;
max-width: 128px;
width: 128px;
image-rendering: pixelated;
}
/*bug-in-github-api-content-can-not-be-empty*/
import RollingAverage from 'https://webgpufundamentals.org/webgpu/resources/js/rolling-average.js';
import TimingHelper from 'https://webgpufundamentals.org/webgpu/resources/js/timing-helper.js';
/**
 * Entry point: acquires a WebGPU device, builds the test images, and runs
 * test() on each of them one after another.
 */
async function main() {
  // navigator.gpu is undefined where WebGPU is unavailable, and
  // requestAdapter() may resolve to null; report that instead of throwing
  // a TypeError on `adapter.requestDevice`.
  const adapter = await navigator.gpu?.requestAdapter();
  if (!adapter) {
    log('WebGPU is not available in this browser.');
    return;
  }

  // Timestamp queries are currently hard-coded off. To enable them, use
  // `adapter.features.has('timestamp-query')` instead of `false`.
  const hasTiming = false;
  const requiredFeatures = hasTiming ? ['timestamp-query'] : [];
  const device = await adapter.requestDevice({ requiredFeatures });
  device.addEventListener('uncapturederror', (e) => console.error(e.error.message));

  const r = [255, 0, 0, 255];
  const y = [255, 255, 0, 255];
  const g = [0, 255, 0, 255];
  const c = [0, 255, 255, 255];
  const b = [0, 0, 255, 255];
  const m = [255, 0, 255, 255];
  const colors = [r, y, g, c, b, m];

  const tests = [
    // Larger cases, currently disabled:
    // makeCanvas(4096, 4096),
    // makeCanvas(4095, 4093),
    makeSmallCanvas(7),
    // Square diagonal-stripe images from 64x64 down to 2x2.
    ...range(63, (i) => {
      const s = 64 - i;
      return makeTestData(s, s, colors);
    }),
  ];

  // One at a time, so each test's output is appended to the page in order.
  for (const canvas of tests) {
    await test(device, canvas);
  }
}
/**
 * Formats a CSS `hsl()` color string.
 * @param {number} h - hue as a fraction of a full turn (scaled by 360, truncated via `| 0`)
 * @param {number} s - saturation fraction (scaled by 100, not truncated)
 * @param {number} l - lightness fraction (scaled by 100, truncated via `| 0`)
 * @returns {string}
 */
const hsl = (h, s, l) => {
  const hueDegrees = (h * 360) | 0;
  const saturationPct = s * 100;
  const lightnessPct = (l * 100) | 0;
  return `hsl(${hueDegrees}, ${saturationPct}%, ${lightnessPct}%)`;
};
/**
 * Builds a w×h RGBA ImageData where pixel (x, y) is colors[(x + y) % colors.length],
 * producing diagonal color stripes.
 * @param {number} w - width in pixels
 * @param {number} h - height in pixels
 * @param {number[][]} colors - RGBA byte quadruples
 * @returns {ImageData}
 */
function makeTestData(w, h, colors) {
  const data = new Uint8ClampedArray(w * h * 4);
  for (let y = 0; y < h; ++y) {
    for (let x = 0; x < w; ++x) {
      // Rows are `w` pixels long. The previous `y * h` stride only worked for
      // square images and overflowed the buffer when h > w.
      data.set(colors[(x + y) % colors.length], (y * w + x) * 4);
    }
  }
  return new ImageData(data, w, h);
}
/**
 * Creates a w×1 OffscreenCanvas whose pixels step through the hue wheel from
 * left to right (full saturation, 50% lightness), one 1×1 rect per column.
 * @param {number} w - canvas width in pixels
 * @returns {OffscreenCanvas}
 */
function makeSmallCanvas(w) {
  const canvas = new OffscreenCanvas(w, 1);
  const ctx = canvas.getContext('2d');
  let column = 0;
  while (column < w) {
    const hueFraction = column / w;
    ctx.fillStyle = hsl(hueFraction, 1, 0.5);
    ctx.fillRect(column, 0, 1, 1);
    column += 1;
  }
  return canvas;
}
/**
 * Creates a w×h OffscreenCanvas with a light background, a 32px border, and
 * a large centered "F". Colors are derived from the size so each canvas differs.
 * @param {number} w - width in pixels
 * @param {number} h - height in pixels
 * @returns {OffscreenCanvas}
 */
function makeCanvas(w, h) {
  const canvas = new OffscreenCanvas(w, h);
  const ctx = canvas.getContext('2d');
  // Hue may exceed 1; hsl() scales it to degrees and CSS wraps hue angles.
  const hue = (w + h) * 0.1434;
  ctx.fillStyle = hsl(hue, 1, 0.75);
  ctx.fillRect(0, 0, w, h);
  ctx.lineWidth = 32;
  // Was misspelled `stokeStyle`, so the border never got this color.
  ctx.strokeStyle = hsl(hue + 0.33, 1, 0.25);
  ctx.strokeRect(0, 0, w, h);
  ctx.font = `bold ${Math.min(w, h) * 0.9 | 0}px monospace`;
  ctx.textAlign = 'center';
  ctx.textBaseline = 'middle';
  ctx.fillStyle = hsl(hue + 0.5, 1, 0.25);
  ctx.fillText('F', w / 2, h / 2);
  return canvas;
}
// TimingHelper for the compute pass in test(); created lazily on the first
// call and reused by later calls.
let computePassTimingHelper;
// TimingHelpers reused across generateMips() calls, indexed by the order of
// the render passes it encodes. Mutated in place, never reassigned.
const savedTimingHelpers = [];
/**
 * Builds the full mip chain of one test image twice and shows both results.
 * The first pass uses the compute shader below, which writes up to 4 levels
 * per dispatch. The second uses generateMips()' render-pass approach.
 * GPU textures created here are destroyed once their work is submitted.
 * @param {GPUDevice} device
 * @param {ImageData|OffscreenCanvas} canvas - source for mip level 0
 */
async function test(device, canvas) {
  const hasTiming = device.features.has('timestamp-query');
  computePassTimingHelper = computePassTimingHelper ?? new TimingHelper(device);
  const { width, height } = canvas;

  // Full mip chain:
  // - STORAGE_BINDING for the compute path
  // - RENDER_ATTACHMENT for generateMips()
  // - COPY_SRC for showMips()
  // - COPY_DST for the upload below
  const texture = device.createTexture({
    format: 'rgba8unorm',
    size: [width, height],
    usage: GPUTextureUsage.STORAGE_BINDING |
      GPUTextureUsage.TEXTURE_BINDING |
      GPUTextureUsage.RENDER_ATTACHMENT |
      GPUTextureUsage.COPY_DST |
      GPUTextureUsage.COPY_SRC,
    mipLevelCount: numMipLevels(width, height),
  });
  device.queue.copyExternalImageToTexture(
    { source: canvas },
    { texture },
    [width, height],
  );
  console.log(`\ntexture: ${width}x${height}, mipLevelCount: ${texture.mipLevelCount} `);

  // 1x1 stand-ins bound to the mip2..mip4 storage slots when those levels
  // don't exist. The textures are kept (not just their views) so they can be
  // destroyed after submission instead of leaking on every call.
  const dummyTextures = range(3, () => device.createTexture({
    format: 'rgba8unorm',
    size: [1],
    usage: GPUTextureUsage.STORAGE_BINDING,
  }));
  const dummyTextureViews = dummyTextures.map((t) => t.createView());

  const module = device.createShaderModule({
    code: `
@group(0) @binding(0) var smp: sampler;
@group(0) @binding(1) var mip0: texture_2d<f32>;
@group(0) @binding(2) var mip1: texture_storage_2d<rgba8unorm, write>;
@group(0) @binding(3) var mip2: texture_storage_2d<rgba8unorm, write>;
@group(0) @binding(4) var mip3: texture_storage_2d<rgba8unorm, write>;
@group(0) @binding(5) var mip4: texture_storage_2d<rgba8unorm, write>;
var<workgroup> texels1: array<array<vec4f, 8>, 8>;
var<workgroup> texels2: array<array<vec4f, 4>, 4>;
var<workgroup> texels3: array<array<vec4f, 2>, 2>;
// It doesn't seem like we need to check bounds. We bind a dummy texture.
// for each mip level not used. We use textureSampleLevel for the top level
// so it will clamp-to-edge. textureStore is speced to "not execute" if
// out of bounds.
fn processMip0ToMip1(blockXY: vec2u, lid: vec2u) {
let mip1TexelXY = blockXY * 8 + lid.xy;
let texelNdx = lid.xy;
let mip1Size = textureDimensions(mip1);
let uv = (vec2f(mip1TexelXY) + 0.5) / vec2f(mip1Size);
let c = textureSampleLevel(mip0, smp, uv, 0.0);
texels1[texelNdx.y][texelNdx.x] = c;
textureStore(mip1, mip1TexelXY, c);
}
@compute @workgroup_size(8, 8) fn cs(
@builtin(local_invocation_id) lid: vec3u,
@builtin(workgroup_id) wid: vec3u,
) {
let blockXY = wid.xy;
// generate mip1 from mip0
processMip0ToMip1(blockXY, lid.xy);
workgroupBarrier();
// generate mip2 from mip1
if (lid.x < 4 && lid.y < 4) {
let mip2TexelXY = blockXY * 4 + lid.xy;
let texelNdx = lid.xy;
let srcNdx = texelNdx * 2;
let c0 = texels1[srcNdx.y ][srcNdx.x ];
let c1 = texels1[srcNdx.y ][srcNdx.x + 1];
let c2 = texels1[srcNdx.y + 1][srcNdx.x ];
let c3 = texels1[srcNdx.y + 1][srcNdx.x + 1];
let c = mix(mix(c0, c1, 0.5), mix(c2, c3, 0.5), 0.5);
texels2[texelNdx.y][texelNdx.x] = c;
textureStore(mip2, mip2TexelXY, c);
}
workgroupBarrier();
// generate mip3 from mip2
if (lid.x < 2 && lid.y < 2) {
let mip3TexelXY = blockXY * 2 + lid.xy;
let texelNdx = lid.xy;
let srcNdx = texelNdx * 2;
let c0 = texels2[srcNdx.y ][srcNdx.x ];
let c1 = texels2[srcNdx.y ][srcNdx.x + 1];
let c2 = texels2[srcNdx.y + 1][srcNdx.x ];
let c3 = texels2[srcNdx.y + 1][srcNdx.x + 1];
let c = mix(mix(c0, c1, 0.5), mix(c2, c3, 0.5), 0.5);
texels3[texelNdx.y][texelNdx.x] = c;
textureStore(mip3, mip3TexelXY, c);
}
workgroupBarrier();
// generate mip4 from mip3
if (lid.x < 1 && lid.y < 1) {
let mip4TexelXY = blockXY + lid.xy;
let texelNdx = lid.xy;
let srcNdx = texelNdx * 2;
let c0 = texels3[srcNdx.y ][srcNdx.x ];
let c1 = texels3[srcNdx.y ][srcNdx.x + 1];
let c2 = texels3[srcNdx.y + 1][srcNdx.x ];
let c3 = texels3[srcNdx.y + 1][srcNdx.x + 1];
let c = mix(mix(c0, c1, 0.5), mix(c2, c3, 0.5), 0.5);
textureStore(mip4, mip4TexelXY, c);
}
}
`,
  });
  const pipeline = device.createComputePipeline({
    layout: 'auto',
    compute: { module },
  });
  const sampler = device.createSampler({
    minFilter: 'linear',
    magFilter: 'linear',
  });

  const encoder = device.createCommandEncoder();
  const pass = computePassTimingHelper.beginComputePass(encoder);
  pass.setPipeline(pipeline);

  // Each dispatch reads level N and writes levels N+1..N+4, e.g. for 4096:
  // 0: 4096 ---0
  // 1: 2048 ---1
  // 2: 1024 ---2
  // 3: 512 ---3
  // 4: 256 ---4 ---0
  // 5: 128 ---1
  // 6: 64 ---2
  // 7: 32 ---3
  // 8: 16 ---4 ---0
  // 9: 8 ---1
  //10: 4 ---2
  //11: 2 ---3
  //12: 1 ---4
  // A view of `mipLevel`, or a dummy view when that level does not exist.
  const getMipLevelView = (tex, mipLevel, dummyNdx) => (mipLevel < tex.mipLevelCount
    ? tex.createView({ baseMipLevel: mipLevel, mipLevelCount: 1 })
    : dummyTextureViews[dummyNdx]);
  const kLevelsPerPass = 4;
  // One 8x8 workgroup writes 8x8 texels of the next level, i.e. it covers a
  // 16x16 block of the level being read.
  const kBlockSize = 16;
  const numMipLevelsToGenerate = texture.mipLevelCount - 1;
  for (let baseMipLevel = 0; baseMipLevel < numMipLevelsToGenerate; baseMipLevel += kLevelsPerPass) {
    const levelSize = [
      Math.max(1, texture.width >> baseMipLevel),
      Math.max(1, texture.height >> baseMipLevel),
    ];
    const bindGroup = device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: sampler },
        { binding: 1, resource: texture.createView({ baseMipLevel, mipLevelCount: 1 }) },
        { binding: 2, resource: texture.createView({ baseMipLevel: baseMipLevel + 1, mipLevelCount: 1 }) },
        { binding: 3, resource: getMipLevelView(texture, baseMipLevel + 2, 0) },
        { binding: 4, resource: getMipLevelView(texture, baseMipLevel + 3, 1) },
        { binding: 5, resource: getMipLevelView(texture, baseMipLevel + 4, 2) },
      ],
    });
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(...levelSize.map((size) => Math.ceil(size / kBlockSize)));
  }
  pass.end();
  device.queue.submit([encoder.finish()]);
  // Already-submitted work still completes after destroy(); nothing else
  // references the dummies.
  for (const dummy of dummyTextures) {
    dummy.destroy();
  }

  const computeNs = await computePassTimingHelper.getResult();
  if (hasTiming) {
    console.log(`compute speed: ${(computeNs / 1000).toFixed(1)}µs`);
  }
  log(`${texture.width}x${texture.height} via compute-pass`);
  showMips(device, texture);
  document.body.append(document.createElement('hr'));

  // Regenerate the same levels with render passes for comparison.
  const renderNs = await generateMips(device, texture);
  if (hasTiming) {
    console.log(`render speed: ${(renderNs / 1000).toFixed(1)}µs`);
  }
  log(`${texture.width}x${texture.height} via render-pass`);
  showMips(device, texture);
  document.body.append(document.createElement('hr'));

  // showMips() has already submitted its copies, so the texture can be released.
  texture.destroy();
}
/**
 * Appends a <pre> element to the page whose text is the arguments joined
 * by single spaces.
 * @param {...*} args
 */
function log(...args) {
  const line = Object.assign(document.createElement('pre'), {
    textContent: args.join(' '),
  });
  document.body.appendChild(line);
}
// -------------
// Shared by every showMips() call: one WebGPU canvas and its context,
// created lazily and resized for each mip level.
let showMipsCanvas;
let showMipsContext;
/**
 * Appends one <img> per mip level of `texture` to document.body.
 *
 * Each level is copied into the shared WebGPU canvas and snapshotted with
 * toBlob(). The snapshot becomes the <img> source, so the canvas can be
 * reused for the next level.
 *
 * @param {GPUDevice} device
 * @param {GPUTexture} texture - needs COPY_SRC usage; presumably must be
 *   'rgba8unorm' to match the canvas format configured below — confirm.
 */
function showMips(device, texture) {
  // Textures no larger than 24px get nearest-neighbor upscaling on their <img>s.
  const bigCSS = Math.max(texture.width, texture.height) <= 24;
  for (let mipLevel = 0; mipLevel < texture.mipLevelCount; ++mipLevel) {
    const width = Math.max(1, texture.width >> mipLevel);
    const height = Math.max(1, texture.height >> mipLevel);
    showMipsCanvas = showMipsCanvas ?? document.createElement('canvas');
    showMipsCanvas.width = width;
    showMipsCanvas.height = height;
    showMipsContext = showMipsContext ?? showMipsCanvas.getContext('webgpu');
    showMipsContext.configure({
      device,
      format: 'rgba8unorm',
      usage: GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.RENDER_ATTACHMENT | GPUTextureUsage.COPY_DST,
    });

    const encoder = device.createCommandEncoder();
    encoder.copyTextureToTexture(
      { texture, mipLevel },
      { texture: showMipsContext.getCurrentTexture() },
      [width, height],
    );
    device.queue.submit([encoder.finish()]);

    const img = new Image();
    if (bigCSS) {
      img.style.imageRendering = 'pixelated';
    }
    showMipsCanvas.toBlob((blob) => {
      // toBlob() passes null when the canvas could not be encoded.
      if (!blob) {
        console.error(`showMips: could not encode mip level ${mipLevel}`);
        return;
      }
      const url = URL.createObjectURL(blob);
      // Release the blob URL once the image has loaded (or failed to);
      // otherwise each preview's blob stays alive for the page's lifetime.
      const revoke = () => URL.revokeObjectURL(url);
      img.addEventListener('load', revoke, { once: true });
      img.addEventListener('error', revoke, { once: true });
      img.src = url;
    });
    document.body.append(img);
  }
}
/**
 * Guesses the texture-binding view dimension for a texture.
 *
 * This is only a heuristic; callers that know better should pass the
 * dimension explicitly. A 2d texture and a 2d-array texture with one layer
 * look identical here, as do a 2d-array texture with 6 layers and a cubemap.
 *
 * @param {string} dimension - GPUTexture.dimension ('1d' | '2d' | '3d')
 * @param {number} depthOrArrayLayers - GPUTexture.depthOrArrayLayers
 * @returns {string} '1d', '3d', '2d-array' or '2d'
 */
function getDefaultViewDimensionForTexture(dimension, depthOrArrayLayers) {
  if (dimension === '1d') {
    return '1d';
  }
  if (dimension === '3d') {
    return '3d';
  }
  // '2d' and any unrecognized value: infer array-ness from the layer count.
  return depthOrArrayLayers > 1 ? '2d-array' : '2d';
}
/**
 * Number of mip levels in a full chain for a texture whose largest dimension
 * is max(...sizes): 1 + floor(log2(largest)), truncated with `| 0`.
 * @param {...number} sizes - texture dimensions
 * @returns {number}
 */
const numMipLevels = (...sizes) => {
  const largest = Math.max(...sizes);
  const fractionalLevels = 1 + Math.log2(largest);
  return fractionalLevels | 0;
};
/**
 * Generates levels 1..mipLevelCount-1 of `texture` by rendering each level
 * from the previous one with a full-screen triangle and a linear sampler.
 * Each array layer gets its own render pass.
 *
 * @param {GPUDevice} device
 * @param {GPUTexture} texture - needs TEXTURE_BINDING | RENDER_ATTACHMENT usage
 * @param {string} [textureBindingViewDimension] - '2d', '2d-array', 'cube' or
 *   'cube-array'; guessed by getDefaultViewDimensionForTexture() when omitted
 * @returns {Promise<number>} sum of the per-pass TimingHelper results
 *   (0 when the texture has a single level)
 */
const generateMips = (() => {
  // Shader modules, samplers and pipelines belong to the GPUDevice that
  // created them, so cache them per device. A single shared cache would hand
  // one device's objects to another device.
  // NOTE(review): savedTimingHelpers is still shared across devices.
  const resourcesByDevice = new WeakMap();

  // Returns (creating on first use) the shader module, sampler and pipeline
  // cache for `device`.
  function getDeviceResources(device) {
    let resources = resourcesByDevice.get(device);
    if (!resources) {
      const module = device.createShaderModule({
        label: 'textured quad shaders for mip level generation',
        code: `
const faceMat = array(
mat3x3f( 0, 0, -2, 0, -2, 0, 1, 1, 1), // pos-x
mat3x3f( 0, 0, 2, 0, -2, 0, -1, 1, -1), // neg-x
mat3x3f( 2, 0, 0, 0, 0, 2, -1, 1, -1), // pos-y
mat3x3f( 2, 0, 0, 0, 0, -2, -1, -1, 1), // neg-y
mat3x3f( 2, 0, 0, 0, -2, 0, -1, 1, 1), // pos-z
mat3x3f(-2, 0, 0, 0, -2, 0, 1, 1, -1)); // neg-z
struct VSOutput {
@builtin(position) position: vec4f,
@location(0) texcoord: vec2f,
@location(1) @interpolate(flat, either) baseArrayLayer: u32,
};
@vertex fn vs(
@builtin(vertex_index) vertexIndex : u32,
@builtin(instance_index) baseArrayLayer: u32,
) -> VSOutput {
var pos = array<vec2f, 3>(
vec2f(-1.0, -1.0),
vec2f(-1.0, 3.0),
vec2f( 3.0, -1.0),
);
var vsOutput: VSOutput;
let xy = pos[vertexIndex];
vsOutput.position = vec4f(xy, 0.0, 1.0);
vsOutput.texcoord = xy * vec2f(0.5, -0.5) + vec2f(0.5);
vsOutput.baseArrayLayer = baseArrayLayer;
return vsOutput;
}
@group(0) @binding(0) var ourSampler: sampler;
@group(0) @binding(1) var ourTexture2d: texture_2d<f32>;
@fragment fn fs2d(fsInput: VSOutput) -> @location(0) vec4f {
return textureSample(ourTexture2d, ourSampler, fsInput.texcoord);
}
@group(0) @binding(1) var ourTexture2dArray: texture_2d_array<f32>;
@fragment fn fs2darray(fsInput: VSOutput) -> @location(0) vec4f {
return textureSample(
ourTexture2dArray,
ourSampler,
fsInput.texcoord,
fsInput.baseArrayLayer);
}
@group(0) @binding(1) var ourTextureCube: texture_cube<f32>;
@fragment fn fscube(fsInput: VSOutput) -> @location(0) vec4f {
return textureSample(
ourTextureCube,
ourSampler,
faceMat[fsInput.baseArrayLayer] * vec3f(fract(fsInput.texcoord), 1));
}
@group(0) @binding(1) var ourTextureCubeArray: texture_cube_array<f32>;
@fragment fn fscubearray(fsInput: VSOutput) -> @location(0) vec4f {
return textureSample(
ourTextureCubeArray,
ourSampler,
faceMat[fsInput.baseArrayLayer % 6] * vec3f(fract(fsInput.texcoord), 1), fsInput.baseArrayLayer / 6);
}
`,
      });
      const sampler = device.createSampler({
        minFilter: 'linear',
        magFilter: 'linear',
      });
      resources = { module, sampler, pipelineByFormatAndView: new Map() };
      resourcesByDevice.set(device, resources);
    }
    return resources;
  }

  // Returns (creating on first use) the render pipeline for this texture
  // format and binding view dimension.
  function getPipeline(device, resources, format, viewDimension) {
    const id = `${format}.${viewDimension}`;
    let pipeline = resources.pipelineByFormatAndView.get(id);
    if (!pipeline) {
      // Pick the fragment entry point for the view dimension; strips the '-'
      // from '2d-array' / 'cube-array'.
      const entryPoint = `fs${viewDimension.replace(/[\W]/, '')}`;
      pipeline = device.createRenderPipeline({
        label: `mip level generator pipeline for ${viewDimension}, format: ${format}`,
        layout: 'auto',
        vertex: {
          module: resources.module,
        },
        fragment: {
          module: resources.module,
          entryPoint,
          targets: [{ format }],
        },
      });
      resources.pipelineByFormatAndView.set(id, pipeline);
    }
    return pipeline;
  }

  return async function generateMips(device, texture, textureBindingViewDimension) {
    const viewDimension = textureBindingViewDimension ??
      getDefaultViewDimensionForTexture(texture.dimension, texture.depthOrArrayLayers);
    const resources = getDeviceResources(device);
    const pipeline = getPipeline(device, resources, texture.format, viewDimension);

    const timingHelpers = [];
    const encoder = device.createCommandEncoder({
      label: 'mip gen encoder',
    });
    for (let baseMipLevel = 1; baseMipLevel < texture.mipLevelCount; ++baseMipLevel) {
      // The source view spans every layer of the previous level, so one bind
      // group serves all layers of this level.
      const bindGroup = device.createBindGroup({
        layout: pipeline.getBindGroupLayout(0),
        entries: [
          { binding: 0, resource: resources.sampler },
          {
            binding: 1,
            resource: texture.createView({
              dimension: viewDimension,
              baseMipLevel: baseMipLevel - 1,
              mipLevelCount: 1,
            }),
          },
        ],
      });
      for (let layer = 0; layer < texture.depthOrArrayLayers; ++layer) {
        const renderPassDescriptor = {
          label: 'our basic canvas renderPass',
          colorAttachments: [
            {
              view: texture.createView({
                dimension: '2d',
                baseMipLevel,
                mipLevelCount: 1,
                baseArrayLayer: layer,
                arrayLayerCount: 1,
              }),
              loadOp: 'clear',
              storeOp: 'store',
            },
          ],
        };
        // Reuse TimingHelpers across calls, indexed by pass order.
        const timingHelper = savedTimingHelpers[timingHelpers.length] ?? new TimingHelper(device);
        savedTimingHelpers[timingHelpers.length] = timingHelper;
        timingHelpers.push(timingHelper);
        const pass = timingHelper.beginRenderPass(encoder, renderPassDescriptor);
        pass.setPipeline(pipeline);
        pass.setBindGroup(0, bindGroup);
        // 3 vertices, 1 instance; firstInstance (instance_index) = layer.
        pass.draw(3, 1, 0, layer);
        pass.end();
      }
    }
    device.queue.submit([encoder.finish()]);
    const ns = await Promise.all(timingHelpers.map((t) => t.getResult()));
    // Initial value 0: a single-level texture encodes no passes, and reduce()
    // on an empty array without a seed throws a TypeError.
    return ns.reduce((a, b) => a + b, 0);
  };
})();
/**
 * Returns [fn(0), fn(1), ..., fn(num - 1)], calling fn in index order.
 * Invalid lengths throw RangeError, just as `new Array(num)` does.
 * @param {number} num
 * @param {(i: number) => *} fn
 * @returns {Array}
 */
const range = (num, fn) => {
  const result = new Array(num);
  for (let i = 0; i < result.length; ++i) {
    result[i] = fn(i);
  }
  return result;
};
// main() is async; report any failure explicitly instead of leaving the
// returned promise floating as an unhandled rejection.
main().catch((err) => console.error(err));
{"name":"WebGPU: Compute Shader Mipmap Generation (v1b)","settings":{},"filenames":["index.html","index.css","index.js"]}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment