"""
Module to interact with the database
"""
from datetime import date, datetime, time, timedelta
from fastapi import HTTPException
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from sqlalchemy import Time, Date, cast
from uuid import uuid4
import secrets
import pytz

from db import models, schemas


# Define CRUD operations to collect the statistics

def get_waiting_time(place: str, db: Session):
    """Get the current waiting-time status for the given place.

    Today's opening slots (Europe/Paris) are scanned in chronological order:

    - before a slot opens: ``WaitingTime(next_timetable="HHhMM")`` with that
      slot's opening time;
    - inside a slot: ``WaitingTime(status=True, waiting_time=...)`` where
      ``waiting_time`` is the most recent record since the slot opened, in
      rounded minutes, or ``None`` if there is no such record;
    - after every slot (or no slot today): a default ``WaitingTime()``.
    """
    current_date = datetime.now(tz=pytz.timezone("Europe/Paris"))
    weekday, current_time = current_date.weekday(), current_date.time()
    # Slots must be visited in chronological order: an unordered result could
    # report a later slot's opening time while an earlier slot is open now.
    opening_hours = db.query(
        models.OpeningHours.open_time, models.OpeningHours.close_time
    ).filter(
        models.OpeningHours.place == place,
        models.OpeningHours.day == weekday,
    ).order_by(models.OpeningHours.open_time).all()

    for time_slot in opening_hours:
        if current_time < time_slot.open_time:
            return schemas.WaitingTime(
                next_timetable=time_slot.open_time.strftime('%Hh%M'))
        if current_time <= time_slot.close_time:
            # Use the Paris date (same clock as current_time), not the
            # server-local date.today().
            limit = datetime.combine(current_date.date(), time_slot.open_time)
            last_record = db.query(models.Records.waiting_time).filter(
                models.Records.place == place,
                models.Records.date >= limit,
            ).order_by(models.Records.date.desc()).first()
            waiting_time = None
            if last_record:
                waiting_time = round(
                    last_record.waiting_time.total_seconds() / 60)
            return schemas.WaitingTime(status=True, waiting_time=waiting_time)
    return schemas.WaitingTime()


# Define some utility functions
def shift_time(t: time, delta: timedelta):
    """Return the time of day reached by moving ``t`` by ``delta``.

    The result wraps around midnight (23:58 + 5 min -> 00:03) and any
    ``tzinfo`` carried by ``t`` is dropped.
    """
    anchored = datetime.combine(date.min, t)
    moved = anchored + delta
    return moved.time()


def add_slot(slots_list, start_time, end_time, function):
    """Append one graph point for the slot [start_time, end_time] to ``slots_list``.

    ``function(start_time, end_time)`` computes the slot's value. No point is
    added when it returns ``None`` (no data). The point's ``name`` is
    ``start_time`` expressed in minutes since midnight.
    """
    average_waiting_time = function(start_time, end_time)
    # A 0-minute average is a real measurement (the query callbacks return it
    # explicitly); only None means "no records in this slot".
    if average_waiting_time is not None:
        name = 60 * start_time.hour + start_time.minute
        slots_list.append(schemas.RecordRead(
            name=name, time=average_waiting_time))


def get_avg_graph_points(place: str, weekday: int, min_time: time,
                         max_time: time, interval: timedelta, db: Session):
    """Get the average waiting time for each ``interval``-long slot of a day.

    Slots start at ``min_time`` and are generated while their start is before
    ``max_time``. Each slot averages every record of ``place`` taken on
    ``weekday`` (any date) whose time of day lies in [slot start, slot end].
    Slots without any matching record produce no point.

    Returns:
        The list of points built by ``add_slot``.

    Raises:
        ValueError: if ``interval`` is not positive (the slots would never
            advance).
    """

    def avg_time_query(start_time, end_time):
        # Mean waiting time in whole minutes (60 * hours + minutes), or None
        # when no record matches.
        # NOTE(review): assumes the database's WEEKDAY() numbers Monday as 0,
        # like datetime.weekday() (true for MySQL) — confirm for other backends.
        records = db.query(
            func.round(
                func.avg(
                    60 * func.extract('HOUR', models.Records.waiting_time) +
                    func.extract('MINUTE', models.Records.waiting_time))
            )
        ).filter(
            models.Records.place == place,
            func.weekday(models.Records.date) == weekday,
            cast(models.Records.date, Time) >= start_time,
            cast(models.Records.date, Time) <= end_time,
        ).first()
        average = records[0]
        return None if average is None else int(average)

    if interval <= timedelta(0):
        raise ValueError("interval must be a positive timedelta")

    # Walk the slots on full datetimes anchored to an arbitrary day. The last
    # slot is clamped to the end of that day; plain time arithmetic would wrap
    # to 00:00, restart below max_time and loop forever.
    anchor = date.min
    slot_start = datetime.combine(anchor, min_time)
    stop = datetime.combine(anchor, max_time)
    end_of_day = datetime.combine(anchor, time.max)

    stats = []
    while slot_start < stop:
        slot_end = min(slot_start + interval, end_of_day)
        add_slot(stats, slot_start.time(), slot_end.time(), avg_time_query)
        slot_start += interval

    return stats


