Skip to content
Snippets Groups Projects
Select Git revision
  • 00ed6a769c8ddb61ddc6648cefd22b64e869ee50
  • master default
2 results

PredictionRepository.php

Blame
  • main.py 2.82 KiB
    from datetime import datetime, time, timedelta
    from typing import List
    from fastapi import Body, Depends, FastAPI
    from fastapi.middleware.cors import CORSMiddleware
    from sqlalchemy.orm import Session
    from dotenv import load_dotenv
    import os
    
    from db import crud, schemas, database, models
    
    
    # Load environment variables (e.g. WEB_ROOT below) from a .env file before
    # anything in this module reads them with os.getenv.
    load_dotenv()

    # Interactive docs and the OpenAPI schema are served under /api, the same
    # prefix used by every route defined in this module.
    app = FastAPI(docs_url="/api/docs", openapi_url="/api/openapi.json")

    # Origins allowed to make cross-origin requests. WEB_ROOT presumably holds
    # the front-end's origin URL (TODO confirm). When it is unset or empty it is
    # skipped, so the allow-list never contains None or "" entries.
    origins = [
        origin for origin in (os.getenv('WEB_ROOT'),) if origin
    ]

    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    
    
    def get_db():
        """Yield a new session from ``database.SessionLocal`` and close it afterwards.

        Written as a generator so it can be used with ``Depends``. The ``finally``
        clause guarantees ``close()`` runs even if the request handler raises.
        """
        session = database.SessionLocal()
        try:
            yield session
        finally:
            session.close()
    
    
    @app.on_event("startup")
    def on_startup():
        """Create the tables declared on ``models.Base`` when the app starts."""
        metadata = models.Base.metadata
        metadata.create_all(bind=database.engine)
    
    
    @app.get('/api/{place}', response_model=List[schemas.Record])
    def eatfast(place: str, db: Session = Depends(get_db)):
        """Return the records ``crud.get_records`` finds for ``place``.

        Declared with plain ``def`` rather than ``async def``: ``db`` is a
        synchronous SQLAlchemy ``Session``, so the query blocks. FastAPI runs
        plain-``def`` endpoints in a worker thread, which keeps that blocking call
        off the event loop.
        """
        return crud.get_records(place, db)
    
    
    @app.post('/api/create', response_model=schemas.Record)
    def post(record: schemas.RecordBase = Body(...), db: Session = Depends(get_db)):
        """Persist the record sent in the (required) request body.

        Delegates to ``crud.create_record`` and returns what it returns,
        validated against ``schemas.Record``. Declared with plain ``def``
        because the SQLAlchemy ``Session`` is synchronous. FastAPI then runs the
        handler in a worker thread instead of blocking the event loop.
        """
        return crud.create_record(record, db)
    
    
    @app.get('/api/{place}/waiting_time', response_model=timedelta)
    def waiting_time(place: str, db: Session = Depends(get_db)):
        """Return the ``timedelta`` computed by ``crud.get_waiting_time`` for ``place``.

        Declared with plain ``def`` because the SQLAlchemy ``Session`` is
        synchronous. FastAPI runs it in a worker thread so the database work
        does not block the event loop.
        """
        return crud.get_waiting_time(place, db)
    
    
    @app.get('/api/{place}/stats/{day}/{min_time_hour}/{min_time_mn}/{max_time_hour}/{max_time_mn}/{interval}', response_model=list)
    def stats(place: str, day: int, min_time_hour: int, min_time_mn: int,
              max_time_hour: int, max_time_mn: int, interval: timedelta, db: Session = Depends(get_db)):
        """Return the list produced by ``crud.get_stats`` for ``place``.

        All path segments are forwarded unchanged. ``day`` is presumably a
        weekday index, and the hour/minute pairs presumably bound a time-of-day
        window split into ``interval``-sized buckets. TODO confirm against
        ``crud.get_stats``. Declared with plain ``def`` because the SQLAlchemy
        ``Session`` is synchronous. FastAPI runs it in a worker thread so the
        query does not block the event loop.
        """
        return crud.get_stats(place, day, min_time_hour, min_time_mn, max_time_hour, max_time_mn, interval, db)
    
    
    # NOTE(review): the string literal below is a disabled prototype of an
    # image-based count-estimation endpoint. It is a bare string expression, so
    # none of it is imported or executed at runtime. Visible problems to fix
    # before re-enabling it:
    #   - `np.unit8` is a typo for `np.uint8`.
    #   - `estimate_` indexes `mask[id]`, but the dict built above it is `masks`.
    #     `id` also shadows the builtin.
    #   - numpy images are (rows, cols, channels) while `cv2.resize` takes
    #     dsize as (width, height). `cv2.resize(img, (1280, 720))` yields 720
    #     rows, but each mask is `np.zeros((1280, 720, 3))` (1280 rows), so
    #     `cv2.bitwise_and` would get mismatched shapes. The polygon points
    #     (x up to 1280, y up to 720) suggest masks should be (720, 1280, 3).
    #   - `resized_img` is float32 while the masks are meant to be uint8.
    #     `cv2.bitwise_and` likely needs matching dtypes (TODO confirm).
    #   - `cv2.fillPoly` probably needs integer numpy point arrays
    #     (e.g. `np.array(polygon, dtype=np.int32)`), not nested Python lists.
    #     TODO confirm with the installed OpenCV.
    #   - `np.sum` returns a numpy scalar. Converting with `float(...)` may be
    #     needed for JSON serialization (TODO confirm).
    """
    import cv2
    import numpy as np
    import keras

    from utils.preprocessing import fix_singular_shape, norm_by_imagenet


    model = keras.models.load_model('model')

    # contours of the zone of a picture that should be analyzed by the model
    contours = {
        'eiffel': [[70, 370], [420, 720], [1280, 720], [1280, 250], [930, 215], [450, 550], [130, 350]]
    }

    masks = {}
    for key, polygon in contours.items():
        mask = np.zeros((1280, 720, 3), dtype=np.unit8)
        cv2.fillPoly(mask, [polygon], (255, 255, 255))
        masks[key] = mask


    @app.get("/estimate/{id}")
    async def estimate_(id: str) -> float:
        # img = fetch(...)
        img = np.zeros((1280, 720, 3))
        resized_img = cv2.cvtColor(cv2.resize(img, (1280, 720)), cv2.COLOR_BGR2RGB).astype(np.float32)
        masked_img = cv2.bitwise_and(resized_img, mask[id])
        treated_img = fix_singular_shape(masked_img, 16)
        input_image = np.expand_dims(np.squeeze(norm_by_imagenet([treated_img])), axis=0)
        pred_map = np.squeeze(model.predict(input_image))
        count_prediction = np.sum(pred_map)
        return count_prediction
    """