Hyperparameter Tuning Example for Self-Organizing Maps (SOM)
This script demonstrates how to use the SOM class for hyperparameter tuning. It defines a grid of hyperparameter values, trains and scores a SOM for every combination, and reports the best-scoring configuration for the given dataset.
Features of This Script:
Grid Search Implementation:
Hyperparameters such as scale_method, x_dim, y_dim, topology, neighborhood_fnc, and epochs are tested exhaustively over a predefined set of values. The metrics Percent Variance Explained (PVE) and Topographic Error are combined into a single score that rates the SOM's performance for each combination.
Ease of Use:
The script uses Python's itertools.product to enumerate every hyperparameter combination. Metrics are calculated with the SOM class's built-in methods (calculate_percent_variance_explained and calculate_topographic_error), so no separate evaluation code is needed.
Visualization of Results:
Once the best hyperparameters are identified, the SOM is retrained, and component planes and categorical data distributions are visualized and saved.
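The two metrics can be sketched in plain Python. This minimal illustration assumes the textbook definitions: PVE as 100 × (1 − quantization SSE / total sum of squares), and topographic error as the fraction of samples whose best- and second-best-matching nodes are not neighbors on a rectangular grid. The SOM class's own methods may differ in detail (for example, in how neighbors are defined or how hexagonal grids are handled), and the codebook layout used here (a dict from (row, col) to a weight vector) is a hypothetical choice for the example.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def percent_variance_explained(data, codebook):
    """100 * (1 - SSE / SST): SSE to each sample's nearest node, SST to the data mean."""
    n_features = len(data[0])
    mean = [sum(row[k] for row in data) / len(data) for k in range(n_features)]
    sst = sum(squared_distance(row, mean) for row in data)
    sse = sum(min(squared_distance(row, w) for w in codebook.values()) for row in data)
    return 100.0 * (1.0 - sse / sst)


def topographic_error(data, codebook):
    """Fraction of samples whose two closest nodes are not adjacent on the grid."""
    errors = 0
    for row in data:
        # Rank node coordinates by distance from the sample
        ranked = sorted(codebook, key=lambda node: squared_distance(row, codebook[node]))
        (r1, c1), (r2, c2) = ranked[0], ranked[1]
        # Assumption: 4-neighborhood on a rectangular grid
        if abs(r1 - r2) + abs(c1 - c2) != 1:
            errors += 1
    return errors / len(data)


# Toy 1 x 3 map with a single feature
codebook = {(0, 0): [0.0], (0, 1): [1.0], (0, 2): [2.0]}
data = [[0.1], [0.9], [2.0]]
print(percent_variance_explained(data, codebook))  # about 98.9
print(topographic_error(data, codebook))            # 0.0
```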
Considerations:
Time Complexity:
Depending on the size of the dataset and the number of hyperparameter combinations, this process can take a long time: the grid below has 2 × 3 × 3 × 2 × 2 × 3 = 216 combinations, each requiring a full training run. Narrow the hyperparameter ranges to balance thoroughness against computational cost.
Extensibility:
The scoring function can be adapted to specific requirements. In this example, the score is PVE (a percentage) minus 100 times the Topographic Error (a fraction between 0 and 1), which puts both terms on the same 0 to 100 scale.
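A minimal sketch of the scoring rule as a standalone function. The te_weight parameter is a hypothetical addition that makes the trade-off between the two metrics explicit; te_weight=100.0 reproduces the formula used in the grid search.

```python
def combined_score(pve, topographic_error, te_weight=100.0):
    """Higher is better: PVE (percent) minus a weighted topographic error (fraction)."""
    return pve - te_weight * topographic_error


# Default weighting, as used in the grid search
print(round(combined_score(90.0, 0.2), 2))                  # 70.0
# Penalize topology violations less heavily
print(round(combined_score(90.0, 0.2, te_weight=50.0), 2))  # 80.0
```

Lowering te_weight favors maps that explain more variance even if they fold more; raising it favors topology preservation.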
Output:
Best hyperparameters and their resulting score.
Final SOM trained with the best parameters.
Saved visualizations of component planes and categorical data distributions.
Usage:
Modify the dataset paths and the hyperparameter grid as needed.
1. Imports
# Standard imports
import itertools
import os
import sys
# Third party imports
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# Local imports
notebook_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(notebook_dir, '..', '..'))
sys.path.append(parent_dir)
from SOM.utils.som_utils import SOM
2. Load Data
# Load data
train_dat_path = os.path.join(parent_dir, 'SOM', 'data', 'titanic_training_data.csv')
other_dat_path = os.path.join(parent_dir, 'SOM', 'data', 'titanic_categorical_data.csv')
train_dat = pd.read_csv(train_dat_path)
other_dat = pd.read_csv(other_dat_path)
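Because both CSV files are resolved relative to parent_dir, running the notebook from an unexpected working directory only surfaces as a pandas error. A small hypothetical helper, resolve_data_paths (not part of the SOM package), can check that the files exist before reading them:

```python
import os


def resolve_data_paths(data_dir, filenames):
    """Join each filename onto data_dir and fail early if any file is missing."""
    paths = [os.path.join(data_dir, name) for name in filenames]
    missing = [p for p in paths if not os.path.isfile(p)]
    if missing:
        raise FileNotFoundError(f"Missing data file(s): {missing}")
    return paths


# Hypothetical usage with the layout from the cell above:
# train_dat_path, other_dat_path = resolve_data_paths(
#     os.path.join(parent_dir, "SOM", "data"),
#     ["titanic_training_data.csv", "titanic_categorical_data.csv"],
# )
```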
3. Define Hyperparameters
# Define hyperparameter grid
hyperparameter_grid = {
    "scale_method": ["zscore", "minmax"],
    "x_dim": [3, 5, 7],
    "y_dim": [2, 4, 6],
    "topology": ["rectangular", "hexagonal"],
    "neighborhood_fnc": ["gaussian", "bubble"],
    "epochs": [50, 100, 200],
}
# Initialize variables to store the best parameters and score
best_params = None
best_score = -float("inf")
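Since each combination is a separate training run, it is worth checking the size of the grid before launching the search. This sketch repeats the grid defined above and assumes Python 3.8+ for math.prod.

```python
import math

# Same grid as in the cell above
hyperparameter_grid = {
    "scale_method": ["zscore", "minmax"],
    "x_dim": [3, 5, 7],
    "y_dim": [2, 4, 6],
    "topology": ["rectangular", "hexagonal"],
    "neighborhood_fnc": ["gaussian", "bubble"],
    "epochs": [50, 100, 200],
}

# Every combination triggers one full SOM training run
n_combinations = math.prod(len(values) for values in hyperparameter_grid.values())
print(n_combinations)  # 216

# Total training epochs across the search: each epoch value appears in
# n_combinations / 3 combinations
runs_per_epoch_value = n_combinations // len(hyperparameter_grid["epochs"])
total_epochs = runs_per_epoch_value * sum(hyperparameter_grid["epochs"])
print(total_epochs)  # 25200
```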
4. Grid Search for Optimal Parameters
for params in itertools.product(
    hyperparameter_grid["scale_method"],
    hyperparameter_grid["x_dim"],
    hyperparameter_grid["y_dim"],
    hyperparameter_grid["topology"],
    hyperparameter_grid["neighborhood_fnc"],
    hyperparameter_grid["epochs"],
):
    scale_method, x_dim, y_dim, topology, neighborhood_fnc, epochs = params
    # Train SOM with the current hyperparameter combination
    som = SOM(
        train_dat=train_dat,
        other_dat=other_dat,
        scale_method=scale_method,
        x_dim=x_dim,
        y_dim=y_dim,
        topology=topology,
        neighborhood_fnc=neighborhood_fnc,
        epochs=epochs,
    )
    som.train_map()
    # Calculate evaluation metrics
    pve = som.calculate_percent_variance_explained()
    topographic_error = som.calculate_topographic_error()
    # Combine metrics into a single score (higher PVE and lower error are better)
    score = pve - topographic_error * 100
    print(f"Tested params: {params} | Score: {score:.2f}")
    # Update the best parameters if the current score is better
    if score > best_score:
        best_score = score
        best_params = params
Tested params: ('zscore', 3, 2, 'rectangular', 'gaussian', 50) | Score: 3.07
Tested params: ('zscore', 3, 2, 'rectangular', 'gaussian', 100) | Score: 22.30
Tested params: ('zscore', 3, 2, 'rectangular', 'gaussian', 200) | Score: 3.37
Tested params: ('zscore', 3, 2, 'rectangular', 'bubble', 50) | Score: 55.71
Tested params: ('zscore', 3, 2, 'rectangular', 'bubble', 100) | Score: -29.33
Tested params: ('zscore', 3, 2, 'rectangular', 'bubble', 200) | Score: -19.81
Tested params: ('zscore', 3, 2, 'hexagonal', 'gaussian', 50) | Score: -25.46
Tested params: ('zscore', 3, 2, 'hexagonal', 'gaussian', 100) | Score: 0.15
Tested params: ('zscore', 3, 2, 'hexagonal', 'gaussian', 200) | Score: -11.31
Tested params: ('zscore', 3, 2, 'hexagonal', 'bubble', 50) | Score: -2.13
Tested params: ('zscore', 3, 2, 'hexagonal', 'bubble', 100) | Score: -36.89
Tested params: ('zscore', 3, 2, 'hexagonal', 'bubble', 200) | Score: -24.99
Tested params: ('zscore', 3, 4, 'rectangular', 'gaussian', 50) | Score: 27.55
Tested params: ('zscore', 3, 4, 'rectangular', 'gaussian', 100) | Score: 54.05
Tested params: ('zscore', 3, 4, 'rectangular', 'gaussian', 200) | Score: 56.14
Tested params: ('zscore', 3, 4, 'rectangular', 'bubble', 50) | Score: 21.94
Tested params: ('zscore', 3, 4, 'rectangular', 'bubble', 100) | Score: 6.22
Tested params: ('zscore', 3, 4, 'rectangular', 'bubble', 200) | Score: 10.33
Tested params: ('zscore', 3, 4, 'hexagonal', 'gaussian', 50) | Score: 17.04
Tested params: ('zscore', 3, 4, 'hexagonal', 'gaussian', 100) | Score: 1.76
Tested params: ('zscore', 3, 4, 'hexagonal', 'gaussian', 200) | Score: 24.68
Tested params: ('zscore', 3, 4, 'hexagonal', 'bubble', 50) | Score: -12.24
Tested params: ('zscore', 3, 4, 'hexagonal', 'bubble', 100) | Score: -21.09
Tested params: ('zscore', 3, 4, 'hexagonal', 'bubble', 200) | Score: -9.83
Tested params: ('zscore', 3, 6, 'rectangular', 'gaussian', 50) | Score: 47.69
Tested params: ('zscore', 3, 6, 'rectangular', 'gaussian', 100) | Score: 64.72
Tested params: ('zscore', 3, 6, 'rectangular', 'gaussian', 200) | Score: 41.66
Tested params: ('zscore', 3, 6, 'rectangular', 'bubble', 50) | Score: 10.19
Tested params: ('zscore', 3, 6, 'rectangular', 'bubble', 100) | Score: 2.10
Tested params: ('zscore', 3, 6, 'rectangular', 'bubble', 200) | Score: 15.35
Tested params: ('zscore', 3, 6, 'hexagonal', 'gaussian', 50) | Score: 0.50
Tested params: ('zscore', 3, 6, 'hexagonal', 'gaussian', 100) | Score: -3.38
Tested params: ('zscore', 3, 6, 'hexagonal', 'gaussian', 200) | Score: 15.23
Tested params: ('zscore', 3, 6, 'hexagonal', 'bubble', 50) | Score: -2.27
Tested params: ('zscore', 3, 6, 'hexagonal', 'bubble', 100) | Score: -9.24
Tested params: ('zscore', 3, 6, 'hexagonal', 'bubble', 200) | Score: -9.86
Tested params: ('zscore', 5, 2, 'rectangular', 'gaussian', 50) | Score: 32.92
Tested params: ('zscore', 5, 2, 'rectangular', 'gaussian', 100) | Score: 37.05
Tested params: ('zscore', 5, 2, 'rectangular', 'gaussian', 200) | Score: 20.39
Tested params: ('zscore', 5, 2, 'rectangular', 'bubble', 50) | Score: 14.60
Tested params: ('zscore', 5, 2, 'rectangular', 'bubble', 100) | Score: -3.79
Tested params: ('zscore', 5, 2, 'rectangular', 'bubble', 200) | Score: -9.82
Tested params: ('zscore', 5, 2, 'hexagonal', 'gaussian', 50) | Score: -12.21
Tested params: ('zscore', 5, 2, 'hexagonal', 'gaussian', 100) | Score: 8.12
Tested params: ('zscore', 5, 2, 'hexagonal', 'gaussian', 200) | Score: -17.13
Tested params: ('zscore', 5, 2, 'hexagonal', 'bubble', 50) | Score: -0.94
Tested params: ('zscore', 5, 2, 'hexagonal', 'bubble', 100) | Score: -8.55
Tested params: ('zscore', 5, 2, 'hexagonal', 'bubble', 200) | Score: -9.82
Tested params: ('zscore', 5, 4, 'rectangular', 'gaussian', 50) | Score: 46.99
Tested params: ('zscore', 5, 4, 'rectangular', 'gaussian', 100) | Score: 60.28
Tested params: ('zscore', 5, 4, 'rectangular', 'gaussian', 200) | Score: 60.30
Tested params: ('zscore', 5, 4, 'rectangular', 'bubble', 50) | Score: -0.35
Tested params: ('zscore', 5, 4, 'rectangular', 'bubble', 100) | Score: -6.41
Tested params: ('zscore', 5, 4, 'rectangular', 'bubble', 200) | Score: -4.08
Tested params: ('zscore', 5, 4, 'hexagonal', 'gaussian', 50) | Score: 23.32
Tested params: ('zscore', 5, 4, 'hexagonal', 'gaussian', 100) | Score: 29.52
Tested params: ('zscore', 5, 4, 'hexagonal', 'gaussian', 200) | Score: 11.27
Tested params: ('zscore', 5, 4, 'hexagonal', 'bubble', 50) | Score: -5.53
Tested params: ('zscore', 5, 4, 'hexagonal', 'bubble', 100) | Score: -12.43
Tested params: ('zscore', 5, 4, 'hexagonal', 'bubble', 200) | Score: -8.00
Tested params: ('zscore', 5, 6, 'rectangular', 'gaussian', 50) | Score: 66.65
Tested params: ('zscore', 5, 6, 'rectangular', 'gaussian', 100) | Score: 72.04
Tested params: ('zscore', 5, 6, 'rectangular', 'gaussian', 200) | Score: 61.56
Tested params: ('zscore', 5, 6, 'rectangular', 'bubble', 50) | Score: 17.12
Tested params: ('zscore', 5, 6, 'rectangular', 'bubble', 100) | Score: 4.29
Tested params: ('zscore', 5, 6, 'rectangular', 'bubble', 200) | Score: 11.11
Tested params: ('zscore', 5, 6, 'hexagonal', 'gaussian', 50) | Score: 46.63
Tested params: ('zscore', 5, 6, 'hexagonal', 'gaussian', 100) | Score: 17.17
Tested params: ('zscore', 5, 6, 'hexagonal', 'gaussian', 200) | Score: 24.15
Tested params: ('zscore', 5, 6, 'hexagonal', 'bubble', 50) | Score: -5.84
Tested params: ('zscore', 5, 6, 'hexagonal', 'bubble', 100) | Score: -1.17
Tested params: ('zscore', 5, 6, 'hexagonal', 'bubble', 200) | Score: 0.19
Tested params: ('zscore', 7, 2, 'rectangular', 'gaussian', 50) | Score: 56.91
Tested params: ('zscore', 7, 2, 'rectangular', 'gaussian', 100) | Score: 46.95
Tested params: ('zscore', 7, 2, 'rectangular', 'gaussian', 200) | Score: 49.48
Tested params: ('zscore', 7, 2, 'rectangular', 'bubble', 50) | Score: 18.53
Tested params: ('zscore', 7, 2, 'rectangular', 'bubble', 100) | Score: 1.30
Tested params: ('zscore', 7, 2, 'rectangular', 'bubble', 200) | Score: 6.63
Tested params: ('zscore', 7, 2, 'hexagonal', 'gaussian', 50) | Score: 16.15
Tested params: ('zscore', 7, 2, 'hexagonal', 'gaussian', 100) | Score: 23.12
Tested params: ('zscore', 7, 2, 'hexagonal', 'gaussian', 200) | Score: 25.97
Tested params: ('zscore', 7, 2, 'hexagonal', 'bubble', 50) | Score: 16.15
Tested params: ('zscore', 7, 2, 'hexagonal', 'bubble', 100) | Score: -10.75
Tested params: ('zscore', 7, 2, 'hexagonal', 'bubble', 200) | Score: -2.47
Tested params: ('zscore', 7, 4, 'rectangular', 'gaussian', 50) | Score: 60.97
Tested params: ('zscore', 7, 4, 'rectangular', 'gaussian', 100) | Score: 70.44
Tested params: ('zscore', 7, 4, 'rectangular', 'gaussian', 200) | Score: 38.54
Tested params: ('zscore', 7, 4, 'rectangular', 'bubble', 50) | Score: -7.36
Tested params: ('zscore', 7, 4, 'rectangular', 'bubble', 100) | Score: -4.40
Tested params: ('zscore', 7, 4, 'rectangular', 'bubble', 200) | Score: 33.14
Tested params: ('zscore', 7, 4, 'hexagonal', 'gaussian', 50) | Score: 35.18
Tested params: ('zscore', 7, 4, 'hexagonal', 'gaussian', 100) | Score: 15.62
Tested params: ('zscore', 7, 4, 'hexagonal', 'gaussian', 200) | Score: 38.17
Tested params: ('zscore', 7, 4, 'hexagonal', 'bubble', 50) | Score: -8.62
Tested params: ('zscore', 7, 4, 'hexagonal', 'bubble', 100) | Score: -8.88
Tested params: ('zscore', 7, 4, 'hexagonal', 'bubble', 200) | Score: 27.96
Tested params: ('zscore', 7, 6, 'rectangular', 'gaussian', 50) | Score: 64.22
Tested params: ('zscore', 7, 6, 'rectangular', 'gaussian', 100) | Score: 73.20
Tested params: ('zscore', 7, 6, 'rectangular', 'gaussian', 200) | Score: 72.54
Tested params: ('zscore', 7, 6, 'rectangular', 'bubble', 50) | Score: 16.17
Tested params: ('zscore', 7, 6, 'rectangular', 'bubble', 100) | Score: -0.89
Tested params: ('zscore', 7, 6, 'rectangular', 'bubble', 200) | Score: 0.65
Tested params: ('zscore', 7, 6, 'hexagonal', 'gaussian', 50) | Score: 31.35
Tested params: ('zscore', 7, 6, 'hexagonal', 'gaussian', 100) | Score: 7.94
Tested params: ('zscore', 7, 6, 'hexagonal', 'gaussian', 200) | Score: 17.92
Tested params: ('zscore', 7, 6, 'hexagonal', 'bubble', 50) | Score: 8.89
Tested params: ('zscore', 7, 6, 'hexagonal', 'bubble', 100) | Score: -4.53
Tested params: ('zscore', 7, 6, 'hexagonal', 'bubble', 200) | Score: -0.75
Tested params: ('minmax', 3, 2, 'rectangular', 'gaussian', 50) | Score: 54.92
Tested params: ('minmax', 3, 2, 'rectangular', 'gaussian', 100) | Score: 73.38
Tested params: ('minmax', 3, 2, 'rectangular', 'gaussian', 200) | Score: 77.83
Tested params: ('minmax', 3, 2, 'rectangular', 'bubble', 50) | Score: 43.36
Tested params: ('minmax', 3, 2, 'rectangular', 'bubble', 100) | Score: 69.40
Tested params: ('minmax', 3, 2, 'rectangular', 'bubble', 200) | Score: -16.27
Tested params: ('minmax', 3, 2, 'hexagonal', 'gaussian', 50) | Score: 37.99
Tested params: ('minmax', 3, 2, 'hexagonal', 'gaussian', 100) | Score: 3.66
Tested params: ('minmax', 3, 2, 'hexagonal', 'gaussian', 200) | Score: -14.72
Tested params: ('minmax', 3, 2, 'hexagonal', 'bubble', 50) | Score: 35.37
Tested params: ('minmax', 3, 2, 'hexagonal', 'bubble', 100) | Score: 32.29
Tested params: ('minmax', 3, 2, 'hexagonal', 'bubble', 200) | Score: -16.55
Tested params: ('minmax', 3, 4, 'rectangular', 'gaussian', 50) | Score: 71.42
Tested params: ('minmax', 3, 4, 'rectangular', 'gaussian', 100) | Score: 54.54
Tested params: ('minmax', 3, 4, 'rectangular', 'gaussian', 200) | Score: 67.86
Tested params: ('minmax', 3, 4, 'rectangular', 'bubble', 50) | Score: -4.85
Tested params: ('minmax', 3, 4, 'rectangular', 'bubble', 100) | Score: 40.56
Tested params: ('minmax', 3, 4, 'rectangular', 'bubble', 200) | Score: -5.09
Tested params: ('minmax', 3, 4, 'hexagonal', 'gaussian', 50) | Score: 1.15
Tested params: ('minmax', 3, 4, 'hexagonal', 'gaussian', 100) | Score: 56.80
Tested params: ('minmax', 3, 4, 'hexagonal', 'gaussian', 200) | Score: 27.91
Tested params: ('minmax', 3, 4, 'hexagonal', 'bubble', 50) | Score: -14.38
Tested params: ('minmax', 3, 4, 'hexagonal', 'bubble', 100) | Score: 10.45
Tested params: ('minmax', 3, 4, 'hexagonal', 'bubble', 200) | Score: -9.01
Tested params: ('minmax', 3, 6, 'rectangular', 'gaussian', 50) | Score: 72.05
Tested params: ('minmax', 3, 6, 'rectangular', 'gaussian', 100) | Score: 75.45
Tested params: ('minmax', 3, 6, 'rectangular', 'gaussian', 200) | Score: 67.80
Tested params: ('minmax', 3, 6, 'rectangular', 'bubble', 50) | Score: 29.96
Tested params: ('minmax', 3, 6, 'rectangular', 'bubble', 100) | Score: 10.18
Tested params: ('minmax', 3, 6, 'rectangular', 'bubble', 200) | Score: -0.06
Tested params: ('minmax', 3, 6, 'hexagonal', 'gaussian', 50) | Score: 10.09
Tested params: ('minmax', 3, 6, 'hexagonal', 'gaussian', 100) | Score: 30.30
Tested params: ('minmax', 3, 6, 'hexagonal', 'gaussian', 200) | Score: 17.54
Tested params: ('minmax', 3, 6, 'hexagonal', 'bubble', 50) | Score: -6.74
Tested params: ('minmax', 3, 6, 'hexagonal', 'bubble', 100) | Score: -8.72
Tested params: ('minmax', 3, 6, 'hexagonal', 'bubble', 200) | Score: -12.10
Tested params: ('minmax', 5, 2, 'rectangular', 'gaussian', 50) | Score: 79.85
Tested params: ('minmax', 5, 2, 'rectangular', 'gaussian', 100) | Score: 62.12
Tested params: ('minmax', 5, 2, 'rectangular', 'gaussian', 200) | Score: 65.07
Tested params: ('minmax', 5, 2, 'rectangular', 'bubble', 50) | Score: 41.49
Tested params: ('minmax', 5, 2, 'rectangular', 'bubble', 100) | Score: 46.84
Tested params: ('minmax', 5, 2, 'rectangular', 'bubble', 200) | Score: -17.90
Tested params: ('minmax', 5, 2, 'hexagonal', 'gaussian', 50) | Score: -0.18
Tested params: ('minmax', 5, 2, 'hexagonal', 'gaussian', 100) | Score: 16.58
Tested params: ('minmax', 5, 2, 'hexagonal', 'gaussian', 200) | Score: -4.60
Tested params: ('minmax', 5, 2, 'hexagonal', 'bubble', 50) | Score: 41.49
Tested params: ('minmax', 5, 2, 'hexagonal', 'bubble', 100) | Score: 44.18
Tested params: ('minmax', 5, 2, 'hexagonal', 'bubble', 200) | Score: -18.46
Tested params: ('minmax', 5, 4, 'rectangular', 'gaussian', 50) | Score: 66.47
Tested params: ('minmax', 5, 4, 'rectangular', 'gaussian', 100) | Score: 72.90
Tested params: ('minmax', 5, 4, 'rectangular', 'gaussian', 200) | Score: 68.85
Tested params: ('minmax', 5, 4, 'rectangular', 'bubble', 50) | Score: 71.66
Tested params: ('minmax', 5, 4, 'rectangular', 'bubble', 100) | Score: 43.72
Tested params: ('minmax', 5, 4, 'rectangular', 'bubble', 200) | Score: 61.40
Tested params: ('minmax', 5, 4, 'hexagonal', 'gaussian', 50) | Score: 16.62
Tested params: ('minmax', 5, 4, 'hexagonal', 'gaussian', 100) | Score: 34.27
Tested params: ('minmax', 5, 4, 'hexagonal', 'gaussian', 200) | Score: 21.11
Tested params: ('minmax', 5, 4, 'hexagonal', 'bubble', 50) | Score: -9.99
Tested params: ('minmax', 5, 4, 'hexagonal', 'bubble', 100) | Score: -2.36
Tested params: ('minmax', 5, 4, 'hexagonal', 'bubble', 200) | Score: -1.21
Tested params: ('minmax', 5, 6, 'rectangular', 'gaussian', 50) | Score: 60.75
Tested params: ('minmax', 5, 6, 'rectangular', 'gaussian', 100) | Score: 78.77
Tested params: ('minmax', 5, 6, 'rectangular', 'gaussian', 200) | Score: 58.31
Tested params: ('minmax', 5, 6, 'rectangular', 'bubble', 50) | Score: 28.64
Tested params: ('minmax', 5, 6, 'rectangular', 'bubble', 100) | Score: 17.27
Tested params: ('minmax', 5, 6, 'rectangular', 'bubble', 200) | Score: 19.81
Tested params: ('minmax', 5, 6, 'hexagonal', 'gaussian', 50) | Score: 25.46
Tested params: ('minmax', 5, 6, 'hexagonal', 'gaussian', 100) | Score: 35.53
Tested params: ('minmax', 5, 6, 'hexagonal', 'gaussian', 200) | Score: 13.92
Tested params: ('minmax', 5, 6, 'hexagonal', 'bubble', 50) | Score: 12.11
Tested params: ('minmax', 5, 6, 'hexagonal', 'bubble', 100) | Score: 5.22
Tested params: ('minmax', 5, 6, 'hexagonal', 'bubble', 200) | Score: 10.57
Tested params: ('minmax', 7, 2, 'rectangular', 'gaussian', 50) | Score: 72.60
Tested params: ('minmax', 7, 2, 'rectangular', 'gaussian', 100) | Score: 72.86
Tested params: ('minmax', 7, 2, 'rectangular', 'gaussian', 200) | Score: 51.63
Tested params: ('minmax', 7, 2, 'rectangular', 'bubble', 50) | Score: 16.66
Tested params: ('minmax', 7, 2, 'rectangular', 'bubble', 100) | Score: 10.83
Tested params: ('minmax', 7, 2, 'rectangular', 'bubble', 200) | Score: 15.46
Tested params: ('minmax', 7, 2, 'hexagonal', 'gaussian', 50) | Score: 12.05
Tested params: ('minmax', 7, 2, 'hexagonal', 'gaussian', 100) | Score: 43.06
Tested params: ('minmax', 7, 2, 'hexagonal', 'gaussian', 200) | Score: 17.09
Tested params: ('minmax', 7, 2, 'hexagonal', 'bubble', 50) | Score: 16.52
Tested params: ('minmax', 7, 2, 'hexagonal', 'bubble', 100) | Score: 10.83
Tested params: ('minmax', 7, 2, 'hexagonal', 'bubble', 200) | Score: 15.46
Tested params: ('minmax', 7, 4, 'rectangular', 'gaussian', 50) | Score: 65.05
Tested params: ('minmax', 7, 4, 'rectangular', 'gaussian', 100) | Score: 70.15
Tested params: ('minmax', 7, 4, 'rectangular', 'gaussian', 200) | Score: 80.84
Tested params: ('minmax', 7, 4, 'rectangular', 'bubble', 50) | Score: 20.95
Tested params: ('minmax', 7, 4, 'rectangular', 'bubble', 100) | Score: 67.74
Tested params: ('minmax', 7, 4, 'rectangular', 'bubble', 200) | Score: 12.57
Tested params: ('minmax', 7, 4, 'hexagonal', 'gaussian', 50) | Score: 10.16
Tested params: ('minmax', 7, 4, 'hexagonal', 'gaussian', 100) | Score: 5.11
Tested params: ('minmax', 7, 4, 'hexagonal', 'gaussian', 200) | Score: 48.04
Tested params: ('minmax', 7, 4, 'hexagonal', 'bubble', 50) | Score: 2.60
Tested params: ('minmax', 7, 4, 'hexagonal', 'bubble', 100) | Score: 30.91
Tested params: ('minmax', 7, 4, 'hexagonal', 'bubble', 200) | Score: -8.15
Tested params: ('minmax', 7, 6, 'rectangular', 'gaussian', 50) | Score: 70.10
Tested params: ('minmax', 7, 6, 'rectangular', 'gaussian', 100) | Score: 46.53
Tested params: ('minmax', 7, 6, 'rectangular', 'gaussian', 200) | Score: 76.17
Tested params: ('minmax', 7, 6, 'rectangular', 'bubble', 50) | Score: 41.95
Tested params: ('minmax', 7, 6, 'rectangular', 'bubble', 100) | Score: 8.98
Tested params: ('minmax', 7, 6, 'rectangular', 'bubble', 200) | Score: 9.69
Tested params: ('minmax', 7, 6, 'hexagonal', 'gaussian', 50) | Score: 16.42
Tested params: ('minmax', 7, 6, 'hexagonal', 'gaussian', 100) | Score: 19.06
Tested params: ('minmax', 7, 6, 'hexagonal', 'gaussian', 200) | Score: 22.83
Tested params: ('minmax', 7, 6, 'hexagonal', 'bubble', 50) | Score: -4.27
Tested params: ('minmax', 7, 6, 'hexagonal', 'bubble', 100) | Score: -5.44
Tested params: ('minmax', 7, 6, 'hexagonal', 'bubble', 200) | Score: 0.58
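The loop above keeps only the single best combination. A variant that records every result makes it easy to inspect runner-up configurations, which matters here because scores shift noticeably between neighboring settings (for example, ('minmax', 7, 4, 'rectangular', 'gaussian') scores 65.05, 70.15, and 80.84 at 50, 100, and 200 epochs). In this sketch, grid_search and toy_score are hypothetical names; in practice, evaluate would construct SOM(train_dat=train_dat, other_dat=other_dat, **params), call train_map(), and return the combined score.

```python
import itertools


def grid_search(grid, evaluate):
    """Score every combination in `grid` and return results sorted best-first.

    `evaluate` receives keyword arguments named after the grid keys and
    returns a score where higher is better.
    """
    keys = list(grid)
    results = []
    for values in itertools.product(*(grid[k] for k in keys)):
        params = dict(zip(keys, values))
        results.append((evaluate(**params), params))
    # Sort on the score only, so the params dicts are never compared
    results.sort(key=lambda item: item[0], reverse=True)
    return results


# Hypothetical stand-in for "train a SOM and compute the score"
def toy_score(x_dim, y_dim, epochs):
    return x_dim * y_dim - abs(epochs - 100) / 50


ranked = grid_search({"x_dim": [3, 5], "y_dim": [2, 4], "epochs": [50, 100]}, toy_score)
for score, params in ranked[:3]:
    print(f"{score:.2f} {params}")
```

Building the product from the dict keys also removes the need to list each hyperparameter by hand in the itertools.product call.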
5. Show the Best Parameters
# Output the best hyperparameters
print("\nBest Hyperparameters:")
print(f"Scale Method: {best_params[0]}")
print(f"x_dim: {best_params[1]}")
print(f"y_dim: {best_params[2]}")
print(f"Topology: {best_params[3]}")
print(f"Neighborhood Function: {best_params[4]}")
print(f"Epochs: {best_params[5]}")
print(f"Best Score: {best_score:.2f}")
Best Hyperparameters:
Scale Method: minmax
x_dim: 7
y_dim: 4
Topology: rectangular
Neighborhood Function: gaussian
Epochs: 200
Best Score: 80.84
6. Train SOM
# Retrain the SOM with the best hyperparameters
best_som = SOM(
    train_dat=train_dat,
    other_dat=other_dat,
    scale_method=best_params[0],
    x_dim=best_params[1],
    y_dim=best_params[2],
    topology=best_params[3],
    neighborhood_fnc=best_params[4],
    epochs=best_params[5],
)
best_som.train_map()
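Indexing best_params by position ties the retraining cell to the order of the arguments passed to itertools.product. A sketch that names the values instead, assuming the names are listed in that same order (true for list(hyperparameter_grid), since Python dicts preserve insertion order):

```python
# Grid keys in the same order they were passed to itertools.product
param_names = ["scale_method", "x_dim", "y_dim", "topology", "neighborhood_fnc", "epochs"]
best_params = ("minmax", 7, 4, "rectangular", "gaussian", 200)  # from the output above

best_kwargs = dict(zip(param_names, best_params))
print(best_kwargs)

# With the real objects this would become:
# best_som = SOM(train_dat=train_dat, other_dat=other_dat, **best_kwargs)
```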
7. Get the Fit Metrics
# Get fit metrics
pve = best_som.calculate_percent_variance_explained()
topographic_error = best_som.calculate_topographic_error()
print("\nFinal SOM Performance:")
print(f"Percent variance explained = {pve}%")
print(f"Topographic error = {topographic_error}")
Final SOM Performance:
Percent variance explained = 94.28111754749402%
Topographic error = 0.13445378151260504
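Plugging these final metrics into the score used during the search gives back the best score reported in step 5, which is consistent with the retrained SOM reproducing the best grid-search run:

```latex
\text{score} = \text{PVE} - 100 \times \text{TE}
             = 94.2811 - 100 \times 0.134454
             = 94.2811 - 13.4454
             = 80.8357 \approx 80.84
```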
8. Plot Component Planes and Categorical Data (see the output directory for figures)
# Plot component planes
best_som.plot_component_planes(output_dir="output/titanic")
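A component plane shows the weight of a single input feature at every node of the trained map, so comparing planes side by side reveals which features vary together. The sketch below builds the data behind one plane, assuming (hypothetically) that the codebook is available as a dict keyed by (x, y) node coordinates; how plot_component_planes stores and draws its weights is not shown here. Each returned grid could be passed to a heatmap function such as matplotlib's imshow.

```python
def component_plane(codebook, feature_index, x_dim, y_dim):
    """Return one feature's weights arranged as an x_dim-by-y_dim grid.

    `codebook` maps (x, y) node coordinates to weight vectors.
    """
    return [
        [codebook[(x, y)][feature_index] for y in range(y_dim)]
        for x in range(x_dim)
    ]


# Toy 2 x 3 map with two features per node
codebook = {(x, y): [x + 0.1 * y, float(x * y)] for x in range(2) for y in range(3)}
plane = component_plane(codebook, feature_index=1, x_dim=2, y_dim=3)
print(plane)  # [[0.0, 0.0, 0.0], [0.0, 1.0, 2.0]]
```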