from typing import List
from fastapi import Body, Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.orm import Session
from dotenv import load_dotenv
import os

from db import crud, schemas, database


# Populate os.environ from a .env file, if one is found, so the os.getenv
# lookups below can see its values.
# NOTE(review): `db` is imported above this call; if db.database reads
# environment variables at import time it will not see .env values — confirm.
load_dotenv()

# Expose the docs UI and the OpenAPI schema under the same /api prefix as the
# routes below.
app = FastAPI(docs_url="/api/docs", openapi_url="/api/openapi.json")

# Browser origins allowed to make cross-origin requests, taken from WEB_ROOT.
# If WEB_ROOT is unset or blank, allow no origins rather than putting `None`
# into the list. A browser's Origin header never ends with "/", so strip a
# trailing slash (e.g. "https://example.com/"); otherwise the configured
# origin could never match.
_web_root = (os.getenv('WEB_ROOT') or '').strip().rstrip('/')
origins = [_web_root] if _web_root else []

# For those origins, credentials, every HTTP method and every request header
# are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def get_db():
    """Yield one SQLAlchemy session and close it when the generator finishes.

    Intended for ``Depends(get_db)``. The ``finally`` clause closes the
    session however the generator is resumed, closed, or has an exception
    thrown into it.
    """
    session = database.SessionLocal()
    try:
        yield session
    finally:
        session.close()


@app.get('/api/{place}', response_model=List[schemas.Record])
async def eatfast(place: str, db: Session = Depends(get_db)):
    """Return the records that ``crud.get_records`` finds for ``place``.

    ``crud.get_records`` is called synchronously with a SQLAlchemy ``Session``,
    so it presumably performs a blocking database query. Calling it directly in
    this ``async`` handler would stall the event loop, and with it every other
    in-flight request, until the query finishes. Run it in the worker thread
    pool instead; the handler remains a coroutine, so its interface is
    unchanged.
    """
    # Local import keeps this handler self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    return await run_in_threadpool(crud.get_records, place, db)


@app.post('/api/create', response_model=schemas.Record)
async def post(record: schemas.RecordBase = Body(...), db: Session = Depends(get_db)):
    """Create a record from the JSON body via ``crud.create_record``.

    The result is serialized as ``schemas.Record``. ``crud.create_record``
    takes a synchronous SQLAlchemy ``Session``, so it presumably blocks on
    database I/O. It is therefore run in the worker thread pool rather than
    on the event loop, and the handler itself remains a coroutine.
    """
    # Local import keeps this handler self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    return await run_in_threadpool(crud.create_record, record, db)


# NOTE(review): the triple-quoted string below is disabled draft code, not a
# docstring. A bare module-level string is evaluated and discarded, so none of
# it runs: cv2/numpy/keras are not imported, no model is loaded, and no
# `/estimate/{id}` route is registered. Consider keeping it in version-control
# history or a separate module instead. Problems visible in the draft that need
# fixing before it is re-enabled:
#   - `np.unit8` is a typo for `np.uint8`.
#   - `estimate_` indexes `mask[id]`, but the per-place lookup table is
#     `masks`; `mask` is only the last array built by the loop.
#   - masks are allocated as (1280, 720, 3), i.e. 1280 rows x 720 columns,
#     while the polygon x-coordinates reach 1280 and `cv2.resize(img,
#     (1280, 720))` (dsize is width, height) yields a (720, 1280, 3) image,
#     so mask and image shapes would not match; (720, 1280, 3) looks intended.
#   - the resized image is cast to float32 while the masks are uint8;
#     `cv2.bitwise_and` expects operands of the same type.
#   - `cv2.fillPoly` receives plain Python lists; OpenCV usually wants int32
#     numpy point arrays (e.g. `np.array(polygon, dtype=np.int32)`) — confirm.
#   - `np.zeros(...)` defaults to float64, which `cv2.cvtColor` may reject —
#     confirm against the installed OpenCV version.
#   - `np.sum` returns a NumPy scalar; wrap it in `float(...)` before
#     returning it from a FastAPI route so it serializes.
#   - the path parameter `id` shadows the `id` builtin.
"""
import cv2
import numpy as np
import keras

from utils.preprocessing import fix_singular_shape, norm_by_imagenet


model = keras.models.load_model('model')

# contours of the zone of a picture that should be analyzed by the model
contours = {
    'eiffel': [[70, 370], [420, 720], [1280, 720], [1280, 250], [930, 215], [450, 550], [130, 350]]
}

masks = {}
for key, polygon in contours.items():
    mask = np.zeros((1280, 720, 3), dtype=np.unit8)
    cv2.fillPoly(mask, [polygon], (255, 255, 255))
    masks[key] = mask


@app.get("/estimate/{id}")
async def estimate_(id: str) -> float:
    # img = fetch(...)
    img = np.zeros((1280, 720, 3))
    resized_img = cv2.cvtColor(cv2.resize(img, (1280, 720)), cv2.COLOR_BGR2RGB).astype(np.float32)
    masked_img = cv2.bitwise_and(resized_img, mask[id])
    treated_img = fix_singular_shape(masked_img, 16)
    input_image = np.expand_dims(np.squeeze(norm_by_imagenet([treated_img])), axis=0)
    pred_map = np.squeeze(model.predict(input_image))
    count_prediction = np.sum(pred_map)
    return count_prediction
"""