Skip to content

Commit

Permalink
introduce tags
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust committed Dec 6, 2024
1 parent a20864a commit 0976082
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 25 deletions.
4 changes: 2 additions & 2 deletions sklearnex/_device_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from onedal.utils._array_api import _convert_to_dpnp

from ._config import get_config
from .utils import get_tags


def _get_backend(obj, queue, method_name, *data):
Expand Down Expand Up @@ -75,8 +76,7 @@ def dispatch(obj, method_name, branches, *args, **kwargs):
if (
sklearn_check_version("1.4")
and get_config()["array_api_dispatch"]
and "array_api_support" in obj._get_tags()
and obj._get_tags()["array_api_support"]
and get_tags(obj)["array_api_support"]
and not has_usm_data
):
# USM ndarrays are also excluded for the fallback Array API. Currently, DPNP.ndarray is
Expand Down
16 changes: 0 additions & 16 deletions sklearnex/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,19 +125,3 @@ def get_hyperparameters(self, op):
return cls

return decorator


# This abstract class is meant to generate a clickable doc link for classses
# in sklearnex that are not part of base scikit-learn. It should be inherited
# before inheriting from a scikit-learn estimator, otherwise will get overriden
# by the estimator's original.
class IntelEstimator(ABC):
@property
def _doc_link_module(self) -> str:
return "sklearnex"

@property
def _doc_link_template(self) -> str:
module_path, _ = self.__class__.__module__.rsplit(".", 1)
class_name = self.__class__.__name__
return f"https://intel.github.io/scikit-learn-intelex/latest/non-scikit-algorithms.html#{module_path}.{class_name}"
57 changes: 57 additions & 0 deletions sklearnex/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# ===============================================================================
# Copyright contributors to the oneDAL project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================

from abc import ABC

from daal4py.sklearn._utils import sklearn_check_version


class IntelEstimator(ABC):

if sklearn_check_version("1.6"):
# Starting in sklearn 1.6, _more_tags is deprecated. An IntelEstimator
# is defined to handle the various versioning issues with the tags and
# with the ongoing rollout of sklearn's array_api support. This will make
# maintenance easier, and centralize tag changes to a single location.

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.onedal_array_api = False
return tags

elif sklearn_check_version("1.3"):

def _more_tags(self):
return {"onedal_array_api": False}

else:
# array_api_support tag was added in sklearn 1.3 via scikit-learn/scikit-learn#26372
def _more_tags(self):
return {"array_api_support": False, "onedal_array_api": False}

if sklearn_check_version("1.4"):

def _get_doc_link(self) -> str:
# This method is meant to generate a clickable doc link for classses
# in sklearnex that are not part of base scikit-learn. It should be
# inherited before inheriting from a scikit-learn estimator, otherwise
# will get overriden by the estimator's original.
url = super()._get_doc_link()
if not url:
module_path, _ = self.__class__.__module__.rsplit(".", 1)
class_name = self.__class__.__name__
url = f"https://intel.github.io/scikit-learn-intelex/latest/non-scikit-algorithms.html#{module_path}.{class_name}"
return url
4 changes: 2 additions & 2 deletions sklearnex/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def with_sklearnex():

# TODO:
# check if required.
#@pytest.fixture
#def with_array_api():
# @pytest.fixture
# def with_array_api():
# if sklearn_check_version("1.2"):
# with config_context(array_api_dispatch=True):
# yield
Expand Down
3 changes: 2 additions & 1 deletion sklearnex/neighbors/knn_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from onedal.neighbors import KNeighborsClassifier as onedal_KNeighborsClassifier

from .._device_offload import dispatch, wrap_output_data
from ..utils import get_tags
from .common import KNeighborsDispatchingBase

if sklearn_check_version("1.6"):
Expand Down Expand Up @@ -184,7 +185,7 @@ def _onedal_fit(self, X, y, queue=None):
}

try:
requires_y = self._get_tags()["requires_y"]
requires_y = self.get_tags()["requires_y"]
except KeyError:
requires_y = False

Expand Down
3 changes: 2 additions & 1 deletion sklearnex/neighbors/knn_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from onedal.neighbors import KNeighborsRegressor as onedal_KNeighborsRegressor

from .._device_offload import dispatch, wrap_output_data
from ..utils import get_tags
from .common import KNeighborsDispatchingBase

if sklearn_check_version("1.6"):
Expand Down Expand Up @@ -166,7 +167,7 @@ def _onedal_fit(self, X, y, queue=None):
}

try:
requires_y = self._get_tags()["requires_y"]
requires_y = get_tags(self)["requires_y"]
except KeyError:
requires_y = False

Expand Down
3 changes: 2 additions & 1 deletion sklearnex/neighbors/knn_unsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from onedal.neighbors import NearestNeighbors as onedal_NearestNeighbors

from .._device_offload import dispatch, wrap_output_data
from ..utils import get_tags
from .common import KNeighborsDispatchingBase

if sklearn_check_version("1.6"):
Expand Down Expand Up @@ -140,7 +141,7 @@ def _onedal_fit(self, X, y=None, queue=None):
}

try:
requires_y = self._get_tags()["requires_y"]
requires_y = get_tags(self)["requires_y"]
except KeyError:
requires_y = False

Expand Down
3 changes: 2 additions & 1 deletion sklearnex/svm/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from .._config import config_context, get_config
from .._utils import PatchingConditionsChain
from ..utils import get_tags

if sklearn_check_version("1.6"):
from sklearn.utils.validation import validate_data
Expand Down Expand Up @@ -156,7 +157,7 @@ def _onedal_fit_checks(self, X, y, sample_weight=None):
)

if y is None:
if self._get_tags()["requires_y"]:
if get_tags(self)["requires_y"]:
raise ValueError(
f"This {self.__class__.__name__} estimator "
f"requires y to be passed, but the target y is None."
Expand Down
19 changes: 18 additions & 1 deletion sklearnex/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,21 @@

from .validation import _assert_all_finite

__all__ = ["_assert_all_finite"]
from daal4py.sklearn._utils import sklearn_check_version

# Not an ideal solution, but this converts the outputs of newer sklearnex tags
# into dicts to match how tags had been used. Someone more clever than me will
# have to find a way of converting older tags into newer ones instead (with
# minimal impact on performance).

if sklearn_check_version("1.6"):
from sklearn.utils import get_tags as _sklearn_get_tags

get_tags = lambda estimator: _sklearn_get_tags(estimator).__dict__

else:
from sklearn.base import BaseEstimator

get_tags = BaseEstimator._get_tags

__all__ = ["_assert_all_finite", "get_tags"]

0 comments on commit 0976082

Please sign in to comment.