Major gradient mismatch (max diff ~0.42) on multi-branch WDAG topologies with jit_compile=True #116258

@Beanlycool

Description

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

No

Source

source

TensorFlow version

tf 2.16.1

Custom code

Yes

OS platform and distribution

Linux Ubuntu 22.04

Mobile device

No response

Python version

3.11

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

12.1

GPU model and memory

No response

Current behavior?

I have identified a significant numerical inconsistency in the XLA compiler during gradient computation for a complex multi-branch (WDAG) model. With jit_compile=True, the computed gradients deviate sharply from the eager-execution baseline: the maximum absolute difference exceeds 0.4.

file.zip
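For context, the comparison pattern used in the repro below can be sketched on a small, self-contained toy model (a hypothetical stand-in, not the attached WDAG graph; on a graph this small the eager and XLA paths normally agree to float32 rounding, which is what makes the 0.4 deviation above stand out):

```python
# Hypothetical toy stand-in (NOT the attached WDAG graph): two parallel
# Dense branches concatenated, gradients compared eager vs. jit_compile=True.
import numpy as np
import tensorflow as tf

tf.keras.utils.set_random_seed(0)

inp = tf.keras.Input(shape=(8,))
a = tf.keras.layers.Dense(4, activation="relu")(inp)
b = tf.keras.layers.Dense(4, activation="tanh")(inp)
out = tf.keras.layers.Dense(1)(tf.keras.layers.Concatenate()([a, b]))
model = tf.keras.Model(inp, out)

x = tf.constant(np.random.RandomState(0).randn(16, 8), dtype=tf.float32)

def grads(jit):
    # Build a fresh tf.function per mode so traces are not shared.
    @tf.function(jit_compile=jit)
    def _compute(x):
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(model(x))
        return tape.gradient(loss, model.trainable_variables)
    return _compute(x)

g_eager, g_xla = grads(False), grads(True)
max_diff = max(float(np.max(np.abs(e.numpy() - g.numpy())))
               for e, g in zip(g_eager, g_xla))
print(f"max abs diff: {max_diff:.3e}")
```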

Standalone code to reproduce the issue

import tensorflow as tf
import numpy as np
import json
import os

# --- 1. Lightweight JSON parser (stands in for the private wdag_keras) ---
class StandaloneWDAGModel(tf.keras.Model):
    def __init__(self, json_path):
        super().__init__()
        with open(json_path, 'r') as f:
            self.config = json.load(f)
        
        self.node_layers = {}
        for node in self.config['nodes']:
            if node['type'] == 'Input': continue
            
            # Dynamically map node types to Keras layers
            if node['type'] == 'Conv1D':
                self.node_layers[node['id']] = tf.keras.layers.Conv1D(**node['args'])
            elif node['type'] == 'Dense':
                self.node_layers[node['id']] = tf.keras.layers.Dense(**node['args'])
            elif node['type'] == 'Concatenate':
                self.node_layers[node['id']] = tf.keras.layers.Concatenate(**node['args'])
            elif node['type'] == 'Flatten':
                self.node_layers[node['id']] = tf.keras.layers.Flatten(**node['args'])

    def call(self, inputs, training=True):
        node_outputs = {self.config['inputs'][0]: inputs}
        
        # Execute in topological order (assumes the JSON nodes are pre-sorted)
        for node in self.config['nodes']:
            nid = node['id']
            if nid in node_outputs: continue
            
            # Gather all incoming edges for the current node
            src_ids = [e['src'] for e in self.config['edges'] if e['tgt'] == nid]
            if not src_ids: continue
            
            in_tensors = [node_outputs[sid] for sid in src_ids]
            
            layer = self.node_layers.get(nid)
            if layer:
                if len(in_tensors) > 1:
                    node_outputs[nid] = layer(in_tensors, training=training)
                else:
                    node_outputs[nid] = layer(in_tensors[0], training=training)
        
        # Return the primary output
        return node_outputs[self.config['outputs'][0]]

# --- 2. Core reproduction logic ---
def run_reproduction():
    JSON_FILE = 'BUG_epoch_1.json'
    DATA_FILE = 'temp_tf_meta_inputs_epoch.npz'
    WEIGHT_FILE = 'temp_tf_meta_epoch.weights.h5'

    # Initialize the model
    model = StandaloneWDAGModel(JSON_FILE)
    data = np.load(DATA_FILE)
    x1, x2 = tf.cast(data['x1'], tf.float32), tf.cast(data['x2'], tf.float32)
    
    # Run once to build the weights
    _ = model(x1) 
    model.load_weights(WEIGHT_FILE)

    # Define the gradient-computation function
    def get_grads(jit):
        @tf.function(jit_compile=jit)
        def _compute(i1, i2):
            with tf.GradientTape() as tape:
                loss = tf.reduce_mean(model(i1)) + tf.reduce_mean(model(i2))
            return tape.gradient(loss, model.trainable_variables)
        return _compute(x1, x2)

    print("Running Eager Baseline...")
    grads_eager = get_grads(jit=False)
    print("Running XLA Optimized...")
    grads_xla = get_grads(jit=True)

    # --- 3. Result verification ---
    print("\n" + "="*30 + "\nNumerical Consistency Report\n" + "="*30)
    for i, (g_e, g_x) in enumerate(zip(grads_eager, grads_xla)):
        if g_e is None or g_x is None: continue
        
        diff = np.abs(g_e.numpy() - g_x.numpy())
        max_diff = np.max(diff)
        if max_diff > 1e-3:
            print(f"Layer {i:2d} | Mismatch Detected! Max Abs Diff: {max_diff:.6e}")
        else:
            print(f"Layer {i:2d} | OK")

if __name__ == "__main__":
    run_reproduction()
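As a side note on interpreting the report above: XLA fusion can legally reassociate float32 sums, so tiny absolute differences are expected, but a max absolute diff of ~0.4 on gradients of this magnitude is far beyond reassociation noise. A relative-error check (a small hedged sketch, not part of the original repro) helps make that distinction explicit:

```python
# Hedged helper sketch: separate expected float32 reassociation noise
# from a genuine gradient mismatch via elementwise relative error.
import numpy as np

def max_rel_diff(grads_a, grads_b, eps=1e-8):
    """Max elementwise |a - b| / (|a| + |b| + eps) over two gradient lists."""
    worst = 0.0
    for a, b in zip(grads_a, grads_b):
        if a is None or b is None:  # skip unconnected variables
            continue
        a, b = np.asarray(a), np.asarray(b)
        rel = np.abs(a - b) / (np.abs(a) + np.abs(b) + eps)
        worst = max(worst, float(np.max(rel)))
    return worst

# Identical gradients give 0; opposite-sign gradients approach 1.
print(max_rel_diff([np.ones(3)], [np.ones(3)]))  # → 0.0
```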

Relevant log output

==============================
Numerical Consistency Report
==============================
Layer  0 | Mismatch Detected! Max Abs Diff: 3.796768e-01
Layer  1 | OK
Layer  2 | Mismatch Detected! Max Abs Diff: 4.209318e-01
Layer  3 | OK
Layer  4 | Mismatch Detected! Max Abs Diff: 2.811756e-01
Layer  5 | OK
Layer  8 | OK
Layer  9 | OK
Layer 10 | OK
Layer 11 | OK
Layer 16 | Mismatch Detected! Max Abs Diff: 2.794993e-02
Layer 17 | OK