Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 17 additions & 72 deletions sumpy/fmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,54 +235,6 @@ def done(self):
# }}}


# {{{ translation classes data

class SumpyTranslationClassesData:
"""A class for building and storing additional, optional data for
precomputation of translation classes passed to the expansion wrangler."""

def __init__(self, queue, trav, is_translation_per_level=True):
# FIXME: Queues should not be part of data.
self.queue = queue
self.trav = trav
self.tree = trav.tree
self.is_translation_per_level = is_translation_per_level

@property
@memoize_method
def translation_classes_builder(self):
from boxtree.translation_classes import TranslationClassesBuilder
return TranslationClassesBuilder(self.queue.context)

@memoize_method
def build_translation_classes_lists(self):
return self.translation_classes_builder(self.queue, self.trav, self.tree,
is_translation_per_level=self.is_translation_per_level)[0]

@memoize_method
def m2l_translation_classes_lists(self):
return (self
.build_translation_classes_lists()
.from_sep_siblings_translation_classes)

@memoize_method
def m2l_translation_vectors(self):
return (self
.build_translation_classes_lists()
.from_sep_siblings_translation_class_to_distance_vector)

def m2l_translation_classes_level_starts(self):
return (self
.build_translation_classes_lists()
.from_sep_siblings_translation_classes_level_starts)


class SumpyTranslationClassesDataNotSuppliedWarning(UserWarning):
pass

# }}}


# {{{ expansion wrangler

class SumpyExpansionWrangler(ExpansionWranglerInterface):
Expand Down Expand Up @@ -315,7 +267,8 @@ def __init__(self, tree_indep, traversal, dtype, fmm_level_to_order,
kernel_extra_kwargs=None,
self_extra_kwargs=None,
translation_classes_data=None,
preprocessed_mpole_dtype=None):
preprocessed_mpole_dtype=None,
*, _disable_translation_classes=False):
super().__init__(tree_indep, traversal)
self.issued_timing_data_warning = False

Expand Down Expand Up @@ -353,27 +306,18 @@ def __init__(self, tree_indep, traversal, dtype, fmm_level_to_order,
self.extra_kwargs = source_extra_kwargs.copy()
self.extra_kwargs.update(self.kernel_extra_kwargs)

if base_kernel.is_translation_invariant:
if translation_classes_data is None:
from warnings import warn
if self.tree_indep.use_fft_for_m2l:
raise NotImplementedError(
"FFT based List 2 (multipole-to-local) translations "
"without translation_classes_data argument is not "
"implemented. Supply a translation_classes_data argument "
"to the wrangler for optimized List 2.")
else:
warn(
"List 2 (multipole-to-local) translations will be "
"unoptimized. Supply a translation_classes_data argument "
"to the wrangler for optimized List 2.",
SumpyTranslationClassesDataNotSuppliedWarning,
stacklevel=2)
self.supports_translation_classes = False
else:
self.supports_translation_classes = True
else:
if _disable_translation_classes or not base_kernel.is_translation_invariant:
self.supports_translation_classes = False
else:
if translation_classes_data is None:
with cl.CommandQueue(self.tree_indep.cl_context) as queue:
from boxtree.translation_classes import TranslationClassesBuilder
translation_classes_builder = TranslationClassesBuilder(
queue.context)
translation_classes_data, _ = translation_classes_builder(
queue, traversal, self.tree,
is_translation_per_level=True)
self.supports_translation_classes = True

self.translation_classes_data = translation_classes_data
self.use_fft_for_m2l = self.tree_indep.use_fft_for_m2l
Expand Down Expand Up @@ -407,7 +351,7 @@ def local_expansions_level_starts(self):
def m2l_translation_class_level_start_box_nrs(self):
with cl.CommandQueue(self.tree_indep.cl_context) as queue:
data = self.translation_classes_data
return data.m2l_translation_classes_level_starts().get(queue)
return data.from_sep_siblings_translation_classes_level_starts.get(queue)

@memoize_method
def m2l_translation_classes_dependent_data_level_starts(self):
Expand Down Expand Up @@ -726,8 +670,9 @@ def multipole_to_local_precompute(self):
if ntranslation_classes == 0:
continue

data = self.translation_classes_data
m2l_translation_vectors = (
self.translation_classes_data.m2l_translation_vectors())
data.from_sep_siblings_translation_class_to_distance_vector)

