You want runtime typechecking.
See either beartype [1] or typeguard [2]. If you're doing any kind of array-based programming (JAX or not), also look at jaxtyping [3].
[1] https://github.com/beartype/beartype/
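For anyone unsure what "runtime typechecking" means in practice, here is a rough stdlib-only sketch of the idea. It is not how beartype or typeguard are implemented, and `runtime_typechecked` and `scale` are hypothetical names made up for illustration. It only handles annotations that are plain classes; the real libraries cover generics, unions, and much more.

```python
import functools
import inspect
from typing import get_origin, get_type_hints


def _is_plain_class(hint):
    # True for annotations like int or str; False for generics such as list[int].
    return isinstance(hint, type) and get_origin(hint) is None


def runtime_typechecked(func):
    """Hypothetical decorator: check plain-class annotations on every call."""
    hints = get_type_hints(func)
    sig = inspect.signature(func)
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            if sig.parameters[name].kind in variadic:
                continue  # *args / **kwargs are not checked in this sketch
            expected = hints.get(name)
            if _is_plain_class(expected) and not isinstance(value, expected):
                raise TypeError(
                    f"{func.__name__}: argument {name!r} expected "
                    f"{expected.__name__}, got {type(value).__name__}"
                )
        result = func(*args, **kwargs)
        expected = hints.get("return")
        if _is_plain_class(expected) and not isinstance(result, expected):
            raise TypeError(
                f"{func.__name__}: return value expected "
                f"{expected.__name__}, got {type(result).__name__}"
            )
        return result

    return wrapper


@runtime_typechecked
def scale(x: float, factor: int) -> float:
    return x * factor
```

Calling `scale(2.0, 3)` passes the checks, while `scale(2.0, "3")` raises `TypeError` at the call site instead of failing somewhere downstream. Note that this sketch is strict about `isinstance`, so an `int` passed where `float` is annotated is rejected.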
Thanks for posting this. I had seen beartype several years ago, but I don't believe it had the whole-module registration feature yet. I'm looking forward to trying both libraries, since the ergonomics are better than decorating every function individually.