forked from tensorforce/tensorforce
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutil.py
More file actions
executable file
·216 lines (168 loc) · 7.16 KB
/
util.py
File metadata and controls
executable file
·216 lines (168 loc) · 7.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
# Copyright 2017 reinforce.io. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import importlib
import logging
import numpy as np
import tensorflow as tf
from tensorflow.core.util.event_pb2 import SessionLog
from tensorforce import TensorForceError
# Small numerical-stability constant (presumably added to denominators or
# log arguments by callers elsewhere — not referenced in this file).
epsilon = 1e-6

# Maps configuration strings to the standard `logging` module's levels.
# NOTE(review): logging.FATAL is an alias of logging.CRITICAL, so the
# 'critical' and 'fatal' keys resolve to the same numeric level.
log_levels = dict(
    info=logging.INFO,
    debug=logging.DEBUG,
    critical=logging.CRITICAL,
    warning=logging.WARNING,
    fatal=logging.FATAL
)
def prod(xs):
    """Multiply together all elements of an iterable.

    Args:
        xs: Iterable containing numbers.

    Returns: The product of the elements; 1 for an empty iterable.
    """
    result = 1
    for factor in xs:
        result = result * factor
    return result
def rank(x):
    """Return the rank (number of dimensions) of tensor `x`, based on its
    static shape."""
    static_shape = x.get_shape()
    return static_shape.ndims
def shape(x, unknown=-1):
    """Return the static shape of tensor `x` as a tuple, substituting
    `unknown` for any dimension whose size is not statically known (None)."""
    dims = x.get_shape().as_list()
    return tuple(unknown if size is None else size for size in dims)
def np_dtype(dtype):
    """Translate a dtype specification to the corresponding numpy data type.

    Args:
        dtype: String describing a numerical type (e.g. 'float'), a Python
            type primitive, a numpy dtype, or a TensorFlow dtype.

    Returns: Numpy data type

    Raises:
        TensorForceError: If the specification is not recognized.
    """
    # Comparisons deliberately use short-circuit `==` chains (not a lookup
    # table) so equality-but-not-hash-equal dtypes still match.
    if dtype == 'float' or dtype == float or dtype == np.float32 or dtype == tf.float32:
        return np.float32
    if dtype == 'int' or dtype == int or dtype == np.int32 or dtype == tf.int32:
        return np.int32
    if dtype == 'bool' or dtype == bool or dtype == np.bool_ or dtype == tf.bool:
        return np.bool_
    raise TensorForceError("Error: Type conversion from type {} not supported.".format(str(dtype)))
def tf_dtype(dtype):
    """Translate a dtype specification to the corresponding TensorFlow data type.

    Args:
        dtype: String describing a numerical type (e.g. 'float'), a Python
            type primitive, a numpy dtype, or a TensorFlow dtype.

    Returns: TensorFlow data type

    Raises:
        TensorForceError: If the specification is not recognized.
    """
    # Mirrors np_dtype: short-circuit `==` chains with early returns.
    if dtype == 'float' or dtype == float or dtype == np.float32 or dtype == tf.float32:
        return tf.float32
    if dtype == 'int' or dtype == int or dtype == np.int32 or dtype == tf.int32:
        return tf.int32
    if dtype == 'bool' or dtype == bool or dtype == np.bool_ or dtype == tf.bool:
        return tf.bool
    raise TensorForceError("Error: Type conversion from type {} not supported.".format(str(dtype)))
def map_tensors(fn, tensors):
    """Recursively apply `fn` to every leaf of a nested structure of tuples,
    lists, dicts, and sets, rebuilding the same container structure.

    None maps to None; any value that is not one of the recognized containers
    is treated as a leaf and passed to `fn`.
    """
    if tensors is None:
        return None
    if isinstance(tensors, tuple):
        # NOTE: tuple subclasses (e.g. namedtuples) come back as plain tuples.
        return tuple(map_tensors(fn, element) for element in tensors)
    if isinstance(tensors, list):
        return [map_tensors(fn, element) for element in tensors]
    if isinstance(tensors, dict):
        return {key: map_tensors(fn, value) for key, value in tensors.items()}
    if isinstance(tensors, set):
        return {map_tensors(fn, element) for element in tensors}
    return fn(tensors)