def get_avg_graph(place: str, db: Session):
    """Return the historical average-waiting-time graph for today's relevant slot.

    The relevant slot is the earliest of today's opening slots (Europe/Paris
    weekday) that has not closed yet, i.e. the current slot or the next one.
    Its points are computed in 5-minute steps by ``get_avg_graph_points``.
    Returns an empty list when every slot of the day is over.
    """
    now = datetime.now(tz=pytz.timezone("Europe/Paris"))
    today = now.weekday()
    now_time = now.time()
    slots = db.query(
        models.OpeningHours.open_time, models.OpeningHours.close_time
    ).filter(
        models.OpeningHours.place == place,
        models.OpeningHours.day == today,
    ).order_by(models.OpeningHours.open_time).all()

    upcoming = next(
        (slot for slot in slots if now_time <= slot.close_time), None)
    if upcoming is None:
        return []
    return get_avg_graph_points(
        place, today, upcoming.open_time, upcoming.close_time,
        timedelta(minutes=5), db)


def get_current_graph_points(place: str, current_date: date,
                             min_time: time, max_time: time, interval: timedelta, db: Session):
    """Get the waiting time on ``current_date`` for each ``interval``-long slot.

    Slots start at ``min_time`` and are generated while their start is before
    ``max_time``. Each slot averages the records of ``place`` dated
    ``current_date`` whose time of day lies in [slot start, slot end]. Slots
    without any matching record produce no point.

    Returns:
        The list of points built by ``add_slot``.

    Raises:
        ValueError: if ``interval`` is not positive (the slots would never
            advance).
    """

    def current_time_query(start_time, end_time):
        # Mean waiting time in whole minutes (60 * hours + minutes), or None
        # when no record matches.
        records = db.query(
            func.round(
                func.avg(
                    60 * func.extract('HOUR', models.Records.waiting_time) +
                    func.extract('MINUTE', models.Records.waiting_time))
            )
        ).filter(
            models.Records.place == place,
            cast(models.Records.date, Date) == current_date,
            cast(models.Records.date, Time) >= start_time,
            cast(models.Records.date, Time) <= end_time
        ).first()
        average = records[0]
        return None if average is None else int(average)

    if interval <= timedelta(0):
        raise ValueError("interval must be a positive timedelta")

    # Walk the slots on full datetimes anchored to an arbitrary day. The last
    # slot is clamped to the end of that day; plain time arithmetic would wrap
    # to 00:00, restart below max_time and loop forever.
    anchor = date.min
    slot_start = datetime.combine(anchor, min_time)
    stop = datetime.combine(anchor, max_time)
    end_of_day = datetime.combine(anchor, time.max)

    stats = []
    while slot_start < stop:
        slot_end = min(slot_start + interval, end_of_day)
        add_slot(stats, slot_start.time(), slot_end.time(), current_time_query)
        slot_start += interval

    return stats


