| import abc |
| import base64 |
| import collections |
| import pickle |
| import warnings |
| from enum import Enum |
| from typing import TYPE_CHECKING, Any, Dict, Union |
|
|
| from ray.air.util.data_batch_conversion import BatchFormat |
| from ray.util.annotations import DeveloperAPI, PublicAPI |
|
|
| if TYPE_CHECKING: |
| import numpy as np |
| import pandas as pd |
|
|
| from ray.air.data_batch_type import DataBatchType |
| from ray.data import Dataset |
|
|
|
|
@PublicAPI(stability="beta")
class PreprocessorNotFittedException(RuntimeError):
    """Raised when a transform is attempted before the preprocessor is fitted."""
|
|
|
|
@PublicAPI(stability="beta")
class Preprocessor(abc.ABC):
    """Implements an ML preprocessing operation.

    Preprocessors are stateful objects that can be fitted against a Dataset and used
    to transform both local data batches and distributed data. For example, a
    Normalization preprocessor may calculate the mean and stdev of a field during
    fitting, and uses these attributes to implement its normalization transform.

    Preprocessors can also be stateless and transform data without needing to be
    fitted. For example, a preprocessor may simply remove a column, which does not
    require any state to be fitted.

    If you are implementing your own Preprocessor sub-class, you should override the
    following:

    * ``_fit`` if your preprocessor is stateful. Otherwise, set
      ``_is_fittable=False``.
    * ``_transform_pandas`` and/or ``_transform_numpy`` for best performance,
      implement both. Otherwise, the data will be converted to match the
      implemented method.
    """

    class FitStatus(str, Enum):
        """The fit status of preprocessor."""

        NOT_FITTABLE = "NOT_FITTABLE"
        NOT_FITTED = "NOT_FITTED"
        # NOTE(review): presumably only reachable for composite/chained
        # preprocessors where some contained preprocessors are fitted and
        # others are not — confirm against the chain implementation.
        PARTIALLY_FITTED = "PARTIALLY_FITTED"
        FITTED = "FITTED"

    # Stateless subclasses should set this to False so that ``fit`` becomes a
    # no-op and ``transform`` works without fitting.
    _is_fittable = True

    def _check_has_fitted_state(self) -> bool:
        """Checks if the Preprocessor has fitted state.

        This is also used as an indication if the Preprocessor has been fit, following
        convention from Ray versions prior to 2.6.
        This allows preprocessors that have been fit in older versions of Ray to be
        used to transform data in newer versions.
        """
        # By convention, fitted state attributes carry a trailing underscore
        # (e.g. ``self.stats_``), so any such attribute implies fitted state.
        fitted_vars = [v for v in vars(self) if v.endswith("_")]
        return bool(fitted_vars)

    def fit_status(self) -> "Preprocessor.FitStatus":
        """Return the fit status of this preprocessor.

        Returns:
            Preprocessor.FitStatus: ``NOT_FITTABLE`` for stateless preprocessors,
                ``FITTED`` if ``fit`` has been called (or legacy fitted state is
                present), and ``NOT_FITTED`` otherwise.
        """
        if not self._is_fittable:
            return Preprocessor.FitStatus.NOT_FITTABLE
        elif (
            hasattr(self, "_fitted") and self._fitted
        ) or self._check_has_fitted_state():
            return Preprocessor.FitStatus.FITTED
        else:
            return Preprocessor.FitStatus.NOT_FITTED

    def fit(self, ds: "Dataset") -> "Preprocessor":
        """Fit this Preprocessor to the Dataset.

        Fitted state attributes will be directly set in the Preprocessor.

        Calling it more than once will overwrite all previously fitted state:
        ``preprocessor.fit(A).fit(B)`` is equivalent to ``preprocessor.fit(B)``.

        Args:
            ds: Input dataset.

        Returns:
            Preprocessor: The fitted Preprocessor with state attributes.
        """
        fit_status = self.fit_status()
        if fit_status == Preprocessor.FitStatus.NOT_FITTABLE:
            # Stateless preprocessor: nothing to fit.
            return self

        if fit_status in (
            Preprocessor.FitStatus.FITTED,
            Preprocessor.FitStatus.PARTIALLY_FITTED,
        ):
            warnings.warn(
                "`fit` has already been called on the preprocessor (or at least one "
                "contained preprocessors if this is a chain). "
                "All previously fitted state will be overwritten!"
            )

        # ``_fit`` returns a Preprocessor (conventionally ``self``), not a
        # Dataset — name the result accordingly.
        fitted_preprocessor = self._fit(ds)
        self._fitted = True
        return fitted_preprocessor

    def fit_transform(self, ds: "Dataset") -> "Dataset":
        """Fit this Preprocessor to the Dataset and then transform the Dataset.

        Calling it more than once will overwrite all previously fitted state:
        ``preprocessor.fit_transform(A).fit_transform(B)``
        is equivalent to ``preprocessor.fit_transform(B)``.

        Args:
            ds: Input Dataset.

        Returns:
            ray.data.Dataset: The transformed Dataset.
        """
        self.fit(ds)
        return self.transform(ds)

    def transform(self, ds: "Dataset") -> "Dataset":
        """Transform the given dataset.

        Args:
            ds: Input Dataset.

        Returns:
            ray.data.Dataset: The transformed Dataset.

        Raises:
            PreprocessorNotFittedException: if ``fit`` is not called yet.
        """
        fit_status = self.fit_status()
        if fit_status in (
            Preprocessor.FitStatus.PARTIALLY_FITTED,
            Preprocessor.FitStatus.NOT_FITTED,
        ):
            raise PreprocessorNotFittedException(
                "`fit` must be called before `transform`, "
                "or simply use fit_transform() to run both steps"
            )
        transformed_ds = self._transform(ds)
        return transformed_ds

    def transform_batch(self, data: "DataBatchType") -> "DataBatchType":
        """Transform a single batch of data.

        The data will be converted to the format supported by the Preprocessor,
        based on which ``_transform_*`` methods are implemented.

        Args:
            data: Input data batch.

        Returns:
            DataBatchType:
                The transformed data batch. This may differ
                from the input type depending on which ``_transform_*`` methods
                are implemented.

        Raises:
            PreprocessorNotFittedException: if ``fit`` is not called yet.
        """
        fit_status = self.fit_status()
        if fit_status in (
            Preprocessor.FitStatus.PARTIALLY_FITTED,
            Preprocessor.FitStatus.NOT_FITTED,
        ):
            raise PreprocessorNotFittedException(
                "`fit` must be called before `transform_batch`."
            )
        return self._transform_batch(data)

    @DeveloperAPI
    def _fit(self, ds: "Dataset") -> "Preprocessor":
        """Sub-classes should override this instead of fit()."""
        raise NotImplementedError()

    def _determine_transform_to_use(self) -> BatchFormat:
        """Determine which batch format to use based on Preprocessor implementation.

        * If only `_transform_pandas` is implemented, then use ``pandas`` batch format.
        * If only `_transform_numpy` is implemented, then use ``numpy`` batch format.
        * If both are implemented, then use the Preprocessor defined preferred batch
          format.

        Raises:
            NotImplementedError: if neither ``_transform_pandas`` nor
                ``_transform_numpy`` is overridden by the subclass.
        """
        # Compare the subclass's method against the base class's to detect
        # whether the subclass overrode it.
        has_transform_pandas = (
            self.__class__._transform_pandas != Preprocessor._transform_pandas
        )
        has_transform_numpy = (
            self.__class__._transform_numpy != Preprocessor._transform_numpy
        )

        if has_transform_numpy and has_transform_pandas:
            return self.preferred_batch_format()
        elif has_transform_numpy:
            return BatchFormat.NUMPY
        elif has_transform_pandas:
            return BatchFormat.PANDAS
        else:
            raise NotImplementedError(
                "None of `_transform_numpy` or `_transform_pandas` are implemented. "
                "At least one of these transform functions must be implemented "
                "for Preprocessor transforms."
            )

    def _transform(self, ds: "Dataset") -> "Dataset":
        """Apply this preprocessor to a distributed Dataset via ``map_batches``."""
        transform_type = self._determine_transform_to_use()

        # Our user-facing batch format should only be pandas or NumPy, other
        # formats internal to the system are not exposed.
        kwargs = self._get_transform_config()
        if transform_type == BatchFormat.PANDAS:
            return ds.map_batches(
                self._transform_pandas, batch_format=BatchFormat.PANDAS, **kwargs
            )
        elif transform_type == BatchFormat.NUMPY:
            return ds.map_batches(
                self._transform_numpy, batch_format=BatchFormat.NUMPY, **kwargs
            )
        else:
            raise ValueError(
                "Invalid transform type returned from _determine_transform_to_use; "
                f'"pandas" and "numpy" allowed, but got: {transform_type}'
            )

    def _get_transform_config(self) -> Dict[str, Any]:
        """Returns kwargs to be passed to :meth:`ray.data.Dataset.map_batches`.

        This can be implemented by subclassing preprocessors.
        """
        return {}

    def _transform_batch(self, data: "DataBatchType") -> "DataBatchType":
        """Apply this preprocessor to a single in-memory batch."""
        import numpy as np
        import pandas as pd

        from ray.air.util.data_batch_conversion import (
            _convert_batch_type_to_numpy,
            _convert_batch_type_to_pandas,
        )

        # pyarrow is an optional dependency; fall back gracefully if missing.
        try:
            import pyarrow
        except ImportError:
            pyarrow = None

        # Build the set of accepted batch types dynamically: referencing
        # ``pyarrow.Table`` unconditionally would raise AttributeError when
        # pyarrow is not installed.
        valid_batch_types = [pd.DataFrame, collections.abc.Mapping, np.ndarray]
        if pyarrow is not None:
            valid_batch_types.append(pyarrow.Table)

        if not isinstance(data, tuple(valid_batch_types)):
            raise ValueError(
                "`transform_batch` is currently only implemented for Pandas "
                "DataFrames, pyarrow Tables, NumPy ndarray and dictionary of "
                f"ndarray. Got {type(data)}."
            )

        transform_type = self._determine_transform_to_use()

        # ``_determine_transform_to_use`` raises for any other format, so the
        # two branches below are exhaustive.
        if transform_type == BatchFormat.PANDAS:
            return self._transform_pandas(_convert_batch_type_to_pandas(data))
        elif transform_type == BatchFormat.NUMPY:
            return self._transform_numpy(_convert_batch_type_to_numpy(data))

    @DeveloperAPI
    def _transform_pandas(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Run the transformation on a data batch in a Pandas DataFrame format."""
        raise NotImplementedError()

    @DeveloperAPI
    def _transform_numpy(
        self, np_data: Union["np.ndarray", Dict[str, "np.ndarray"]]
    ) -> Union["np.ndarray", Dict[str, "np.ndarray"]]:
        """Run the transformation on a data batch in a NumPy ndarray format."""
        raise NotImplementedError()

    @classmethod
    @DeveloperAPI
    def preferred_batch_format(cls) -> BatchFormat:
        """Batch format hint for upstream producers to try yielding best block format.

        The preferred batch format to use if both `_transform_pandas` and
        `_transform_numpy` are implemented. Defaults to Pandas.

        Can be overridden by Preprocessor classes depending on which transform
        path is the most optimal.
        """
        return BatchFormat.PANDAS

    @DeveloperAPI
    def serialize(self) -> str:
        """Return this preprocessor serialized as a string.

        Note: this is not a stable serialization format as it uses `pickle`.
        """
        # Base64-encode so the pickled bytes survive transport as text.
        return base64.b64encode(pickle.dumps(self)).decode("ascii")

    @staticmethod
    @DeveloperAPI
    def deserialize(serialized: str) -> "Preprocessor":
        """Load the original preprocessor serialized via `self.serialize()`.

        Warning: ``pickle.loads`` executes arbitrary code during
        deserialization — only call this on strings produced by a trusted
        ``serialize()`` call, never on untrusted input.
        """
        return pickle.loads(base64.b64decode(serialized))
|
|