Skip to content

Commit

Permalink
Use L2SqrtExpanded in pairwise_distances
Browse files Browse the repository at this point in the history
Previously we were using `L2SqrtUnexpanded`, which is more costly.
  • Loading branch information
jcrist committed Feb 7, 2025
1 parent db16f23 commit 0529289
Showing 1 changed file with 3 additions and 17 deletions.
20 changes: 3 additions & 17 deletions python/cuml/cuml/metrics/pairwise_distances.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -66,9 +66,9 @@ cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics":
PAIRWISE_DISTANCE_METRICS = {
"cityblock": DistanceType.L1,
"cosine": DistanceType.CosineExpanded,
"euclidean": DistanceType.L2SqrtUnexpanded,
"euclidean": DistanceType.L2SqrtExpanded,
"l1": DistanceType.L1,
"l2": DistanceType.L2SqrtUnexpanded,
"l2": DistanceType.L2SqrtExpanded,
"manhattan": DistanceType.L1,
"sqeuclidean": DistanceType.L2Expanded,
"canberra": DistanceType.Canberra,
Expand Down Expand Up @@ -102,20 +102,6 @@ PAIRWISE_DISTANCE_SPARSE_METRICS = {


def _determine_metric(metric_str, is_sparse_=False):
# Available options in scikit-learn and their pairs. See
# sklearn.metrics.pairwise.PAIRWISE_DISTANCE_FUNCTIONS:
# 'cityblock': L1
# 'cosine': CosineExpanded
# 'euclidean': L2SqrtUnexpanded
# 'haversine': N/A
# 'l2': L2SqrtUnexpanded
# 'l1': L1
# 'manhattan': L1
# 'nan_euclidean': N/A
# 'sqeuclidean': L2Unexpanded
# Note: many are duplicates following this:
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/pairwise.py#L1321

if metric_str == 'haversine':
raise ValueError(" The metric: '{}', is not supported at this time."
.format(metric_str))
Expand Down

0 comments on commit 0529289

Please sign in to comment.