def get_current_graph(place: str, db: Session):
    """Return today's live waiting-time graph for the slot that is open right now.

    When a slot of today (Europe/Paris) contains the current time, the graph
    holds its 5-minute points from opening up to now. ``start`` and ``end``
    are the slot's opening and closing times in minutes since midnight.
    Otherwise an empty ``Graph`` is returned.
    """
    now = datetime.now(tz=pytz.timezone("Europe/Paris"))
    today_weekday = now.weekday()
    today = now.date()
    now_time = now.time()
    slots = db.query(
        models.OpeningHours.open_time, models.OpeningHours.close_time
    ).filter(
        models.OpeningHours.place == place,
        models.OpeningHours.day == today_weekday,
    ).all()

    open_slot = next(
        (slot for slot in slots
         if slot.open_time <= now_time <= slot.close_time),
        None,
    )
    if open_slot is None:
        return schemas.Graph(data=[])

    def minutes_since_midnight(moment):
        return 60 * moment.hour + moment.minute

    points = get_current_graph_points(
        place, today, open_slot.open_time, now_time, timedelta(minutes=5), db)
    return schemas.Graph(
        data=points,
        start=minutes_since_midnight(open_slot.open_time),
        end=minutes_since_midnight(open_slot.close_time),
    )


# Define CRUD operations for the comments

def get_comments(place: str, page: int, db: Session):
    """Get comments with their author's username, returned oldest first.

    For ``page >= 1``, returns the ``page``-th batch of 20 comments for
    ``place``, counting back from the most recent one. ``page == 0`` returns
    every comment without pagination.

    Each element is a ``schemas.Comment`` built from the ORM row plus the
    author's ``username``.
    """
    newest_first = (models.Comments.published_at.desc(),
                    models.Comments.id.desc())
    # Both cases must join Users: every row is unpacked as (comment, username)
    # below, and a bare Comments query cannot be unpacked that way.
    query = db.query(models.Comments, models.Users.username).join(models.Users)
    if page == 0:
        # NOTE(review): page 0 is not filtered by ``place`` — confirm that
        # returning comments for every place is intended.
        query = query.order_by(*newest_first)
    else:
        query = query.filter(
            models.Comments.place == place
        ).order_by(*newest_first).slice((page - 1) * 20, page * 20)

    comments_list = [
        schemas.Comment(**comment.__dict__, username=username)
        for comment, username in query.all()
    ]
    # Fetched newest first; present them chronologically.
    comments_list.reverse()
    return comments_list


def create_comment(user: schemas.User, place: str, new_comments: schemas.CommentBase, db: Session):
    """Store a comment written by ``user`` about ``place``.

    The comment is stamped with the current Europe/Paris time, committed and
    refreshed. It is returned as a ``schemas.Comment`` carrying the author's
    username.
    """
    published_at = datetime.now(tz=pytz.timezone("Europe/Paris"))
    fields = new_comments.dict()
    db_comment = models.Comments(
        **fields,
        published_at=published_at,
        place=place,
        user_id=user.id,
    )
    db.add(db_comment)
    db.commit()
    db.refresh(db_comment)
    return schemas.Comment(**db_comment.__dict__, username=user.username)


def delete_comment(id: int, db: Session):
    """Delete the comment whose primary key is ``id``; ``id == 0`` deletes all comments."""
    query = db.query(models.Comments)
    if id != 0:
        query = query.filter(models.Comments.id == id)
    query.delete()
    db.commit()


# Define CRUD operations for the news

def get_news(place: str, db: Session):
    """Return every news item for ``place``, most recently published first."""
    return (
        db.query(models.News)
        .filter(models.News.place == place)
        .order_by(models.News.published_at.desc())
        .all()
    )


def create_news(new_news: schemas.NewsBase, db: Session):
    """Persist ``new_news`` stamped with the current Europe/Paris time.

    Returns the committed and refreshed ORM row.
    """
    published_at = datetime.now(tz=pytz.timezone("Europe/Paris"))
    stored = models.News(**new_news.dict(), published_at=published_at)
    db.add(stored)
    db.commit()
    db.refresh(stored)
    return stored


def delete_news(id: int, db: Session):
    """Delete the news item whose primary key is ``id``; ``id == 0`` deletes all news."""
    query = db.query(models.News)
    if id != 0:
        query = query.filter(models.News.id == id)
    query.delete()
    db.commit()


# Define CRUD operations for the opening hours