evt, _ = precompute_kernel(
queue,
Expand Down Expand Up @@ -767,7 +712,7 @@ def _add_m2l_precompute_kwargs(self, kwargs_for_m2l,
kwargs_for_m2l["translation_classes_level_start"] = \
translation_classes_level_start
kwargs_for_m2l["m2l_translation_classes_lists"] = \
self.translation_classes_data.m2l_translation_classes_lists()
self.translation_classes_data.from_sep_siblings_translation_classes

def multipole_to_local(self,
level_start_target_box_nrs,
Expand Down
41 changes: 12 additions & 29 deletions test/test_fmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,9 @@
LinearPDEConformingVolumeTaylorLocalExpansion)
from sumpy.fmm import (
SumpyTreeIndependentDataForWrangler,
SumpyExpansionWrangler,
SumpyTranslationClassesData,
SumpyTranslationClassesDataNotSuppliedWarning)
SumpyExpansionWrangler)

import pytest
import warnings

import logging
logger = logging.getLogger(__name__)
Expand All @@ -57,7 +54,7 @@
faulthandler.enable()


@pytest.mark.parametrize("optimized_m2l, use_fft",
@pytest.mark.parametrize("use_translation_classes, use_fft",
[(False, False), (True, False), (True, True)])
@pytest.mark.parametrize(
("knl", "local_expn_class", "mpole_expn_class",
Expand All @@ -84,7 +81,7 @@
False),
])
def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class,
order_varies_with_level, optimized_m2l, use_fft):
order_varies_with_level, use_translation_classes, use_fft):
logging.basicConfig(level=logging.INFO)

if local_expn_class == VolumeTaylorLocalExpansion and use_fft:
Expand Down Expand Up @@ -188,11 +185,6 @@ def test_sumpy_fmm(ctx_factory, knl, local_expn_class, mpole_expn_class,
for order in order_values:
target_kernels = [knl]

if optimized_m2l:
translation_classes_data = SumpyTranslationClassesData(queue, trav)
else:
translation_classes_data = None

tree_indep = SumpyTreeIndependentDataForWrangler(
ctx,
partial(mpole_expn_class, knl),
Expand All @@ -206,14 +198,10 @@ def fmm_level_to_order(kernel, kernel_args, tree, lev):
def fmm_level_to_order(kernel, kernel_args, tree, lev):
return order

with warnings.catch_warnings():
if not optimized_m2l:
warnings.simplefilter("ignore",
SumpyTranslationClassesDataNotSuppliedWarning)
wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype,
fmm_level_to_order=fmm_level_to_order,
kernel_extra_kwargs=extra_kwargs,
translation_classes_data=translation_classes_data)
wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype,
fmm_level_to_order=fmm_level_to_order,
kernel_extra_kwargs=extra_kwargs,
_disable_translation_classes=not use_translation_classes)

from boxtree.fmm import drive_fmm

Expand Down Expand Up @@ -315,8 +303,7 @@ def test_unified_single_and_double(ctx_factory):
strength_usage=strength_usage)
wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype,
fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order,
source_extra_kwargs=source_extra_kwargs,
translation_classes_data=SumpyTranslationClassesData(queue, trav))
source_extra_kwargs=source_extra_kwargs)

from boxtree.fmm import drive_fmm

Expand Down Expand Up @@ -376,8 +363,7 @@ def test_sumpy_fmm_timing_data_collection(ctx_factory):
target_kernels)

wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype,
fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order,
translation_classes_data=SumpyTranslationClassesData(queue, trav))
fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order)
from boxtree.fmm import drive_fmm

timing_data = {}
Expand Down Expand Up @@ -435,8 +421,7 @@ def test_sumpy_fmm_exclude_self(ctx_factory):

wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype,
fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order,
self_extra_kwargs=self_extra_kwargs,
translation_classes_data=SumpyTranslationClassesData(queue, trav))
self_extra_kwargs=self_extra_kwargs)

from boxtree.fmm import drive_fmm

Expand Down Expand Up @@ -510,8 +495,7 @@ def test_sumpy_axis_source_derivative(ctx_factory):

wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype,
fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order,
self_extra_kwargs=self_extra_kwargs,
translation_classes_data=SumpyTranslationClassesData(queue, trav))
self_extra_kwargs=self_extra_kwargs)

from boxtree.fmm import drive_fmm

Expand Down Expand Up @@ -580,8 +564,7 @@ def test_sumpy_target_point_multiplier(ctx_factory, deriv_axes):

wrangler = SumpyExpansionWrangler(tree_indep, trav, dtype,
fmm_level_to_order=lambda kernel, kernel_args, tree, lev: order,
self_extra_kwargs=self_extra_kwargs,
translation_classes_data=SumpyTranslationClassesData(queue, trav))
self_extra_kwargs=self_extra_kwargs)

from boxtree.fmm import drive_fmm

Expand Down