Skip to content

Commit 3fb6cef

Browse files
authored
Meshlet fill cluster buffers rewritten (#15955)
# Objective

- Make the meshlet fill cluster buffers pass slightly faster.
- Address #15920 for meshlets.
- Added `PreviousGlobalTransform` as a required meshlet component to avoid extra archetype moves, slightly alleviating #14681 for meshlets.
- Enforce that `MeshletPlugin::cluster_buffer_slots` is not greater than 2^25 (glitches will occur otherwise). Technically this field controls post-lod/culling cluster count, and the issue is on pre-lod/culling cluster count, but it's still valid now, and in the future this will be more true.

Needs to be merged after #15846 and #15886.

## Solution

- Old pass dispatched a thread per cluster, and did a binary search over the instances to find which instance the cluster belongs to, and what meshlet index within the instance it is.
- New pass dispatches a workgroup per instance, and has the workgroup loop over all meshlets in the instance in order to write out the cluster data.
- Use a push constant instead of `arrayLength` to fix the linked bug.
- Remap 1d->2d dispatch for software raster only if actually needed, to save on spawning excess workgroups.

## Testing

- Did you test these changes? If so, how?
  - Ran the meshlet example, and an example with 1041 instances of 32217 meshlets per instance. Profiled the second scene with Nsight, went from 0.55ms -> 0.40ms. Small savings. We're pretty much VRAM bandwidth bound at this point.
- How can other people (reviewers) test your changes? Is there anything specific they need to know?
  - Run the meshlet example.

## Changelog (non-meshlets)

- `PreviousGlobalTransform` now implements the `Default` trait.
1 parent 6d42830 commit 3fb6cef

11 files changed

+140
-83
lines changed

crates/bevy_pbr/src/meshlet/cull_clusters.wgsl

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
meshlet_software_raster_indirect_args,
1414
meshlet_hardware_raster_indirect_args,
1515
meshlet_raster_clusters,
16-
meshlet_raster_cluster_rightmost_slot,
16+
constants,
1717
MeshletBoundingSphere,
1818
}
1919
#import bevy_render::maths::affine3_to_square
@@ -32,7 +32,7 @@ fn cull_clusters(
3232
) {
3333
// Calculate the cluster ID for this thread
3434
let cluster_id = local_invocation_index + 128u * dot(workgroup_id, vec3(num_workgroups.x * num_workgroups.x, num_workgroups.x, 1u));
35-
if cluster_id >= arrayLength(&meshlet_cluster_meshlet_ids) { return; }
35+
if cluster_id >= constants.scene_cluster_count { return; }
3636

3737
#ifdef MESHLET_SECOND_CULLING_PASS
3838
if !cluster_is_second_pass_candidate(cluster_id) { return; }
@@ -138,7 +138,7 @@ fn cull_clusters(
138138
} else {
139139
// Append this cluster to the list for hardware rasterization
140140
buffer_slot = atomicAdd(&meshlet_hardware_raster_indirect_args.instance_count, 1u);
141-
buffer_slot = meshlet_raster_cluster_rightmost_slot - buffer_slot;
141+
buffer_slot = constants.meshlet_raster_cluster_rightmost_slot - buffer_slot;
142142
}
143143
meshlet_raster_clusters[buffer_slot] = cluster_id;
144144
}
Lines changed: 31 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1,44 +1,50 @@
11
#import bevy_pbr::meshlet_bindings::{
2-
cluster_count,
3-
meshlet_instance_meshlet_counts_prefix_sum,
2+
scene_instance_count,
3+
meshlet_global_cluster_count,
4+
meshlet_instance_meshlet_counts,
45
meshlet_instance_meshlet_slice_starts,
56
meshlet_cluster_instance_ids,
67
meshlet_cluster_meshlet_ids,
78
}
89

910
/// Writes out instance_id and meshlet_id to the global buffers for each cluster in the scene.
1011

