PR #41205: [XLA:GPU] Migrate CuDnnCmd into CuDnnThunk #116349

Draft
copybara-service[bot] wants to merge 1 commit into master from exported_pr_902579190

Conversation

@copybara-service

PR #41205: [XLA:GPU] Migrate CuDnnCmd into CuDnnThunk

Imported from GitHub PR openxla/xla#41205

Follows the same pattern as the MemzeroCmd/Memset32Cmd/CublasLtMatmulThunk migrations: eliminate the standalone CuDnnCmd wrapper class by making CuDnnThunk directly implement Command (via TracedCommand) and provide its own Record() method.

Before

CuDnnThunk (a plain Thunk subclass) was converted by the command buffer emitter into a heap-allocated CuDnnCmd via Convert&lt;CuDnnThunk&gt;(), splitting the same logical op across two classes with duplicated operand/graph state.

After

CuDnnThunk inherits from TracedCommand (using the protected constructor that preserves Kind::kCuDnn + ThunkInfo), implements Record() directly, and is emplaced into the command sequence via static_cast -- matching GemmThunk, CublasLtMatmulThunk, and MemzeroThunk. CuDnnCmd is deleted entirely.

Record() prefers explicit command-buffer construction when the graph reports SupportsExplicitCommandBufferConstruction()==true (CreateDnnGraphCommand / UpdateDnnGraphCommand) and falls back to RecordTracedCommand otherwise.

Initialize() is made safe for pre-populated graphs (e.g. when tests bypass the fingerprint deserialization path): the short-circuit lives inside absl::call_once to keep the read synchronized with concurrent Initialize() calls from other streams/devices.

Tests

  • cuda_command_buffer_thunk_test.cc: existing CommandBufferThunkTest.CuDnnCmd rewired to the new owning-SequentialThunk shape.
  • cudnn_thunk_test.cc: new CuDnnThunkCmdBufTest fixture + 4 TEST_F cases covering Record() in {Create, Update} x {explicit, implicit} modes. The implicit path is driven by a test-only FakeDnnGraph that forces SupportsExplicitCommandBufferConstruction()->false and whose Execute() emits a verifiable Memset32 on the trace stream. The Update tests allocate a fresh output buffer to exercise the non-cache-hit path in both UpdateDnnGraphCommand and RecordTracedCommand.

Copybara import of the project:

--
f856b7d88c29d4986b872f74d54b7a56d3ab959f by Shawn Wang <shawnw@nvidia.com>:

[XLA:GPU] Migrate CuDnnCmd into CuDnnThunk

BUILD

  • :command_buffer_cmd drops //xla/stream_executor:dnn (no longer used).
  • :cudnn_thunk picks up :command, :traced_command, :command_buffer, :stream,
    //xla:util, absl/log, absl/log:check; drops :errors (no TF_* macros left).
  • :cudnn_thunk_test promoted from xla_cc_test to
    xla_test(backends=["gpu"], tags=["cuda-only"]) with the deps needed for
    the new direct-Record tests (cudnn_plugin, cudnn_frontend, command_state,
    collective_params, etc.).
  • BUILD integrity audited 1:1 against #include graphs for all modified
    targets.

Merging this change closes #41205

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#41205 from shawnwang18:shawnw/migrate_cudnn_cmd_to_thunk f856b7d88c29d4986b872f74d54b7a56d3ab959f

PiperOrigin-RevId: 902579190