Issue type
Bug
Have you reproduced the bug with TensorFlow Nightly?
No
Source
source
TensorFlow version
tf 2.16.1
Custom code
Yes
OS platform and distribution
Linux Ubuntu 22.04
Mobile device
No response
Python version
3.11
Bazel version
No response
GCC/compiler version
No response
CUDA/cuDNN version
12.1
GPU model and memory
No response
Current behavior?
I have identified a significant numerical inconsistency in the XLA compiler during gradient computation for a complex multi-branch model. When jit_compile=True is enabled, the computed gradients deviate drastically from the eager-execution baseline (max absolute difference > 0.4 on several layers).
file.zip
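For context, this is the comparison pattern the attached repro uses, shown here as a minimal sketch on a hypothetical toy model (not the attached one; on a graph this simple the two gradient sets are expected to agree to float32 rounding, which is exactly what does not happen with the multi-branch model below):

import tensorflow as tf

# Hypothetical toy model, only to illustrate the eager-vs-XLA gradient check.
toy = tf.keras.Sequential([tf.keras.layers.Dense(4, activation='relu'),
                           tf.keras.layers.Dense(1)])
x = tf.random.normal([8, 3])
_ = toy(x)  # build the weights

def toy_grads(jit):
    @tf.function(jit_compile=jit)
    def step(inp):
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(toy(inp))
        return tape.gradient(loss, toy.trainable_variables)
    return step(x)

# Expected on a trivial model: differences at rounding level (~1e-7).
for g_e, g_x in zip(toy_grads(False), toy_grads(True)):
    print(float(tf.reduce_max(tf.abs(g_e - g_x))))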
Standalone code to reproduce the issue
import tensorflow as tf
import numpy as np
import json
import os

# --- 1. Lightweight JSON parser (replaces the private wdag_keras package) ---
class StandaloneWDAGModel(tf.keras.Model):
    def __init__(self, json_path):
        super().__init__()
        with open(json_path, 'r') as f:
            self.config = json.load(f)
        self.node_layers = {}
        for node in self.config['nodes']:
            if node['type'] == 'Input':
                continue
            # Dynamically map node types to Keras layers
            if node['type'] == 'Conv1D':
                self.node_layers[node['id']] = tf.keras.layers.Conv1D(**node['args'])
            elif node['type'] == 'Dense':
                self.node_layers[node['id']] = tf.keras.layers.Dense(**node['args'])
            elif node['type'] == 'Concatenate':
                self.node_layers[node['id']] = tf.keras.layers.Concatenate(**node['args'])
            elif node['type'] == 'Flatten':
                self.node_layers[node['id']] = tf.keras.layers.Flatten(**node['args'])

    def call(self, inputs, training=True):
        node_outputs = {self.config['inputs'][0]: inputs}
        # Execute nodes in topological order (assumes the JSON nodes are already sorted)
        for node in self.config['nodes']:
            nid = node['id']
            if nid in node_outputs:
                continue
            # Collect all incoming edges of the current node
            src_ids = [e['src'] for e in self.config['edges'] if e['tgt'] == nid]
            if not src_ids:
                continue
            in_tensors = [node_outputs[sid] for sid in src_ids]
            layer = self.node_layers.get(nid)
            if layer:
                if len(in_tensors) > 1:
                    node_outputs[nid] = layer(in_tensors, training=training)
                else:
                    node_outputs[nid] = layer(in_tensors[0], training=training)
        # Return the primary output
        return node_outputs[self.config['outputs'][0]]

# --- 2. Core reproduction logic ---
def run_reproduction():
    JSON_FILE = 'BUG_epoch_1.json'
    DATA_FILE = 'temp_tf_meta_inputs_epoch.npz'
    WEIGHT_FILE = 'temp_tf_meta_epoch.weights.h5'

    # Initialize the model and build it
    model = StandaloneWDAGModel(JSON_FILE)
    data = np.load(DATA_FILE)
    x1, x2 = tf.cast(data['x1'], tf.float32), tf.cast(data['x2'], tf.float32)

    # Run once to create the weights, then load the saved ones
    _ = model(x1)
    model.load_weights(WEIGHT_FILE)

    # Gradient helper: identical computation, toggled between eager and XLA
    def get_grads(jit):
        @tf.function(jit_compile=jit)
        def _compute(i1, i2):
            with tf.GradientTape() as tape:
                loss = tf.reduce_mean(model(i1)) + tf.reduce_mean(model(i2))
            return tape.gradient(loss, model.trainable_variables)
        return _compute(x1, x2)

    print(" Running Eager Baseline...")
    grads_eager = get_grads(jit=False)
    print(" Running XLA Optimized...")
    grads_xla = get_grads(jit=True)

    # --- 3. Result verification ---
    print("\n" + "=" * 30 + "\nNumerical Consistency Report\n" + "=" * 30)
    for i, (g_e, g_x) in enumerate(zip(grads_eager, grads_xla)):
        if g_e is None or g_x is None:
            continue
        diff = np.abs(g_e.numpy() - g_x.numpy())
        max_diff = np.max(diff)
        if max_diff > 1e-3:
            print(f"Layer {i:2d} | Mismatch Detected! Max Abs Diff: {max_diff:.6e}")
        else:
            print(f"Layer {i:2d} | OK")

if __name__ == "__main__":
    run_reproduction()
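For reference, a drop-in variant of the verification loop inside run_reproduction() (assuming the same grads_eager / grads_xla lists as above) that additionally reports the maximum relative difference per variable, which helps distinguish genuine divergence from large gradient magnitudes:

    # Variant of the verification loop: also report max relative difference
    for i, (g_e, g_x) in enumerate(zip(grads_eager, grads_xla)):
        if g_e is None or g_x is None:
            continue
        a, b = g_e.numpy(), g_x.numpy()
        abs_diff = np.max(np.abs(a - b))
        # Small epsilon guards against division by zero-valued gradients
        rel_diff = np.max(np.abs(a - b) / (np.abs(a) + 1e-12))
        print(f"Layer {i:2d} | max abs {abs_diff:.3e} | max rel {rel_diff:.3e}")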
Relevant log output
==============================
Numerical Consistency Report
==============================
Layer 0 | Mismatch Detected! Max Abs Diff: 3.796768e-01
Layer 1 | OK
Layer 2 | Mismatch Detected! Max Abs Diff: 4.209318e-01
Layer 3 | OK
Layer 4 | Mismatch Detected! Max Abs Diff: 2.811756e-01
Layer 5 | OK
Layer 8 | OK
Layer 9 | OK
Layer 10 | OK
Layer 11 | OK
Layer 16 | Mismatch Detected! Max Abs Diff: 2.794993e-02
Layer 17 | OK