File size: 5,030 Bytes
98ca408 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 | import logging
from dataclasses import dataclass
from typing import Generic, Iterable, List, Optional, TypeVar
import ray
from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block, BlockAccessor
from ray.util.annotations import DeveloperAPI
logger = logging.getLogger(name=__name__)

# Type parameter for whatever a single `Datasink.write` call returns; the
# collected values are later surfaced through `WriteResult.write_returns`.
WriteReturnType = TypeVar("WriteReturnType")
"""Generic type for the return value of `Datasink.write`."""
@dataclass
@DeveloperAPI
class WriteResult(Generic[WriteReturnType]):
    """Aggregated result of the Datasink write operations.

    Attributes:
        num_rows: Total number of written rows.
        size_bytes: Total size in bytes of written data.
        write_returns: All returned values of `Datasink.write`.
    """

    num_rows: int
    size_bytes: int
    write_returns: List[WriteReturnType]
@DeveloperAPI
class Datasink(Generic[WriteReturnType]):
    """Interface for defining write-related logic.

    If you want to write data to something that isn't built-in, subclass this class
    and call :meth:`~ray.data.Dataset.write_datasink`.
    """

    def on_write_start(self) -> None:
        """Callback for when a write job starts.

        Use this method to perform setup for write tasks. For example, creating a
        staging bucket in S3. The default implementation does nothing.
        """
        pass

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> WriteReturnType:
        """Write blocks. This is used by a single write task.

        Args:
            blocks: Iterable (typically a generator) of data blocks.
            ctx: ``TaskContext`` for the write task.

        Returns:
            Result of this write task. When the entire write operator finishes,
            all returned values will be passed as `WriteResult.write_returns`
            to `Datasink.on_write_complete`.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                override this method.
        """
        raise NotImplementedError

    def on_write_complete(self, write_result: WriteResult[WriteReturnType]) -> None:
        """Callback for when a write job completes.

        This can be used to "commit" a write output. This method must
        succeed prior to ``write_datasink()`` returning to the user. If this
        method fails, then ``on_write_failed()`` is called. The default
        implementation does nothing.

        Args:
            write_result: Aggregated result of the Write operator, containing
                write results and stats.
        """
        pass

    def on_write_failed(self, error: Exception) -> None:
        """Callback for when a write job fails.

        This is called on a best-effort basis on write failures. The default
        implementation does nothing.

        Args:
            error: The first error encountered.
        """
        pass

    def get_name(self) -> str:
        """Return a human-readable name for this datasink.

        This is used as the names of the write tasks. The name is derived from
        the concrete class name by dropping a single leading underscore and a
        trailing ``"Datasink"`` suffix, e.g. ``_CSVDatasink`` -> ``CSV``.
        """
        name = type(self).__name__
        datasink_suffix = "Datasink"
        # Only one underscore is stripped, so private classes read naturally.
        if name.startswith("_"):
            name = name[1:]
        if name.endswith(datasink_suffix):
            name = name[: -len(datasink_suffix)]
        return name

    @property
    def supports_distributed_writes(self) -> bool:
        """If ``False``, only launch write tasks on the driver's node."""
        return True

    @property
    def min_rows_per_write(self) -> Optional[int]:
        """The target number of rows to pass to each :meth:`~ray.data.Datasink.write` call.

        If ``None``, Ray Data passes a system-chosen number of rows.
        """
        return None
@DeveloperAPI
class DummyOutputDatasink(Datasink[None]):
    """An example implementation of a writable datasink for testing.

    Every block handed to ``write`` is forwarded to a single Ray actor that
    counts the rows it receives. ``num_ok`` is incremented by
    ``on_write_complete`` and ``num_failed`` by ``on_write_failed``. When
    ``enabled`` is False, ``write`` raises ``ValueError``.

    Examples:
        >>> import ray
        >>> from ray.data.datasource import DummyOutputDatasink
        >>> output = DummyOutputDatasink()
        >>> ray.data.range(10).write_datasink(output)
        >>> assert output.num_ok == 1
    """

    def __init__(self):
        ctx = ray.data.DataContext.get_current()

        # Setup a dummy actor to send the data. In a real datasink, write
        # tasks would send data to an external system instead of a Ray actor.
        @ray.remote(scheduling_strategy=ctx.scheduling_strategy)
        class DataSink:
            """Actor that accumulates the row count of every block it receives."""

            def __init__(self):
                self.rows_written = 0

            def write(self, block: Block) -> None:
                self.rows_written += BlockAccessor.for_block(block).num_rows()

            def get_rows_written(self):
                return self.rows_written

        self.data_sink = DataSink.remote()
        self.num_ok = 0
        self.num_failed = 0
        # Set to False to make ``write`` raise instead of forwarding blocks.
        self.enabled = True

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        """Forward each block to the counting actor and wait for all calls.

        Args:
            blocks: Iterable of data blocks for this write task.
            ctx: ``TaskContext`` for the write task (unused here).

        Raises:
            ValueError: If ``enabled`` is False.
        """
        # Fail fast before any actor calls are scheduled.
        if not self.enabled:
            raise ValueError("disabled")
        ray.get([self.data_sink.write.remote(b) for b in blocks])

    def on_write_complete(self, write_result: WriteResult[None]) -> None:
        """Record one successful write job."""
        self.num_ok += 1

    def on_write_failed(self, error: Exception) -> None:
        """Record one failed write job."""
        self.num_failed += 1
|