from tqz_strategy.future.strategies.pair_trading_strategy import PairTradingStrategy  # noqa
import os
from datetime import datetime, date # noqa
from tqsdk import TqApi, TqAuth
from tqsdk.tools import DataDownloader
from contextlib import closing

from public_module.tqz_extern.tools.file_path_operator.file_path_operator import TQZFilePathOperator
from public_module.tqz_extern.tools.pandas_operator.pandas_operator import pandas # noqa

from server_api.tqz_object import BarData, Exchange

# Eight hours expressed in nanoseconds. It is added to TqSdk's nanosecond
# kline timestamps before formatting (presumably a UTC -> UTC+8 shift).
TIME_GAP = 8 * 60 * 60 * 10 ** 9

class TQZTianQinClient:
    """
    Thin wrapper around TqSdk's ``TqApi`` (TianQin).

    The TianQin interface only fetches data for a single contract per call!
    Every query method below closes ``self.api`` when it is done, even on
    failure, so an instance is effectively single-use: create a new client
    for each query.
    """

    # Class-level cache of non-expired FUTURE symbols. The first instance
    # constructed fills it, and all later instances share it.
    __tq_symbols = None

    def __init__(self, account: str = "account", pass_word: str = "password"):
        """
        Open a ``TqApi`` connection and, on first use, cache the symbol list.

        :param account: TianQin account (the default looks like a placeholder).
        :param pass_word: TianQin password (the default looks like a placeholder).
        """
        self.api = TqApi(auth=TqAuth(account, pass_word))

        if TQZTianQinClient.__tq_symbols is None:
            try:
                TQZTianQinClient.__tq_symbols = self.api.query_quotes(ins_class="FUTURE", expired=False)
            except Exception:
                # Don't leave the freshly opened connection dangling.
                self.api.close()
                raise

    def query_history_bars(self, tq_symbol: str, tq_duration_seconds: int, tq_data_length: int = 8964) -> list:
        """
        Fetch the latest bars of one non-expired futures contract.

        :param tq_symbol: "EXCHANGE.contract" symbol such as "SHFE.rb2205";
            it must be in the cached symbol list.
        :param tq_duration_seconds: bar period in seconds.
        :param tq_data_length: number of bars requested from TianQin.
        :return: list of BarData. ``symbol`` is the contract part and
            ``exchange`` is built from the exchange part of ``tq_symbol``.
        :raises AssertionError: if ``tq_symbol`` is unknown (asserts enabled only).
        """
        assert tq_symbol in TQZTianQinClient.__tq_symbols, f'bad tq_symbol: {tq_symbol}'

        kline_df = self._fetch_kline_serial(tq_symbol, tq_duration_seconds, tq_data_length)
        exchange_str, symbol = tq_symbol.split(".")[0], tq_symbol.split(".")[1]
        return self._kline_df_to_bars(kline_df, symbol=symbol, exchange=Exchange(exchange_str))

    def query_index_history_bars(self, tq_index_symbol: str, tq_duration_seconds: int, tq_data_length: int = 8964) -> list:
        """
        Fetch the latest bars of an index symbol such as "KQ.i@SHFE.ag".

        The bars are labelled ``<product>000`` (e.g. "ag000") on the
        "TQZ_INDEX" exchange.
        """
        kline_df = self._fetch_kline_serial(tq_index_symbol, tq_duration_seconds, tq_data_length)
        return self._kline_df_to_bars(
            kline_df,
            symbol=f'{tq_index_symbol.split(".")[2]}000',
            exchange=Exchange("TQZ_INDEX"),
        )

    def query_main_history_bars(self, tq_main_symbol: str, tq_duration_seconds: int, tq_data_length: int = 8964) -> list:
        """
        Fetch the latest bars of a main-contract symbol such as "KQ.m@SHFE.ag".

        The bars are labelled ``<product>888`` (e.g. "ag888") on ``Exchange.TQZ_MAIN``.
        """
        kline_df = self._fetch_kline_serial(tq_main_symbol, tq_duration_seconds, tq_data_length)
        return self._kline_df_to_bars(
            kline_df,
            symbol=f'{tq_main_symbol.split(".")[2]}888',
            exchange=Exchange.TQZ_MAIN,
        )

    @classmethod
    def query_index_history_bars_from_csv(cls, tq_i_symbol, csv_path: str):
        """
        Load index bars from a CSV whose columns are prefixed with ``tq_i_symbol``.

        The bars keep ``tq_i_symbol`` as their symbol and use the "TQZ_INDEX" exchange.
        """
        return cls._csv_to_bars(
            csv_path=csv_path,
            column_prefix=tq_i_symbol,
            symbol=tq_i_symbol,
            exchange=Exchange("TQZ_INDEX"),
        )

    @classmethod
    def query_history_bars_from_csv(cls, tq_symbol, csv_path: str):
        """
        Load contract bars from a CSV whose columns are prefixed with ``tq_symbol``.

        ``tq_symbol`` is "EXCHANGE.contract", split into BarData symbol/exchange.
        """
        exchange_str, symbol = tq_symbol.split(".")[0], tq_symbol.split(".")[1]
        return cls._csv_to_bars(
            csv_path=csv_path,
            column_prefix=tq_symbol,
            symbol=symbol,
            exchange=Exchange(exchange_str),
        )

    def dump_history_df_to_csv(self, tq_symbols: list, tq_duration_seconds: int, start_dt, end_dt):
        """
        Download bars for each symbol into ``<symbol with "." -> "_">.csv``.

        Files are written relative to the current working directory. Progress
        is printed until every download finishes, and ``self.api`` is closed
        afterwards, including when a downloader fails to start.
        """
        downloading_result = {}  # noqa

        # Wrap the downloader creation too, so a failing constructor still closes the api.
        with closing(self.api):
            for tq_symbol in tq_symbols:
                downloading_result[tq_symbol] = DataDownloader(
                    self.api,
                    symbol_list=tq_symbol,
                    dur_sec=tq_duration_seconds,
                    start_dt=start_dt,
                    end_dt=end_dt,
                    csv_file_name=f'{tq_symbol.replace(".", "_")}.csv',
                )

            while not all(v.is_finished() for v in downloading_result.values()):
                self.api.wait_update()
                print("progress: ", {k: ("%.2f%%" % v.get_progress()) for k, v in downloading_result.items()})

    def load_all_tq_symbols(self):
        """Close ``self.api`` and return the class-level cached symbol list."""
        self.api.close()
        return TQZTianQinClient.__tq_symbols

    def query_quote(self, tq_symbol: str) -> dict:
        """
        Return TqSdk's quote object for ``tq_symbol``.

        ``self.api`` is closed before returning, so the object presumably no
        longer receives updates; treat it as a snapshot.
        """
        try:
            return self.api.get_quote(symbol=tq_symbol)  # noqa
        finally:
            self.api.close()

    # --- private part ---
    def _fetch_kline_serial(self, tq_symbol: str, tq_duration_seconds: int, tq_data_length: int):
        """Request a kline serial and always close ``self.api`` afterwards."""
        try:
            return self.api.get_kline_serial(
                symbol=tq_symbol,
                duration_seconds=tq_duration_seconds,
                data_length=tq_data_length,
            )
        finally:
            self.api.close()

    @staticmethod
    def _kline_df_to_bars(kline_df, symbol: str, exchange) -> list:
        """
        Convert a TqSdk kline DataFrame into a list of BarData.

        Rows with ``id < 0`` are skipped. The nanosecond ``datetime`` column is
        shifted by TIME_GAP and formatted as '%Y-%m-%d %H:%M:%S'.
        """
        if kline_df is None:
            return []

        # Filter before formatting, so skipped rows (e.g. padding) never reach
        # strftime. A new frame is built, so the api-owned DataFrame is not mutated.
        valid_df = kline_df.loc[kline_df["id"] >= 0]
        valid_df = valid_df.assign(
            datetime=pandas.to_datetime(valid_df["datetime"] + TIME_GAP).dt.strftime('%Y-%m-%d %H:%M:%S')
        )

        return [
            BarData(
                symbol=symbol,
                exchange=exchange,
                interval='any_interval',  # noqa
                datetime=row["datetime"],
                open_price=row["open"],
                high_price=row["high"],
                low_price=row["low"],
                close_price=row["close"],
                volume=row["volume"],
                open_interest=row.get("open_oi", 0),
                gateway_name="TQ",
            )
            for _, row in valid_df.iterrows()
        ]

    @staticmethod
    def _csv_to_bars(csv_path: str, column_prefix: str, symbol: str, exchange) -> list:
        """
        Read a CSV and convert each row into a BarData.

        Expects the columns ``datetime`` and ``<prefix>.open/.high/.low/.close/.volume``.
        ``<prefix>.open_oi`` is optional and defaults to 0.
        """
        csv_df = pandas.read_csv(csv_path)
        return [
            BarData(
                symbol=symbol,
                exchange=exchange,
                interval='any_interval',  # noqa
                datetime=row["datetime"],
                open_price=row[f'{column_prefix}.open'],
                high_price=row[f'{column_prefix}.high'],
                low_price=row[f'{column_prefix}.low'],
                close_price=row[f'{column_prefix}.close'],
                volume=row[f'{column_prefix}.volume'],
                open_interest=row.get(f'{column_prefix}.open_oi', 0),
                gateway_name="TQ",
            )
            for _, row in csv_df.iterrows()
        ]

    @classmethod
    def __get_symbol_exchange(cls, tq_symbol: str) -> "tuple[str, Exchange]":
        """
        Split "EXCHANGE.contract" into ``(contract, Exchange)``.

        Only SHFE, INE, DCE, CZCE and CFFEX are accepted.

        :raises ValueError: for any other exchange. An explicit raise is used
            instead of ``assert`` so the check survives ``python -O``.
        """
        exchange_str, symbol = tq_symbol.split('.')[0], tq_symbol.split('.')[1]

        for exchange in (Exchange.SHFE, Exchange.INE, Exchange.DCE, Exchange.CZCE, Exchange.CFFEX):
            if exchange_str == exchange.value:
                return symbol, exchange

        raise ValueError(f'bad exchange_str: {exchange_str}')


