"""
ShaderManager
=============
Example Usage
=============
Illustrative only (needs an OpenGL context, loaded shader sources, ``my_mvp_matrix``,
and ``set_common_uniforms`` in scope)::
shader_manager = ShaderManager()
for shader_type_value in ShaderType:
shader_manager.load_shader_source_string(shader_type_value)
if shader_manager.use_shader_type(ShaderType.ATOMS):
shader_manager.current_shader_program = shader_manager.get(ShaderType.ATOMS)
set_common_uniforms(
shader_manager.current_shader_program,
mvp_matrix=my_mvp_matrix,
point_size=15.0,
highlight=True,
highlight_color=(1.0, 1.0, 0.0),
)
File naming convention:
=======================
Ensure GLSL files follow the naming pattern:
atoms_vert.glsl
atoms_frag.glsl
bonds_vert.glsl
bonds_frag.glsl
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Dict, Iterable, Optional, Tuple, Union
import numpy as np
from decologr import Decologr as log
from pyglm import glm
from picogl.backend.modern.core.shader.context import gl_context_available
from picogl.backend.modern.core.shader.program import ShaderProgram
from picogl.backend.modern.core.uniform.location import get_uniform_location
from picogl.backend.modern.core.uniform.location_value import set_uniform_location_value
from picogl.backend.modern.core.uniform.mvp import shader_uniform_set_mvp
from picogl.backend.modern.core.uniform.set_location import set_uniform_name_value
from picogl.globals import PICOGL_SHADER_SRC_DIRECTORY, SHADER_SRC_DIRECTORY
from picogl.shaders.compile import compile_shaders
from picogl.shaders.generate import generate_shader_programs
from picogl.shaders.load import (
load_fragment_and_vertex_for_shader_type,
load_shader_source_string,
)
from picogl.shaders.type import ShaderType
def _progress_iter(
    pairs: "Iterable[Tuple[int, ShaderType]]", *, desc: str, total: int
) -> "Iterable[Tuple[int, ShaderType]]":
    """
    Wrap ``pairs`` in a tqdm progress bar, but only in a real terminal.

    GUI/IDE runs get plain iteration (no tqdm, hence no monitor thread).
    ``pairs`` is returned unchanged when stderr is missing (``sys.stderr`` is
    ``None`` e.g. under ``pythonw``), closed, or not a TTY, or when tqdm is not
    installed. ``tqdm.rich`` is preferred over plain ``tqdm`` when available.

    :param pairs: iterable of ``(shader_number, shader_type)`` pairs
    :param desc: progress-bar label
    :param total: number of items, used for the bar length
    :return: ``pairs`` itself, or a tqdm wrapper yielding the same items
    """
    import sys

    # tqdm spawns a background monitor thread; mixing that with Qt + gl init has
    # caused segfaults when stderr is not a TTY (IDE / GUI runs).
    stream = sys.stderr
    try:
        interactive = stream is not None and stream.isatty()
    except (AttributeError, ValueError, OSError):
        # Replaced streams may lack isatty(); closed ones raise ValueError.
        interactive = False
    if not interactive:
        return pairs

    try:
        from tqdm.rich import tqdm as tqdm_cls
    except ImportError:
        try:
            from tqdm import tqdm as tqdm_cls
        except ImportError:
            return pairs

    # ``monitor_interval`` is a tqdm *class* attribute, not a constructor
    # argument (unknown keyword arguments make tqdm raise TqdmKeyError).
    # A throwaway subclass disables the monitor thread for this bar only,
    # without mutating the global tqdm class used by other code.
    class _NoMonitorTqdm(tqdm_cls):
        monitor_interval = 0

    return _NoMonitorTqdm(
        pairs,
        desc=desc,
        total=total,
        unit="shader",
        leave=False,
    )
@dataclass
class ShaderManager:
    """
    Registry of compiled shader programs keyed by ``ShaderType``, plus binding helpers.

    Programs are compiled in bulk by ``initialize_shaders`` or on first request by
    ``get_shader_type``. When a type's sources fail to load or compile, a shared
    fallback program is compiled once and registered for that type instead.
    Compilation is skipped (with a warning) while no OpenGL context is current.
    """

    # ShaderType -> program. The value may be the shared fallback program, or
    # None when even the fallback could not be compiled.
    shaders: Dict[ShaderType, Optional[ShaderProgram]] = field(default_factory=dict)
    # Shared program substituted for types that failed; compiled on first need.
    fallback_shader: Optional[ShaderProgram] = None
    # Type bound at the end of initialize_shaders and by use_default_shader.
    default_shader_type: ShaderType = ShaderType.DEFAULT
    # Type most recently bound successfully via use_shader_type / initialize_shaders.
    current_shader_type: ShaderType = ShaderType.DEFAULT
    # Program object most recently bound, and its GL program id.
    current_shader: Optional[ShaderProgram] = None
    current_shader_program: Optional[int] = None
    # Set once initialize_shaders has bound the default program.
    _initialized: bool = False
    # Re-entrancy guard: True while initialize_shaders is loading programs.
    _initializing: bool = False
    # Root directory the per-type GLSL sources are loaded from.
    shader_directory: str = ""
    # NOTE(review): not read by the methods below; _fallback_shader_sources
    # derives the fallback location from shader_directory instead.
    fallback_shader_directory: str = ""

    def use_shader_program(self, shader_program: ShaderProgram) -> None:
        """
        use_shader_program
        Bind ``shader_program`` and record it as the current program.

        Runs ``initialize_shaders`` first if the manager is not initialized (and
        not already initializing). Failures are logged, never raised: a falsy
        program is rejected, a ``RuntimeError`` from binding is logged as a
        skipped bind, any other exception as an error. Callers can detect success
        by checking ``current_shader is shader_program`` afterwards.

        :param shader_program: ShaderProgram to bind
        :return: None
        """
        if not self._initialized and not self._initializing:
            self.initialize_shaders()
        if not shader_program:
            log.error(
                "❌ Cannot bind: shader_program is None or invalid", scope="load_shader"
            )
            return
        try:
            shader_program.bind()
            self.current_shader = shader_program
            self.current_shader_program = shader_program.program_id()
        except RuntimeError as ex:
            log.warning(
                f"Shader bind skipped: {ex}",
                scope="load_shader",
            )
        except Exception as ex:
            log.error(
                f"❌ Failed to bind shader shader_program: {ex}", scope="load_shader"
            )

    def bind(self, shader_program: ShaderProgram) -> None:
        """
        bind
        Bind ``shader_program`` and record it as current, without the lazy
        initialization, validation and error handling of ``use_shader_program``.

        :param shader_program: ShaderProgram to bind
        :return: None
        :raises: anything ``shader_program.bind()`` raises is propagated
        """
        shader_program.bind()
        self.current_shader = shader_program
        self.current_shader_program = shader_program.program_id()

    def unbind(self) -> None:
        """
        unbind
        Unbind any program via ``glUseProgram(0)`` and clear the current-program
        references.

        :return: None
        """
        from OpenGL.GL import glUseProgram

        glUseProgram(0)
        self.current_shader = None
        self.current_shader_program = None

    def get_shader_type(self, shader_type: ShaderType) -> Optional[ShaderProgram]:
        """
        Return the program registered for ``shader_type``, loading it on first use.

        :param shader_type: ShaderType to look up
        :return: the registered program (possibly the fallback), or None if
            loading failed and no fallback could be compiled
        """
        cached = self.shaders.get(shader_type)
        if cached is not None:
            return cached
        if shader_type in self.shaders:
            # Registered as None: an earlier load failed and even the fallback
            # could not be compiled. Do not retry here.
            return None
        self.load_shader(shader_type, list(ShaderType).index(shader_type))
        return self.shaders.get(shader_type)

    def use_shader_type(
        self,
        shader_type: ShaderType,
        mvp_matrix: Optional[Union[np.ndarray, glm.mat4]] = None,
        zoom_scale: Optional[Union[int, float]] = None,
    ) -> bool:
        """
        Load (if needed) and bind the shader of the given type.

        :param shader_type: ShaderType to bind
        :param mvp_matrix: optional MVP matrix uploaded after a successful bind
        :param zoom_scale: optional value for the ``zoom_scale`` uniform; only
            applied for ``ShaderType.ATOMS`` and only if the uniform exists
        :return: True when the program is bound in the current gl context
        """
        if not self._initialized and not self._initializing:
            self.initialize_shaders()
        shader = self.get_shader_type(shader_type)
        if not shader:
            log.error(
                f"❌ Shader type {shader_type} could not be loaded or bound.",
                scope=self.__class__.__name__,
            )
            return False
        self.use_shader_program(shader)
        # use_shader_program only updates current_shader when the bind succeeded.
        if self.current_shader is not shader:
            return False
        self.current_shader_type = shader_type
        if mvp_matrix is not None:
            self.update_mvp_uniform(mvp_matrix=mvp_matrix)
        if zoom_scale is not None and self.current_shader_type == ShaderType.ATOMS:
            loc = get_uniform_location(self.current_shader.program_id(), "zoom_scale")
            # -1 means the uniform is absent (or optimized out) in this program.
            if loc != -1:
                set_uniform_location_value(loc, zoom_scale)
        return True

    def use_default_shader(
        self, mvp_matrix: Optional[Union[np.ndarray, glm.mat4]] = None
    ) -> None:
        """
        use_default_shader
        Bind the program for ``default_shader_type``.

        :param mvp_matrix: optional MVP matrix forwarded to ``use_shader_type``
        :return: None
        """
        if not self._initialized and not self._initializing:
            self.initialize_shaders()
        self.use_shader_type(
            shader_type=self.default_shader_type, mvp_matrix=mvp_matrix
        )

    def initialize_shaders(
        self,
        shader_dir: Optional[Union[str, Path]] = None,
        *,
        on_shader_loaded: Optional[Callable[[int, int, ShaderType], None]] = None,
    ):
        """
        Compile a program for every ShaderType and bind the default one.

        :param shader_dir: root directory of the GLSL sources. When omitted, the
            current ``shader_directory`` is reused, else PicoGL's packaged
            ``PICOGL_SHADER_SRC_DIRECTORY``.
        :param on_shader_loaded: optional progress callback called after each
            type as ``(shader_number, total, shader_type)`` with a 0-based
            ``shader_number``; exceptions it raises are logged and ignored.
        :return: None

        Re-initializing with the directory already loaded is a no-op; a different
        directory releases the current programs first. Without a current OpenGL
        context nothing is compiled: the requested directory is remembered and a
        warning is logged, so a later call (the ``use_*`` methods call this with
        no arguments while uninitialized) loads from that directory.
        ``_initialized`` becomes True only once the default program is bound.
        """
        # A nested call made while the load loop runs must neither release nor reload.
        if self._initializing:
            return
        if shader_dir:
            target_dir = str(shader_dir)
        elif self.shader_directory:
            target_dir = str(self.shader_directory)
        else:
            target_dir = str(PICOGL_SHADER_SRC_DIRECTORY)
        if self._initialized:
            if target_dir == str(self.shader_directory):
                return
            self.release_shaders()
            self._initialized = False
        if not gl_context_available():
            # Remember the requested directory: the deferred retry is made without
            # arguments and would otherwise fall back to the old/default directory.
            self.shader_directory = target_dir
            log.warning(
                "ShaderManager.initialize_shaders deferred: no current OpenGL context. "
                "Load shaders from initializeGL / paintGL after the gl widget context is current.",
                scope="load_shader",
            )
            return
        self._initializing = True
        try:
            self.shader_directory = target_dir
            failed = []
            shader_pairs = list(enumerate(ShaderType))
            total = len(shader_pairs)
            for shader_number, shader_type in _progress_iter(
                shader_pairs, desc="Shader programs", total=total
            ):
                log.message(
                    f"Loading shader type: '{shader_type.value} from {self.shader_directory}'",
                    silent=True,
                    scope="load_shader",
                )
                self.load_shader(shader_type, shader_number)
                if on_shader_loaded is not None:
                    try:
                        on_shader_loaded(shader_number, total, shader_type)
                    except Exception as ex:
                        # Progress reporting is best-effort; never abort loading for it.
                        log.warning(
                            f"on_shader_loaded callback failed for {shader_type}: {ex}",
                            scope="load_shader",
                        )
                loaded = self.shaders.get(shader_type)
                # load_shader registers nothing if the context vanished, and None
                # if even the fallback failed: both count as failures here.
                if loaded is None or loaded is self.fallback_shader:
                    failed.append(shader_type)
            if failed:
                log.warning(
                    f"⚠️ Shader fallback used for: {', '.join(str(st.value) for st in failed)}",
                    scope="load_shader",
                )
            log.message(
                "✅ Shader sources loaded (including fallback where needed).",
                scope="load_shader",
                silent=True,
            )
            default_shader = self.shaders.get(self.default_shader_type)
            if default_shader is None:
                default_shader = self.get_shader_type(self.default_shader_type)
            if default_shader:
                self.use_shader_program(default_shader)
                if self.current_shader is default_shader:
                    self.current_shader_type = self.default_shader_type
                    self._initialized = True
            if not self._initialized:
                log.error(
                    "ShaderManager: default shader could not be bound; "
                    "modern rendering will stay disabled until gl init succeeds.",
                    scope="load_shader",
                )
        finally:
            self._initializing = False

    def load_shader(self, shader_type: ShaderType, shader_number: int) -> None:
        """
        load_shader
        Compile the program for ``shader_type`` from ``shader_directory`` and
        register it in ``shaders``.

        On any load/compile failure the fallback program is compiled (once) and
        registered instead; that entry is None if the fallback also failed. Does
        nothing (apart from a warning) when no OpenGL context is current.

        :param shader_type: ShaderType to load
        :param shader_number: 0-based index of the type, used in progress logs
        :return: None
        """
        if not gl_context_available():
            log.warning(
                f"Cannot compile shader {shader_type}: no current OpenGL context",
                scope="load_shader",
            )
            return
        try:
            log.message(
                f"Loading shaders from {self.shader_directory}",
                silent=True,
                scope="load_shader",
            )
            vertex_src, fragment_src = load_fragment_and_vertex_for_shader_type(
                shader_type.value, self.shader_directory
            )
            picogl_shader_program = generate_shader_programs(
                vertex_src, fragment_src, shader_type
            )
            if picogl_shader_program:
                # shader_number is 0-based; display it 1-based ("[1/N]" .. "[N/N]").
                log.message(
                    f"[{shader_number + 1}/{len(ShaderType)}] ✅ Shader type `{shader_type}` compiled and registered",
                    scope=self.__class__.__name__,
                    silent=True,
                )
                self.shaders[shader_type] = picogl_shader_program
            else:
                log.warning(f"⚠️ Falling back for {shader_type}", scope="load_shader")
                self._ensure_fallback()
                self.shaders[shader_type] = self.fallback_shader
        except Exception as ex:
            log.warning(
                f"⚠️ Shader load failed for {shader_type}: {ex}", scope="load_shader"
            )
            self._ensure_fallback()
            self.shaders[shader_type] = self.fallback_shader

    def _fallback_shader_sources(self) -> tuple[str, str]:
        """
        Return ``(vertex_src, fragment_src)`` for the fallback program.

        Prefers ``<root>/fallback/vertex.glsl`` and ``fragment.glsl``, where
        ``<root>`` is ``shader_directory/src`` if that directory exists, else
        ``shader_directory``. Otherwise uses ``fallback_vertex.glsl`` /
        ``fallback_fragment.glsl`` from ``SHADER_SRC_DIRECTORY``.
        """
        if self.shader_directory:
            base = Path(self.shader_directory)
            root = base / "src" if (base / "src").is_dir() else base
            fallback_dir = root / "fallback"
            vert_path = fallback_dir / "vertex.glsl"
            frag_path = fallback_dir / "fragment.glsl"
            if vert_path.is_file() and frag_path.is_file():
                return (
                    load_shader_source_string(str(vert_path)),
                    load_shader_source_string(str(frag_path)),
                )
        return (
            load_shader_source_string("fallback_vertex.glsl", SHADER_SRC_DIRECTORY),
            load_shader_source_string("fallback_fragment.glsl", SHADER_SRC_DIRECTORY),
        )

    def _ensure_fallback(self):
        """
        _ensure_fallback
        Compile ``fallback_shader`` if it has not been compiled yet.

        Failures are logged and leave ``fallback_shader`` as None, so the next
        call retries.

        :return: None
        """
        if self.fallback_shader is not None:
            return
        try:
            vert, frag = self._fallback_shader_sources()
            self.fallback_shader = compile_shaders(vert, frag, "fallback")
            log.message(
                "✅ Fallback shader_manager.current_shader_program compiled",
                silent=True,
                scope="load_shader",
            )
        except Exception as ex:
            log.error(
                f"❌ Fallback shader_manager.current_shader_program setup failed: {ex}",
                scope="load_shader",
            )

    def get(self, shader_type: ShaderType) -> Optional[ShaderProgram]:
        """Return the program registered for ``shader_type`` without loading it."""
        return self.shaders.get(shader_type)

    def release_shaders(self):
        """
        release_shaders
        Release every registered program and the fallback program, then clear
        the registry.

        The fallback program may be registered under several ShaderTypes; each
        distinct program object is released exactly once, and None entries are
        skipped. Release errors are logged and ignored (best-effort teardown).
        ``current_shader`` / ``current_shader_program`` are cleared if they refer
        to a program released here, so no stale handle survives.

        :return: None
        """
        programs = list(self.shaders.values())
        if self.fallback_shader is not None:
            programs.append(self.fallback_shader)
        # ``programs`` keeps every object alive, so id() values stay unique here.
        released_ids = set()
        for program in programs:
            if program is None or id(program) in released_ids:
                continue
            released_ids.add(id(program))
            try:
                program.release()
            except Exception as ex:
                log.message(
                    f"Shader release skipped: {ex}",
                    silent=True,
                    scope="load_shader",
                )
        self.shaders.clear()
        self.fallback_shader = None
        if self.current_shader is not None and id(self.current_shader) in released_ids:
            self.current_shader = None
            self.current_shader_program = None