-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathApp.java
More file actions
63 lines (55 loc) · 2.62 KB
/
App.java
File metadata and controls
63 lines (55 loc) · 2.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
package tftest;
import java.util.Collections;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.GPUOptions;
import org.tensorflow.proto.framework.GraphOptions;
import org.tensorflow.proto.framework.OptimizerOptions;
import org.tensorflow.proto.framework.OptimizerOptions.GlobalJitLevel;
import org.tensorflow.proto.framework.OptimizerOptions.Level;
import org.tensorflow.types.TFloat32;
public class App {
    /**
     * Default SavedModel directory; can be overridden by passing a path as the
     * first command-line argument (backward compatible: no args = old behavior).
     */
    private static final String DEFAULT_MODEL_PATH = "/tmp/vit_b32_fe/";

    /**
     * ViT-B/32 feature extractors consume 224x224 RGB images. The original code
     * used 244, which looks like a transposed-digit typo — TODO confirm against
     * the exported model's input signature.
     */
    private static final int IMAGE_SIZE = 224;

    /**
     * Loads the SavedModel twice — once with default session options and once
     * with an explicit ConfigProto that disables JIT/graph optimizations — and
     * times a single warm-up inference for each.
     *
     * @param args optional: args[0] overrides the SavedModel directory
     */
    public static void main(String[] args) {
        String modelPath = args.length > 0 ? args[0] : DEFAULT_MODEL_PATH;

        System.out.println("Testing query without ConfigProto");
        try (SavedModelBundle savedModel =
                SavedModelBundle.loader(modelPath).withTags(new String[] {"serve"}).load()) {
            doInference(savedModel, "Java: Model without ConfigProto");
        }

        System.out.println("Testing query with ConfigProto");
        ConfigProto config = buildConfig();
        try (SavedModelBundle savedModel =
                SavedModelBundle.loader(modelPath)
                        .withConfigProto(config)
                        .withTags(new String[] {"serve"})
                        .load()) {
            doInference(savedModel, "Java: Model with ConfigProto/disabled JIT");
        }
    }

    /**
     * Builds a session ConfigProto that switches off XLA JIT and the grappler
     * optimizer passes, and caps GPU memory growth at half the device.
     * Used to compare load/inference behavior against the default config.
     */
    private static ConfigProto buildConfig() {
        // newBuilder() already starts from the default instance; wrapping
        // getDefaultInstance() was redundant.
        return ConfigProto.newBuilder()
                .setLogDevicePlacement(false)
                .setGraphOptions(
                        GraphOptions.newBuilder()
                                .setOptimizerOptions(
                                        OptimizerOptions.newBuilder()
                                                .setCpuGlobalJit(false)
                                                .setGlobalJitLevel(GlobalJitLevel.OFF)
                                                .setDoCommonSubexpressionElimination(false)
                                                .setDoConstantFolding(false)
                                                .setDoFunctionInlining(false)
                                                .setOptLevel(Level.L0)
                                                .build()))
                .setGpuOptions(
                        GPUOptions.newBuilder()
                                .setForceGpuCompatible(false)
                                .setAllowGrowth(true)
                                .setPerProcessGpuMemoryFraction(0.5)
                                .build())
                .build();
    }

    /**
     * Runs one inference on an all-zeros {@code [1, 224, 224, 3]} float tensor
     * and prints the elapsed wall-clock time.
     *
     * @param savedModel an already-loaded model exposing an "inputs" -> "output_0" signature
     * @param msg        label prefixed to the timing line
     */
    public static void doInference(SavedModelBundle savedModel, String msg) {
        long start = System.currentTimeMillis();
        // Both tensors are AutoCloseable; try-with-resources frees native memory.
        try (TFloat32 xTensor =
                        TFloat32.tensorOf(NdArrays.ofFloats(Shape.of(1, IMAGE_SIZE, IMAGE_SIZE, 3)));
                TFloat32 zTensor =
                        (TFloat32) savedModel
                                .call(Collections.singletonMap("inputs", xTensor))
                                .get("output_0")) {
            long end = System.currentTimeMillis();
            System.out.println(msg + ", warm up took " + ((end - start) / 1000f) + " seconds");
        }
    }
}