Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • a_first_steps
  • b_sessions
  • c_second_page
  • d_callback
  • e_oidc
  • f_refreshing_tokens
  • main
7 results

Target

Select target project
  • ViaRezo/formations/demonstration-auth-vr
1 result
Select Git revision
  • a_first_steps
  • b_sessions
  • c_second_page
  • d_callback
  • e_oidc
  • f_refreshing_tokens
  • main
7 results
Show changes

Commits on Source 10

......@@ -4,3 +4,4 @@ AUTHORIZATION_URL="https://auth.viarezo.fr/oauth/authorize"
TOKEN_URL="https://auth.viarezo.fr/oauth/token"
USERINFO_URL="https://auth.viarezo.fr/oidc/userinfo"
LOGOUT_URL="https://auth.viarezo.fr/logout"
SESSION_SECRET="my-super-secret"
\ No newline at end of file
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
import base64
import os
from flask import Flask
from dotenv import load_dotenv
from flask import Flask, session, redirect
from flask import request
import requests
import uuid
import jwt # PyJWT library
load_dotenv()
CLIENT_ID = os.getenv('CLIENT_ID')
CLIENT_SECRET = os.getenv('CLIENT_SECRET')
AUTHORIZATION_URL = os.getenv('AUTHORIZATION_URL')
TOKEN_URL = os.getenv('TOKEN_URL')
USERINFO_URL = os.getenv('USERINFO_URL')
SESSION_SECRET = os.getenv('SESSION_SECRET')
LOGOUT_URL = os.getenv('LOGOUT_URL')
def fetch_jwks(jwks_url):
response = requests.get(jwks_url)
return response.json()
def jwk_to_pem(jwk):
modulus = int.from_bytes(base64.urlsafe_b64decode(jwk["n"] + "==="), "big")
exponent = int.from_bytes(base64.urlsafe_b64decode(jwk["e"] + "==="), "big")
# Create RSA Public Key
public_key = rsa.RSAPublicNumbers(exponent, modulus).public_key()
# Convert to PEM format
pem_public_key = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
return pem_public_key.decode()
jwks = fetch_jwks(os.getenv('JWKS_URL'))
first_key = jwks["keys"][0]
PUBLIC_KEY = jwk_to_pem(first_key)
app = Flask(__name__)
app.secret_key = SESSION_SECRET
app.config['SESSION_COOKIE_HTTPONLY'] = True # Prevents JavaScript access to the cookie
app.config['SESSION_COOKIE_SECURE'] = True # Ensures the cookie is only sent over HTTPS
app.config['SESSION_COOKIE_SAMESITE'] = 'Lax' # Mitigates CSRF attacks
app.config['SESSION_TYPE'] = 'redis' # Use Redis to store sessions (server-side)
app.config['SESSION_PERMANENT'] = False # Session is deleted when the browser is closed
app.config['SESSION_USE_SIGNER'] = True # Sign the session cookie
app.config['SESSION_KEY_PREFIX'] = 'oauth_' # Prefix for session keys
@app.route('/')
def home():
if 'id_token' in session and 'nonce' in session:
decoded_id_token = decodeIdToken(session['id_token'], session['nonce'])
if 'first_name' in decoded_id_token and not 'refresh' in request.args:
return f'Hello {decoded_id_token["first_name"]} <a href="/logout">Log out</a> / <a href="/avatar">Avatar page</a>'
elif (len(decoded_id_token) == 2 and decoded_id_token[0] == 'Token expired') or 'refresh' in request.args:
session['nonce'] = None
req = refreshTokens()
if req.status_code == 200:
session['access_token'] = req.json()['access_token']
session['id_token'] = req.json()['id_token']
session['refresh_token'] = req.json()['refresh_token']
return redirect('/')
# Generate a state token to prevent CSRF
state = str(uuid.uuid4())
session['state'] = state
# Generate a nonce token to prevent replay attacks
nonce = str(uuid.uuid4())
session['nonce'] = nonce
# If the user hasn't been redirected or has an invalid code, show a login button
login_url = getLoginUrl(state, nonce)
session['redirect_path'] = '/'
return f'Main page -> <a href={login_url}>Login</a> / <a href="/avatar">Avatar page</a>'
@app.route('/avatar')
def avatar():
if 'id_token' in session and 'nonce' in session:
decoded_id_token = decodeIdToken(session['id_token'], session['nonce'])
if 'avatar' in decoded_id_token and not 'refresh' in request.args:
return f'Here is your face : <img src="{
decoded_id_token['avatar']} "> / <a href="/">Main page</a>'
elif (len(decoded_id_token) == 2 and decoded_id_token[0] == 'Token expired') or 'refresh' in request.args:
session['nonce'] = None
req = refreshTokens()
if req.status_code == 200:
session['access_token'] = req.json()['access_token']
session['id_token'] = req.json()['id_token']
session['refresh_token'] = req.json()['refresh_token']
return redirect('/avatar')
# Generate a state token to prevent CSRF
state = str(uuid.uuid4())
session['state'] = state
# Generate a nonce token to prevent replay attacks
nonce = str(uuid.uuid4())
session['nonce'] = nonce
# The user isn't logged in, or his associated Access Token is no longer valid
login_url = getLoginUrl(state, nonce)
session['redirect_path'] = '/avatar'
return f'Avatar page -> <a href={login_url}>Login</a> / <a href="/">Main page</a>'
@app.route('/callback')
def callback():
# Check if the user has been redirected back from the OAuth provider
if 'code' in request.args:
code = request.args['code']
# Check if the state token matches
if 'state' not in request.args or request.args['state'] != session['state']:
return 'A possible CSRF attempt was detected', 401
response = getAccessToken(code)
# If the response is successful, use the access token to fetch the user's profile
if response.status_code == 200:
access_token = response.json()['access_token']
id_token = response.json()['id_token']
refresh_token = response.json()['refresh_token']
session['access_token'] = access_token
session['id_token'] = id_token
session['refresh_token'] = refresh_token
if 'redirect_path' in session:
redirect_path = session['redirect_path']
session.pop('redirect_path', None)
return redirect(redirect_path)
else:
return 'Could not fetch access token', 500
# Redirect to the main page
return redirect('/')
@app.route('/logout')
def logout():
session.clear()
return redirect(LOGOUT_URL)
def getAccessToken(code):
payload = {
"code": code,
"client_id": os.environ.get('CLIENT_ID'),
"client_secret": os.environ.get('CLIENT_SECRET'),
"redirect_uri": "http://localhost:3000/",
"client_id": CLIENT_ID,
"client_secret": CLIENT_SECRET,
"redirect_uri": "http://localhost:3000/callback",
"grant_type": "authorization_code"
}
headers = {
"Content-Type": "application/x-www-form-urlencoded"
}
response = requests.post(os.environ.get('TOKEN_URL'), data=payload, headers=headers)
return requests.post(TOKEN_URL, data=payload, headers=headers)
# If the response is successful, use the access token to fetch the user's profile
if response.status_code == 200:
access_token = response.json()['access_token']
def refreshTokens():
payload = {
"refresh_token": session['refresh_token'],
"client_id": CLIENT_ID,
"client_secret": CLIENT_SECRET,
"grant_type": "refresh_token"
}
headers = {
"Authorization": f"Bearer {access_token}"
"Content-Type": "application/x-www-form-urlencoded"
}
response = requests.get(os.environ.get('USERINFO_URL'), headers=headers)
if response.status_code == 200:
return f'Hello, {response.json()["first_name"]}'
else:
return 'Could not fetch user profile'
return requests.post(TOKEN_URL, data=payload, headers=headers)
# If the user hasn't been redirected or has an invalid code, show a login button
def getLoginUrl(state, nonce):
params = {
"client_id": os.environ.get('CLIENT_ID'),
"redirect_uri": "http://localhost:3000/",
"client_id": CLIENT_ID,
"redirect_uri": "http://localhost:3000/callback",
"response_type": "code",
"scope": "profile"
"scope": "profile openid",
"state": state,
"nonce": nonce
}
login_url = requests.Request(
'GET', os.environ.get('AUTHORIZATION_URL'),
return requests.Request(
'GET', AUTHORIZATION_URL,
params=params).prepare().url
return f'Main page -> <a href={login_url}>Login</a>'
def decodeIdToken(id_token, nonce):
try:
# Verify token signature
decoded_id_token = jwt.decode(
id_token, PUBLIC_KEY, algorithms=["RS256"], audience=CLIENT_ID
)
except jwt.ExpiredSignatureError:
return 'Token expired', 401
except jwt.InvalidAudienceError:
return 'Invalid audience', 401
except jwt.InvalidIssuerError:
return 'Invalid issuer', 401
except jwt.InvalidSignatureError:
return 'Invalid signature', 401
except:
return 'Invalid token', 401
# Verify nonce to prevent replay attacks
if decoded_id_token.get("nonce") != nonce:
return 'Invalid nonce', 401
return decoded_id_token
if __name__ == '__main__':
......
cryptography
requests
dotenv
flask
PyJWT
\ No newline at end of file