def get_object(obj, predefined_objects=None, default_object=None, kwargs=None):
    """
    Utility method to map some kind of object specification to its content,
    e.g. optimizer or baseline specifications to the respective classes.

    Args:
        obj: A specification dict (value for key 'type' optionally specifies
            the object, options as follows), a module path (e.g.,
            my_module.MyClass), a key in predefined_objects, or a callable
            (e.g., the class type object).
        predefined_objects: Dict containing predefined set of objects,
            accessible via their key
        default_object: Default object if no other is specified
        kwargs: Arguments for object creation

    Returns: The retrieved object (or the instantiated result of calling it
        with the collected arguments)

    Raises:
        TensorForceError: If `obj` is a string without a dot that is not a
            key in `predefined_objects`.
    """
    args = ()
    # BUGFIX: copy the caller's kwargs dict. The previous implementation
    # updated and popped the passed-in dict in place, mutating the caller's
    # data whenever `obj` was a specification dict.
    kwargs = dict() if kwargs is None else dict(kwargs)

    if isinstance(obj, dict):
        # Spec dict: its entries become kwargs; 'type' selects the object.
        kwargs.update(obj)
        obj = kwargs.pop('type', None)

    if predefined_objects is not None and obj in predefined_objects:
        obj = predefined_objects[obj]
    elif isinstance(obj, str):
        if '.' in obj:
            # Dotted module path: import the module and fetch the attribute.
            module_name, function_name = obj.rsplit('.', 1)
            module = importlib.import_module(module_name)
            obj = getattr(module, function_name)
        else:
            raise TensorForceError("Error: object {} not found in predefined objects: {}".format(
                obj,
                list(predefined_objects or ())
            ))
    elif callable(obj):
        pass
    elif default_object is not None:
        # Unrecognized value: treat it as the sole positional argument to
        # the default object.
        args = (obj,)
        obj = default_object
    else:
        # assumes the object is already instantiated
        return obj

    return obj(*args, **kwargs)
def prepare_kwargs(raw, string_parameter='name'):
    """
    Utility method to convert raw string/dict input into a kwargs dictionary
    to pass into a function. Always returns a (fresh) dictionary.

    Args:
        raw: String or dictionary. A string is stored under the key given by
            `string_parameter`; a dictionary is copied through unchanged.
        string_parameter: Key under which a string `raw` is stored.

    Returns: kwargs dictionary for **kwargs
    """
    if isinstance(raw, dict):
        return dict(raw)
    if isinstance(raw, str):
        return {string_parameter: raw}
    # Anything else (including None) yields an empty kwargs dict.
    return dict()
class UpdateSummarySaverHook(tf.train.SummarySaverHook):
    """SummarySaverHook variant that additionally gates summary collection on
    the model's observe flag (`model.is_observe`), so summaries are only
    requested during observe/update runs rather than on every triggered step.

    NOTE(review): this hook reads several private members of the parent
    SummarySaverHook (_next_step, _timer, _get_summary_op,
    _global_step_tensor, _summary_writer) — behavior is tied to the
    TF 1.x tf.train.SummarySaverHook implementation; verify against the
    installed TensorFlow version.
    """

    def __init__(self, model, *args, **kwargs):
        # Forward all standard SummarySaverHook arguments unchanged and keep
        # a handle to the model so before_run can consult model.is_observe.
        super(UpdateSummarySaverHook, self).__init__(*args, **kwargs)
        self.model = model

    def before_run(self, run_context):
        # Request a summary only when all three hold:
        #   (a) this session.run has original args at index 1 (not None) —
        #       presumably the feed dict; TODO confirm against the caller,
        #   (b) the model is currently in an observe step, and
        #   (c) the parent's timer says it is time to trigger again (or no
        #       step has been recorded yet).
        self._request_summary = run_context.original_args[1] is not None and \
            self.model.is_observe and \
            (self._next_step is None or self._timer.should_trigger_for_step(self._next_step))
        # run_context.original_args[1].get(self.is_optimizing, False) and \
        requests = {'global_step': self._global_step_tensor}
        if self._request_summary:
            if self._get_summary_op() is not None:
                requests['summary'] = self._get_summary_op()
        return tf.train.SessionRunArgs(requests)

    def after_run(self, run_context, run_values):
        # Write out any collected summaries and schedule the next step,
        # mirroring the parent hook's after_run logic.
        if not self._summary_writer:
            return
        stale_global_step = run_values.results['global_step']
        global_step = stale_global_step + 1
        if self._next_step is None or self._request_summary:
            # The pre-run value may be stale; re-read the true global step.
            global_step = run_context.session.run(self._global_step_tensor)
        if self._next_step is None:
            # First invocation: record a session START event.
            self._summary_writer.add_session_log(SessionLog(status=SessionLog.START), global_step)
        if 'summary' in run_values.results:
            self._timer.update_last_triggered_step(global_step)
            for summary in run_values.results['summary']:
                self._summary_writer.add_summary(summary, global_step)
        self._next_step = global_step + 1