File size: 5,904 Bytes
e0fc633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import operator
import warnings

import dask
from dask import core
from dask.core import istask
from dask.dataframe.core import _concat
from dask.dataframe.optimize import optimize
from dask.dataframe.shuffle import shuffle_group
from dask.highlevelgraph import HighLevelGraph

from .scheduler import MultipleReturnFunc, multiple_return_get

try:
    from dask.dataframe.shuffle import SimpleShuffleLayer
except ImportError:
    # SimpleShuffleLayer doesn't exist in this version of Dask.
    SimpleShuffleLayer = None

if SimpleShuffleLayer is not None:

    class MultipleReturnSimpleShuffleLayer(SimpleShuffleLayer):
        @classmethod
        def clone(cls, layer: SimpleShuffleLayer):
            # TODO(Clark): Probably don't need this since SimpleShuffleLayer
            # implements __copy__() and the shallow clone should be enough?
            return cls(
                name=layer.name,
                column=layer.column,
                npartitions=layer.npartitions,
                npartitions_input=layer.npartitions_input,
                ignore_index=layer.ignore_index,
                name_input=layer.name_input,
                meta_input=layer.meta_input,
                parts_out=layer.parts_out,
                annotations=layer.annotations,
            )

        def __repr__(self):
            return (
                f"MultipleReturnSimpleShuffleLayer<name='{self.name}', "
                f"npartitions={self.npartitions}>"
            )

        def __reduce__(self):
            attrs = [
                "name",
                "column",
                "npartitions",
                "npartitions_input",
                "ignore_index",
                "name_input",
                "meta_input",
                "parts_out",
                "annotations",
            ]
            return (
                MultipleReturnSimpleShuffleLayer,
                tuple(getattr(self, attr) for attr in attrs),
            )

        def _cull(self, parts_out):
            return MultipleReturnSimpleShuffleLayer(
                self.name,
                self.column,
                self.npartitions,
                self.npartitions_input,
                self.ignore_index,
                self.name_input,
                self.meta_input,
                parts_out=parts_out,
            )

        def _construct_graph(self):
            """Construct graph for a simple shuffle operation."""

            shuffle_group_name = "group-" + self.name
            shuffle_split_name = "split-" + self.name

            dsk = {}
            n_parts_out = len(self.parts_out)
            for part_out in self.parts_out:
                # TODO(Clark): Find better pattern than in-scheduler concat.
                _concat_list = [
                    (shuffle_split_name, part_out, part_in)
                    for part_in in range(self.npartitions_input)
                ]
                dsk[(self.name, part_out)] = (_concat, _concat_list, self.ignore_index)
                for _, _part_out, _part_in in _concat_list:
                    dsk[(shuffle_split_name, _part_out, _part_in)] = (
                        multiple_return_get,
                        (shuffle_group_name, _part_in),
                        _part_out,
                    )
                    if (shuffle_group_name, _part_in) not in dsk:
                        dsk[(shuffle_group_name, _part_in)] = (
                            MultipleReturnFunc(
                                shuffle_group,
                                n_parts_out,
                            ),
                            (self.name_input, _part_in),
                            self.column,
                            0,
                            self.npartitions,
                            self.npartitions,
                            self.ignore_index,
                            self.npartitions,
                        )

            return dsk

    def rewrite_simple_shuffle_layer(dsk, keys):
        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
        else:
            dsk = dsk.copy()

        layers = dsk.layers.copy()
        for key, layer in layers.items():
            if type(layer) is SimpleShuffleLayer:
                dsk.layers[key] = MultipleReturnSimpleShuffleLayer.clone(layer)
        return dsk

    def dataframe_optimize(dsk, keys, **kwargs):
        if not isinstance(keys, (list, set)):
            keys = [keys]
        keys = list(core.flatten(keys))

        if not isinstance(dsk, HighLevelGraph):
            dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())

        dsk = rewrite_simple_shuffle_layer(dsk, keys=keys)
        return optimize(dsk, keys, **kwargs)

else:

    def dataframe_optimize(dsk, keys, **kwargs):
        warnings.warn(
            "Custom dataframe shuffle optimization only works on "
            "dask>=2020.12.0, you are on version "
            f"{dask.__version__}, please upgrade Dask."
            "Falling back to default dataframe optimizer."
        )
        return optimize(dsk, keys, **kwargs)


# Stale approaches below.


def fuse_splits_into_multiple_return(dsk, keys):
    if not isinstance(dsk, HighLevelGraph):
        dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
    else:
        dsk = dsk.copy()
    dependencies = dsk.dependencies.copy()
    for k, v in dsk.items():
        if istask(v) and v[0] == shuffle_group:
            task_deps = dependencies[k]
            # Only rewrite shuffle group split if all downstream dependencies
            # are splits.
            if all(
                istask(dsk[dep]) and dsk[dep][0] == operator.getitem
                for dep in task_deps
            ):
                for dep in task_deps:
                    # Rewrite split
                    pass