feat(shaders): Implement uniform data passing for ShaderManager
This commit is contained in:
@@ -3,7 +3,7 @@ import OpenGL.GL as gl
|
|||||||
|
|
||||||
class ShaderManager:
|
class ShaderManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
self.program = None
|
||||||
|
|
||||||
def compile_shader(self, vertex_src: str, fragment_src: str) -> int:
|
def compile_shader(self, vertex_src: str, fragment_src: str) -> int:
|
||||||
program = gl.glCreateProgram()
|
program = gl.glCreateProgram()
|
||||||
@@ -37,4 +37,26 @@ class ShaderManager:
|
|||||||
gl.glDeleteShader(vert_shader)
|
gl.glDeleteShader(vert_shader)
|
||||||
gl.glDeleteShader(frag_shader)
|
gl.glDeleteShader(frag_shader)
|
||||||
|
|
||||||
|
self.program = program
|
||||||
return program
|
return program
|
||||||
|
|
||||||
|
def update_uniforms(self, uniforms: dict):
|
||||||
|
if self.program is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
for name, value in uniforms.items():
|
||||||
|
loc = gl.glGetUniformLocation(self.program, name)
|
||||||
|
if loc == -1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(value, float):
|
||||||
|
gl.glUniform1f(loc, value)
|
||||||
|
elif isinstance(value, int):
|
||||||
|
gl.glUniform1i(loc, value)
|
||||||
|
elif isinstance(value, (list, tuple)):
|
||||||
|
if len(value) == 2:
|
||||||
|
gl.glUniform2f(loc, value[0], value[1])
|
||||||
|
elif len(value) == 3:
|
||||||
|
gl.glUniform3f(loc, value[0], value[1], value[2])
|
||||||
|
elif len(value) == 4:
|
||||||
|
gl.glUniform4f(loc, value[0], value[1], value[2], value[3])
|
||||||
|
|||||||
@@ -23,3 +23,29 @@ def test_shader_manager_initialization_and_compilation():
|
|||||||
assert program_id == 1
|
assert program_id == 1
|
||||||
assert mock_gl.glCreateProgram.called
|
assert mock_gl.glCreateProgram.called
|
||||||
assert mock_gl.glCreateShader.called
|
assert mock_gl.glCreateShader.called
|
||||||
|
|
||||||
|
def test_shader_manager_uniform_update():
|
||||||
|
# Mock OpenGL.GL functions
|
||||||
|
with patch("src.shader_manager.gl") as mock_gl:
|
||||||
|
from src.shader_manager import ShaderManager
|
||||||
|
manager = ShaderManager()
|
||||||
|
# Set a mock program ID
|
||||||
|
manager.program = 1
|
||||||
|
|
||||||
|
# Mock glGetUniformLocation to return some valid locations
|
||||||
|
# u_time -> 10, u_resolution -> 20
|
||||||
|
def mock_get_loc(prog, name):
|
||||||
|
if name == "u_time": return 10
|
||||||
|
if name == "u_resolution": return 20
|
||||||
|
return -1
|
||||||
|
|
||||||
|
mock_gl.glGetUniformLocation.side_effect = mock_get_loc
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
manager.update_uniforms({"u_time": 1.5, "u_resolution": (800, 600)})
|
||||||
|
|
||||||
|
# Assert calls
|
||||||
|
mock_gl.glGetUniformLocation.assert_any_call(1, "u_time")
|
||||||
|
mock_gl.glGetUniformLocation.assert_any_call(1, "u_resolution")
|
||||||
|
mock_gl.glUniform1f.assert_called_once_with(10, 1.5)
|
||||||
|
mock_gl.glUniform2f.assert_called_once_with(20, 800, 600)
|
||||||
|
|||||||
Reference in New Issue
Block a user