def get_opening_hours(place: str, db: Session):
    """Return all opening-hour rows for ``place``, sorted by day then opening time."""
    query = db.query(models.OpeningHours)
    query = query.filter(models.OpeningHours.place == place)
    query = query.order_by(models.OpeningHours.day,
                           models.OpeningHours.open_time)
    return query.all()


def create_opening_hours(
        new_opening_hours: schemas.OpeningHoursBase, db: Session):
    """Persist a new opening-hours row built from ``new_opening_hours``.

    Returns the committed and refreshed ORM row.
    """
    stored = models.OpeningHours(**new_opening_hours.dict())
    db.add(stored)
    db.commit()
    db.refresh(stored)
    return stored


def delete_opening_hours(id: int, db: Session):
    """Delete the opening-hours row whose primary key is ``id``; ``id == 0`` deletes them all."""
    query = db.query(models.OpeningHours)
    if id != 0:
        query = query.filter(models.OpeningHours.id == id)
    query.delete()
    db.commit()


# Restaurant information

def get_restaurants(db: Session):
    """List every place that has opening hours, with today's timetable and status.

    For each distinct place, today's slots (Europe/Paris weekday) are formatted
    as ``"HHhMM-HHhMM"`` and joined with ``"/"``. They are merged with the
    place's ``get_waiting_time`` result into a ``schemas.Restaurant``.
    """
    today = datetime.now(tz=pytz.timezone("Europe/Paris")).weekday()
    restaurants = []

    for row in db.query(models.OpeningHours.place).distinct():
        name = row.place
        todays_slots = db.query(models.OpeningHours).filter(
            models.OpeningHours.place == name,
            models.OpeningHours.day == today,
        ).order_by(models.OpeningHours.open_time).all()
        timetable = "/".join(
            f"{slot.open_time.strftime('%Hh%M')}-{slot.close_time.strftime('%Hh%M')}"
            for slot in todays_slots
        )
        infos = get_waiting_time(name, db)
        restaurants.append(schemas.Restaurant(
            **infos.dict(), name=name, timetable=timetable))

    return restaurants


# Define CRUD operations for the authentication

def init_user(db: Session):
    """Create and persist a new user row for a fresh session.

    The row gets a random ``cookie`` (``uuid4()``), a random URL-safe
    ``state`` token and an expiration date 10 minutes from now
    (Europe/Paris).

    Returns:
        The committed and refreshed ``models.Users`` object.
    """
    cookie = uuid4()
    # Unguessable token; presumably verified later in the login flow
    # (cleared by delete_state) — confirm against the caller.
    state = secrets.token_urlsafe(30)
    expiration_date = datetime.now(tz=pytz.timezone(
        "Europe/Paris")) + timedelta(minutes=10)
    db_user = models.Users(state=state, cookie=cookie,
                           expiration_date=expiration_date)
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user


def get_user(cookie: str, db: Session):
    """Return the user owning the session ``cookie``.

    Raises:
        HTTPException(401, "Invalid cookie"): no cookie was given, or it
            matches no user or several users.
        HTTPException(401, "Expired cookie"): the session has expired. The
            user's cookie is cleared and committed before raising.
    """
    from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound

    # ``Users.cookie == None`` would compile to ``IS NULL`` and match users
    # whose cookie was cleared (expired or ended sessions).
    if not cookie:
        raise HTTPException(status_code=401, detail="Invalid cookie")
    try:
        user = db.query(models.Users).filter(
            models.Users.cookie == cookie).one()
    except (NoResultFound, MultipleResultsFound):
        # Only lookup failures mean "invalid cookie"; database errors and
        # interpreter exits are no longer masked as a 401.
        raise HTTPException(
            status_code=401, detail="Invalid cookie") from None

    paris = pytz.timezone("Europe/Paris")
    # localize() expects the stored expiration_date to be naive; it is
    # interpreted as Paris local time before comparing with the aware "now".
    if paris.localize(user.expiration_date) < datetime.now(tz=paris):
        user.cookie = None
        db.add(user)
        db.commit()
        raise HTTPException(status_code=401, detail="Expired cookie")

    return user


def delete_state(user: schemas.User, db: Session):
    """Clear the ``state`` token of ``user`` and commit the change."""
    cleared_state = None
    user.state = cleared_state
    db.add(user)
    db.commit()


