Skip to content

Commit 3d7c36c

Browse files
committed
完成py3的兼容
1 parent f9ef38f commit 3d7c36c

7 files changed

Lines changed: 64 additions & 84 deletions

File tree

demo/mplot_demo.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# -*- coding: utf-8 -*-
22
import matplotlib
3-
import matplotlib.pyplot as plt
43
import pandas as pd
5-
matplotlib.use('TkAgg')
64

75
from quantdigger.widgets.mplotwidgets import widgets
86
from quantdigger.widgets.mplotwidgets.mplots import Candles
97
from quantdigger.technicals.common import MA, Volume
108

9+
matplotlib.use('TkAgg')
10+
import matplotlib.pyplot as plt
11+
1112
price_data = pd.read_csv('data/IF000.csv', index_col=0, parse_dates=True)
1213
fig = plt.figure()
1314

quantdigger/event/event.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
import json
44
import six
55

6+
import sys
7+
if sys.version_info >= (3,):
8+
py = 3
9+
else:
10+
py = 2
11+
612

713
# @TODO REMOVE EventsPool
814
class EventsPool(object):
@@ -60,6 +66,8 @@ def args(self):
6066

6167
@classmethod
6268
def message_to_event(self, message):
69+
if py == 3:
70+
message = message.decode('utf8')
6371
route, args = message.split('&')
6472
route = route[1:]
6573
return Event(route=route, args=json.loads(args))

quantdigger/event/rpc.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,4 @@
11
# encoding: UTF-8
2-
##
3-
# @file rpc.py
4-
# @brief
5-
# @author wondereamer
6-
# @version 0.5
7-
# @date 2016-05-17
82

