{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Hyperparameter Tuning Example for Self-Organizing Maps (SOM)\n", "============================================================\n", "\n", "This notebook demonstrates how to use the SOM class to perform hyperparameter tuning. By defining\n", "a grid of hyperparameters and exhaustively testing every combination, it identifies the\n", "combination that achieves the highest score on the given dataset. The result is optimal only\n", "with respect to the values in the grid and the chosen scoring function.\n", "\n", "Features of This Notebook:\n", "--------------------------\n", "1. Grid Search Implementation:\n", "    - Hyperparameters such as `scale_method`, `x_dim`, `y_dim`, `topology`, `neighborhood_fnc`,\n", "      and `epochs` are tested over a predefined set of values.\n", "    - Percent Variance Explained (PVE) and Topographic Error are combined into a single score\n", "      that evaluates the SOM's performance for each combination.\n", "\n", "2. Ease of Use:\n", "    - Python's `itertools.product` enumerates every hyperparameter combination in a single loop.\n", "    - Metrics are calculated with the SOM class's built-in methods,\n", "      `calculate_percent_variance_explained()` and `calculate_topographic_error()`.\n", "\n", "3. Visualization of Results:\n", "    - Once the best hyperparameters are identified, the SOM is retrained, and component planes\n", "      and categorical data distributions are visualized and saved.\n", "\n", "Considerations:\n", "---------------\n", "- Time Complexity:\n", "    - Every combination trains a separate SOM, so runtime grows with the product of the grid\n", "      sizes (the grid used here has 2 × 3 × 3 × 2 × 2 × 3 = 216 combinations), as well as with\n", "      the dataset size and the number of epochs. Trim the hyperparameter ranges to balance\n", "      thoroughness against computational cost.\n", "\n", "- Extensibility:\n", "    - The scoring function can be adjusted based on specific requirements. 
In this example, the score is computed as\n", "      `PVE - 100 * topographic_error`. The topographic error is a fraction between 0 and 1, so\n", "      multiplying it by 100 puts it on the same percentage scale as PVE.\n", "\n", "Output:\n", "-------\n", "- Best hyperparameters and their resulting score.\n", "- Final SOM trained with the best parameters.\n", "- Saved visualizations of component planes and categorical data distributions.\n", "\n", "Usage:\n", "------\n", "Modify the dataset paths (`train_dat_path`, `other_dat_path`) and `hyperparameter_grid` as needed." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 1. Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Standard imports\n", "import itertools\n", "import os\n", "import sys\n", "\n", "# Third party imports\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "\n", "# Local imports\n", "notebook_dir = os.getcwd()\n", "parent_dir = os.path.abspath(os.path.join(notebook_dir, '..', '..'))\n", "sys.path.append(parent_dir)\n", "from SOM.utils.som_utils import SOM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2. Load Data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Load data\n", "train_dat_path = os.path.join(parent_dir, 'SOM', 'data', 'titanic_training_data.csv')\n", "other_dat_path = os.path.join(parent_dir, 'SOM', 'data', 'titanic_categorical_data.csv')\n", "\n", "train_dat = pd.read_csv(train_dat_path)\n", "other_dat = pd.read_csv(other_dat_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3. 
Define Hyperparameters" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Define hyperparameter grid\n", "hyperparameter_grid = {\n", " \"scale_method\": [\"zscore\", \"minmax\"],\n", " \"x_dim\": [3, 5, 7],\n", " \"y_dim\": [2, 4, 6],\n", " \"topology\": [\"rectangular\", \"hexagonal\"],\n", " \"neighborhood_fnc\": [\"gaussian\", \"bubble\"],\n", " \"epochs\": [50, 100, 200],\n", "}\n", "\n", "# Initialize variables to store the best parameters and score\n", "best_params = None\n", "best_score = -float(\"inf\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 4. Grid Search for Optimal Parameters" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tested params: ('zscore', 3, 2, 'rectangular', 'gaussian', 50) | Score: 3.07\n", "Tested params: ('zscore', 3, 2, 'rectangular', 'gaussian', 100) | Score: 22.30\n", "Tested params: ('zscore', 3, 2, 'rectangular', 'gaussian', 200) | Score: 3.37\n", "Tested params: ('zscore', 3, 2, 'rectangular', 'bubble', 50) | Score: 55.71\n", "Tested params: ('zscore', 3, 2, 'rectangular', 'bubble', 100) | Score: -29.33\n", "Tested params: ('zscore', 3, 2, 'rectangular', 'bubble', 200) | Score: -19.81\n", "Tested params: ('zscore', 3, 2, 'hexagonal', 'gaussian', 50) | Score: -25.46\n", "Tested params: ('zscore', 3, 2, 'hexagonal', 'gaussian', 100) | Score: 0.15\n", "Tested params: ('zscore', 3, 2, 'hexagonal', 'gaussian', 200) | Score: -11.31\n", "Tested params: ('zscore', 3, 2, 'hexagonal', 'bubble', 50) | Score: -2.13\n", "Tested params: ('zscore', 3, 2, 'hexagonal', 'bubble', 100) | Score: -36.89\n", "Tested params: ('zscore', 3, 2, 'hexagonal', 'bubble', 200) | Score: -24.99\n", "Tested params: ('zscore', 3, 4, 'rectangular', 'gaussian', 50) | Score: 27.55\n", "Tested params: ('zscore', 3, 4, 'rectangular', 'gaussian', 100) | Score: 54.05\n", "Tested params: ('zscore', 3, 4, 'rectangular', 
'gaussian', 200) | Score: 56.14\n", "Tested params: ('zscore', 3, 4, 'rectangular', 'bubble', 50) | Score: 21.94\n", "Tested params: ('zscore', 3, 4, 'rectangular', 'bubble', 100) | Score: 6.22\n", "Tested params: ('zscore', 3, 4, 'rectangular', 'bubble', 200) | Score: 10.33\n", "Tested params: ('zscore', 3, 4, 'hexagonal', 'gaussian', 50) | Score: 17.04\n", "Tested params: ('zscore', 3, 4, 'hexagonal', 'gaussian', 100) | Score: 1.76\n", "Tested params: ('zscore', 3, 4, 'hexagonal', 'gaussian', 200) | Score: 24.68\n", "Tested params: ('zscore', 3, 4, 'hexagonal', 'bubble', 50) | Score: -12.24\n", "Tested params: ('zscore', 3, 4, 'hexagonal', 'bubble', 100) | Score: -21.09\n", "Tested params: ('zscore', 3, 4, 'hexagonal', 'bubble', 200) | Score: -9.83\n", "Tested params: ('zscore', 3, 6, 'rectangular', 'gaussian', 50) | Score: 47.69\n", "Tested params: ('zscore', 3, 6, 'rectangular', 'gaussian', 100) | Score: 64.72\n", "Tested params: ('zscore', 3, 6, 'rectangular', 'gaussian', 200) | Score: 41.66\n", "Tested params: ('zscore', 3, 6, 'rectangular', 'bubble', 50) | Score: 10.19\n", "Tested params: ('zscore', 3, 6, 'rectangular', 'bubble', 100) | Score: 2.10\n", "Tested params: ('zscore', 3, 6, 'rectangular', 'bubble', 200) | Score: 15.35\n", "Tested params: ('zscore', 3, 6, 'hexagonal', 'gaussian', 50) | Score: 0.50\n", "Tested params: ('zscore', 3, 6, 'hexagonal', 'gaussian', 100) | Score: -3.38\n", "Tested params: ('zscore', 3, 6, 'hexagonal', 'gaussian', 200) | Score: 15.23\n", "Tested params: ('zscore', 3, 6, 'hexagonal', 'bubble', 50) | Score: -2.27\n", "Tested params: ('zscore', 3, 6, 'hexagonal', 'bubble', 100) | Score: -9.24\n", "Tested params: ('zscore', 3, 6, 'hexagonal', 'bubble', 200) | Score: -9.86\n", "Tested params: ('zscore', 5, 2, 'rectangular', 'gaussian', 50) | Score: 32.92\n", "Tested params: ('zscore', 5, 2, 'rectangular', 'gaussian', 100) | Score: 37.05\n", "Tested params: ('zscore', 5, 2, 'rectangular', 'gaussian', 200) | Score: 20.39\n", 
"Tested params: ('zscore', 5, 2, 'rectangular', 'bubble', 50) | Score: 14.60\n", "Tested params: ('zscore', 5, 2, 'rectangular', 'bubble', 100) | Score: -3.79\n", "Tested params: ('zscore', 5, 2, 'rectangular', 'bubble', 200) | Score: -9.82\n", "Tested params: ('zscore', 5, 2, 'hexagonal', 'gaussian', 50) | Score: -12.21\n", "Tested params: ('zscore', 5, 2, 'hexagonal', 'gaussian', 100) | Score: 8.12\n", "Tested params: ('zscore', 5, 2, 'hexagonal', 'gaussian', 200) | Score: -17.13\n", "Tested params: ('zscore', 5, 2, 'hexagonal', 'bubble', 50) | Score: -0.94\n", "Tested params: ('zscore', 5, 2, 'hexagonal', 'bubble', 100) | Score: -8.55\n", "Tested params: ('zscore', 5, 2, 'hexagonal', 'bubble', 200) | Score: -9.82\n", "Tested params: ('zscore', 5, 4, 'rectangular', 'gaussian', 50) | Score: 46.99\n", "Tested params: ('zscore', 5, 4, 'rectangular', 'gaussian', 100) | Score: 60.28\n", "Tested params: ('zscore', 5, 4, 'rectangular', 'gaussian', 200) | Score: 60.30\n", "Tested params: ('zscore', 5, 4, 'rectangular', 'bubble', 50) | Score: -0.35\n", "Tested params: ('zscore', 5, 4, 'rectangular', 'bubble', 100) | Score: -6.41\n", "Tested params: ('zscore', 5, 4, 'rectangular', 'bubble', 200) | Score: -4.08\n", "Tested params: ('zscore', 5, 4, 'hexagonal', 'gaussian', 50) | Score: 23.32\n", "Tested params: ('zscore', 5, 4, 'hexagonal', 'gaussian', 100) | Score: 29.52\n", "Tested params: ('zscore', 5, 4, 'hexagonal', 'gaussian', 200) | Score: 11.27\n", "Tested params: ('zscore', 5, 4, 'hexagonal', 'bubble', 50) | Score: -5.53\n", "Tested params: ('zscore', 5, 4, 'hexagonal', 'bubble', 100) | Score: -12.43\n", "Tested params: ('zscore', 5, 4, 'hexagonal', 'bubble', 200) | Score: -8.00\n", "Tested params: ('zscore', 5, 6, 'rectangular', 'gaussian', 50) | Score: 66.65\n", "Tested params: ('zscore', 5, 6, 'rectangular', 'gaussian', 100) | Score: 72.04\n", "Tested params: ('zscore', 5, 6, 'rectangular', 'gaussian', 200) | Score: 61.56\n", "Tested params: ('zscore', 5, 6, 
'rectangular', 'bubble', 50) | Score: 17.12\n", "Tested params: ('zscore', 5, 6, 'rectangular', 'bubble', 100) | Score: 4.29\n", "Tested params: ('zscore', 5, 6, 'rectangular', 'bubble', 200) | Score: 11.11\n", "Tested params: ('zscore', 5, 6, 'hexagonal', 'gaussian', 50) | Score: 46.63\n", "Tested params: ('zscore', 5, 6, 'hexagonal', 'gaussian', 100) | Score: 17.17\n", "Tested params: ('zscore', 5, 6, 'hexagonal', 'gaussian', 200) | Score: 24.15\n", "Tested params: ('zscore', 5, 6, 'hexagonal', 'bubble', 50) | Score: -5.84\n", "Tested params: ('zscore', 5, 6, 'hexagonal', 'bubble', 100) | Score: -1.17\n", "Tested params: ('zscore', 5, 6, 'hexagonal', 'bubble', 200) | Score: 0.19\n", "Tested params: ('zscore', 7, 2, 'rectangular', 'gaussian', 50) | Score: 56.91\n", "Tested params: ('zscore', 7, 2, 'rectangular', 'gaussian', 100) | Score: 46.95\n", "Tested params: ('zscore', 7, 2, 'rectangular', 'gaussian', 200) | Score: 49.48\n", "Tested params: ('zscore', 7, 2, 'rectangular', 'bubble', 50) | Score: 18.53\n", "Tested params: ('zscore', 7, 2, 'rectangular', 'bubble', 100) | Score: 1.30\n", "Tested params: ('zscore', 7, 2, 'rectangular', 'bubble', 200) | Score: 6.63\n", "Tested params: ('zscore', 7, 2, 'hexagonal', 'gaussian', 50) | Score: 16.15\n", "Tested params: ('zscore', 7, 2, 'hexagonal', 'gaussian', 100) | Score: 23.12\n", "Tested params: ('zscore', 7, 2, 'hexagonal', 'gaussian', 200) | Score: 25.97\n", "Tested params: ('zscore', 7, 2, 'hexagonal', 'bubble', 50) | Score: 16.15\n", "Tested params: ('zscore', 7, 2, 'hexagonal', 'bubble', 100) | Score: -10.75\n", "Tested params: ('zscore', 7, 2, 'hexagonal', 'bubble', 200) | Score: -2.47\n", "Tested params: ('zscore', 7, 4, 'rectangular', 'gaussian', 50) | Score: 60.97\n", "Tested params: ('zscore', 7, 4, 'rectangular', 'gaussian', 100) | Score: 70.44\n", "Tested params: ('zscore', 7, 4, 'rectangular', 'gaussian', 200) | Score: 38.54\n", "Tested params: ('zscore', 7, 4, 'rectangular', 'bubble', 50) | Score: 
-7.36\n", "Tested params: ('zscore', 7, 4, 'rectangular', 'bubble', 100) | Score: -4.40\n", "Tested params: ('zscore', 7, 4, 'rectangular', 'bubble', 200) | Score: 33.14\n", "Tested params: ('zscore', 7, 4, 'hexagonal', 'gaussian', 50) | Score: 35.18\n", "Tested params: ('zscore', 7, 4, 'hexagonal', 'gaussian', 100) | Score: 15.62\n", "Tested params: ('zscore', 7, 4, 'hexagonal', 'gaussian', 200) | Score: 38.17\n", "Tested params: ('zscore', 7, 4, 'hexagonal', 'bubble', 50) | Score: -8.62\n", "Tested params: ('zscore', 7, 4, 'hexagonal', 'bubble', 100) | Score: -8.88\n", "Tested params: ('zscore', 7, 4, 'hexagonal', 'bubble', 200) | Score: 27.96\n", "Tested params: ('zscore', 7, 6, 'rectangular', 'gaussian', 50) | Score: 64.22\n", "Tested params: ('zscore', 7, 6, 'rectangular', 'gaussian', 100) | Score: 73.20\n", "Tested params: ('zscore', 7, 6, 'rectangular', 'gaussian', 200) | Score: 72.54\n", "Tested params: ('zscore', 7, 6, 'rectangular', 'bubble', 50) | Score: 16.17\n", "Tested params: ('zscore', 7, 6, 'rectangular', 'bubble', 100) | Score: -0.89\n", "Tested params: ('zscore', 7, 6, 'rectangular', 'bubble', 200) | Score: 0.65\n", "Tested params: ('zscore', 7, 6, 'hexagonal', 'gaussian', 50) | Score: 31.35\n", "Tested params: ('zscore', 7, 6, 'hexagonal', 'gaussian', 100) | Score: 7.94\n", "Tested params: ('zscore', 7, 6, 'hexagonal', 'gaussian', 200) | Score: 17.92\n", "Tested params: ('zscore', 7, 6, 'hexagonal', 'bubble', 50) | Score: 8.89\n", "Tested params: ('zscore', 7, 6, 'hexagonal', 'bubble', 100) | Score: -4.53\n", "Tested params: ('zscore', 7, 6, 'hexagonal', 'bubble', 200) | Score: -0.75\n", "Tested params: ('minmax', 3, 2, 'rectangular', 'gaussian', 50) | Score: 54.92\n", "Tested params: ('minmax', 3, 2, 'rectangular', 'gaussian', 100) | Score: 73.38\n", "Tested params: ('minmax', 3, 2, 'rectangular', 'gaussian', 200) | Score: 77.83\n", "Tested params: ('minmax', 3, 2, 'rectangular', 'bubble', 50) | Score: 43.36\n", "Tested params: ('minmax', 3, 2, 
'rectangular', 'bubble', 100) | Score: 69.40\n", "Tested params: ('minmax', 3, 2, 'rectangular', 'bubble', 200) | Score: -16.27\n", "Tested params: ('minmax', 3, 2, 'hexagonal', 'gaussian', 50) | Score: 37.99\n", "Tested params: ('minmax', 3, 2, 'hexagonal', 'gaussian', 100) | Score: 3.66\n", "Tested params: ('minmax', 3, 2, 'hexagonal', 'gaussian', 200) | Score: -14.72\n", "Tested params: ('minmax', 3, 2, 'hexagonal', 'bubble', 50) | Score: 35.37\n", "Tested params: ('minmax', 3, 2, 'hexagonal', 'bubble', 100) | Score: 32.29\n", "Tested params: ('minmax', 3, 2, 'hexagonal', 'bubble', 200) | Score: -16.55\n", "Tested params: ('minmax', 3, 4, 'rectangular', 'gaussian', 50) | Score: 71.42\n", "Tested params: ('minmax', 3, 4, 'rectangular', 'gaussian', 100) | Score: 54.54\n", "Tested params: ('minmax', 3, 4, 'rectangular', 'gaussian', 200) | Score: 67.86\n", "Tested params: ('minmax', 3, 4, 'rectangular', 'bubble', 50) | Score: -4.85\n", "Tested params: ('minmax', 3, 4, 'rectangular', 'bubble', 100) | Score: 40.56\n", "Tested params: ('minmax', 3, 4, 'rectangular', 'bubble', 200) | Score: -5.09\n", "Tested params: ('minmax', 3, 4, 'hexagonal', 'gaussian', 50) | Score: 1.15\n", "Tested params: ('minmax', 3, 4, 'hexagonal', 'gaussian', 100) | Score: 56.80\n", "Tested params: ('minmax', 3, 4, 'hexagonal', 'gaussian', 200) | Score: 27.91\n", "Tested params: ('minmax', 3, 4, 'hexagonal', 'bubble', 50) | Score: -14.38\n", "Tested params: ('minmax', 3, 4, 'hexagonal', 'bubble', 100) | Score: 10.45\n", "Tested params: ('minmax', 3, 4, 'hexagonal', 'bubble', 200) | Score: -9.01\n", "Tested params: ('minmax', 3, 6, 'rectangular', 'gaussian', 50) | Score: 72.05\n", "Tested params: ('minmax', 3, 6, 'rectangular', 'gaussian', 100) | Score: 75.45\n", "Tested params: ('minmax', 3, 6, 'rectangular', 'gaussian', 200) | Score: 67.80\n", "Tested params: ('minmax', 3, 6, 'rectangular', 'bubble', 50) | Score: 29.96\n", "Tested params: ('minmax', 3, 6, 'rectangular', 'bubble', 100) | 
Score: 10.18\n", "Tested params: ('minmax', 3, 6, 'rectangular', 'bubble', 200) | Score: -0.06\n", "Tested params: ('minmax', 3, 6, 'hexagonal', 'gaussian', 50) | Score: 10.09\n", "Tested params: ('minmax', 3, 6, 'hexagonal', 'gaussian', 100) | Score: 30.30\n", "Tested params: ('minmax', 3, 6, 'hexagonal', 'gaussian', 200) | Score: 17.54\n", "Tested params: ('minmax', 3, 6, 'hexagonal', 'bubble', 50) | Score: -6.74\n", "Tested params: ('minmax', 3, 6, 'hexagonal', 'bubble', 100) | Score: -8.72\n", "Tested params: ('minmax', 3, 6, 'hexagonal', 'bubble', 200) | Score: -12.10\n", "Tested params: ('minmax', 5, 2, 'rectangular', 'gaussian', 50) | Score: 79.85\n", "Tested params: ('minmax', 5, 2, 'rectangular', 'gaussian', 100) | Score: 62.12\n", "Tested params: ('minmax', 5, 2, 'rectangular', 'gaussian', 200) | Score: 65.07\n", "Tested params: ('minmax', 5, 2, 'rectangular', 'bubble', 50) | Score: 41.49\n", "Tested params: ('minmax', 5, 2, 'rectangular', 'bubble', 100) | Score: 46.84\n", "Tested params: ('minmax', 5, 2, 'rectangular', 'bubble', 200) | Score: -17.90\n", "Tested params: ('minmax', 5, 2, 'hexagonal', 'gaussian', 50) | Score: -0.18\n", "Tested params: ('minmax', 5, 2, 'hexagonal', 'gaussian', 100) | Score: 16.58\n", "Tested params: ('minmax', 5, 2, 'hexagonal', 'gaussian', 200) | Score: -4.60\n", "Tested params: ('minmax', 5, 2, 'hexagonal', 'bubble', 50) | Score: 41.49\n", "Tested params: ('minmax', 5, 2, 'hexagonal', 'bubble', 100) | Score: 44.18\n", "Tested params: ('minmax', 5, 2, 'hexagonal', 'bubble', 200) | Score: -18.46\n", "Tested params: ('minmax', 5, 4, 'rectangular', 'gaussian', 50) | Score: 66.47\n", "Tested params: ('minmax', 5, 4, 'rectangular', 'gaussian', 100) | Score: 72.90\n", "Tested params: ('minmax', 5, 4, 'rectangular', 'gaussian', 200) | Score: 68.85\n", "Tested params: ('minmax', 5, 4, 'rectangular', 'bubble', 50) | Score: 71.66\n", "Tested params: ('minmax', 5, 4, 'rectangular', 'bubble', 100) | Score: 43.72\n", "Tested params: 
('minmax', 5, 4, 'rectangular', 'bubble', 200) | Score: 61.40\n", "Tested params: ('minmax', 5, 4, 'hexagonal', 'gaussian', 50) | Score: 16.62\n", "Tested params: ('minmax', 5, 4, 'hexagonal', 'gaussian', 100) | Score: 34.27\n", "Tested params: ('minmax', 5, 4, 'hexagonal', 'gaussian', 200) | Score: 21.11\n", "Tested params: ('minmax', 5, 4, 'hexagonal', 'bubble', 50) | Score: -9.99\n", "Tested params: ('minmax', 5, 4, 'hexagonal', 'bubble', 100) | Score: -2.36\n", "Tested params: ('minmax', 5, 4, 'hexagonal', 'bubble', 200) | Score: -1.21\n", "Tested params: ('minmax', 5, 6, 'rectangular', 'gaussian', 50) | Score: 60.75\n", "Tested params: ('minmax', 5, 6, 'rectangular', 'gaussian', 100) | Score: 78.77\n", "Tested params: ('minmax', 5, 6, 'rectangular', 'gaussian', 200) | Score: 58.31\n", "Tested params: ('minmax', 5, 6, 'rectangular', 'bubble', 50) | Score: 28.64\n", "Tested params: ('minmax', 5, 6, 'rectangular', 'bubble', 100) | Score: 17.27\n", "Tested params: ('minmax', 5, 6, 'rectangular', 'bubble', 200) | Score: 19.81\n", "Tested params: ('minmax', 5, 6, 'hexagonal', 'gaussian', 50) | Score: 25.46\n", "Tested params: ('minmax', 5, 6, 'hexagonal', 'gaussian', 100) | Score: 35.53\n", "Tested params: ('minmax', 5, 6, 'hexagonal', 'gaussian', 200) | Score: 13.92\n", "Tested params: ('minmax', 5, 6, 'hexagonal', 'bubble', 50) | Score: 12.11\n", "Tested params: ('minmax', 5, 6, 'hexagonal', 'bubble', 100) | Score: 5.22\n", "Tested params: ('minmax', 5, 6, 'hexagonal', 'bubble', 200) | Score: 10.57\n", "Tested params: ('minmax', 7, 2, 'rectangular', 'gaussian', 50) | Score: 72.60\n", "Tested params: ('minmax', 7, 2, 'rectangular', 'gaussian', 100) | Score: 72.86\n", "Tested params: ('minmax', 7, 2, 'rectangular', 'gaussian', 200) | Score: 51.63\n", "Tested params: ('minmax', 7, 2, 'rectangular', 'bubble', 50) | Score: 16.66\n", "Tested params: ('minmax', 7, 2, 'rectangular', 'bubble', 100) | Score: 10.83\n", "Tested params: ('minmax', 7, 2, 'rectangular', 
'bubble', 200) | Score: 15.46\n", "Tested params: ('minmax', 7, 2, 'hexagonal', 'gaussian', 50) | Score: 12.05\n", "Tested params: ('minmax', 7, 2, 'hexagonal', 'gaussian', 100) | Score: 43.06\n", "Tested params: ('minmax', 7, 2, 'hexagonal', 'gaussian', 200) | Score: 17.09\n", "Tested params: ('minmax', 7, 2, 'hexagonal', 'bubble', 50) | Score: 16.52\n", "Tested params: ('minmax', 7, 2, 'hexagonal', 'bubble', 100) | Score: 10.83\n", "Tested params: ('minmax', 7, 2, 'hexagonal', 'bubble', 200) | Score: 15.46\n", "Tested params: ('minmax', 7, 4, 'rectangular', 'gaussian', 50) | Score: 65.05\n", "Tested params: ('minmax', 7, 4, 'rectangular', 'gaussian', 100) | Score: 70.15\n", "Tested params: ('minmax', 7, 4, 'rectangular', 'gaussian', 200) | Score: 80.84\n", "Tested params: ('minmax', 7, 4, 'rectangular', 'bubble', 50) | Score: 20.95\n", "Tested params: ('minmax', 7, 4, 'rectangular', 'bubble', 100) | Score: 67.74\n", "Tested params: ('minmax', 7, 4, 'rectangular', 'bubble', 200) | Score: 12.57\n", "Tested params: ('minmax', 7, 4, 'hexagonal', 'gaussian', 50) | Score: 10.16\n", "Tested params: ('minmax', 7, 4, 'hexagonal', 'gaussian', 100) | Score: 5.11\n", "Tested params: ('minmax', 7, 4, 'hexagonal', 'gaussian', 200) | Score: 48.04\n", "Tested params: ('minmax', 7, 4, 'hexagonal', 'bubble', 50) | Score: 2.60\n", "Tested params: ('minmax', 7, 4, 'hexagonal', 'bubble', 100) | Score: 30.91\n", "Tested params: ('minmax', 7, 4, 'hexagonal', 'bubble', 200) | Score: -8.15\n", "Tested params: ('minmax', 7, 6, 'rectangular', 'gaussian', 50) | Score: 70.10\n", "Tested params: ('minmax', 7, 6, 'rectangular', 'gaussian', 100) | Score: 46.53\n", "Tested params: ('minmax', 7, 6, 'rectangular', 'gaussian', 200) | Score: 76.17\n", "Tested params: ('minmax', 7, 6, 'rectangular', 'bubble', 50) | Score: 41.95\n", "Tested params: ('minmax', 7, 6, 'rectangular', 'bubble', 100) | Score: 8.98\n", "Tested params: ('minmax', 7, 6, 'rectangular', 'bubble', 200) | Score: 9.69\n", "Tested 
params: ('minmax', 7, 6, 'hexagonal', 'gaussian', 50) | Score: 16.42\n", "Tested params: ('minmax', 7, 6, 'hexagonal', 'gaussian', 100) | Score: 19.06\n", "Tested params: ('minmax', 7, 6, 'hexagonal', 'gaussian', 200) | Score: 22.83\n", "Tested params: ('minmax', 7, 6, 'hexagonal', 'bubble', 50) | Score: -4.27\n", "Tested params: ('minmax', 7, 6, 'hexagonal', 'bubble', 100) | Score: -5.44\n", "Tested params: ('minmax', 7, 6, 'hexagonal', 'bubble', 200) | Score: 0.58\n" ] } ], "source": [ "for params in itertools.product(\n", "    hyperparameter_grid[\"scale_method\"],\n", "    hyperparameter_grid[\"x_dim\"],\n", "    hyperparameter_grid[\"y_dim\"],\n", "    hyperparameter_grid[\"topology\"],\n", "    hyperparameter_grid[\"neighborhood_fnc\"],\n", "    hyperparameter_grid[\"epochs\"],\n", "):\n", "    scale_method, x_dim, y_dim, topology, neighborhood_fnc, epochs = params\n", "\n", "    # Train SOM with the current hyperparameter combination\n", "    som = SOM(\n", "        train_dat=train_dat,\n", "        other_dat=other_dat,\n", "        scale_method=scale_method,\n", "        x_dim=x_dim,\n", "        y_dim=y_dim,\n", "        topology=topology,\n", "        neighborhood_fnc=neighborhood_fnc,\n", "        epochs=epochs,\n", "    )\n", "    som.train_map()\n", "\n", "    # Calculate evaluation metrics\n", "    pve = som.calculate_percent_variance_explained()\n", "    topographic_error = som.calculate_topographic_error()\n", "\n", "    # Combine metrics into a single score (higher PVE and lower error are better).\n", "    # PVE is a percentage and topographic error is a fraction in [0, 1], so the\n", "    # error is multiplied by 100 to put both terms on the same scale.\n", "    score = pve - topographic_error * 100\n", "\n", "    print(f\"Tested params: {params} | Score: {score:.2f}\")\n", "\n", "    # Update the best parameters if the current score is better\n", "    if score > best_score:\n", "        best_score = score\n", "        best_params = params" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 5. 
Show the Best Parameters" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Best Hyperparameters:\n", "Scale Method: minmax\n", "x_dim: 7\n", "y_dim: 4\n", "Topology: rectangular\n", "Neighborhood Function: gaussian\n", "Epochs: 200\n", "Best Score: 80.84\n" ] } ], "source": [ "# Output the best hyperparameters\n", "print(\"\\nBest Hyperparameters:\")\n", "print(f\"Scale Method: {best_params[0]}\")\n", "print(f\"x_dim: {best_params[1]}\")\n", "print(f\"y_dim: {best_params[2]}\")\n", "print(f\"Topology: {best_params[3]}\")\n", "print(f\"Neighborhood Function: {best_params[4]}\")\n", "print(f\"Epochs: {best_params[5]}\")\n", "print(f\"Best Score: {best_score:.2f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 6. Train SOM" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Retrain the SOM with the best hyperparameters found by the grid search\n", "best_som = SOM(\n", "    train_dat=train_dat,\n", "    other_dat=other_dat,\n", "    scale_method=best_params[0],\n", "    x_dim=best_params[1],\n", "    y_dim=best_params[2],\n", "    topology=best_params[3],\n", "    neighborhood_fnc=best_params[4],\n", "    epochs=best_params[5],\n", ")\n", "best_som.train_map()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 7. 
Get the Fit Metrics" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Final SOM Performance:\n", "Percent variance explained = 94.28111754749402%\n", "Topographic error = 0.13445378151260504\n" ] } ], "source": [ "# Get fit metrics\n", "pve = best_som.calculate_percent_variance_explained()\n", "topographic_error = best_som.calculate_topographic_error()\n", "\n", "print(\"\\nFinal SOM Performance:\")\n", "print(f\"Percent variance explained = {pve}%\")\n", "print(f\"Topographic error = {topographic_error}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 8. Plot Component Planes and Categorical Data (figures are saved to `output/titanic`)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Plot component planes\n", "best_som.plot_component_planes(output_dir=\"output/titanic\")" ] } ], "metadata": { "kernelspec": { "display_name": "chen5150", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.13" } }, "nbformat": 4, "nbformat_minor": 4 }