def update_user(user: schemas.User, user_info: dict, db: Session):
    """Attach identity information to the session user and extend the session.

    ``user_info`` must provide ``'firstName'`` and ``'lastName'``. The session
    is extended to 3 days from now (Europe/Paris).

    If a user named ``"<firstName> <lastName>"`` already exists, that account
    takes over ``user``'s cookie and ``user`` is deleted. Otherwise ``user``
    is renamed. Returns the committed, refreshed account.
    """
    full_name = f"{user_info['firstName']} {user_info['lastName']}"
    expiration_date = datetime.now(
        tz=pytz.timezone("Europe/Paris")) + timedelta(days=3)
    # NOTE(review): accounts are matched on the display name alone, so two
    # people with the same first and last name share one account — confirm.
    existing_user = db.query(models.Users).filter(
        models.Users.username == full_name).first()

    if not existing_user:
        user.username = full_name
        user.expiration_date = expiration_date
        db.add(user)
        db.commit()
        db.refresh(user)
        return user

    # Move the session cookie onto the known account and drop the
    # temporary user row.
    existing_user.cookie = user.cookie
    existing_user.expiration_date = expiration_date
    db.delete(user)
    db.add(existing_user)
    db.commit()
    db.refresh(existing_user)
    return existing_user


def end_session(cookie: str, db: Session):
    """Log out the user owning ``cookie``.

    The user's expiration date is set to now (Europe/Paris) and the cookie is
    cleared.

    Raises:
        HTTPException(401, "Invalid cookie"): no cookie was given, or it
            matches no user or several users.
    """
    from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound

    # ``Users.cookie == None`` would compile to ``IS NULL`` and match every
    # user whose cookie was already cleared.
    if not cookie:
        raise HTTPException(status_code=401, detail="Invalid cookie")
    try:
        user = db.query(models.Users).filter(
            models.Users.cookie == cookie).one()
    except (NoResultFound, MultipleResultsFound):
        # Report an unknown cookie like get_user does instead of a 500.
        raise HTTPException(
            status_code=401, detail="Invalid cookie") from None

    user.expiration_date = datetime.now(tz=pytz.timezone("Europe/Paris"))
    user.cookie = None
    db.add(user)
    db.commit()


def delete_user(cookie: str, db: Session):
    """Delete the user owning ``cookie``; do nothing when no cookie is given."""
    # ``Users.cookie == None`` compiles to ``IS NULL``. Without this guard a
    # missing cookie would delete every user whose cookie was cleared
    # (ended or expired sessions).
    if not cookie:
        return
    db.query(models.Users).filter(models.Users.cookie == cookie).delete()
    db.commit()


# Define CRUD operations for data collection

def get_records(place: str, db: Session):
    """Return every waiting-time record of ``place``, most recent first."""
    query = db.query(models.Records)
    query = query.filter(models.Records.place == place)
    query = query.order_by(models.Records.date.desc())
    return query.all()


def create_record(record: schemas.RecordBase, db: Session):
    """Persist ``record`` as a new ``models.Records`` row.

    Returns the committed and refreshed row.
    """
    stored = models.Records(**record.dict())
    db.add(stored)
    db.commit()
    db.refresh(stored)
    return stored


def delete_record(id: int, db: Session):
    """Delete the record whose primary key is ``id``; ``id == 0`` deletes all records."""
    query = db.query(models.Records)
    if id != 0:
        query = query.filter(models.Records.id == id)
    query.delete()
    db.commit()


def get_collaborative_records(place: str, db: Session):
    """Return the collaborative records of ``place``, most recent first.

    Each record is returned as a ``schemas.CollaborativeRecords``.
    """
    rows = (
        db.query(models.CollaborativeRecords)
        .filter(models.CollaborativeRecords.place == place)
        .order_by(models.CollaborativeRecords.date.desc())
        .all()
    )
    converted = []
    for row in rows:
        converted.append(schemas.CollaborativeRecords(**row.__dict__))
    return converted