93
import six
104
import time
@@ -20,7 +14,7 @@ def __init__(self, name, event_engine, service, event_client=None, event_server=
2014
self.EVENT_FROM_CLIENT = event_client if event_client else "EVENT_FROM_%s_CLIENT" % service.upper()
2115
self.EVENT_FROM_SERVER = event_server if event_server else "EVENT_FROM_%s_SERVER" % service.upper()
2216
self.rid = 0
23-
self._handlers = { }
17+
self._handlers = {}
2418
self._name = name
2519
self._handlers_lock = Lock()
2620
self._event_engine = event_engine
@@ -31,19 +25,19 @@ def __init__(self, name, event_engine, service, event_client=None, event_server=
3125
self._timer_sleep = 1
3226
self._sync_call_time_lock = Lock()
3327
self._sync_call_time = datetime.now()
34-
timer = Thread(target = self._run_timer)
28+
timer = Thread(target=self._run_timer)
3529
timer.daemon = True
3630
timer.start()
3731

3832
def _run_timer(self):
39-
## @TODO 用python自带的Event替代。
33+
# @TODO 用python自带的Event替代。
4034
while True:
4135
if not self._timeout == 0:
4236
with self._sync_call_time_lock:
4337
mtime = self._sync_call_time
44-
delta = (datetime.now()-mtime).seconds
38+
delta = (datetime.now() - mtime).seconds
4539
if delta >= self._timeout:
46-
#six.print_("timeout", self._timeout, delta)
40+
# print("timeout", self._timeout, delta)
4741
# 不可重入,保证self.rid就是超时的那个
4842
with self._handlers_lock:
4943
del self._handlers[self.rid]
@@ -79,15 +73,15 @@ def _process_apiback(self, event):
7973
def call(self, apiname, args, handler):
8074
""" 给定参数args,异步调用RPCServer的apiname服务,
8175
返回结果做为回调函数handler的参数。
82-
76+
8377
Args:
8478
apiname (str): 服务API名称。
8579
args (dict): 给服务API的参数。
8680
handler (function): 回调函数。
8781
"""
8882
if not isinstance(args, dict):
8983
raise InvalidRPCClientArguments(argtype=type(args))
90-
assert(not handler == None)
84+
assert(handler is not None)
9185
log.debug('RPCClient [%s] sync_call: %s' % (self._name, apiname))
9286
self.rid += 1
9387
args['apiname'] = apiname
@@ -96,10 +90,10 @@ def call(self, apiname, args, handler):
9690
with self._handlers_lock:
9791
self._handlers[self.rid] = handler
9892

99-
def sync_call(self, apiname, args={ }, timeout=5):
93+
def sync_call(self, apiname, args={}, timeout=5):
10094
""" 给定参数args,同步调用RPCServer的apiname服务,
10195
返回该服务的处理结果。如果超时,返回None。
102-
96+
10397
Args:
10498
apiname (str): 服务API名称。
10599
args (dict): 给服务API的参数。
@@ -117,11 +111,10 @@ def sync_call(self, apiname, args={ }, timeout=5):
117111
self._sync_call_time = datetime.now()
118112
self._timeout = timeout
119113
with self._handlers_lock:
120-
self._handlers[self.rid] = None #
114+
self._handlers[self.rid] = None
121115
self._event_engine.emit(Event(self.EVENT_FROM_CLIENT, args))
122116
self._waiting_server_data()
123117
ret = self._sync_ret
124-
#self._sync_ret = None
125118
return ret
126119

127120
def _waiting_server_data(self):
@@ -136,7 +129,7 @@ def _notify_server_data(self):
136129
class EventRPCServer(object):
137130
def __init__(self, event_engine, service, event_client=None, event_server=None):
138131
super(EventRPCServer, self).__init__()
139-
self._routes = { }
132+
self._routes = {}
140133
self._routes_lock = Lock()
141134
# server监听的client事件
142135
self.EVENT_FROM_CLIENT = event_client if event_client else "EVENT_FROM_%s_CLIENT" % service.upper()
@@ -149,16 +142,16 @@ def __init__(self, event_engine, service, event_client=None, event_server=None):
149142

150143
def register(self, route, handler):
151144
""" 注册服务函数。
152-
145+
153146
Args:
154147
route (str): 服务名
155148
handler (function): 回调函数
156-
149+
157150
Returns:
158151
Bool. 是否注册成功。
159152
"""
160153
if route in self._routes:
161-
return False
154+
return False
162155
with self._routes_lock:
163156
self._routes[route] = handler
164157
return True
@@ -179,30 +172,28 @@ def _process_request(self, event):
179172
try:
180173
with self._routes_lock:
181174
handler = self._routes[apiname]
182-
## @TODO async
175+
# @TODO async
183176
ret = handler(**args)
184177
except Exception as e:
185178
log.exception(e)
186179
else:
187-
args = { 'ret': ret,
180+
args = {'ret': ret,
188181
'rid': rid
189182
}
190183
log.debug('RPCServer [%s] emit %s' % (self._name,
191-
str(self.EVENT_FROM_SERVER)))
192-
#str(Event(self.EVENT_FROM_SERVER, args))))
184+
str(self.EVENT_FROM_SERVER)))
193185
self._event_engine.emit(Event(self.EVENT_FROM_SERVER, args))
194186

195187

196-
197188
if __name__ == '__main__':
198189

199190
from eventengine import ZMQEventEngine
200191
import sys
201192

202193
def print_hello(data):
203-
""""""
194+
""""""
204195
six.print_("***************")
205-
six.print_("print_hello" )
196+
six.print_("print_hello")
206197
six.print_("args: ", data)
207198
six.print_("return: ", 123)
208199
return "123"

quantdigger/interaction/serialize.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
##
22
# @file serialize.py
3-
# @brief
3+
# @brief
44
# @author Wells
55
# @version 0.5
66
# @date 2016-08-07
@@ -12,16 +12,16 @@
1212
import json
1313
from json import JSONEncoder
1414

15+
1516
class DataStructCoder(JSONEncoder):
1617
def default(self, o):
17-
return o.__dict__
18+
return o.__dict__
1819

19-
2020

2121
def serialize_pcontract_bars(str_pcontract, bars):
2222
data = {
2323
'pcontract': str_pcontract,
24-
'datetime': map(lambda x: str(x), bars.index),
24+
'datetime': list(map(lambda x: str(x), bars.index)),
2525
'open': bars.open.tolist(),
2626
'close': bars.close.tolist(),
2727
'high': bars.high.tolist(),
@@ -33,24 +33,25 @@ def serialize_pcontract_bars(str_pcontract, bars):
3333

3434
def deserialize_pcontract_bars(data):
3535
data = json.loads(data)
36-
dt = map(lambda x: datetime.strptime(x, "%Y-%m-%d %H:%M:%S"), data['datetime'])
37-
#datetime.datetime.strptime(string_date, "%Y-%m-%d %H:%M:%S.%f")
36+
dt = list(map(lambda x: datetime.strptime(x, "%Y-%m-%d %H:%M:%S"), data['datetime']))
37+
# datetime.datetime.strptime(string_date, "%Y-%m-%d %H:%M:%S.%f")
3838
pcon = data['pcontract']
3939
del data['pcontract']
4040
del data['datetime']
41-
return pcon, pd.DataFrame(data, index = dt)
41+
return pcon, pd.DataFrame(data, index=dt)
4242

4343

4444
def serialize_all_pcontracts(pcontracts):
4545
return [str(pcontract) for pcontract in pcontracts]
4646

47+
4748
def serialize_all_contracts(contracts):
4849
return [str(contract) for contract in contracts]
4950

51+
5052
def deserialize_all_pcontracts(pcontracts):
5153
return [PContract.from_string(strpcon) for strpcon in pcontracts]
5254

55+
5356
def deserialize_all_contracts(contracts):
5457
return [Contract.from_string(strcon) for strcon in contracts]
55-
56-

quantdigger/widgets/mplotwidgets/mainwindow.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import matplotlib.pyplot as plt
77
from matplotlib.widgets import Button
88

9-
from quantdigger.technicals.common import MA, Volume
109
from quantdigger.util import gen_logger as log
1110
from quantdigger.interaction.windowgate import WindowGate
1211
from quantdigger.widgets.mplotwidgets import widgets
@@ -34,7 +33,7 @@ def __init__(self):
3433
self._fig = plt.figure()
3534
self._gate = WindowGate(self)
3635
self._cur_contract_index = 0
37-
self._pcontracts_of_contract = { } # {[], []}
36+
self._pcontracts_of_contract = {} # {[], []}
3837
self._subwindows = []
3938
self._cur_period = 0
4039

@@ -65,7 +64,7 @@ def _create_toolbar(self):
6564
def _create_technical_window(self):
6665
self.frame = widgets.TechnicalWidget(self._fig, price_data, height=0.85)
6766
axes = self.frame.init_layout(50, 4, 1)
68-
ax_candles, ax_volume = axes[0], axes[1]
67+
ax_volume = axes[1]
6968
# at most 5 ticks, pruning the upper and lower so they don't overlap
7069
# with other ticks
7170
ax_volume.yaxis.set_major_locator(widgets.MyLocator(5, prune='both'))
@@ -102,7 +101,7 @@ def close(self):
102101
def on_next_contract(self, event):
103102
if self._cur_contract_index + 1 < len(self._pcontracts_of_contract.keys()):
104103
self._cur_contract_index += 1
105-
contract = self._pcontracts_of_contract.keys()[self._cur_contract_index]
104+
contract = list(self._pcontracts_of_contract.keys())[self._cur_contract_index]
106105

107106
pcon = self._pcontracts_of_contract[contract][self._cur_period]
108107
pcon, data = self._gate.get_pcontract(str(pcon))
@@ -116,7 +115,7 @@ def on_next_contract(self, event):
116115
def on_previous_contract(self, event):
117116
if self._cur_contract_index - 1 >= 0:
118117
self._cur_contract_index -= 1
119-
contract = self._pcontracts_of_contract.keys()[self._cur_contract_index]
118+
contract = list(self._pcontracts_of_contract.keys())[self._cur_contract_index]
120119

121120
pcon = self._pcontracts_of_contract[contract][self._cur_period]
122121
pcon, data = self._gate.get_pcontract(str(pcon))

quantdigger/widgets/mplotwidgets/mplots.py

Lines changed: 18 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,4 @@
11
# -*- coding: utf-8 -*-
2-
##
3-
# @file mplots.py
4-
# @brief 绘制k线,交易信号连线。
5-
# @author wondereamer
6-
# @version 2.0
7-
# @date 2015-10-19
82

93
import six
104
from six.moves import range
@@ -75,7 +69,7 @@ def __init__(self, data, tracker, name='candle',
7569
# note this code assumes if any value open, close, low, high is
7670
# missing they all are missing
7771
@override_attributes
78-
def plot(self, widget, data, width=0.6,
72+
def plot(self, widget, data, width=0.6,
7973
colorup='r', colordown='g', lc='k', alpha=1):
8074

8175
if self.lineCollection:
@@ -86,20 +80,19 @@ def plot(self, widget, data, width=0.6,
8680
self.set_yrange(data.low.values, data.high.values)
8781
self.data = data
8882
"""docstring for plot"""
89-
delta = self.width/2.
90-
barVerts = [((i-delta, open),
91-
(i-delta, close),
92-
(i+delta, close),
93-
(i+delta, open))
83+
delta = self.width / 2.
84+
barVerts = [((i - delta, open),
85+
(i - delta, close),
86+
(i + delta, close),
87+
(i + delta, open))
9488
for i, open, close in zip(range(len(self.data)),
9589
self.data.open,
9690
self.data.close)
9791
if open != -1 and close != -1]
9892
rangeSegments = [((i, low), (i, high))
99-
for i, low, high in zip(
100-
range(len(self.data)),
101-
self.data.low,
102-
self.data.high)
93+
for i, low, high in zip(range(len(self.data)),
94+
self.data.low,
95+
self.data.high)
10396
if low != -1]
10497
r, g, b = colorConverter.to_rgb(self.colorup)
10598
colorup = r, g, b, self.alpha
@@ -118,32 +111,23 @@ def plot(self, widget, data, width=0.6,
118111
r, g, b = colorConverter.to_rgb(self.lc)
119112
linecolor = r, g, b, self.alpha
120113
self.lineCollection = LineCollection(rangeSegments,
121-
colors=(linecolor,),
122-
linewidths=lw,
123-
antialiaseds=useAA,
124-
zorder=0)
114+
colors=(linecolor,),
115+
linewidths=lw,
116+
antialiaseds=useAA,
117+
zorder=0)
125118

126119
self.barCollection = PolyCollection(barVerts,
127-
facecolors=colors,
128-
edgecolors=colors,
129-
antialiaseds=useAA,
130-
linewidths=lw,
131-
zorder=1)
132-
#minx, maxx = 0, len(rangeSegments)
133-
#miny = min([low for low in self.data.low if low !=-1])
134-
#maxy = max([high for high in self.data.high if high != -1])
135-
#corners = (minx, miny), (maxx, maxy)
136-
#ax.update_datalim(corners)
120+
facecolors=colors,
121+
edgecolors=colors,
122+
antialiaseds=useAA,
123+
linewidths=lw,
124+
zorder=1)
137125
widget.autoscale_view()
138126
# add these last
139127
widget.add_collection(self.barCollection)
140128
widget.add_collection(self.lineCollection)
141-
142-
#ax.plot(self.data.close, color = 'y')
143-
#lineCollection, barCollection = None, None
144129
return self.lineCollection, self.barCollection
145130

146-
147131
def set_yrange(self, lower, upper=[]):
148132
self.upper = upper if len(upper) > 0 else lower
149133
self.lower = lower
@@ -159,10 +143,6 @@ def y_interval(self, w_left, w_right):
159143
class TradingSignal(object):
160144
""" 从信号坐标(时间, 价格)中绘制交易信号。 """
161145
def __init__(self, signal, name="Signal", c=None, lw=2):
162-
#self.set_yrange(price)
163-
#self.signal=signal
164-
#self.c = c
165-
#self.lw = lw
166146
self.signal = signal
167147
self.name = name
168148

0 commit comments

Comments
 (0)