class TQZTianQinDataManager:
    """
    Load history bars for back testing, either from TianQin or from CSV dumps.

    The figures below appear to be the trading-day coverage of the default
    8964-bar request (e.g. 8964 / 6 is roughly 1493):

    Hourly bars: about 6 per day (rough)
    1493 trading days

    30-minute bars: 12 per day
    746 trading days

    10-minute bars: 36 per day
    248 trading days
    """

    # Process-wide cache filled by the first load_history_bars_map() call.
    __load_history_bars_map: dict = None

    @classmethod
    def load_history_bars_map(cls, tq_symbols: list, tq_duration_seconds: int, tq_data_length: int = 8964) -> dict:
        """
        Load history bars of multiple symbols from TianQin, only once before back testing.

        Only the first call downloads. Every later call returns the cached map
        and ignores its arguments.

        :param tq_symbols: list of symbols (duplicates are fetched once).
        :param tq_duration_seconds: bar period in seconds.
        :param tq_data_length: number of bars per symbol.
        :return: dict mapping each symbol to its list of BarData.
        """
        if cls.__load_history_bars_map is None:
            load_history_bars_map = {}
            for tq_symbol in tq_symbols:
                if tq_symbol in load_history_bars_map:
                    continue
                # Each client closes its api after one query, so use a fresh client per symbol.
                load_history_bars_map[tq_symbol] = TQZTianQinClient().query_history_bars(
                    tq_symbol=tq_symbol,
                    tq_duration_seconds=tq_duration_seconds,
                    tq_data_length=tq_data_length
                )

            # Cache only after every symbol has loaded successfully.
            cls.__load_history_bars_map = load_history_bars_map

        return cls.__load_history_bars_map

    @classmethod
    def load_index_history_bars_from_csv(cls, tq_i_symbol: str, csv_path: str):
        """Load index bars (columns prefixed with ``tq_i_symbol``) from ``csv_path``."""
        return TQZTianQinClient.query_index_history_bars_from_csv(tq_i_symbol=tq_i_symbol, csv_path=csv_path)

    @classmethod
    def load_history_bars_from_csv(cls, tq_symbol: str, csv_path: str):
        """Load contract bars (columns prefixed with ``tq_symbol``) from ``csv_path``."""
        return TQZTianQinClient.query_history_bars_from_csv(tq_symbol=tq_symbol, csv_path=csv_path)

    @classmethod
    def load_arbitrage_strategy_barsData_from_csv(cls, tq_i_symbol: str, csv_name: str):
        """
        Build per-datetime bar snapshots for the pair/arbitrage back tester.

        Contract CSVs are read from ``<base>/pair_source_data``. The index CSV
        ``csv_name`` is read from ``<base>/index_source_data``, where ``<base>``
        is ``TQZFilePathOperator.grandfather_path(__file__)``.

        :param tq_i_symbol: index symbol, used as the index CSV's column prefix.
        :param csv_name: file name of the index CSV.
        :return: one dict per index bar, in index-CSV order. Each dict maps
            "index_bar" to the index bar, plus every contract symbol that has
            a bar with the same datetime to that bar.
        """
        base_path = TQZFilePathOperator.grandfather_path(source_path=__file__)
        pair_source_data_path = f'{base_path}/pair_source_data'
        index_source_data_path = f'{base_path}/index_source_data/{csv_name}'

        contracts_data = cls._load_contracts_bars_from_dir(pair_source_data_path)
        index_data_source = cls.load_index_history_bars_from_csv(
            tq_i_symbol=tq_i_symbol,
            csv_path=index_source_data_path
        )

        return cls._merge_bars_by_datetime(index_data_source, contracts_data)

    @classmethod
    def _load_contracts_bars_from_dir(cls, directory: str) -> dict:
        """
        Load every ``*.csv`` file under ``directory`` (recursively) as contract bars.

        A file such as "SHFE_ag2001.csv" is loaded as symbol "SHFE.ag2001":
        the name up to the first "." is taken and "_" is replaced by ".".
        Non-CSV files are ignored. Files are visited in sorted order, so the
        resulting symbol order is deterministic.
        """
        contracts_data = {}
        for root, dirs, files in os.walk(directory):
            dirs.sort()  # os.walk honours in-place edits: visit sub-directories in sorted order
            for file_name in sorted(files):
                if not file_name.lower().endswith(".csv"):
                    continue
                tq_symbol = file_name.split(".")[0].replace("_", ".")
                contracts_data[tq_symbol] = cls.load_history_bars_from_csv(
                    tq_symbol=tq_symbol,
                    # Join with ``root`` so files in sub-directories resolve correctly.
                    csv_path=os.path.join(root, file_name),
                )
        return contracts_data

    @staticmethod
    def _merge_bars_by_datetime(index_bars: list, contracts_bars: dict) -> list:
        """
        Align contract bars onto the index bars' timeline.

        :param index_bars: bars defining the rows, in output order.
        :param contracts_bars: symbol -> list of bars. A bar is attached to
            every index row with an equal ``datetime``; if a contract has
            several bars with the same datetime, the last one wins.
        :return: list of dicts: {"index_bar": bar, <symbol>: bar, ...}, with
            contract keys in ``contracts_bars`` order and only for matches.
        """
        # One datetime -> bar lookup per contract: O(n + m) instead of scanning rows per bar.
        contract_lookups = {
            tq_symbol: {bar.datetime: bar for bar in bars}
            for tq_symbol, bars in contracts_bars.items()
        }

        bars_data_list = []
        for index_bar in index_bars:
            bars_data = {"index_bar": index_bar}
            for tq_symbol, lookup in contract_lookups.items():
                matched_bar = lookup.get(index_bar.datetime)
                if matched_bar is not None:
                    bars_data[tq_symbol] = matched_bar
            bars_data_list.append(bars_data)

        return bars_data_list


