Skip to content
Snippets Groups Projects
Verified Commit b1408368 authored by Arthur Conrozier's avatar Arthur Conrozier
Browse files

oidc

parent 64323b4f
Branches
No related tags found
1 merge request!1Updating main
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
import base64
import os
from dotenv import load_dotenv
from flask import Flask, session, redirect
from flask import request
import requests
import uuid
import jwt # PyJWT library
load_dotenv()
CLIENT_ID = os.getenv('CLIENT_ID')
......@@ -14,6 +18,33 @@ USERINFO_URL = os.getenv('USERINFO_URL')
SESSION_SECRET = os.getenv('SESSION_SECRET')
LOGOUT_URL = os.getenv('LOGOUT_URL')
def fetch_jwks(jwks_url):
response = requests.get(jwks_url)
return response.json()
def jwk_to_pem(jwk):
modulus = int.from_bytes(base64.urlsafe_b64decode(jwk["n"] + "==="), "big")
exponent = int.from_bytes(base64.urlsafe_b64decode(jwk["e"] + "==="), "big")
# Create RSA Public Key
public_key = rsa.RSAPublicNumbers(exponent, modulus).public_key()
# Convert to PEM format
pem_public_key = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
return pem_public_key.decode()
jwks = fetch_jwks(os.getenv('JWKS_URL'))
first_key = jwks["keys"][0]
PUBLIC_KEY = jwk_to_pem(first_key)
app = Flask(__name__)
app.secret_key = SESSION_SECRET
......@@ -28,17 +59,21 @@ app.config['SESSION_KEY_PREFIX'] = 'oauth_' # Prefix for session keys
@app.route('/')
def home():
if 'access_token' in session:
user_profile = fetchUserProfile(session['access_token'])
if user_profile:
return f'Hello, {user_profile["first_name"]} -> <a href="/logout">Log out</a> / <a href="/avatar">Avatar page</a>'
if 'id_token' in session and 'nonce' in session:
decoded_id_token = decodeIdToken(session['id_token'], session['nonce'])
if 'first_name' in decoded_id_token:
return f'Hello {decoded_id_token["first_name"]} <a href="/logout">Log out</a> / <a href="/avatar">Avatar page</a>'
# Generate a state token to prevent CSRF
state = str(uuid.uuid4())
session['state'] = state
# Generate a nonce token to prevent replay attacks
nonce = str(uuid.uuid4())
session['nonce'] = nonce
# If the user hasn't been redirected or has an invalid code, show a login button
login_url = getLoginUrl(state)
login_url = getLoginUrl(state, nonce)
session['redirect_path'] = '/'
return f'Main page -> <a href={login_url}>Login</a> / <a href="/avatar">Avatar page</a>'
......@@ -46,18 +81,22 @@ def home():
@app.route('/avatar')
def avatar():
if 'access_token' in session:
user_profile = fetchUserProfile(session['access_token'])
if user_profile:
if 'id_token' in session and 'nonce' in session:
decoded_id_token = decodeIdToken(session['id_token'], session['nonce'])
if 'avatar' in decoded_id_token:
return f'Here is your face : <img src="{
user_profile["avatar"]} "> / <a href="/">Main page</a>'
decoded_id_token['avatar']} "> / <a href="/">Main page</a>'
# Generate a state token to prevent CSRF
state = str(uuid.uuid4())
session['state'] = state
# Generate a nonce token to prevent replay attacks
nonce = str(uuid.uuid4())
session['nonce'] = nonce
# The user isn't logged in, or his associated Access Token is no longer valid
login_url = getLoginUrl(state)
login_url = getLoginUrl(state, nonce)
session['redirect_path'] = '/avatar'
return f'Avatar page -> <a href={login_url}>Login</a> / <a href="/">Main page</a>'
......@@ -78,7 +117,9 @@ def callback():
# If the response is successful, use the access token to fetch the user's profile
if response.status_code == 200:
access_token = response.json()['access_token']
id_token = response.json()['id_token']
session['access_token'] = access_token
session['id_token'] = id_token
if 'redirect_path' in session:
redirect_path = session['redirect_path']
......@@ -97,16 +138,6 @@ def logout():
return redirect(LOGOUT_URL)
def fetchUserProfile(access_token):
headers = {
"Authorization": f"Bearer {access_token}"
}
response = requests.get(USERINFO_URL, headers=headers)
if response.status_code == 200:
return response.json()
return
def getAccessToken(code):
payload = {
"code": code,
......@@ -121,18 +152,32 @@ def getAccessToken(code):
return requests.post(TOKEN_URL, data=payload, headers=headers)
def getLoginUrl(state):
def getLoginUrl(state, nonce):
params = {
"client_id": CLIENT_ID,
"redirect_uri": "http://localhost:3000/callback",
"response_type": "code",
"scope": "profile",
"state": state
"scope": "profile openid",
"state": state,
"nonce": nonce
}
return requests.Request(
'GET', AUTHORIZATION_URL,
params=params).prepare().url
def decodeIdToken(id_token, nonce):
# Verify token signature
decoded_id_token = jwt.decode(
id_token, PUBLIC_KEY, algorithms=["RS256"], audience=CLIENT_ID
)
# Verify nonce to prevent replay attacks
if decoded_id_token.get("nonce") != nonce:
return
return decoded_id_token
if __name__ == '__main__':
app.run(debug=True, port=3000)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment