Hello world

The code below draws a (hard-coded) triangle to the screen and contains the bare minimum needed to draw something with custom WebGPU shaders.

In [1]:
import webgpu.jupyter

# WGSL source for a minimal vertex/fragment shader pair:
#  - vertex_main picks one of three hard-coded clip-space positions and one of
#    three hard-coded colors (red, green, blue), indexed by the builtin
#    vertex_index, so drawing three vertices produces a single triangle;
#  - fragment_main writes the color it receives from the vertex stage
#    (interpolated across the triangle by the rasterizer).
# NOTE: everything inside the triple-quoted string is sent verbatim to the
# WGSL compiler, so comments there must use WGSL `//` syntax.
shader_code = """

// Data structure which is the output of the vertex shader and the input of the fragment shader
struct FragmentInput {
    @builtin(position) p: vec4<f32>,
    @location(0) color: vec4<f32>,
};

// Vertex shader, returns a FragmentInput object
@vertex
fn vertex_main(
  @builtin(vertex_index) vertex_index : u32
) -> FragmentInput {

  var pos = array<vec4f, 3>(
    vec4f( 0.0,  0.5, 0., 1.),
    vec4f(-0.5, -0.5, 0., 1.),
    vec4f( 0.5, -0.5, 0., 1.)
  );

  var color = array<vec4f, 3>(
    vec4f(1., 0., 0., 1.),
    vec4f(0., 1., 0., 1.),
    vec4f(0., 0., 1., 1.)
  );
  
  return FragmentInput( pos[vertex_index], color[vertex_index] );
}

@fragment
fn fragment_main(input: FragmentInput) -> @location(0) vec4f {
  return input.color;
}
"""


def draw_function_in_pyodide(data):
    """Build a render pipeline from the given WGSL source and schedule a frame.

    This function is serialized and executed inside the Pyodide environment
    running in the browser. That is why it can import modules such as ``js``
    and ``pyodide.ffi``, which are not available in a regular Python
    environment. All imports are done locally for the same reason.

    Args:
        data: Dictionary sent from the kernel. It must contain the key
            ``"shader_code"`` holding WGSL source that defines the entry
            points ``vertex_main`` and ``fragment_main``.
    """
    # Interface to the JavaScript environment in the browser, provided by Pyodide
    import js
    import pyodide.ffi

    # webgpu.jupyter initializes the Pyodide environment as soon as it is imported
    from webgpu.jupyter import gpu

    # Data structures defined by the WebGPU standard
    from webgpu.webgpu_api import (
        CompareFunction,
        DepthStencilState,
        FragmentState,
        PrimitiveState,
        PrimitiveTopology,
        TextureFormat,
        VertexState,
    )

    device = gpu.device

    # Compile the shader code.
    # Open the JS console (F12 in most browsers) to check for compile errors.
    shader_module = device.createShaderModule(data["shader_code"])

    # The render pipeline ties together the vertex and fragment shaders, the
    # depth-buffer behaviour, the output color format, multisampling, etc.
    pipeline = device.createRenderPipeline(
        # Empty pipeline layout: no bind groups, i.e. the shader is expected
        # to declare no buffer/texture bindings (true for the hello-world shader).
        device.createPipelineLayout([]),
        vertex=VertexState(module=shader_module, entryPoint="vertex_main"),
        fragment=FragmentState(
            module=shader_module,
            entryPoint="fragment_main",
            targets=[gpu.color_target],
        ),
        primitive=PrimitiveState(topology=PrimitiveTopology.triangle_list),
        depthStencil=DepthStencilState(
            format=gpu.depth_format,
            depthWriteEnabled=True,
            depthCompare=CompareFunction.less,
        ),
        multisample=gpu.multisample,
    )

    def render_function(t):
        """Encode and submit the commands for one frame.

        Called every time a new frame is requested.

        Args:
            t: Presumably the frame timestamp supplied by the caller
                (``requestAnimationFrame`` passes one); unused here.
        """
        # Record and submit on the same device the pipeline was created on,
        # instead of re-reading ``gpu.device`` on every frame.
        encoder = device.createCommandEncoder()
        render_pass = gpu.begin_render_pass(encoder)
        render_pass.setPipeline(pipeline)
        # Three vertices -> one triangle in triangle_list topology; the
        # hello-world shader derives each vertex from its vertex_index.
        render_pass.draw(3)
        render_pass.end()
        device.queue.submit([encoder.finish()])

    # Wrap the Python callable in a persistent JS proxy so JavaScript can keep
    # calling it after this function returns. The proxy is stored on the input
    # handler, which keeps a reference to it (presumably so the handler can
    # request new frames later — verify). The first frame is then scheduled.
    render_proxy = pyodide.ffi.create_proxy(render_function)
    gpu.input_handler.render_function = render_proxy
    js.requestAnimationFrame(render_proxy)


# Execute draw_function_in_pyodide inside the browser's Pyodide environment,
# handing it the WGSL source as its ``data`` argument.
webgpu.jupyter.DrawCustom(
    {"shader_code": shader_code},
    draw_function_in_pyodide,
)