Skip to content

Commit 233c8fb

Browse files
committed
api: add mul interp mode
1 parent 995fc55 commit 233c8fb

File tree

10 files changed

+83
-17
lines changed

10 files changed

+83
-17
lines changed

devito/core/cpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def _normalize_kwargs(cls, **kwargs):
7676

7777
# Code generation options for derivatives
7878
o['expand'] = oo.pop('expand', cls.EXPAND)
79+
o['eval-mul-first'] = oo.pop('eval-mul-first', cls.MUL_FIRST)
7980
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
8081
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
8182
o['deriv-unroll'] = oo.pop('deriv-unroll', False)

devito/core/gpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def _normalize_kwargs(cls, **kwargs):
8989

9090
# Code generation options for derivatives
9191
o['expand'] = oo.pop('expand', cls.EXPAND)
92+
o['eval-mul-first'] = oo.pop('eval-mul-first', cls.MUL_FIRST)
9293
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
9394
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
9495
o['deriv-unroll'] = oo.pop('deriv-unroll', False)

devito/core/operator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ class BasicOperator(Operator):
123123
finite-difference derivatives.
124124
"""
125125

126+
MUL_FIRST = False
127+
"""
128+
When evaluating expressions location, prioritize multiplication
129+
operations.
130+
"""
131+
126132
DERIV_COLLECT = True
127133
"""
128134
Factorize finite-difference derivatives exploiting the linearity of the FD

devito/finite_differences/derivative.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def T(self):
474474

475475
return self._rebuild(transpose=adjoint)
476476

477-
def _eval_at(self, func):
477+
def _eval_at(self, func, **kwargs):
478478
"""
479479
Evaluates the derivative at the location of `func`. It is necessary for staggered
480480
setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx
@@ -521,7 +521,7 @@ def _eval_at(self, func):
521521
return self._rebuild(self.expr, **rkw)
522522
args = [self.expr.func(*v) for v in mapper.values()]
523523
args.extend([a for a in self.expr.args if a not in self.expr._args_diff])
524-
args = [self._rebuild(a)._eval_at(func) for a in args]
524+
args = [self._rebuild(a)._eval_at(func, **kwargs) for a in args]
525525
return self.expr.func(*args)
526526
elif self.expr.is_Mul:
527527
# For Mul, We treat the basic case `u(x + h_x/2) * v(x) which is what appear
@@ -594,7 +594,7 @@ def _eval_fd(self, expr, **kwargs):
594594
res = generic_derivative(expr, self.dims[0], self.fd_order[0],
595595
self.deriv_order[0], weights=self.weights,
596596
side=self.side, matvec=self.transpose,
597-
x0=self.x0, expand=expand)
597+
x0=x0_deriv, expand=expand)
598598

599599
# Step 4: Apply substitutions
600600
for e in self._ppsubs:

devito/finite_differences/differentiable.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,12 @@ def coefficients(self):
155155
key = lambda x: coeff_priority.get(x, -1)
156156
return sorted(coefficients, key=key, reverse=True)[0]
157157

158-
def _eval_at(self, func):
158+
def _eval_at(self, func, **kwargs):
159159
if not func.is_Staggered:
160160
# Cartesian grid, do no waste time
161161
return self
162-
return self.func(*[getattr(a, '_eval_at', lambda x: a)(func) for a in self.args])
162+
return self.func(*[getattr(a, '_eval_at', lambda x: a)(func, **kwargs)
163+
for a in self.args])
163164

164165
def _subs(self, old, new, **hints):
165166
if old == self:
@@ -454,7 +455,11 @@ def highest_priority(DiffOp):
454455
# set of dimensions is used when multiple ones with the same
455456
# priority appear
456457
prio = lambda x: (getattr(x, '_fd_priority', 0), len(x.dimensions))
457-
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]
458+
args = DiffOp._args_diff
459+
if not args:
460+
return DiffOp
461+
else:
462+
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]
458463

459464

460465
class DifferentiableOp(Differentiable):
@@ -520,7 +525,7 @@ class DifferentiableFunction(DifferentiableOp):
520525
def __new__(cls, *args, **kwargs):
521526
return cls.__sympy_class__.__new__(cls, *args, **kwargs)
522527

523-
def _eval_at(self, func):
528+
def _eval_at(self, func, **kwargs):
524529
return self
525530

526531

@@ -629,6 +634,56 @@ def _gather_for_diff(self):
629634

630635
return self.func(*new_args, evaluate=False)
631636

637+
def _eval_at(self, func, mul_first=False, **kwargs):
638+
# Dont evaluate mul first
639+
if not mul_first:
640+
return super()._eval_at(func, mul_first=mul_first)
641+
642+
# Not a basic a*b*c... expression, just defer to superclass
643+
if any(isinstance(f, DifferentiableOp) for f in self.args):
644+
return super()._eval_at(func, mul_first=mul_first)
645+
646+
# Split Derivative and Differentiable args
647+
derivs, other = split(self.args, lambda e: isinstance(e, sympy.Derivative))
648+
649+
if derivs:
650+
derivs = Differentiable._eval_at(self.func(*derivs), func,
651+
mul_first=mul_first)
652+
else:
653+
derivs = 1
654+
655+
if not other:
656+
return derivs
657+
elif len(other) > 1:
658+
expr = self.func(*other)._gather_for_diff
659+
else:
660+
expr = other[0]
661+
662+
# Non differentiable expr (e.g., number)
663+
if not isinstance(expr, Differentiable):
664+
return self.func(derivs, expr)
665+
666+
# Build mapper for dimensions that need to be interpolated
667+
mapper = {}
668+
for d in self.dimensions:
669+
try:
670+
if self.indices_ref[d] is not func.indices_ref[d]:
671+
mapper[d] = func.indices_ref[d]
672+
except KeyError:
673+
pass
674+
675+
# Nothing to interpolate
676+
if not mapper:
677+
return super()._eval_at(func, mul_first=mul_first)
678+
679+
# Interpolate expr at the required indices
680+
interp = expr.diff(*mapper.keys(), deriv_order=[0 for _ in mapper],
681+
fd_order=[self.interp_order for _ in mapper],
682+
x0=mapper)
683+
684+
# Return the full expression with Derivatives
685+
return self.func(derivs, interp)
686+
632687

633688
class Pow(DifferentiableOp, sympy.Pow):
634689
_fd_priority = 0
@@ -971,7 +1026,7 @@ def _subs(self, old, new, **hints):
9711026

9721027
class DiffDerivative(IndexDerivative, DifferentiableOp):
9731028

974-
def _eval_at(self, func):
1029+
def _eval_at(self, func, **kwargs):
9751030
# Like EvalDerivative, a DiffDerivative must have already been evaluated
9761031
# at a valid x0 and should not be re-evaluated at a different location
9771032
return self
@@ -1022,7 +1077,7 @@ def _new_rawargs(self, *args, **kwargs):
10221077
kwargs.pop('is_commutative', None)
10231078
return self.func(*args, **kwargs)
10241079

1025-
def _eval_at(self, func):
1080+
def _eval_at(self, func, **kwargs):
10261081
# An EvalDerivative must have already been evaluated at a valid x0
10271082
# and should not be re-evaluated at a different location
10281083
return self

devito/operator/operator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ def _lower_exprs(cls, expressions, **kwargs):
338338
* Shift indices for domain alignment.
339339
"""
340340
expand = kwargs['options'].get('expand', True)
341+
mul_first = kwargs['options'].get('eval-mul-first', False)
341342