def create_collaborative_record(user: schemas.User, place: str, db: Session):
    """Register that ``user`` joined the queue of ``place`` now.

    A new record can only be created when ``place`` has exactly one opening
    slot containing the current time (Europe/Paris). The user's latest record,
    at any place, must not be later than that slot's opening.

    Returns:
        The committed and refreshed ``models.CollaborativeRecords`` row.

    Raises:
        HTTPException(404, "No restaurant opened"): no slot (or several
            overlapping slots) of ``place`` contains the current time.
        HTTPException(406, "Client already registered"): the user's latest
            record was created after the current slot opened.
    """
    from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound

    current_date = datetime.now(tz=pytz.timezone("Europe/Paris"))
    # ``today`` (not ``date``) avoids shadowing the imported ``date`` class.
    today, weekday, current_time = (
        current_date.date(), current_date.weekday(), current_date.time())

    try:
        time_slot = db.query(
            models.OpeningHours).filter(
            models.OpeningHours.place == place,
            models.OpeningHours.day == weekday,
            models.OpeningHours.open_time <= current_time,
            models.OpeningHours.close_time >= current_time).one()
    except (NoResultFound, MultipleResultsFound):
        # Only "no single open slot" maps to 404; database errors propagate.
        raise HTTPException(
            status_code=404, detail="No restaurant opened") from None

    slot_start = datetime.combine(today, time_slot.open_time)
    last_record = db.query(models.CollaborativeRecords).filter(
        models.CollaborativeRecords.user_id == user.id
    ).order_by(models.CollaborativeRecords.date.desc()).first()
    # One record per user per slot, whatever the place of the previous one.
    if last_record and last_record.date > slot_start:
        raise HTTPException(
            status_code=406, detail="Client already registered")

    db_record = models.CollaborativeRecords(
        user_id=user.id, place=place, date=current_date)
    db.add(db_record)
    db.commit()
    db.refresh(db_record)
    return db_record


def update_collaborative_record(user: schemas.User, db: Session):
    """Close ``user``'s latest collaborative record by storing its waiting time.

    The waiting time is the elapsed time between the record's creation and now
    (Europe/Paris). It is only stored when the record's place has exactly one
    slot open now, the record was created since that slot opened, and it has
    no waiting time yet.

    Returns:
        The updated record as ``schemas.CollaborativeRecords``.

    Raises:
        HTTPException(404, "No restaurant opened"): the user has no record, or
            no single slot of the record's place contains the current time.
        HTTPException(406, "Client already registered"): the latest record
            predates the current slot or already has a waiting time.
    """
    from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound

    paris = pytz.timezone("Europe/Paris")
    current_date = datetime.now(tz=paris)
    today, weekday, current_time = (
        current_date.date(), current_date.weekday(), current_date.time())
    last_record = db.query(models.CollaborativeRecords).filter(
        models.CollaborativeRecords.user_id == user.id
    ).order_by(models.CollaborativeRecords.date.desc()).first()
    if last_record is None:
        # Keep the existing 404 response for users who never registered.
        raise HTTPException(status_code=404, detail="No restaurant opened")

    try:
        time_slot = db.query(
            models.OpeningHours).filter(
            models.OpeningHours.place == last_record.place,
            models.OpeningHours.day == weekday,
            models.OpeningHours.open_time <= current_time,
            models.OpeningHours.close_time >= current_time).one()
    except (NoResultFound, MultipleResultsFound):
        raise HTTPException(
            status_code=404, detail="No restaurant opened") from None

    slot_start = datetime.combine(today, time_slot.open_time)
    if last_record.date < slot_start or last_record.waiting_time:
        # NOTE(review): same detail text as create_collaborative_record even
        # though this means "nothing to close" here — kept for client compat.
        raise HTTPException(
            status_code=406, detail="Client already registered")

    # localize() expects the stored date to be naive; it is interpreted as
    # Paris time so it can be subtracted from the aware current_date.
    last_record.waiting_time = current_date - paris.localize(last_record.date)
    db.add(last_record)
    db.commit()
    db.refresh(last_record)
    return schemas.CollaborativeRecords(**last_record.__dict__)


def delete_collaborative_record(id: int, db: Session):
    """Delete the collaborative record whose primary key is ``id``; ``id == 0`` deletes them all."""
    query = db.query(models.CollaborativeRecords)
    if id != 0:
        query = query.filter(models.CollaborativeRecords.id == id)
    query.delete()
    db.commit()