diff --git a/src/shader_manager.py b/src/shader_manager.py index 19ecf57..2f72c5b 100644 --- a/src/shader_manager.py +++ b/src/shader_manager.py @@ -3,7 +3,7 @@ import OpenGL.GL as gl class ShaderManager: def __init__(self): - pass + self.program = None def compile_shader(self, vertex_src: str, fragment_src: str) -> int: program = gl.glCreateProgram() @@ -37,4 +37,26 @@ class ShaderManager: gl.glDeleteShader(vert_shader) gl.glDeleteShader(frag_shader) + self.program = program return program + + def update_uniforms(self, uniforms: dict): + if self.program is None: + return + + for name, value in uniforms.items(): + loc = gl.glGetUniformLocation(self.program, name) + if loc == -1: + continue + + if isinstance(value, float): + gl.glUniform1f(loc, value) + elif isinstance(value, int): + gl.glUniform1i(loc, value) + elif isinstance(value, (list, tuple)): + if len(value) == 2: + gl.glUniform2f(loc, value[0], value[1]) + elif len(value) == 3: + gl.glUniform3f(loc, value[0], value[1], value[2]) + elif len(value) == 4: + gl.glUniform4f(loc, value[0], value[1], value[2], value[3]) diff --git a/tests/test_shader_manager.py b/tests/test_shader_manager.py index f4e9cc6..124d15c 100644 --- a/tests/test_shader_manager.py +++ b/tests/test_shader_manager.py @@ -23,3 +23,29 @@ def test_shader_manager_initialization_and_compilation(): assert program_id == 1 assert mock_gl.glCreateProgram.called assert mock_gl.glCreateShader.called + +def test_shader_manager_uniform_update(): + # Mock OpenGL.GL functions + with patch("src.shader_manager.gl") as mock_gl: + from src.shader_manager import ShaderManager + manager = ShaderManager() + # Set a mock program ID + manager.program = 1 + + # Mock glGetUniformLocation to return some valid locations + # u_time -> 10, u_resolution -> 20 + def mock_get_loc(prog, name): + if name == "u_time": return 10 + if name == "u_resolution": return 20 + return -1 + + mock_gl.glGetUniformLocation.side_effect = mock_get_loc + + # Call the method + manager.update_uniforms({"u_time": 1.5, "u_resolution": (800, 600)}) + + # Assert calls + mock_gl.glGetUniformLocation.assert_any_call(1, "u_time") + mock_gl.glGetUniformLocation.assert_any_call(1, "u_resolution") + mock_gl.glUniform1f.assert_called_once_with(10, 1.5) + mock_gl.glUniform2f.assert_called_once_with(20, 800, 600)