Skip to content
Snippets Groups Projects
Select Git revision
  • 6ea64d690ab8b989dc7782f5e97af1d2cbb9cd20
  • main default
  • tp3
  • tp2
  • tp1
  • tp3-correction
  • tp2-correction
  • tp1-correction
  • admins
9 results

test_operators.py

Blame
  • Forked from an inaccessible project.
    main.py 2.79 KiB
    from fastapi import FastAPI
    from fastapi.middleware.cors import CORSMiddleware
    from dotenv import load_dotenv
    import os
    from db.database import get_db
    from fastapi import Depends
    from sqlalchemy.orm import Session
    from db import schemas
    from typing import List
    
    from db import database, models
    from routers import stats, comments, news
    
    app = FastAPI(docs_url="/api/docs", openapi_url="/api/openapi.json")

    # load environment variables (from a .env file, if one is present)
    load_dotenv()

    # Origins allowed to make cross-origin requests to this API.
    # - If WEB_ROOT is unset, the old code produced `[None]`, which looks configured
    #   but matches no request Origin. The value is now dropped instead.
    # - A trailing slash is stripped, because browsers send the Origin header
    #   without one; "https://site.example/" would otherwise never match.
    _web_root = (os.getenv('WEB_ROOT') or '').rstrip('/')
    origins = [_web_root] if _web_root else []

    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    
    
    @app.on_event("startup")
    def on_startup():
        """Create the ORM tables on application startup.

        SQLAlchemy's ``create_all`` only creates tables that do not exist yet.
        """
        metadata = models.Base.metadata
        metadata.create_all(bind=database.engine)
    
    
    # Integration of routers
    app.include_router(stats.router)
    app.include_router(comments.router)
    app.include_router(news.router)
    
    
    @app.get('/api/records', response_model=List[schemas.Record])
    def get_records(place: str, db: Session = Depends(get_db)):
        """Return all records for the given place, newest first."""
        # Plain `def` rather than `async def`: the SQLAlchemy Session used here is
        # synchronous. FastAPI runs sync path operations in a worker thread, so the
        # blocking query no longer stalls the event loop for every other request.
        query = (
            db.query(models.Records)
            .filter(models.Records.place == place)
            .order_by(models.Records.date.desc())
        )
        return query.all()
    
    
    # NOTE(review): this handler (and the DELETE handler) is named `stats`, which
    # rebinds the `stats` router module imported at the top of the file. It only
    # works because `app.include_router(stats.router)` has already run.
    # Consider a distinct function name; pass `name="stats"` on the route if the
    # OpenAPI operation id must stay stable.
    @app.post('/api/records', response_model=schemas.Record)
    def stats(record: schemas.RecordBase, db: Session = Depends(get_db)):
        """Store a new record and return it as persisted."""
        # Plain `def` rather than `async def`: the SQLAlchemy Session is
        # synchronous. FastAPI runs this in a worker thread instead of blocking
        # the event loop during the INSERT/COMMIT round-trips.
        db_record = models.Records(**record.dict())
        db.add(db_record)
        db.commit()
        # reload the row's attributes from the database after the commit
        db.refresh(db_record)
        return db_record
    
    
    @app.delete('/api/records', response_model=None)
    def stats(id: int, db: Session = Depends(get_db)):
        """Delete the record with the given id; `id=0` deletes every record."""
        # Plain `def` rather than `async def`: the SQLAlchemy Session is
        # synchronous. FastAPI runs this in a worker thread instead of blocking
        # the event loop during the DELETE/COMMIT.
        # `id` shadows the builtin but is the public query-parameter name, so it
        # is kept as-is.
        records = db.query(models.Records)
        if id != 0:
            # 0 is the "wipe everything" sentinel; any other value targets one row
            records = records.filter(models.Records.id == id)
        records.delete()
        db.commit()
        return None
    
    
    """
    import cv2
    import numpy as np
    import keras
    
    from utils.preprocessing import fix_singular_shape, norm_by_imagenet
    
    
    model = keras.models.load_model('model')
    
    # contours of the zone of a picture that should be analyzed by the model
    contours = {
        'eiffel': [[70, 370], [420, 720], [1280, 720], [1280, 250], [930, 215], [450, 550], [130, 350]]
    }
    
    masks = {}
    for key, polygon in contours.items():
        mask = np.zeros((1280, 720, 3), dtype=np.unit8)
        cv2.fillPoly(mask, [polygon], (255, 255, 255))
        masks[key] = mask
    
    
    @app.get("/estimate/{id}")
    async def estimate_(id: str) -> float:
        # img = fetch(...)
        img = np.zeros((1280, 720, 3))
        resized_img = cv2.cvtColor(cv2.resize(img, (1280, 720)), cv2.COLOR_BGR2RGB).astype(np.float32)
        masked_img = cv2.bitwise_and(resized_img, mask[id])
        treated_img = fix_singular_shape(masked_img, 16)
        input_image = np.expand_dims(np.squeeze(norm_by_imagenet([treated_img])), axis=0)
        pred_map = np.squeeze(model.predict(input_image))
        count_prediction = np.sum(pred_map)
        return count_prediction
    """