10
10
from pandas .api .types import (
11
11
is_list_like ,
12
12
is_string_dtype ,
13
+ is_categorical_dtype ,
13
14
)
14
15
from pandas .core .dtypes .concat import concat_compat
15
16
@@ -1176,6 +1177,7 @@ def _final_frame_longer(
1176
1177
df = {** index , ** outcome , ** values }
1177
1178
1178
1179
df = pd .DataFrame (df , copy = False , index = df_index )
1180
+ df_index = None
1179
1181
1180
1182
if sort_by_appearance :
1181
1183
df = _sort_by_appearance_for_melt (df = df , len_index = len_index )
@@ -1198,6 +1200,9 @@ def pivot_wider(
1198
1200
flatten_levels : Optional [bool ] = True ,
1199
1201
names_sep : str = "_" ,
1200
1202
names_glue : str = None ,
1203
+ reset_index : bool = True ,
1204
+ names_expand : bool = False ,
1205
+ index_expand : bool = False ,
1201
1206
) -> pd .DataFrame :
1202
1207
"""
1203
1208
Reshapes data from *long* to *wide* form.
@@ -1222,6 +1227,7 @@ def pivot_wider(
1222
1227
at the start of each label in the columns.
1223
1228
1224
1229
1230
+
1225
1231
Example:
1226
1232
1227
1233
>>> import pandas as pd
@@ -1292,9 +1298,16 @@ def pivot_wider(
1292
1298
and uses python's `str.format_map` under the hood.
1293
1299
Simply create the string template,
1294
1300
using the column labels in `names_from`,
1295
- and special `_value` as a placeholder
1296
- if there are multiple `values_from`.
1301
+ and special `_value` as a placeholder for `values_from`.
1297
1302
Applicable only if `flatten_levels` is `True`.
1303
+ :param reset_index: Determines whether to restore `index`
1304
+ as a column/columns. Applicable only if `index` is provided,
1305
+ and `flatten_levels` is `True`. Default is `True`.
1306
+ :param names_expand: Expand columns to show all the categories.
1307
+ Applies only if `names_from` is a categorical column.
1308
+ Default is `False`.
1309
+ :param index_expand: Expand the index to show all the categories.
1310
+ Applies only if `index` is a categorical column. Default is `False`.
1298
1311
:returns: A pandas DataFrame that has been unpivoted from long to wide
1299
1312
form.
1300
1313
"""
@@ -1309,6 +1322,9 @@ def pivot_wider(
1309
1322
flatten_levels ,
1310
1323
names_sep ,
1311
1324
names_glue ,
1325
+ reset_index ,
1326
+ names_expand ,
1327
+ index_expand ,
1312
1328
)
1313
1329
1314
1330
@@ -1320,6 +1336,9 @@ def _computations_pivot_wider(
1320
1336
flatten_levels : Optional [bool ] = True ,
1321
1337
names_sep : str = "_" ,
1322
1338
names_glue : str = None ,
1339
+ reset_index : bool = True ,
1340
+ names_expand : bool = False ,
1341
+ index_expand : bool = False ,
1323
1342
) -> pd .DataFrame :
1324
1343
"""
1325
1344
This is the main workhorse of the `pivot_wider` function.
@@ -1339,6 +1358,9 @@ def _computations_pivot_wider(
1339
1358
flatten_levels ,
1340
1359
names_sep ,
1341
1360
names_glue ,
1361
+ reset_index ,
1362
+ names_expand ,
1363
+ index_expand ,
1342
1364
) = _data_checks_pivot_wider (
1343
1365
df ,
1344
1366
index ,
@@ -1347,30 +1369,54 @@ def _computations_pivot_wider(
1347
1369
flatten_levels ,
1348
1370
names_sep ,
1349
1371
names_glue ,
1372
+ reset_index ,
1373
+ names_expand ,
1374
+ index_expand ,
1350
1375
)
1351
- if flatten_levels :
1352
- # check dtype of `names_from` is string
1353
- names_from_all_strings = (
1354
- df .filter (names_from ).agg (is_string_dtype ).all ().item ()
1355
- )
1356
-
1357
- # check dtype of columns
1358
- column_dtype = is_string_dtype (df .columns )
1359
1376
1360
1377
df = df .pivot ( # noqa: PD010
1361
1378
index = index , columns = names_from , values = values_from
1362
1379
)
1363
1380
1364
- # an empty df is likely because
1365
- # there is no `values_from`
1381
+ indexer = df .index
1382
+ if index_expand and index :
1383
+ any_categoricals = (indexer .get_level_values (name ) for name in index )
1384
+ any_categoricals = any (map (is_categorical_dtype , any_categoricals ))
1385
+ if any_categoricals :
1386
+ indexer = _expand (indexer , retain_categories = True )
1387
+ df = df .reindex (index = indexer )
1388
+
1389
+ indexer = df .columns
1390
+ if names_expand :
1391
+ any_categoricals = (
1392
+ indexer .get_level_values (name ) for name in names_from
1393
+ )
1394
+ any_categoricals = any (map (is_categorical_dtype , any_categoricals ))
1395
+ if any_categoricals :
1396
+ retain_categories = True
1397
+ if flatten_levels & (
1398
+ (names_glue is not None )
1399
+ | isinstance (indexer , pd .MultiIndex )
1400
+ | ((index is not None ) & reset_index )
1401
+ ):
1402
+ retain_categories = False
1403
+ indexer = _expand (indexer , retain_categories = retain_categories )
1404
+ df = df .reindex (columns = indexer )
1405
+
1406
+ indexer = None
1366
1407
if any ((df .empty , not flatten_levels )):
1367
1408
return df
1368
1409
1369
1410
if isinstance (df .columns , pd .MultiIndex ):
1370
- if (not names_from_all_strings ) or (not column_dtype ):
1371
- new_columns = [tuple (map (str , entry )) for entry in df ]
1372
- else :
1373
- new_columns = [entry for entry in df ]
1411
+ new_columns = df .columns
1412
+ all_strings = (
1413
+ new_columns .get_level_values (num )
1414
+ for num in range (new_columns .nlevels )
1415
+ )
1416
+ all_strings = all (map (is_string_dtype , all_strings ))
1417
+ if not all_strings :
1418
+ new_columns = (tuple (map (str , entry )) for entry in new_columns )
1419
+
1374
1420
if names_glue is not None :
1375
1421
if ("_value" in names_from ) and (None in df .columns .names ):
1376
1422
warnings .warn (
@@ -1403,24 +1449,18 @@ def _computations_pivot_wider(
1403
1449
1404
1450
df .columns = new_columns
1405
1451
else :
1406
- if (not names_from_all_strings ) or (not column_dtype ):
1407
- df .columns = df .columns .astype (str )
1408
1452
if names_glue is not None :
1409
1453
try :
1410
1454
df .columns = [
1411
1455
names_glue .format_map ({names_from [0 ]: entry })
1412
- for entry in df
1456
+ for entry in df . columns
1413
1457
]
1414
1458
except KeyError as error :
1415
1459
raise KeyError (
1416
1460
f"{ error } is not a column label in names_from."
1417
1461
) from error
1418
1462
1419
- # if columns are of category type
1420
- # this returns columns to object dtype
1421
- # also, resetting index with category columns is not possible
1422
- df .columns = [* df .columns ]
1423
- if index :
1463
+ if index and reset_index :
1424
1464
df = df .reset_index ()
1425
1465
1426
1466
if df .columns .names :
@@ -1437,6 +1477,9 @@ def _data_checks_pivot_wider(
1437
1477
flatten_levels ,
1438
1478
names_sep ,
1439
1479
names_glue ,
1480
+ reset_index ,
1481
+ names_expand ,
1482
+ index_expand ,
1440
1483
):
1441
1484
1442
1485
"""
@@ -1464,9 +1507,12 @@ def _data_checks_pivot_wider(
1464
1507
if values_from is not None :
1465
1508
if is_list_like (values_from ):
1466
1509
values_from = [* values_from ]
1467
- values_from = _select_column_names (values_from , df )
1468
- if len (values_from ) == 1 :
1469
- values_from = values_from [0 ]
1510
+ out = _select_column_names (values_from , df )
1511
+ # hack to align with pd.pivot
1512
+ if values_from == out [0 ]:
1513
+ values_from = out [0 ]
1514
+ else :
1515
+ values_from = out
1470
1516
1471
1517
check ("flatten_levels" , flatten_levels , [bool ])
1472
1518
@@ -1476,6 +1522,10 @@ def _data_checks_pivot_wider(
1476
1522
if names_glue is not None :
1477
1523
check ("names_glue" , names_glue , [str ])
1478
1524
1525
+ check ("reset_index" , reset_index , [bool ])
1526
+ check ("names_expand" , names_expand , [bool ])
1527
+ check ("index_expand" , index_expand , [bool ])
1528
+
1479
1529
return (
1480
1530
df ,
1481
1531
index ,
@@ -1484,4 +1534,51 @@ def _data_checks_pivot_wider(
1484
1534
flatten_levels ,
1485
1535
names_sep ,
1486
1536
names_glue ,
1537
+ reset_index ,
1538
+ names_expand ,
1539
+ index_expand ,
1487
1540
)
1541
+
1542
+
1543
+ def _expand (indexer , retain_categories ):
1544
+ """
1545
+ Expand Index to all categories.
1546
+ Applies to categorical index, and used
1547
+ in _computations_pivot_wider for scenarios where
1548
+ names_expand and/or index_expand is True.
1549
+ Categories are preserved where possible.
1550
+ If `retain_categories` is False, a fastpath is taken
1551
+ to generate all possible combinations.
1552
+
1553
+ Returns an Index.
1554
+ """
1555
+ if indexer .nlevels > 1 :
1556
+ names = indexer .names
1557
+ if not retain_categories :
1558
+ indexer = pd .MultiIndex .from_product (indexer .levels , names = names )
1559
+ else :
1560
+ indexer = [
1561
+ indexer .get_level_values (n ) for n in range (indexer .nlevels )
1562
+ ]
1563
+ indexer = [
1564
+ pd .Categorical (
1565
+ values = arr .categories ,
1566
+ categories = arr .categories ,
1567
+ ordered = arr .ordered ,
1568
+ )
1569
+ if is_categorical_dtype (arr )
1570
+ else arr .unique ()
1571
+ for arr in indexer
1572
+ ]
1573
+ indexer = pd .MultiIndex .from_product (indexer , names = names )
1574
+
1575
+ else :
1576
+ if not retain_categories :
1577
+ indexer = indexer .categories
1578
+ else :
1579
+ indexer = pd .Categorical (
1580
+ values = indexer .categories ,
1581
+ categories = indexer .categories ,
1582
+ ordered = indexer .ordered ,
1583
+ )
1584
+ return indexer
0 commit comments