forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathload_library.py
More file actions
140 lines (116 loc) · 4.93 KB
/
load_library.py
File metadata and controls
140 lines (116 loc) · 4.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Function for loading TensorFlow plugins."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import imp
import sys
import threading
import types

from six.moves.builtins import bytes  # pylint: disable=redefined-builtin

from tensorflow.core.framework import op_def_pb2
from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.python import pywrap_tensorflow as py_tf
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
# Memoizes library filename -> generated Python wrapper module, so repeat
# loads of the same op library can hand back the original module. The dict
# is not itself thread safe; every read and write goes through the lock.
_OP_LIBRARY_MAP = dict()
_OP_LIBRARY_MAP_LOCK = threading.Lock()
def load_op_library(library_filename):
  """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here. When the
  library is loaded, ops and kernels registered in the library via the
  REGISTER_* macros are made available in the TensorFlow process. Note
  that ops with the same name as an existing op are rejected and not
  registered with the process.

  If the loader reports the library as already loaded and this function
  previously built a module for the same `library_filename`, that memoized
  module is returned instead of an error.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin. The module also carries `LIB_HANDLE` (the handle returned by
    `TF_LoadLibrary`) and `OP_LIST` (an `op_def_pb2.OpList` parsed from the
    library's op list). The module is also registered in `sys.modules`.

  Raises:
    The exception built by `errors._make_specific_exception` for the failing
    status code when the library cannot be loaded (presumably an
    `errors.OpError` subclass -- NOTE(review): confirm). Errors from fetching
    the op list or generating the wrappers propagate unchanged.
  """
  status = py_tf.TF_NewStatus()
  try:
    # The load happens inside the try so the status object is freed even if
    # the TF_LoadLibrary call itself raises.
    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    error_code = py_tf.TF_GetCode(status)
    if error_code != 0:
      error_msg = compat.as_text(py_tf.TF_Message(status))
      with _OP_LIBRARY_MAP_LOCK:
        # A repeat load of a library this function already wrapped: return
        # the module built the first time.
        if (error_code == error_codes_pb2.ALREADY_EXISTS and
            'has already been loaded' in error_msg and
            library_filename in _OP_LIBRARY_MAP):
          return _OP_LIBRARY_MAP[library_filename]
      # NOTE(review): the lock is not held between the load above and the
      # memoization below, so two concurrent first-time loads of the same
      # file may race and the loser raise instead of getting the module.
      # pylint: disable=protected-access
      raise errors._make_specific_exception(None, None, error_msg, error_code)
      # pylint: enable=protected-access
  finally:
    py_tf.TF_DeleteStatus(status)

  op_list_str = py_tf.TF_GetOpList(lib_handle)
  op_list = op_def_pb2.OpList()
  op_list.ParseFromString(compat.as_bytes(op_list_str))
  wrappers = py_tf.GetPythonWrappers(op_list_str)

  # Get a unique name for the module from a digest of the generated source.
  module_name = hashlib.md5(wrappers).hexdigest()
  # types.ModuleType is the equivalent of the deprecated imp.new_module
  # (the imp module was removed in Python 3.12).
  module = types.ModuleType(module_name)
  # pylint: disable=exec-used
  exec(wrappers, module.__dict__)
  # Stash away the library handle for making calls into the dynamic library.
  module.LIB_HANDLE = lib_handle
  # OpDefs of the list of ops defined in the library.
  module.OP_LIST = op_list
  sys.modules[module_name] = module
  # Memoize the filename to module mapping.
  with _OP_LIBRARY_MAP_LOCK:
    _OP_LIBRARY_MAP[library_filename] = module
  return module
# Records library filename -> handle returned by TF_LoadLibrary for every
# file system plugin loaded successfully. The dict is not itself thread
# safe; every read and write goes through the lock.
_FILE_SYSTEM_LIBRARY_MAP = dict()
_FILE_SYSTEM_LIBRARY_MAP_LOCK = threading.Lock()
def load_file_system_library(library_filename):
  """Loads a TensorFlow plugin, containing file system implementation.

  Pass `library_filename` to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here.

  If the loader reports the library as already loaded and this function
  previously recorded the same `library_filename`, the call returns quietly
  instead of raising.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    None.

  Raises:
    The exception built by `errors._make_specific_exception` for the failing
    status code when the library cannot be loaded (presumably an
    `errors.OpError` subclass -- NOTE(review): confirm).
  """
  status = py_tf.TF_NewStatus()
  try:
    # The load happens inside the try so the status object is freed even if
    # the TF_LoadLibrary call itself raises.
    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    error_code = py_tf.TF_GetCode(status)
    if error_code != 0:
      error_msg = compat.as_text(py_tf.TF_Message(status))
      with _FILE_SYSTEM_LIBRARY_MAP_LOCK:
        # A repeat load of a plugin this function already recorded is fine.
        if (error_code == error_codes_pb2.ALREADY_EXISTS and
            'has already been loaded' in error_msg and
            library_filename in _FILE_SYSTEM_LIBRARY_MAP):
          return
      # pylint: disable=protected-access
      raise errors._make_specific_exception(None, None, error_msg, error_code)
      # pylint: enable=protected-access
  finally:
    py_tf.TF_DeleteStatus(status)

  # Record the handle; this also marks the filename as loaded for the
  # ALREADY_EXISTS check above on later calls.
  with _FILE_SYSTEM_LIBRARY_MAP_LOCK:
    _FILE_SYSTEM_LIBRARY_MAP[library_filename] = lib_handle