feat(shaders): Implement uniform data passing for ShaderManager

This commit is contained in:
2026-03-13 12:29:10 -04:00
parent ac4f63b76e
commit 09383960be
2 changed files with 49 additions and 1 deletions

View File

@@ -3,7 +3,7 @@ import OpenGL.GL as gl
class ShaderManager: class ShaderManager:
def __init__(self): def __init__(self):
pass self.program = None
def compile_shader(self, vertex_src: str, fragment_src: str) -> int: def compile_shader(self, vertex_src: str, fragment_src: str) -> int:
program = gl.glCreateProgram() program = gl.glCreateProgram()
@@ -37,4 +37,26 @@ class ShaderManager:
gl.glDeleteShader(vert_shader) gl.glDeleteShader(vert_shader)
gl.glDeleteShader(frag_shader) gl.glDeleteShader(frag_shader)
self.program = program
return program return program
def update_uniforms(self, uniforms: dict):
if self.program is None:
return
for name, value in uniforms.items():
loc = gl.glGetUniformLocation(self.program, name)
if loc == -1:
continue
if isinstance(value, float):
gl.glUniform1f(loc, value)
elif isinstance(value, int):
gl.glUniform1i(loc, value)
elif isinstance(value, (list, tuple)):
if len(value) == 2:
gl.glUniform2f(loc, value[0], value[1])
elif len(value) == 3:
gl.glUniform3f(loc, value[0], value[1], value[2])
elif len(value) == 4:
gl.glUniform4f(loc, value[0], value[1], value[2], value[3])

View File

@@ -23,3 +23,29 @@ def test_shader_manager_initialization_and_compilation():
assert program_id == 1 assert program_id == 1
assert mock_gl.glCreateProgram.called assert mock_gl.glCreateProgram.called
assert mock_gl.glCreateShader.called assert mock_gl.glCreateShader.called
def test_shader_manager_uniform_update():
# Mock OpenGL.GL functions
with patch("src.shader_manager.gl") as mock_gl:
from src.shader_manager import ShaderManager
manager = ShaderManager()
# Set a mock program ID
manager.program = 1
# Mock glGetUniformLocation to return some valid locations
# u_time -> 10, u_resolution -> 20
def mock_get_loc(prog, name):
if name == "u_time": return 10
if name == "u_resolution": return 20
return -1
mock_gl.glGetUniformLocation.side_effect = mock_get_loc
# Call the method
manager.update_uniforms({"u_time": 1.5, "u_resolution": (800, 600)})
# Assert calls
mock_gl.glGetUniformLocation.assert_any_call(1, "u_time")
mock_gl.glGetUniformLocation.assert_any_call(1, "u_resolution")
mock_gl.glUniform1f.assert_called_once_with(10, 1.5)
mock_gl.glUniform2f.assert_called_once_with(20, 800, 600)