{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.4"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python394jvsc74a57bd0d55a872fb12b64c3eb6a530d12935ddebcb38da0925d2cc3bd9c2ebc1d370b0d",
   "display_name": "Python 3.9.4 64-bit"
  },
  "metadata": {
   "interpreter": {
    "hash": "d55a872fb12b64c3eb6a530d12935ddebcb38da0925d2cc3bd9c2ebc1d370b0d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "INFO:tensorflow:Enabling eager execution\n",
      "INFO:tensorflow:Enabling v2 tensorshape\n",
      "INFO:tensorflow:Enabling resource variables\n",
      "INFO:tensorflow:Enabling tensor equality\n",
      "INFO:tensorflow:Enabling control flow v2\n",
      "WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named \"keras_metadata.pb\" in the SavedModel directory.\n",
      "Model used: firstModel\n"
     ]
    }
   ],
   "source": [
    "#@title Imports\n",
    "#%load_ext autoreload  #Need to uncomment for import sometime, dont understand\n",
    "\n",
    "#Tensorflow :\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import datasets, layers, models, losses\n",
    "import tensorflow_datasets as tfds\n",
    "#from google.colab import files\n",
    "\n",
    "#Others :\n",
    "from matplotlib import image\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "import random as rd\n",
    "import cv2\n",
    "import csv\n",
    "\n",
    "#Data loaders :\n",
    "from loadFer2013DS import *\n",
    "from loadRavdessDS import *\n",
    "from loadExpWDS import *\n",
    "from loadAffwildDS import *\n",
    "\n",
    "#Others\n",
    "from utils import *\n",
    "from config import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Chargement des données\n",
    "print(\"Array loading...\")\n",
    "Xf = np.load(\"data/array/Xf\")\n",
    "Xe = np.load(\"data/array/Xe\")\n",
    "Xa = np.load(\"data/array/Xa\")\n",
    "Xr = np.load(\"data/array/Xr\")\n",
    "\n",
    "Yf = np.load(\"data/array/Yf\")\n",
    "Ye = np.load(\"data/array/Ye\")\n",
    "Ya = np.load(\"data/array/Ya\")\n",
    "Yr = np.load(\"data/array/Yr\")\n",
    "\n",
    "print(\"Concatenation...\")\n",
    "X = np.concatenate([Xf, Xa, Xe, Xr])\n",
    "Y = np.concatenate([Yf, Ya, Ye, Yr])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Enregistre X et Y directement, à faire si assez de ram\n",
    "np.save(\"data/array/X\", X)\n",
    "np.save(\"data/array/Y\", Y)\n",
    "def loadData():\n",
    "    return np.load(\"data/array/X\"), np.load(\"data/array/Y\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Visualisation de chaque dataset\n",
    "for X, Y, name in zip([Xf, Xr, Xe, Xa], [Yf, Yr, Ye, Ya], [\"fer2013\", \"ravdess\", \"expW\", \"affwild\"]):\n",
    "    N=5\n",
    "    M=5\n",
    "    print(\"Dataset:\", name)\n",
    "    print(\"Images:\", X.shape, \"La   bels:\", Y.shape)\n",
    "    plt.figure()\n",
    "    for i in range(N*M):\n",
    "        if X.shape[0] == 0: continue\n",
    "        k = rd.randrange(X.shape[0])\n",
    "        plt.subplot(N, M, i+1)\n",
    "        plt.xticks([])\n",
    "        plt.yticks([])\n",
    "        plt.grid(False)\n",
    "\n",
    "        afficher(X[k])\n",
    "        plt.title(emotions[int(Y[k])])\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Visualisation du dataset global\n",
    "print(\"X_train:\", X_train.shape)\n",
    "print(\"Y_train:\", Y_train.shape)\n",
    "\n",
    "print(\"\\nX_test:\", X_test.shape)\n",
    "print(\"Y_test:\", Y_test.shape)\n",
    "\n",
    "N=5\n",
    "M=5\n",
    "plt.figure()\n",
    "for i in range(N*M):\n",
    "    k = rd.randrange(X_train.shape[0])\n",
    "    plt.subplot(N, M, i+1)\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    plt.grid(False)\n",
    "\n",
    "    afficher(X_train[k])\n",
    "    plt.title(emotions[int(Y_train[k])])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Hyperparamètres\n",
    "epochs = 2\n",
    "batch_size = 128\n",
    "validation_size = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Labels catégoriques\n",
    "Ycat = keras.utils.to_categorical(Y)\n",
    "\n",
    "print(\"X\", X.shape)\n",
    "print(\"Y\", Ycat.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#MODELE\n",
    "class MyModel(keras.Sequential):\n",
    "\n",
    "    def __init__(self, input_shape):\n",
    "        super(MyModel, self).__init__()\n",
    "        #Pre processing\n",
    "        self.add(keras.layers.experimental.preprocessing.RandomContrast(factor=(0.5,0.5)))\n",
    "        self.add(keras.layers.experimental.preprocessing.RandomFlip(mode=\"horizontal\"))\n",
    "        \n",
    "        #48*48 *1\n",
    "        self.add(keras.layers.Conv2D(32, kernel_size = (3, 3), activation = 'relu', input_shape = input_shape))        \n",
    "        self.add(keras.layers.MaxPooling2D(pool_size = 2))\n",
    "        self.add(keras.layers.BatchNormalization())\n",
    "\n",
    "        #23*23 *32\n",
    "        self.add(keras.layers.Conv2D(64, kernel_size = (3, 3), activation = 'relu'))\n",
    "        self.add(keras.layers.MaxPooling2D(pool_size = 2))\n",
    "        self.add(keras.layers.BatchNormalization())\n",
    "\n",
    "        #10*10 *64\n",
    "        self.add(keras.layers.Conv2D(128, kernel_size = (3, 3), activation = 'relu'))\n",
    "        self.add(keras.layers.MaxPooling2D(pool_size = 2))\n",
    "        self.add(keras.layers.BatchNormalization())\n",
    "\n",
    "        #4*4 *128\n",
    "        self.add(keras.layers.Conv2D(256, kernel_size = (3, 3), activation = 'relu'))\n",
    "        self.add(keras.layers.MaxPooling2D(pool_size = 2))\n",
    "        self.add(keras.layers.BatchNormalization())\n",
    "\n",
    "        #1*1 *256\n",
    "        self.add(keras.layers.Flatten())\n",
    "        self.add(keras.layers.Dense(128, activation = 'relu'))\n",
    "        self.add(keras.layers.Dropout(0.2))\n",
    "        self.add(keras.layers.Dense(64, activation = 'relu'))\n",
    "        self.add(keras.layers.Dropout(0.2))\n",
    "        #self.add(keras.layers.BatchNormalization())\n",
    "        self.add(keras.layers.Dense(7, activation = 'softmax'))\n",
    "        #7\n",
    "    \n",
    "    def predir(self, monImage):\n",
    "        return self.predict(np.array([monImage]))[0,:]\n",
    "\n",
    "    def compile_o(self):\n",
    "        self.compile(optimizer = 'adam', loss=losses.categorical_crossentropy, metrics = ['accuracy'])\n",
    "\n",
    "myModel = MyModel(input_shape)\n",
    "myModel.compile_o()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "theImage = X_train[0]\n",
    "afficher(theImage)\n",
    "print(predir(myModel, theImage))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "history = myModel.fit(X, Y, epochs=5, validation_rate=0.05)\n",
    "\n",
    "#Affichage de l'historique de l'apprentissage\n",
    "plt.plot(history.history['accuracy'], label='accuracy')\n",
    "plt.plot(history.history['val_accuracy'], label='val_accuracy')\n",
    "plt.legend()\n",
    "plt.ylim([min(history.history['val_accuracy']+history.history['accuracy']), 1])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myModel.save('exp904')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}