The dtype promotion check in `_fix_promotion` does not correctly identify scalar inputs and unconditionally accesses `.dtype`. This breaks binary operators with `float` scalar inputs. This can be fixed by accessing the dtype via `getattr` with a `None` default, or by validating that the input is not a scalar. Happy to provide a PR.
Minimal repro, in version 1.4:
```python
import torch
import numpy
import array_api_compat as aac

aac.__version__  # 1.4 when reproduced

t = torch.arange(10)
n = numpy.arange(10)

numpy.add(n, 1.0)                 # works
torch.add(t, 1.0)                 # works
aac.get_namespace(n).add(n, 1.0)  # works
aac.get_namespace(t).add(t, 1.0)  # raises AttributeError
```
Raises:
```
9 torch.add(t, 1.0)
11 aac.get_namespace(n).add(n, 1.0)
---> 12 aac.get_namespace(t).add(t, 1.0)

File ~/ab/main/.conda/lib/python3.10/site-packages/array_api_compat/torch/_aliases.py:91, in _two_arg.<locals>._f(x1, x2, **kwargs)
89 @wraps(f)
90 def _f(x1, x2, /, **kwargs):
---> 91 x1, x2 = _fix_promotion(x1, x2)
92 return f(x1, x2, **kwargs)

File ~/ab/main/.conda/lib/python3.10/site-packages/array_api_compat/torch/_aliases.py:104, in _fix_promotion(x1, x2, only_scalar)
103 def _fix_promotion(x1, x2, only_scalar=True):
--> 104 if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
105 return x1, x2
106 # If an argument is 0-D pytorch downcasts the other argument

AttributeError: 'float' object has no attribute 'dtype'
```
I would expect behavior equivalent to `torch.add`.
See:
https://gist.github.com/asford/ee688d59f0747a6507b9670a83fa7c47