Skip to content

Commit 1d99f90

Browse files
committed
clarify behavior when dtype=None in sum, prod and trace
1 parent e72b3ca commit 1d99f90

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

src/array_api_stubs/_draft/linalg.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -718,10 +718,11 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> arr
718718
data type of the returned array. If ``None``,
719719
720720
- if the default data type corresponding to the data type "kind" (integer, real-valued floating-point, or complex floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``.
721-
- if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
722-
- if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type.
723-
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
724-
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
721+
- if the default data type corresponding to the data type "kind" of ``x`` has the same or a larger range of values than the data type of ``x``,
722+
- if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
723+
- if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type.
724+
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
725+
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
725726
726727
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``.
727728

src/array_api_stubs/_draft/statistical_functions.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,11 @@ def prod(
143143
data type of the returned array. If ``None``,
144144
145145
- if the default data type corresponding to the data type "kind" (integer, real-valued floating-point, or complex floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``.
146-
- if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
147-
- if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type.
148-
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
149-
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
146+
- if the default data type corresponding to the data type "kind" of ``x`` has the same or a larger range of values than the data type of ``x``,
147+
- if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
148+
- if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type.
149+
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
150+
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
150151
151152
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the product. Default: ``None``.
152153
@@ -240,10 +241,11 @@ def sum(
240241
data type of the returned array. If ``None``,
241242
242243
- if the default data type corresponding to the data type "kind" (integer, real-valued floating-point, or complex floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``.
243-
- if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
244-
- if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type.
245-
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
246-
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
244+
- if the default data type corresponding to the data type "kind" of ``x`` has the same or a larger range of values than the data type of ``x``,
245+
- if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
246+
- if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type.
247+
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
248+
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
247249
248250
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``.
249251

0 commit comments

Comments
 (0)