diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a31f3801b..dfe3906e48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,8 +4,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + ## [0.11.2] - 11/xx/2021 +### Added +- Extending `dpctl.device_context` with nested contexts (#678) + ## Fixed - Fixed issue #649 about incorrect behavior of `.T` method on sliced arrays (#653) diff --git a/dpctl/__init__.py b/dpctl/__init__.py index 80ce781f3b..efd93f3630 100644 --- a/dpctl/__init__.py +++ b/dpctl/__init__.py @@ -61,6 +61,7 @@ get_current_queue, get_num_activated_queues, is_in_device_context, + nested_context_factories, set_global_queue, ) @@ -111,6 +112,7 @@ "get_current_queue", "get_num_activated_queues", "is_in_device_context", + "nested_context_factories", "set_global_queue", ] __all__ += [ diff --git a/dpctl/_sycl_queue_manager.pyx b/dpctl/_sycl_queue_manager.pyx index 53d5058d04..c814b7286a 100644 --- a/dpctl/_sycl_queue_manager.pyx +++ b/dpctl/_sycl_queue_manager.pyx @@ -19,7 +19,7 @@ # cython: linetrace=True import logging -from contextlib import contextmanager +from contextlib import ExitStack, contextmanager from .enum_types import backend_type, device_type @@ -210,6 +210,22 @@ cpdef get_current_backend(): return _mgr.get_current_backend() +nested_context_factories = [] + + +def _get_nested_contexts(ctxt): + _help_numba_dppy() + return (factory(ctxt) for factory in nested_context_factories) + + +def _help_numba_dppy(): + """Import numba-dppy for registering nested contexts""" + try: + import numba_dppy + except Exception: + pass + + @contextmanager def device_context(arg): """ @@ -222,6 +238,9 @@ def device_context(arg): the context manager's scope. The yielded queue is removed as the currently usable queue on exiting the context manager. + You can register context factory in the list of factories. + This context manager uses context factories to create and activate nested contexts. + Args: queue_str (str) : A string corresponding to the DPC++ filter selector. @@ -243,11 +262,26 @@ def device_context(arg): with dpctl.device_context("level0:gpu:0"): pass + The following example registers nested context factory: + + .. code-block:: python + + import dctl + + def factory(sycl_queue): + ... + return context + + dpctl.nested_context_factories.append(factory) + """ ctxt = None try: ctxt = _mgr._set_as_current_queue(arg) - yield ctxt + with ExitStack() as stack: + for nested_context in _get_nested_contexts(ctxt): + stack.enter_context(nested_context) + yield ctxt finally: # Code to release resource if ctxt: diff --git a/dpctl/tests/test_sycl_queue_manager.py b/dpctl/tests/test_sycl_queue_manager.py index 5ff33e09b3..ae7c75cbbd 100644 --- a/dpctl/tests/test_sycl_queue_manager.py +++ b/dpctl/tests/test_sycl_queue_manager.py @@ -17,6 +17,8 @@ """Defines unit test cases for the SyclQueueManager class. """ +import contextlib + import pytest import dpctl @@ -156,3 +158,73 @@ def test_get_current_backend(): dpctl.set_global_queue("gpu") elif has_cpu(): dpctl.set_global_queue("cpu") + + +def test_nested_context_factory_is_empty_list(): + assert isinstance(dpctl.nested_context_factories, list) + assert not dpctl.nested_context_factories + + +@contextlib.contextmanager +def _register_nested_context_factory(factory): + dpctl.nested_context_factories.append(factory) + try: + yield + finally: + dpctl.nested_context_factories.remove(factory) + + +def test_register_nested_context_factory_context(): + def factory(): + pass + + with _register_nested_context_factory(factory): + assert factory in dpctl.nested_context_factories + + assert isinstance(dpctl.nested_context_factories, list) + assert not dpctl.nested_context_factories + + +@pytest.mark.skipif(not has_cpu(), reason="No OpenCL CPU queues available") +def test_device_context_activates_nested_context(): + in_context = False + factory_called = False + + @contextlib.contextmanager + def context(): + nonlocal in_context + old, in_context = in_context, True + yield + in_context = old + + def factory(_): + nonlocal factory_called + factory_called = True + return context() + + with _register_nested_context_factory(factory): + assert not factory_called + assert not in_context + + with dpctl.device_context("opencl:cpu:0"): + assert factory_called + assert in_context + + assert not in_context + + +@pytest.mark.skipif(not has_cpu(), reason="No OpenCL CPU queues available") +@pytest.mark.parametrize( + "factory, exception, match", + [ + (True, TypeError, "object is not callable"), + (lambda x: None, AttributeError, "no attribute '__exit__'"), + ], +) +def test_nested_context_factory_exception_if_wrong_factory( + factory, exception, match +): + with pytest.raises(exception, match=match): + with _register_nested_context_factory(factory): + with dpctl.device_context("opencl:cpu:0"): + pass