Hi Juan,
sorry for the late reply, I got busy with some other stuff in between. Thanks for your reply and for pointing out more information on the caching topic.
It turns out that the current version of the PyJWT library (2.6.0) has quite a few new features that address these issues. The documentation does not cover them yet (there is a TODO note in the docs about that). However, the library now seems to be able to cache keys for a given lifetime. You initialize the PyJWKClient like this:
jwks_client = PyJWKClient(jwks_url, cache_jwk_set=True, lifespan=300)
The lifespan is given in seconds, so with this setting the key set is cached for at most 5 minutes. (cache_jwk_set=True and lifespan=300 are also the default parameters, so you might as well just omit them.)
With this setup, the JWKS is requested at most every 5 minutes, and a key that gets removed from the JWKS (e.g. because it was compromised) drops out of the cache after 5 minutes at the latest. New keys are requested as soon as a token with an unknown kid reaches the client. So the DDoS issue with unknown kids remains: every token with a kid that is not in the cached set triggers a fresh request to the JWKS endpoint. But honestly, if you sent DDoS attacks to my backend, you wouldn't need such tricks to take it down…
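To make that behaviour a bit more concrete, here is a simplified model of it. This is not PyJWT's actual code. TtlJwksCache, fetch_keys and the injectable clock are hypothetical names made up for the sketch, and keys are plain strings instead of real key objects.

```python
import time
from typing import Callable, Dict, Optional


class TtlJwksCache:
    """Hypothetical, simplified model of a JWKS cache with a lifespan.

    fetch_keys is an assumed callable that downloads the JWKS and returns
    a mapping of kid -> key. Not PyJWT's implementation.
    """

    def __init__(
        self,
        fetch_keys: Callable[[], Dict[str, str]],
        lifespan: float = 300,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._fetch_keys = fetch_keys
        self._lifespan = lifespan
        self._clock = clock
        self._keys: Dict[str, str] = {}
        self._fetched_at: Optional[float] = None
        self.fetch_count = 0  # number of JWKS downloads, for illustration

    def _refresh(self) -> None:
        self._keys = self._fetch_keys()
        self._fetched_at = self._clock()
        self.fetch_count += 1

    def get_key(self, kid: str) -> str:
        refreshed = False
        # Download the key set if nothing is cached yet or the lifespan is over.
        if self._fetched_at is None or self._clock() - self._fetched_at >= self._lifespan:
            self._refresh()
            refreshed = True
        # Unknown kid: refetch immediately (once). This is the path that
        # tokens with random kids can abuse to hammer the JWKS endpoint.
        if kid not in self._keys and not refreshed:
            self._refresh()
        if kid not in self._keys:
            raise KeyError(f"unknown kid: {kid}")
        return self._keys[kid]
```

In this model a key removed from the JWKS is still served until the lifespan runs out, a newly added key is picked up on first use, and each token with a bogus kid costs one extra download.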