if __name__ == '__main__':
    # Usage examples, kept as comments. Uncomment the one you need.
    #
    # Contract data straight from TianQin:
    # content = TQZTianQinDataManager.load_history_bars_map(tq_symbols=["CZCE.SM206", "SHFE.rb2205"], tq_duration_seconds=60 * 60)
    #
    # Dump history data from TianQin to CSV:
    # TQZTianQinClient().dump_history_df_to_csv(tq_symbols=["SHFE.ag1912"], tq_duration_seconds=24 * 60 * 60, start_dt=date(2020, 1, 1), end_dt=date(2022, 1, 1))
    # TQZTianQinClient().dump_history_df_to_csv(tq_symbols=["KQ.m@SHFE.ag"], tq_duration_seconds=24 * 60 * 60, start_dt=date(2015, 1, 1), end_dt=date(2022, 1, 1))
    # TQZTianQinClient().dump_history_df_to_csv(tq_symbols=["KQ.i@SHFE.ag"], tq_duration_seconds=24 * 60 * 60, start_dt=date(2020, 1, 1), end_dt=date(2020, 6, 1))
    #
    # Load index history bars from a CSV (takes a full csv_path, not a csv_name):
    # index_bars = TQZTianQinDataManager.load_index_history_bars_from_csv(tq_i_symbol="KQ.i@SHFE.ag", csv_path="<path>/KQ_i@SHFE_ag.csv")
    #
    # The older load_arbitrage_strategy_data_from_csv(...) example no longer
    # exists. Use load_arbitrage_strategy_barsData_from_csv(...) below, which
    # returns a single list of per-datetime bar dicts.

    _bars_data_list = TQZTianQinDataManager.load_arbitrage_strategy_barsData_from_csv(tq_i_symbol="KQ.i@SHFE.ag", csv_name="KQ_i@SHFE_ag.csv")

    # Pair-trading strategy smoke test:
    # pt_s = PairTradingStrategy(strategy_engine=None, strategy_name="pair", vt_symbols=["SHFE.ag2001", "SHFE.ag2002", "SHFE.ag2003", "SHFE.ag2007"], setting={})
    # for bars_data in _bars_data_list[:10]:
    #     pt_s.on_bars(bars=bars_data)