12+
var<workgroup> cluster_slice_start_workgroup: u32;
13+
1114
@compute
12-
@workgroup_size(128, 1, 1) // 128 threads per workgroup, 1 cluster per thread
15+
@workgroup_size(1024, 1, 1) // 1024 threads per workgroup, 1 instance per workgroup
1316
fn fill_cluster_buffers(
1417
@builtin(workgroup_id) workgroup_id: vec3<u32>,
1518
@builtin(num_workgroups) num_workgroups: vec3<u32>,
1619
@builtin(local_invocation_index) local_invocation_index: u32,
1720
) {
18-
// Calculate the cluster ID for this thread
19-
let cluster_id = local_invocation_index + 128u * dot(workgroup_id, vec3(num_workgroups.x * num_workgroups.x, num_workgroups.x, 1u));
20-
if cluster_id >= cluster_count { return; } // TODO: Could be an arrayLength?
21-
22-
// Binary search to find the instance this cluster belongs to
23-
var left = 0u;
24-
var right = arrayLength(&meshlet_instance_meshlet_counts_prefix_sum) - 1u;
25-
while left <= right {
26-
let mid = (left + right) / 2u;
27-
if meshlet_instance_meshlet_counts_prefix_sum[mid] <= cluster_id {
28-
left = mid + 1u;
29-
} else {
30-
right = mid - 1u;
31-
}
21+
// Calculate the instance ID for this workgroup
22+
var instance_id = workgroup_id.x + (workgroup_id.y * num_workgroups.x);
23+
if instance_id >= scene_instance_count { return; }
24+
25+
let instance_meshlet_count = meshlet_instance_meshlet_counts[instance_id];
26+
let instance_meshlet_slice_start = meshlet_instance_meshlet_slice_starts[instance_id];
27+
28+
// Reserve cluster slots for the instance and broadcast to the workgroup
29+
if local_invocation_index == 0u {
30+
cluster_slice_start_workgroup = atomicAdd(&meshlet_global_cluster_count, instance_meshlet_count);
3231
}
33-
let instance_id = right;
32+
let cluster_slice_start = workgroupUniformLoad(&cluster_slice_start_workgroup);
3433

35-
// Find the meshlet ID for this cluster within the instance's MeshletMesh
36-
let meshlet_id_local = cluster_id - meshlet_instance_meshlet_counts_prefix_sum[instance_id];
34+
// Loop enough times to write out all the meshlets for the instance given that each thread writes 1 meshlet in each iteration
35+
for (var clusters_written = 0u; clusters_written < instance_meshlet_count; clusters_written += 1024u) {
36+
// Calculate meshlet ID within this instance's MeshletMesh to process for this thread
37+
let meshlet_id_local = clusters_written + local_invocation_index;
38+
if meshlet_id_local >= instance_meshlet_count { return; }
3739

38-
// Find the overall meshlet ID in the global meshlet buffer
39-
let meshlet_id = meshlet_id_local + meshlet_instance_meshlet_slice_starts[instance_id];
40+
// Find the overall cluster ID in the global cluster buffer
41+
let cluster_id = cluster_slice_start + meshlet_id_local;
4042

41-
// Write results to buffers
42-
meshlet_cluster_instance_ids[cluster_id] = instance_id;
43-
meshlet_cluster_meshlet_ids[cluster_id] = meshlet_id;
43+
// Find the overall meshlet ID in the global meshlet buffer
44+
let meshlet_id = instance_meshlet_slice_start + meshlet_id_local;
45+
46+
// Write results to buffers
47+
meshlet_cluster_instance_ids[cluster_id] = instance_id;
48+
meshlet_cluster_meshlet_ids[cluster_id] = meshlet_id;
49+
}
4450
}

crates/bevy_pbr/src/meshlet/instance_manager.rs

Lines changed: 27 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -10,42 +10,46 @@ use bevy_ecs::{
1010
query::Has,
1111
system::{Local, Query, Res, ResMut, Resource, SystemState},
1212
};
13-
use bevy_render::sync_world::MainEntity;
14-
use bevy_render::{render_resource::StorageBuffer, view::RenderLayers, MainWorld};
13+
use bevy_render::{
14+
render_resource::StorageBuffer, sync_world::MainEntity, view::RenderLayers, MainWorld,
15+
};
1516
use bevy_transform::components::GlobalTransform;
1617
use bevy_utils::{HashMap, HashSet};
1718
use core::ops::{DerefMut, Range};
1819

1920
/// Manages data for each entity with a [`MeshletMesh`].
2021
#[derive(Resource)]
2122
pub struct InstanceManager {
22-
/// Amount of clusters in the scene (sum of all meshlet counts across all instances)
23+
/// Amount of instances in the scene.
24+
pub scene_instance_count: u32,
25+
/// Amount of clusters in the scene.
2326
pub scene_cluster_count: u32,
2427

25-
/// Per-instance [`MainEntity`], [`RenderLayers`], and [`NotShadowCaster`]
28+
/// Per-instance [`MainEntity`], [`RenderLayers`], and [`NotShadowCaster`].
2629
pub instances: Vec<(MainEntity, RenderLayers, bool)>,
27-
/// Per-instance [`MeshUniform`]
30+
/// Per-instance [`MeshUniform`].
2831
pub instance_uniforms: StorageBuffer<Vec<MeshUniform>>,
29-
/// Per-instance material ID
32+
/// Per-instance material ID.
3033
pub instance_material_ids: StorageBuffer<Vec<u32>>,
31-
/// Prefix-sum of meshlet counts per instance
32-
pub instance_meshlet_counts_prefix_sum: StorageBuffer<Vec<u32>>,
33-
/// Per-instance index to the start of the instance's slice of the meshlets buffer
34+
/// Per-instance count of meshlets in the instance's [`MeshletMesh`].
35+
pub instance_meshlet_counts: StorageBuffer<Vec<u32>>,
36+
/// Per-instance index to the start of the instance's slice of the meshlets buffer.
3437
pub instance_meshlet_slice_starts: StorageBuffer<Vec<u32>>,
3538
/// Per-view per-instance visibility bit. Used for [`RenderLayers`] and [`NotShadowCaster`] support.
3639
pub view_instance_visibility: EntityHashMap<StorageBuffer<Vec<u32>>>,
3740

38-
/// Next material ID available for a [`Material`]
41+
/// Next material ID available for a [`Material`].
3942
next_material_id: u32,
40-
/// Map of [`Material`] to material ID
43+
/// Map of [`Material`] to material ID.
4144
material_id_lookup: HashMap<UntypedAssetId, u32>,
42-
/// Set of material IDs used in the scene
45+
/// Set of material IDs used in the scene.
4346
material_ids_present_in_scene: HashSet<u32>,
4447
}
4548

4649
impl InstanceManager {
4750
pub fn new() -> Self {
4851
Self {
52+
scene_instance_count: 0,
4953
scene_cluster_count: 0,
5054

5155
instances: Vec::new(),
@@ -59,9 +63,9 @@ impl InstanceManager {
5963
buffer.set_label(Some("meshlet_instance_material_ids"));
6064
buffer
6165
},
62-
instance_meshlet_counts_prefix_sum: {
66+
instance_meshlet_counts: {
6367
let mut buffer = StorageBuffer::default();
64-
buffer.set_label(Some("meshlet_instance_meshlet_counts_prefix_sum"));
68+
buffer.set_label(Some("meshlet_instance_meshlet_counts"));
6569
buffer
6670
},
6771
instance_meshlet_slice_starts: {
@@ -80,7 +84,7 @@ impl InstanceManager {
8084
#[allow(clippy::too_many_arguments)]
8185
pub fn add_instance(
8286
&mut self,
83-
instance: Entity,
87+
instance: MainEntity,
8488
meshlets_slice: Range<u32>,
8589
transform: &GlobalTransform,
8690
previous_transform: Option<&PreviousGlobalTransform>,
@@ -108,20 +112,21 @@ impl InstanceManager {
108112

109113
// Append instance data
110114
self.instances.push((
111-
instance.into(),
115+
instance,
112116
render_layers.cloned().unwrap_or(RenderLayers::default()),
113117
not_shadow_caster,
114118
));
115119
self.instance_uniforms.get_mut().push(mesh_uniform);
116120
self.instance_material_ids.get_mut().push(0);
117-
self.instance_meshlet_counts_prefix_sum
121+
self.instance_meshlet_counts
118122
.get_mut()
119-
.push(self.scene_cluster_count);
123+
.push(meshlets_slice.len() as u32);
120124
self.instance_meshlet_slice_starts
121125
.get_mut()
122126
.push(meshlets_slice.start);
123127

124-
self.scene_cluster_count += meshlets_slice.end - meshlets_slice.start;
128+
self.scene_instance_count += 1;
129+
self.scene_cluster_count += meshlets_slice.len() as u32;
125130
}
126131

127132
/// Get the material ID for a [`crate::Material`].
@@ -140,12 +145,13 @@ impl InstanceManager {
140145
}
141146

142147
pub fn reset(&mut self, entities: &Entities) {
148+
self.scene_instance_count = 0;
143149
self.scene_cluster_count = 0;
144150

145151
self.instances.clear();
146152
self.instance_uniforms.get_mut().clear();
147153
self.instance_material_ids.get_mut().clear();
148-
self.instance_meshlet_counts_prefix_sum.get_mut().clear();
154+
self.instance_meshlet_counts.get_mut().clear();
149155
self.instance_meshlet_slice_starts.get_mut().clear();
150156
self.view_instance_visibility
151157
.retain(|view_entity, _| entities.contains(*view_entity));
@@ -227,7 +233,7 @@ pub fn extract_meshlet_mesh_entities(
227233

228234
// Add the instance's data to the instance manager
229235
instance_manager.add_instance(
230-
instance,
236+
instance.into(),
231237
meshlets_slice,
232238
transform,
233239
previous_transform,

crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -51,25 +51,27 @@ struct DrawIndirectArgs {
5151
const CENTIMETERS_PER_METER = 100.0;
5252

5353
#ifdef MESHLET_FILL_CLUSTER_BUFFERS_PASS
54-
var<push_constant> cluster_count: u32;
55-
@group(0) @binding(0) var<storage, read> meshlet_instance_meshlet_counts_prefix_sum: array<u32>; // Per entity instance
54+
var<push_constant> scene_instance_count: u32;
55+
@group(0) @binding(0) var<storage, read> meshlet_instance_meshlet_counts: array<u32>; // Per entity instance
5656
@group(0) @binding(1) var<storage, read> meshlet_instance_meshlet_slice_starts: array<u32>; // Per entity instance
5757
@group(0) @binding(2) var<storage, read_write> meshlet_cluster_instance_ids: array<u32>; // Per cluster
5858
@group(0) @binding(3) var<storage, read_write> meshlet_cluster_meshlet_ids: array<u32>; // Per cluster
59+
@group(0) @binding(4) var<storage, read_write> meshlet_global_cluster_count: atomic<u32>; // Single object shared between all workgroups
5960
#endif
6061

6162
#ifdef MESHLET_CULLING_PASS
62-
var<push_constant> meshlet_raster_cluster_rightmost_slot: u32;
63+
struct Constants { scene_cluster_count: u32, meshlet_raster_cluster_rightmost_slot: u32 }
64+
var<push_constant> constants: Constants;
6365
@group(0) @binding(0) var<storage, read> meshlet_cluster_meshlet_ids: array<u32>; // Per cluster
6466
@group(0) @binding(1) var<storage, read> meshlet_bounding_spheres: array<MeshletBoundingSpheres>; // Per meshlet
6567
@group(0) @binding(2) var<storage, read> meshlet_simplification_errors: array<u32>; // Per meshlet
6668
@group(0) @binding(3) var<storage, read> meshlet_cluster_instance_ids: array<u32>; // Per cluster
6769
@group(0) @binding(4) var<storage, read> meshlet_instance_uniforms: array<Mesh>; // Per entity instance
6870
@group(0) @binding(5) var<storage, read> meshlet_view_instance_visibility: array<u32>; // 1 bit per entity instance, packed as a bitmask
6971
@group(0) @binding(6) var<storage, read_write> meshlet_second_pass_candidates: array<atomic<u32>>; // 1 bit per cluster , packed as a bitmask
70-
@group(0) @binding(7) var<storage, read_write> meshlet_software_raster_indirect_args: DispatchIndirectArgs; // Single object shared between all workgroups/clusters/triangles
71-
@group(0) @binding(8) var<storage, read_write> meshlet_hardware_raster_indirect_args: DrawIndirectArgs; // Single object shared between all workgroups/clusters/triangles
72-
@group(0) @binding(9) var<storage, read_write> meshlet_raster_clusters: array<u32>; // Single object shared between all workgroups/clusters/triangles
72+
@group(0) @binding(7) var<storage, read_write> meshlet_software_raster_indirect_args: DispatchIndirectArgs; // Single object shared between all workgroups
73+
@group(0) @binding(8) var<storage, read_write> meshlet_hardware_raster_indirect_args: DrawIndirectArgs; // Single object shared between all workgroups
74+
@group(0) @binding(9) var<storage, read_write> meshlet_raster_clusters: array<u32>; // Single object shared between all workgroups
7375
@group(0) @binding(10) var depth_pyramid: texture_2d<f32>; // From the end of the last frame for the first culling pass, and from the first raster pass for the second culling pass
7476
@group(0) @binding(11) var<uniform> view: View;
7577
@group(0) @binding(12) var<uniform> previous_view: PreviousViewUniforms;
@@ -95,7 +97,7 @@ fn cluster_is_second_pass_candidate(cluster_id: u32) -> bool {
9597
@group(0) @binding(3) var<storage, read> meshlet_vertex_positions: array<u32>; // Many per meshlet
9698
@group(0) @binding(4) var<storage, read> meshlet_cluster_instance_ids: array<u32>; // Per cluster
9799
@group(0) @binding(5) var<storage, read> meshlet_instance_uniforms: array<Mesh>; // Per entity instance
98-
@group(0) @binding(6) var<storage, read> meshlet_raster_clusters: array<u32>; // Single object shared between all workgroups/clusters/triangles
100+
@group(0) @binding(6) var<storage, read> meshlet_raster_clusters: array<u32>; // Single object shared between all workgroups
99101
@group(0) @binding(7) var<storage, read> meshlet_software_raster_cluster_count: u32;
100102
#ifdef MESHLET_VISIBILITY_BUFFER_RASTER_PASS_OUTPUT
101103
@group(0) @binding(8) var<storage, read_write> meshlet_visibility_buffer: array<atomic<u64>>; // Per pixel

crates/bevy_pbr/src/meshlet/mod.rs

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ use self::{
5656
},
5757
visibility_buffer_raster_node::MeshletVisibilityBufferRasterPassNode,
5858
};
59-
use crate::{graph::NodePbr, Material, MeshMaterial3d};
59+
use crate::{graph::NodePbr, Material, MeshMaterial3d, PreviousGlobalTransform};
6060
use bevy_app::{App, Plugin, PostUpdate};
6161
use bevy_asset::{load_internal_asset, AssetApp, AssetId, Handle};
6262
use bevy_core_pipeline::{
@@ -129,6 +129,8 @@ pub struct MeshletPlugin {
129129
/// If this number is too low, you'll see rendering artifacts like missing or blinking meshes.
130130
///
131131
/// Each cluster slot costs 4 bytes of VRAM.
132+
///
133+
/// Must not be greater than 2^25.
132134
pub cluster_buffer_slots: u32,
133135
}
134136

@@ -147,6 +149,11 @@ impl Plugin for MeshletPlugin {
147149
#[cfg(target_endian = "big")]
148150
compile_error!("MeshletPlugin is only supported on little-endian processors.");
149151

152+
if self.cluster_buffer_slots > 2_u32.pow(25) {
153+
error!("MeshletPlugin::cluster_buffer_slots must not be greater than 2^25.");
154+
std::process::exit(1);
155+
}
156+
150157
load_internal_asset!(
151158
app,
152159
MESHLET_BINDINGS_SHADER_HANDLE,
@@ -293,7 +300,7 @@ impl Plugin for MeshletPlugin {
293300
/// The meshlet mesh equivalent of [`bevy_render::mesh::Mesh3d`].
294301
#[derive(Component, Clone, Debug, Default, Deref, DerefMut, Reflect, PartialEq, Eq, From)]
295302
#[reflect(Component, Default)]
296-
#[require(Transform, Visibility)]
303+
#[require(Transform, PreviousGlobalTransform, Visibility)]
297304
pub struct MeshletMesh3d(pub Handle<MeshletMesh>);
298305

299306
impl From<MeshletMesh3d> for AssetId<MeshletMesh> {

crates/bevy_pbr/src/meshlet/pipelines.rs

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ impl FromWorld for MeshletPipelines {
8484
layout: vec![cull_layout.clone()],
8585
push_constant_ranges: vec![PushConstantRange {
8686
stages: ShaderStages::COMPUTE,
87-
range: 0..4,
87+
range: 0..8,
8888
}],
8989
shader: MESHLET_CULLING_SHADER_HANDLE,
9090
shader_defs: vec![
@@ -99,7 +99,7 @@ impl FromWorld for MeshletPipelines {
9999
layout: vec![cull_layout],
100100
push_constant_ranges: vec![PushConstantRange {
101101
stages: ShaderStages::COMPUTE,
102-
range: 0..4,
102+
range: 0..8,
103103
}],
104104
shader: MESHLET_CULLING_SHADER_HANDLE,
105105
shader_defs: vec![
@@ -441,7 +441,10 @@ impl FromWorld for MeshletPipelines {
441441
pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
442442
label: Some("meshlet_remap_1d_to_2d_dispatch_pipeline".into()),
443443
layout: vec![layout],
444-
push_constant_ranges: vec![],
444+
push_constant_ranges: vec![PushConstantRange {
445+
stages: ShaderStages::COMPUTE,
446+
range: 0..4,
447+
}],
445448
shader: MESHLET_REMAP_1D_TO_2D_DISPATCH_SHADER_HANDLE,
446449
shader_defs: vec![],
447450
entry_point: "remap_dispatch".into(),

crates/bevy_pbr/src/meshlet/remap_1d_to_2d_dispatch.wgsl

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,13 +8,16 @@ struct DispatchIndirectArgs {
88

99
@group(0) @binding(0) var<storage, read_write> meshlet_software_raster_indirect_args: DispatchIndirectArgs;
1010
@group(0) @binding(1) var<storage, read_write> meshlet_software_raster_cluster_count: u32;
11+
var<push_constant> max_compute_workgroups_per_dimension: u32;
1112

1213
@compute
1314
@workgroup_size(1, 1, 1)
1415
fn remap_dispatch() {
1516
meshlet_software_raster_cluster_count = meshlet_software_raster_indirect_args.x;
1617

17-
let n = u32(ceil(sqrt(f32(meshlet_software_raster_indirect_args.x))));
18-
meshlet_software_raster_indirect_args.x = n;
19-
meshlet_software_raster_indirect_args.y = n;
18+
if meshlet_software_raster_cluster_count > max_compute_workgroups_per_dimension {
19+
let n = u32(ceil(sqrt(f32(meshlet_software_raster_cluster_count))));
20+
meshlet_software_raster_indirect_args.x = n;
21+
meshlet_software_raster_indirect_args.y = n;
22+
}
2023
}

0 commit comments

Comments
 (0)