Skip to content

Commit

Permalink
FIX Make cross-validation figures less inexact
Browse files Browse the repository at this point in the history
  • Loading branch information
ArturoAmorQ committed Feb 23, 2024
1 parent 8124c5b commit 900c6f7
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
Binary file modified figures/cross_validation_train_test_diagram.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/nested_cross_validation_diagram.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 13 additions & 6 deletions figures/plot_parameter_tuning_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,12 @@ def plot_cv_indices(cv, X, y, ax, lw=50):
Patch(color=cmap_cv(0.5)),
Patch(color=cmap_cv(0.02)),
],
["Testing samples", "Training samples", "Validation samples"],
loc=(1.02, 0.7),
[
"Testing samples\n(reserved for\nfinal evaluation)",
"Training samples",
"Validation samples",
],
loc=(1.02, 0.5),
)
return ax

Expand All @@ -82,7 +86,6 @@ def plot_cv_nested_indices(cv_inner, cv_outer, X, y, ax, lw=50):

# Generate the training/testing visualizations for each CV split
for ii, (train_outer, test_outer) in enumerate(splits_outer):

splits_inner = list(cv_inner.split(train_outer))
n_splits_inner = len(splits_inner)

Expand Down Expand Up @@ -116,7 +119,7 @@ def plot_cv_nested_indices(cv_inner, cv_outer, X, y, ax, lw=50):
)
yticklabels = list(range(n_splits_outer))
ax.set(
yticks=n_splits_inner*np.arange(n_splits_outer) + 0.5,
yticks=n_splits_inner * np.arange(n_splits_outer) + 0.5,
yticklabels=yticklabels,
xlabel="Sample index",
ylabel="CV outer iteration",
Expand All @@ -129,8 +132,12 @@ def plot_cv_nested_indices(cv_inner, cv_outer, X, y, ax, lw=50):
Patch(color=cmap_cv(0.5)),
Patch(color=cmap_cv(0.02)),
],
["Testing samples", "Training samples", "Validation samples"],
loc=(1.06, .93),
[
"Testing samples\n(reserved for\nouter evaluation)",
"Training samples",
"Validation samples",
],
loc=(1.06, 0.85),
)
return ax

Expand Down

0 comments on commit 900c6f7

Please sign in to comment.