342343
# Specialization is performed on unevaluated expressions
343344
expressions = cls._specialize_dsl(expressions, **kwargs)
@@ -348,7 +349,8 @@ def _lower_exprs(cls, expressions, **kwargs):
348349
# ModuloDimensions
349350
if not expand:
350351
expand = lambda d: d.is_Stepping
351-
expressions = flatten([i._evaluate(expand=expand) for i in expressions])
352+
expressions = flatten([i._evaluate(expand=expand, mul_first=mul_first)
353+
for i in expressions])
352354

353355
# Scalarize the tensor equations, if any
354356
expressions = [j for i in expressions for j in i._flatten]

devito/types/dense.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,9 +1070,9 @@ def __fd_setup__(self):
10701070

10711071
@cached_property
10721072
def _fd_priority(self):
1073-
return 1 if self.staggered.on_node else 2
1073+
return 1.2 if self.staggered.on_node else 1.1
10741074

1075-
def _eval_at(self, func):
1075+
def _eval_at(self, func, **kwargs):
10761076
if self.staggered == func.staggered:
10771077
return self
10781078

@@ -1491,7 +1491,7 @@ def __shape_setup__(cls, **kwargs):
14911491

14921492
@cached_property
14931493
def _fd_priority(self):
1494-
return 2.1 if self.staggered.on_node else 2.2
1494+
return 2.1 if self.staggered.on_node else 2
14951495

14961496
@property
14971497
def time_order(self):

devito/types/equation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _evaluate(self, **kwargs):
110110
"""
111111
try:
112112
lhs = self.lhs._evaluate(**kwargs)
113-
rhs = self.rhs._eval_at(self.lhs)._evaluate(**kwargs)
113+
rhs = self.rhs._eval_at(self.lhs, **kwargs)._evaluate(**kwargs)
114114
except AttributeError:
115115
lhs, rhs = self._evaluate_args(**kwargs)
116116
eq = self.func(lhs, rhs, subdomain=self.subdomain,

devito/types/sparse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ def _dist_scatter(self, alias=None, data=None):
662662
mapper.update(self._dist_subfunc_scatter(sf))
663663
return mapper
664664

665-
def _eval_at(self, func):
665+
def _eval_at(self, func, **kwargs):
666666
return self
667667

668668
def _halo_exchange(self):

devito/types/tensor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,13 @@ def __getattr__(self, name):
155155
except:
156156
raise AttributeError("%r object has no attribute %r" % (self.__class__, name))
157157

158-
def _eval_at(self, func):
158+
def _eval_at(self, func, **kwargs):
159159
"""
160160
Evaluate tensor at func location
161161
"""
162162
def entries(i, j, func):
163-
return getattr(self[i, j], '_eval_at', lambda x: self[i, j])(func[i, j])
163+
return getattr(self[i, j], '_eval_at',
164+
lambda x: self[i, j])(func[i, j], **kwargs)
164165
entry = lambda i, j: entries(i, j, func)
165166
return self._new(self.rows, self.cols, entry)
166167

0 commit comments

Comments
 (0)