@@ -155,11 +155,12 @@ def coefficients(self):
155155 key = lambda x : coeff_priority .get (x , - 1 )
156156 return sorted (coefficients , key = key , reverse = True )[0 ]
157157
158- def _eval_at (self , func ):
158+ def _eval_at (self , func , ** kwargs ):
159159 if not func .is_Staggered :
160160 # Cartesian grid, do no waste time
161161 return self
162- return self .func (* [getattr (a , '_eval_at' , lambda x : a )(func ) for a in self .args ])
162+ return self .func (* [getattr (a , '_eval_at' , lambda x : a )(func , ** kwargs )
163+ for a in self .args ])
163164
164165 def _subs (self , old , new , ** hints ):
165166 if old == self :
@@ -454,7 +455,11 @@ def highest_priority(DiffOp):
454455 # set of dimensions is used when multiple ones with the same
455456 # priority appear
456457 prio = lambda x : (getattr (x , '_fd_priority' , 0 ), len (x .dimensions ))
457- return sorted (DiffOp ._args_diff , key = prio , reverse = True )[0 ]
458+ args = DiffOp ._args_diff
459+ if not args :
460+ return DiffOp
461+ else :
462+ return sorted (DiffOp ._args_diff , key = prio , reverse = True )[0 ]
458463
459464
460465class DifferentiableOp (Differentiable ):
@@ -520,7 +525,7 @@ class DifferentiableFunction(DifferentiableOp):
520525 def __new__ (cls , * args , ** kwargs ):
521526 return cls .__sympy_class__ .__new__ (cls , * args , ** kwargs )
522527
523- def _eval_at (self , func ):
528+ def _eval_at (self , func , ** kwargs ):
524529 return self
525530
526531
@@ -629,6 +634,56 @@ def _gather_for_diff(self):
629634
630635 return self .func (* new_args , evaluate = False )
631636
637+ def _eval_at (self , func , mul_first = False , ** kwargs ):
638+ # Dont evaluate mul first
639+ if not mul_first :
640+ return super ()._eval_at (func , mul_first = mul_first )
641+
642+ # Not a basic a*b*c... expression, just defer to superclass
643+ if any (isinstance (f , DifferentiableOp ) for f in self .args ):
644+ return super ()._eval_at (func , mul_first = mul_first )
645+
646+ # Split Derivative and Differentiable args
647+ derivs , other = split (self .args , lambda e : isinstance (e , sympy .Derivative ))
648+
649+ if derivs :
650+ derivs = Differentiable ._eval_at (self .func (* derivs ), func ,
651+ mul_first = mul_first )
652+ else :
653+ derivs = 1
654+
655+ if not other :
656+ return derivs
657+ elif len (other ) > 1 :
658+ expr = self .func (* other )._gather_for_diff
659+ else :
660+ expr = other [0 ]
661+
662+ # Non differentiable expr (e.g., number)
663+ if not isinstance (expr , Differentiable ):
664+ return self .func (derivs , expr )
665+
666+ # Build mapper for dimensions that need to be interpolated
667+ mapper = {}
668+ for d in self .dimensions :
669+ try :
670+ if self .indices_ref [d ] is not func .indices_ref [d ]:
671+ mapper [d ] = func .indices_ref [d ]
672+ except KeyError :
673+ pass
674+
675+ # Nothing to interpolate
676+ if not mapper :
677+ return super ()._eval_at (func , mul_first = mul_first )
678+
679+ # Interpolate expr at the required indices
680+ interp = expr .diff (* mapper .keys (), deriv_order = [0 for _ in mapper ],
681+ fd_order = [self .interp_order for _ in mapper ],
682+ x0 = mapper )
683+
684+ # Return the full expression with Derivatives
685+ return self .func (derivs , interp )
686+
632687
633688class Pow (DifferentiableOp , sympy .Pow ):
634689 _fd_priority = 0
@@ -971,7 +1026,7 @@ def _subs(self, old, new, **hints):
9711026
9721027class DiffDerivative (IndexDerivative , DifferentiableOp ):
9731028
974- def _eval_at (self , func ):
1029+ def _eval_at (self , func , ** kwargs ):
9751030 # Like EvalDerivative, a DiffDerivative must have already been evaluated
9761031 # at a valid x0 and should not be re-evaluated at a different location
9771032 return self
@@ -1022,7 +1077,7 @@ def _new_rawargs(self, *args, **kwargs):
10221077 kwargs .pop ('is_commutative' , None )
10231078 return self .func (* args , ** kwargs )
10241079
1025- def _eval_at (self , func ):
1080+ def _eval_at (self , func , ** kwargs ):
10261081 # An EvalDerivative must have already been evaluated at a valid x0
10271082 # and should not be re-evaluated at a different location
10281083 return self
0 commit comments