ranim_render/primitives/mesh_items.rs

1use crate::utils::{WgpuContext, WgpuVecBuffer};
2use bytemuck::{Pod, Zeroable};
3use glam::Vec3;
4use ranim_core::{components::rgba::Rgba, core_item::mesh_item::MeshItem};
5
6#[repr(C)]
7#[derive(Debug, Default, Clone, Copy, Pod, Zeroable)]
8pub struct MeshTransform {
9    pub transform: [[f32; 4]; 4],
10}
11
12pub struct MeshItemsBuffer {
13    /// Per-vertex positions (vertex buffer)
14    pub(crate) vertices_buffer: WgpuVecBuffer<Vec3>,
15    /// Per-vertex mesh id (vertex buffer)
16    pub(crate) mesh_ids_buffer: WgpuVecBuffer<u32>,
17    /// Per-vertex colors (vertex buffer)
18    pub(crate) vertex_colors_buffer: WgpuVecBuffer<Rgba>,
19    /// Per-vertex normals (vertex buffer) — all-zero → flat shading fallback
20    pub(crate) vertex_normals_buffer: WgpuVecBuffer<Vec3>,
21    /// Merged triangle indices (index buffer)
22    pub(crate) indices_buffer: WgpuVecBuffer<u32>,
23
24    /// Per-mesh transform matrices (storage buffer, indexed by mesh_id)
25    pub(crate) transforms_buffer: WgpuVecBuffer<MeshTransform>,
26
27    pub(crate) item_count: u32,
28    pub(crate) total_vertices: u32,
29    pub(crate) total_indices: u32,
30
31    pub(crate) render_bind_group: Option<wgpu::BindGroup>,
32}
33
34impl MeshItemsBuffer {
35    pub fn new(ctx: &WgpuContext) -> Self {
36        let vertex_usage = wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST;
37        let index_usage = wgpu::BufferUsages::INDEX | wgpu::BufferUsages::COPY_DST;
38        let storage_ro = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST;
39
40        Self {
41            vertices_buffer: WgpuVecBuffer::new(ctx, Some("MeshVertices"), vertex_usage, 1),
42            mesh_ids_buffer: WgpuVecBuffer::new(ctx, Some("MeshIds"), vertex_usage, 1),
43            vertex_colors_buffer: WgpuVecBuffer::new(
44                ctx,
45                Some("MeshVertexColors"),
46                vertex_usage,
47                1,
48            ),
49            vertex_normals_buffer: WgpuVecBuffer::new(
50                ctx,
51                Some("MeshVertexNormals"),
52                vertex_usage,
53                1,
54            ),
55            indices_buffer: WgpuVecBuffer::new(ctx, Some("MeshIndices"), index_usage, 1),
56            transforms_buffer: WgpuVecBuffer::new(ctx, Some("MeshTransforms"), storage_ro, 1),
57            item_count: 0,
58            total_vertices: 0,
59            total_indices: 0,
60            render_bind_group: None,
61        }
62    }
63
64    pub fn update(&mut self, ctx: &WgpuContext, mesh_items: &[MeshItem]) {
65        if mesh_items.is_empty() {
66            self.item_count = 0;
67            self.total_vertices = 0;
68            self.total_indices = 0;
69            return;
70        }
71
72        let item_count = mesh_items.len();
73        let total_vertices: usize = mesh_items.iter().map(|m| m.points.len()).sum();
74        let total_indices: usize = mesh_items.iter().map(|m| m.triangle_indices.len()).sum();
75
76        let mut transforms = Vec::with_capacity(item_count);
77        let mut all_vertices = Vec::with_capacity(total_vertices);
78        let mut all_mesh_ids = Vec::with_capacity(total_vertices);
79        let mut all_vertex_colors = Vec::with_capacity(total_vertices);
80        let mut all_vertex_normals = Vec::with_capacity(total_vertices);
81        let mut all_indices = Vec::with_capacity(total_indices);
82
83        let mut vertex_offset: u32 = 0;
84
85        for (mesh_idx, mesh) in mesh_items.iter().enumerate() {
86            let vc = mesh.points.len() as u32;
87
88            transforms.push(MeshTransform {
89                transform: mesh.transform.to_cols_array_2d(),
90            });
91
92            all_vertices.extend_from_slice(&mesh.points);
93            all_mesh_ids.extend(std::iter::repeat_n(mesh_idx as u32, vc as usize));
94            all_vertex_colors.extend_from_slice(&mesh.vertex_colors);
95
96            // Pad normals with zero if shorter than points (flat shading fallback)
97            let normals = &mesh.vertex_normals;
98            let normals_len = normals.len();
99            if normals_len >= vc as usize {
100                all_vertex_normals.extend_from_slice(&normals[..vc as usize]);
101            } else {
102                all_vertex_normals.extend_from_slice(normals);
103                all_vertex_normals
104                    .extend(std::iter::repeat_n(Vec3::ZERO, vc as usize - normals_len));
105            }
106
107            all_indices.extend(mesh.triangle_indices.iter().map(|&i| i + vertex_offset));
108
109            vertex_offset += vc;
110        }
111
112        self.item_count = item_count as u32;
113        self.total_vertices = total_vertices as u32;
114        self.total_indices = total_indices as u32;
115
116        // Vertex/index buffers (no bind group dependency)
117        self.vertices_buffer.set(ctx, &all_vertices);
118        self.mesh_ids_buffer.set(ctx, &all_mesh_ids);
119        self.vertex_colors_buffer.set(ctx, &all_vertex_colors);
120        self.vertex_normals_buffer.set(ctx, &all_vertex_normals);
121        self.indices_buffer.set(ctx, &all_indices);
122
123        // Storage buffers (bind group recreated on realloc)
124        let any_realloc = self.transforms_buffer.set(ctx, &transforms);
125
126        if any_realloc || self.render_bind_group.is_none() {
127            self.render_bind_group = Some(Self::create_render_bind_group(ctx, self));
128        }
129    }
130
131    pub fn item_count(&self) -> u32 {
132        self.item_count
133    }
134
135    pub fn total_indices(&self) -> u32 {
136        self.total_indices
137    }
138
139    pub fn vertex_buffer_layouts() -> [wgpu::VertexBufferLayout<'static>; 4] {
140        [
141            // Slot 0: positions (vec3<f32>)
142            wgpu::VertexBufferLayout {
143                array_stride: std::mem::size_of::<Vec3>() as u64,
144                step_mode: wgpu::VertexStepMode::Vertex,
145                attributes: &[wgpu::VertexAttribute {
146                    format: wgpu::VertexFormat::Float32x3,
147                    offset: 0,
148                    shader_location: 0,
149                }],
150            },
151            // Slot 1: mesh_id (u32)
152            wgpu::VertexBufferLayout {
153                array_stride: std::mem::size_of::<u32>() as u64,
154                step_mode: wgpu::VertexStepMode::Vertex,
155                attributes: &[wgpu::VertexAttribute {
156                    format: wgpu::VertexFormat::Uint32,
157                    offset: 0,
158                    shader_location: 1,
159                }],
160            },
161            // Slot 2: vertex_color (vec4<f32>)
162            wgpu::VertexBufferLayout {
163                array_stride: std::mem::size_of::<Rgba>() as u64,
164                step_mode: wgpu::VertexStepMode::Vertex,
165                attributes: &[wgpu::VertexAttribute {
166                    format: wgpu::VertexFormat::Float32x4,
167                    offset: 0,
168                    shader_location: 2,
169                }],
170            },
171            // Slot 3: vertex_normal (vec3<f32>)
172            wgpu::VertexBufferLayout {
173                array_stride: std::mem::size_of::<Vec3>() as u64,
174                step_mode: wgpu::VertexStepMode::Vertex,
175                attributes: &[wgpu::VertexAttribute {
176                    format: wgpu::VertexFormat::Float32x3,
177                    offset: 0,
178                    shader_location: 3,
179                }],
180            },
181        ]
182    }
183
184    pub fn render_bind_group_layout(ctx: &WgpuContext) -> wgpu::BindGroupLayout {
185        ctx.device
186            .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
187                label: Some("MeshItems Render BGL"),
188                entries: &[
189                    // binding 0: transforms (per-mesh, vertex stage)
190                    bgl_storage_entry(0, wgpu::ShaderStages::VERTEX),
191                ],
192            })
193    }
194
195    fn create_render_bind_group(ctx: &WgpuContext, this: &Self) -> wgpu::BindGroup {
196        ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
197            label: Some("MeshItems Render BG"),
198            layout: &Self::render_bind_group_layout(ctx),
199            entries: &[bg_entry(0, &this.transforms_buffer.buffer)],
200        })
201    }
202}
203
204fn bgl_storage_entry(binding: u32, visibility: wgpu::ShaderStages) -> wgpu::BindGroupLayoutEntry {
205    wgpu::BindGroupLayoutEntry {
206        binding,
207        visibility,
208        ty: wgpu::BindingType::Buffer {
209            ty: wgpu::BufferBindingType::Storage { read_only: true },
210            has_dynamic_offset: false,
211            min_binding_size: None,
212        },
213        count: None,
214    }
215}
216
217fn bg_entry(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
218    wgpu::BindGroupEntry {
219        binding,
220        resource: wgpu::BindingResource::Buffer(buffer.as_entire_buffer_binding()),
221    }
222}
223
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use crate::{Renderer, resource::RenderPool};
    use glam::{Mat4, Vec3};
    use pollster::block_on;
    use ranim_core::{components::rgba::Rgba, core_item::CoreItem, store::CoreItemStore};

    /// Background color shared by the rendering tests.
    const CLEAR_COLOR: wgpu::Color = wgpu::Color {
        r: 0.1,
        g: 0.1,
        b: 0.1,
        a: 1.0,
    };

    /// Creates the parent directory of `path` if needed.
    ///
    /// Without this, saving an image fails just because the output directory
    /// does not exist yet.
    fn ensure_parent_dir(path: &Path) {
        if let Some(dir) = path.parent() {
            std::fs::create_dir_all(dir).expect("Failed to create output directory");
        }
    }

    /// One triangle in the z = 0 plane, translated by `offset`.
    ///
    /// Every vertex gets `color` and an all-zero normal (flat-shading fallback).
    fn create_triangle_mesh(color: Rgba, offset: Vec3) -> MeshItem {
        MeshItem {
            points: vec![
                Vec3::new(0.0, 1.0, 0.0) + offset,
                Vec3::new(-1.0, -1.0, 0.0) + offset,
                Vec3::new(1.0, -1.0, 0.0) + offset,
            ],
            triangle_indices: vec![0, 1, 2],
            transform: Mat4::IDENTITY,
            vertex_colors: vec![color; 3],
            vertex_normals: vec![Vec3::ZERO; 3],
        }
    }

    /// A 2×2 quad (two triangles) in the z = 0 plane, translated by `offset`.
    ///
    /// Every vertex gets `color` and an all-zero normal.
    fn create_quad_mesh(color: Rgba, offset: Vec3) -> MeshItem {
        MeshItem {
            points: vec![
                Vec3::new(-1.0, 1.0, 0.0) + offset,
                Vec3::new(1.0, 1.0, 0.0) + offset,
                Vec3::new(1.0, -1.0, 0.0) + offset,
                Vec3::new(-1.0, -1.0, 0.0) + offset,
            ],
            triangle_indices: vec![0, 1, 2, 0, 2, 3],
            transform: Mat4::IDENTITY,
            vertex_colors: vec![color; 4],
            vertex_normals: vec![Vec3::ZERO; 4],
        }
    }

    /// A 20×20-segment UV sphere of `radius` centred at `position`.
    ///
    /// Seam and pole vertices are duplicated (21×21 grid). Normals point
    /// outward from the centre.
    fn create_sphere_mesh(color: Rgba, radius: f32, position: Vec3) -> MeshItem {
        let mut points = Vec::new();
        let mut indices = Vec::new();

        // Simple UV sphere
        let lat_segments = 20;
        let lon_segments = 20;

        for lat in 0..=lat_segments {
            let theta = lat as f32 * std::f32::consts::PI / lat_segments as f32;
            let sin_theta = theta.sin();
            let cos_theta = theta.cos();

            for lon in 0..=lon_segments {
                let phi = lon as f32 * 2.0 * std::f32::consts::PI / lon_segments as f32;
                let sin_phi = phi.sin();
                let cos_phi = phi.cos();

                let x = sin_theta * cos_phi;
                let y = sin_theta * sin_phi;
                let z = cos_theta;

                points.push(Vec3::new(x * radius, y * radius, z * radius) + position);
            }
        }

        for lat in 0..lat_segments {
            for lon in 0..lon_segments {
                // `first` is on ring `lat`, `second` directly below it on ring `lat + 1`.
                let first = lat * (lon_segments + 1) + lon;
                let second = first + lon_segments + 1;

                indices.push(first);
                indices.push(second);
                indices.push(first + 1);

                indices.push(second);
                indices.push(second + 1);
                indices.push(first + 1);
            }
        }

        let vertex_colors = vec![color; points.len()];
        let vertex_normals = points.iter().map(|p| (*p - position).normalize()).collect();

        MeshItem {
            points,
            triangle_indices: indices,
            transform: Mat4::IDENTITY,
            vertex_colors,
            vertex_normals,
        }
    }

    #[test]
    fn render_mesh_items() {
        use ranim_core::core_item::camera_frame::CameraFrame;

        let ctx = block_on(WgpuContext::new());

        let width = 800u32;
        let height = 600u32;

        let mut renderer = Renderer::new(&ctx, width, height, 8);
        let mut render_textures = renderer.new_render_textures(&ctx);
        let mut pool = RenderPool::new();

        let mut store = CoreItemStore::new();

        let red = Rgba(glam::Vec4::new(1.0, 0.0, 0.0, 1.0));
        let green = Rgba(glam::Vec4::new(0.0, 1.0, 0.0, 1.0));
        let blue = Rgba(glam::Vec4::new(0.0, 0.0, 1.0, 0.8));
        let yellow = Rgba(glam::Vec4::new(1.0, 1.0, 0.0, 0.9));

        let camera_frame = CameraFrame::default();
        let triangle1 = create_triangle_mesh(red, Vec3::new(-2.0, 0.0, 0.0));
        let triangle2 = create_triangle_mesh(green, Vec3::new(2.0, 0.0, 0.0));
        let quad1 = create_quad_mesh(blue, Vec3::new(0.0, 2.0, 0.0));
        let quad2 = create_quad_mesh(yellow, Vec3::new(0.0, -2.0, 0.0));

        store.update(
            [
                ((0, 0), CoreItem::CameraFrame(camera_frame)),
                ((1, 0), CoreItem::MeshItem(triangle1)),
                ((1, 1), CoreItem::MeshItem(triangle2)),
                ((2, 0), CoreItem::MeshItem(quad1)),
                ((3, 1), CoreItem::MeshItem(quad2)),
            ]
            .into_iter(),
        );

        renderer.render_store_with_pool(&ctx, &mut render_textures, CLEAR_COLOR, &store, &mut pool);
        pool.clean();

        ctx.device
            .poll(wgpu::PollType::wait_indefinitely())
            .unwrap();

        let buffer = render_textures.get_rendered_texture_img_buffer(&ctx);

        let output_path = Path::new("../../output/mesh_items_render.png");
        ensure_parent_dir(output_path);
        buffer.save(output_path).expect("Failed to save image");

        println!("Rendered image saved to: {:?}", output_path);
        println!("Open it to see the mesh rendering result!");

        assert!(output_path.exists(), "Image file should be created");
    }

    #[test]
    fn test_nested_transparent_spheres() {
        use ranim_core::core_item::camera_frame::CameraFrame;

        let ctx = block_on(WgpuContext::new());
        let width = 800u32;
        let height = 600u32;

        let mut renderer = Renderer::new(&ctx, width, height, 8);
        let mut render_textures = renderer.new_render_textures(&ctx);
        let mut pool = RenderPool::new();
        let mut store = CoreItemStore::new();

        // Create nested spheres:
        // 1. Outer transparent sphere (blue, alpha=0.3, radius=2.0)
        // 2. Middle opaque sphere (red, alpha=1.0, radius=1.5)
        // 3. Inner transparent sphere (green, alpha=0.5, radius=1.0)

        let outer_transparent = Rgba(glam::Vec4::new(0.0, 0.0, 1.0, 0.3));
        let middle_opaque = Rgba(glam::Vec4::new(1.0, 0.0, 0.0, 1.0));
        let inner_transparent = Rgba(glam::Vec4::new(0.0, 1.0, 0.0, 0.5));

        let outer_sphere = create_sphere_mesh(outer_transparent, 2.0, Vec3::ZERO);
        let middle_sphere = create_sphere_mesh(middle_opaque, 1.5, Vec3::ZERO);
        let inner_sphere = create_sphere_mesh(inner_transparent, 1.0, Vec3::ZERO);

        let camera_frame = CameraFrame::default();

        store.update(
            [
                ((0, 0), CoreItem::CameraFrame(camera_frame)),
                ((1, 0), CoreItem::MeshItem(outer_sphere)),
                ((2, 0), CoreItem::MeshItem(middle_sphere)),
                ((3, 0), CoreItem::MeshItem(inner_sphere)),
            ]
            .into_iter(),
        );

        renderer.render_store_with_pool(&ctx, &mut render_textures, CLEAR_COLOR, &store, &mut pool);
        pool.clean();

        ctx.device
            .poll(wgpu::PollType::wait_indefinitely())
            .unwrap();

        // Analyze depth buffer. Samples within 0.001 of 1.0 are skipped as
        // background (assumes the depth attachment is cleared to 1.0 — TODO confirm).
        let depth_data = render_textures.get_depth_texture_data(&ctx);
        let mut min_depth = f32::MAX;
        let mut max_depth = f32::MIN;
        // Ordered by bucket key, so the histogram prints in depth order without sorting.
        let mut depth_histogram = std::collections::BTreeMap::<u32, usize>::new();

        for &d in depth_data {
            if (d - 1.0).abs() > 0.001 {
                min_depth = min_depth.min(d);
                max_depth = max_depth.max(d);
                let bucket = (d * 10000.0) as u32;
                *depth_histogram.entry(bucket).or_default() += 1;
            }
        }

        println!("\n=== Nested Spheres Depth Test ===");
        println!("Depth buffer analysis:");
        println!("  Min depth: {}", min_depth);
        println!("  Max depth: {}", max_depth);
        println!("\nDepth histogram (10 smallest depth buckets):");
        for (bucket, count) in depth_histogram.iter().take(10) {
            println!(
                "    depth ~{:.4}: {} pixels",
                *bucket as f32 / 10000.0,
                count
            );
        }

        // Read the color target back once; it is used for sampling and for saving.
        let buffer = render_textures.get_rendered_texture_img_buffer(&ctx);

        // Sample some pixels to see actual colors
        println!("\nColor samples (center region):");
        let center_x = width / 2;
        let center_y = height / 2;
        for dy in [-50i32, 0, 50] {
            for dx in [-50i32, 0, 50] {
                let x = (center_x as i32 + dx) as u32;
                let y = (center_y as i32 + dy) as u32;
                if x < width && y < height {
                    let pixel = buffer.get_pixel(x, y);
                    println!(
                        "  ({:3}, {:3}): R={:3} G={:3} B={:3} A={:3}",
                        dx, dy, pixel[0], pixel[1], pixel[2], pixel[3]
                    );
                }
            }
        }

        let output_path = Path::new("../../output/nested_spheres_render.png");
        ensure_parent_dir(output_path);
        buffer.save(output_path).expect("Failed to save image");

        let depth_buffer = render_textures.get_depth_texture_img_buffer(&ctx);
        let depth_path = Path::new("../../output/nested_spheres_depth.png");
        ensure_parent_dir(depth_path);
        depth_buffer
            .save(depth_path)
            .expect("Failed to save depth image");

        println!("\nImages saved to output/");
        println!("\nExpected behavior:");
        println!("  - Outer transparent blue sphere should be visible");
        println!("  - Middle opaque red sphere should occlude inner green sphere");
        println!("  - Inner green sphere should NOT be visible from outside");
        println!("  - Depth buffer should show opaque red sphere's depth");

        assert!(output_path.exists(), "Image file should be created");
        assert!(depth_path.exists(), "Depth image file should be created");
    }
}