Skip to content

Migrate Data.transpose from LAMA to Dask #247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 14 additions & 29 deletions cf/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12184,6 +12184,7 @@ def tolist(self):
"""
return self.array.tolist()

@daskified(1)
@_deprecated_kwarg_check("i")
@_inplace_enabled(default=False)
def transpose(self, axes=None, inplace=False, i=False):
Expand Down Expand Up @@ -12224,44 +12225,28 @@ def transpose(self, axes=None, inplace=False, i=False):
"""
d = _inplace_enabled_define_and_cleanup(self)

ndim = d._ndim

# Parse the axes. By default, reverse the order of the axes.
ndim = d.ndim
if axes is None:
if ndim <= 1:
return d

iaxes = tuple(range(ndim - 1, -1, -1))
else:
iaxes = d._parse_axes(axes) # , 'transpose')

# Return unchanged if axes are in the same order as the data
if iaxes == tuple(range(ndim)):
if inplace:
d = None
return d

if len(iaxes) != ndim:
raise ValueError(
"Can't tranpose: Axes don't match array: {}".format(iaxes)
)
# --- End: if
iaxes = d._parse_axes(axes)

# Permute the axes.
# Note: _axes attribute is still important/utilised post-Daskification
# because e.g. axes labelled as cyclic by the _cyclic attribute use it
# to determine their position (see #discussion_r694096462 on PR #247).
data_axes = d._axes
d._axes = [data_axes[i] for i in iaxes]

# Permute the shape
shape = d._shape
d._shape = tuple([shape[i] for i in iaxes])

# Permute the locations map
for partition in d.partitions.matrix.flat:
location = partition.location
shape = partition.shape

partition.location = [location[i] for i in iaxes]
partition.shape = [shape[i] for i in iaxes]
dx = d._get_dask()
try:
dx = da.transpose(dx, axes=axes)
except ValueError:
raise ValueError(
f"Can't transpose: Axes don't match array: {axes}"
)
d._set_dask(dx, reset_mask_hardness=False)

return d

Expand Down
29 changes: 12 additions & 17 deletions cf/test/test_Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1582,29 +1582,24 @@ def test_Data_swapaxes(self):
self.assertEqual(b.shape, e.shape, message)
self.assertTrue((b == e.array).all(), message)

@unittest.skipIf(TEST_DASKIFIED_ONLY, "no attribute 'chunk_sizes'")
def test_Data_transpose(self):
if self.test_only and inspect.stack()[0][3] not in self.test_only:
return

a = np.arange(10 * 15 * 19).reshape(10, 1, 15, 19)

for chunksize in self.chunk_sizes:
with cf.chunksize(chunksize):
d = cf.Data(a.copy())

for indices in (range(a.ndim), range(-a.ndim, 0)):
for axes in itertools.permutations(indices):
a = np.transpose(a, axes)
d.transpose(axes, inplace=True)
message = (
"cf.Data.transpose({}) failed: "
"d.shape={}, a.shape={}".format(
axes, d.shape, a.shape
)
)
self.assertEqual(d.shape, a.shape, message)
self.assertTrue((d.array == a).all(), message)
d = cf.Data(a.copy())

for indices in (range(a.ndim), range(-a.ndim, 0)):
for axes in itertools.permutations(indices):
a = np.transpose(a, axes)
d.transpose(axes, inplace=True)
message = (
"cf.Data.transpose({}) failed: "
"d.shape={}, a.shape={}".format(axes, d.shape, a.shape)
)
self.assertEqual(d.shape, a.shape, message)
self.assertTrue((d.array == a).all(), message)

@unittest.skipIf(TEST_DASKIFIED_ONLY, "no attr. 'partition_configuration'")
def test_Data_unique(self):
Expand Down