Source code for dask.array._shuffle

from __future__ import annotations

import copy
import math
from functools import reduce
from itertools import count, product
from operator import mul
from typing import Literal

import numpy as np

from dask import config
from dask._task_spec import DataNode, List, Task, TaskRef
from dask.array.chunk import getitem
from dask.array.core import Array, unknown_chunk_message
from dask.array.dispatch import concatenate_lookup, take_lookup
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph


def shuffle(x, indexer: list[list[int]], axis: int, chunks: Literal["auto"] = "auto"):
    """
    Reorders one dimension of a Dask Array based on an indexer.

    The indexer defines a list of positional groups that will end up in the same chunk
    together. A single group is in at most one chunk on this dimension, but a chunk
    might contain multiple groups to avoid fragmentation of the array.

    The algorithm tries to balance the chunksizes as much as possible to ideally keep
    the number of chunks consistent or at least manageable.

    Parameters
    ----------
    x: dask array
        Array to be shuffled.
    indexer: list[list[int]]
        The indexer that determines which elements along the dimension will end up in
        the same chunk. Multiple groups can be in the same chunk to avoid
        fragmentation, but each group will end up in exactly one chunk.
    axis: int
        The axis to shuffle along.
    chunks: "auto"
        Hint on how to rechunk if single groups are becoming too large. The default is
        to split chunks along the other dimensions evenly to keep the chunksize
        consistent. The rechunking is done in a way that ensures that no all-to-all
        network communication is necessary; chunks are only split and not combined
        with other chunks.

    Examples
    --------
    >>> import dask.array as da
    >>> import numpy as np
    >>> arr = np.array([[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]])
    >>> x = da.from_array(arr, chunks=(2, 4))

    Separate the elements into different groups.

    >>> y = x.shuffle([[6, 5, 2], [4, 1], [3, 0, 7]], axis=1)

    The shuffle algorithm will combine the first 2 groups into a single chunk to keep
    the number of chunks small.

    The tolerance of increasing the chunk size is controlled by the configuration
    "array.chunk-size-tolerance". The default value is 1.25.

    >>> y.chunks
    ((2,), (5, 3))

    The array was reordered along axis 1 according to the positional indexer that was
    given.

    >>> y.compute()
    array([[ 7,  6,  3,  5,  2,  4,  1,  8],
           [15, 14, 11, 13, 10, 12,  9, 16]])
    """
    if np.isnan(x.shape).any():
        raise ValueError(
            f"Shuffling only allowed with known chunk sizes. {unknown_chunk_message}"
        )
    assert isinstance(axis, int), "axis must be an integer"

    _validate_indexer(x.chunks, indexer, axis)

    x = _rechunk_other_dimensions(x, max(map(len, indexer)), axis, chunks)

    token = tokenize(x, indexer, axis)
    out_name = f"shuffle-{token}"

    chunks, layer = _shuffle(x.chunks, indexer, axis, x.name, out_name, token)
    if len(layer) == 0:
        return Array(x.dask, x.name, x.chunks, meta=x)

    graph = HighLevelGraph.from_collections(out_name, layer, dependencies=[x])
    return Array(graph, out_name, chunks, meta=x)
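# Illustrative sketch (not part of the original module): how the helper below splits
# the non-shuffle dimensions once a single group outgrows the chunk-size tolerance.
# The chunk values are made up for demonstration and assume the default
# "array.chunk-size-tolerance" of 1.25. With input chunks ((4, 4), (4, 4)) and a
# group of length 8 along axis 1, the largest input chunk holds 4 * 4 = 16 elements,
# so axis 0 is split in half to keep the chunk size roughly constant:
#
#     >>> _calculate_new_chunksizes(
#     ...     input_chunks=((4, 4), (4, 4)),
#     ...     new_chunks=[(4, 4), (8,)],
#     ...     changeable_dimensions={0},
#     ...     maximum_chunk=16,
#     ... )
#     [(2, 2, 2, 2), (8,)]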
def _calculate_new_chunksizes(
    input_chunks, new_chunks, changeable_dimensions: set, maximum_chunk: int
):
    chunksize_tolerance = config.get("array.chunk-size-tolerance")
    maximum_chunk = max(maximum_chunk, 1)

    # iterate until we have distributed the increase in chunksize across all dimensions
    # or every non-shuffle dimension is all 1
    while changeable_dimensions:
        n_changeable_dimensions = len(changeable_dimensions)
        chunksize_inc_factor = reduce(mul, map(max, new_chunks)) / maximum_chunk
        if chunksize_inc_factor <= 1:
            break

        for i in list(changeable_dimensions):
            new_chunksizes = []

            # calculate what the max chunk size in this dimension is and split every
            # chunk that is larger than that. We split the increase factor evenly
            # between all dimensions that are not shuffled.
            up_chunksize_limit_for_dim = max(new_chunks[i]) / (
                chunksize_inc_factor ** (1 / n_changeable_dimensions)
            )
            for c in input_chunks[i]:
                if c > chunksize_tolerance * up_chunksize_limit_for_dim:
                    factor = math.ceil(c / up_chunksize_limit_for_dim)
                    # Ensure that we end up at least with chunksize 1
                    factor = min(factor, c)
                    chunksize, remainder = divmod(c, factor)
                    nc = [chunksize] * factor
                    for ii in range(remainder):
                        # Add remainder parts to the first few chunks
                        nc[ii] += 1
                    new_chunksizes.extend(nc)
                else:
                    new_chunksizes.append(c)

            if tuple(new_chunksizes) == new_chunks[i] or max(new_chunksizes) == 1:
                changeable_dimensions.remove(i)

            new_chunks[i] = tuple(new_chunksizes)
    return new_chunks


def _rechunk_other_dimensions(
    x: Array, longest_group: int, axis: int, chunks: Literal["auto"]
) -> Array:
    assert chunks == "auto", "Only auto is supported for now"
    chunksize_tolerance = config.get("array.chunk-size-tolerance")

    if longest_group <= max(x.chunks[axis]) * chunksize_tolerance:
        # We are staying below our threshold, so don't rechunk
        return x

    changeable_dimensions = set(range(len(x.chunks))) - {axis}
    new_chunks = list(x.chunks)
    new_chunks[axis] = (longest_group,)

    # How large is the largest chunk in the input
    maximum_chunk = reduce(mul, map(max, x.chunks))

    new_chunks = _calculate_new_chunksizes(
        x.chunks, new_chunks, changeable_dimensions, maximum_chunk
    )
    new_chunks[axis] = x.chunks[axis]
    return x.rechunk(tuple(new_chunks))


def _validate_indexer(chunks, indexer, axis):
    if not isinstance(indexer, list) or not all(isinstance(i, list) for i in indexer):
        raise ValueError("indexer must be a list of lists of positional indices")

    if not axis <= len(chunks):
        raise ValueError(
            f"Axis {axis} is out of bounds for array with {len(chunks)} axes"
        )

    if max(map(max, indexer)) >= sum(chunks[axis]):
        raise IndexError(
            f"Indexer contains out of bounds index. Dimension only has {sum(chunks[axis])} elements."
        )


def _shuffle(chunks, indexer, axis, in_name, out_name, token):
    _validate_indexer(chunks, indexer, axis)

    if len(indexer) == len(chunks[axis]):
        # check if the array is already shuffled the way we want
        ctr = 0
        for idx, c in zip(indexer, chunks[axis]):
            if idx != list(range(ctr, ctr + c)):
                break
            ctr += c
        else:
            return chunks, {}

    indexer = copy.deepcopy(indexer)

    chunksize_tolerance = config.get("array.chunk-size-tolerance")
    chunk_size_limit = int(sum(chunks[axis]) / len(chunks[axis]) * chunksize_tolerance)

    # Figure out how many groups we can put into one chunk
    current_chunk, new_chunks = [], []
    for idx in indexer:
        if len(current_chunk) + len(idx) > chunk_size_limit and len(current_chunk) > 0:
            new_chunks.append(current_chunk)
            current_chunk = idx.copy()
        else:
            current_chunk.extend(idx)
            if len(current_chunk) > chunk_size_limit / chunksize_tolerance:
                new_chunks.append(current_chunk)
                current_chunk = []

    if len(current_chunk) > 0:
        new_chunks.append(current_chunk)

    chunk_boundaries = np.cumsum(chunks[axis])

    # Get existing chunk tuple locations
    chunk_tuples = list(
        product(*(range(len(c)) for i, c in enumerate(chunks) if i != axis))
    )

    intermediates = dict()
    merges = dict()
    dtype = np.min_scalar_type(max(chunks[axis]))
    split_name = f"shuffle-split-{token}"
    slices = [slice(None)] * len(chunks)
    split_name_suffixes = count()
    sorter_name = "shuffle-sorter-"
    taker_name = "shuffle-taker-"

    old_blocks = {
        old_index: (in_name,) + old_index
        for old_index in np.ndindex(tuple([len(c) for c in chunks]))
    }

    for new_chunk_idx, new_chunk_taker in enumerate(new_chunks):
        new_chunk_taker = np.array(new_chunk_taker)
        sorter = np.argsort(new_chunk_taker).astype(dtype)
        sorter_key = sorter_name + tokenize(sorter)
        # low level fusion can't deal with arrays on first position
        merges[sorter_key] = DataNode(sorter_key, (1, sorter))
        sorted_array = new_chunk_taker[sorter]
        source_chunk_nr, taker_boundary = np.unique(
            np.searchsorted(chunk_boundaries, sorted_array, side="right"),
            return_index=True,
        )
        taker_boundary = taker_boundary.tolist()
        taker_boundary.append(len(new_chunk_taker))

        taker_cache = {}
        for chunk_tuple in chunk_tuples:
            merge_keys = []

            for c, b_start, b_end in zip(
                source_chunk_nr, taker_boundary[:-1], taker_boundary[1:]
            ):
                # insert our axis chunk id into the chunk_tuple
                chunk_key = convert_key(chunk_tuple, c, axis)
                name = (split_name, next(split_name_suffixes))
                this_slice = slices.copy()

                # Cache the takers to allow de-duplication when serializing
                # Ugly!
                if c in taker_cache:
                    taker_key = taker_cache[c]
                else:
                    this_slice[axis] = (
                        sorted_array[b_start:b_end]
                        - (chunk_boundaries[c - 1] if c > 0 else 0)
                    ).astype(dtype)
                    if len(source_chunk_nr) == 1:
                        this_slice[axis] = this_slice[axis][np.argsort(sorter)]

                    taker_key = taker_name + tokenize(this_slice)
                    # low level fusion can't deal with arrays on first position
                    intermediates[taker_key] = DataNode(
                        taker_key, (1, tuple(this_slice))
                    )
                    taker_cache[c] = taker_key

                intermediates[name] = Task(
                    name, _getitem, TaskRef(old_blocks[chunk_key]), TaskRef(taker_key)
                )
                merge_keys.append(name)

            merge_suffix = convert_key(chunk_tuple, new_chunk_idx, axis)
            out_name_merge = (out_name,) + merge_suffix
            if len(merge_keys) > 1:
                merges[out_name_merge] = Task(
                    out_name_merge,
                    concatenate_arrays,
                    List(*(TaskRef(m) for m in merge_keys)),
                    TaskRef(sorter_key),
                    axis,
                )
            elif len(merge_keys) == 1:
                t = intermediates.pop(merge_keys[0])
                t.key = out_name_merge
                merges[out_name_merge] = t
            else:
                raise NotImplementedError

    output_chunks = []
    for i, c in enumerate(chunks):
        if i == axis:
            output_chunks.append(tuple(map(len, new_chunks)))
        else:
            output_chunks.append(c)

    layer = {**merges, **intermediates}
    return tuple(output_chunks), layer


def _getitem(obj, index):
    return getitem(obj, index[1])


def concatenate_arrays(arrs, sorter, axis):
    return take_lookup(
        concatenate_lookup.dispatch(type(arrs[0]))(arrs, axis=axis),
        np.argsort(sorter[1]),
        axis=axis,
    )


def convert_key(key, chunk, axis):
    key = list(key)
    key.insert(axis, chunk)
    return tuple(key)
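# Minimal usage sketch (illustrative, not part of the module): an indexer that simply
# reproduces the existing chunk layout is detected in _shuffle, which returns an empty
# layer, so shuffle() hands back the original graph untouched. The shape and chunks
# below are assumptions chosen for the example.
#
#     >>> import dask.array as da
#     >>> x = da.ones(8, chunks=4)                       # chunks along axis 0: (4, 4)
#     >>> y = x.shuffle([[0, 1, 2, 3], [4, 5, 6, 7]], axis=0)
#     >>> y.name == x.name and y.chunks == x.chunks
#     True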