import baostock as bs
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.ticker import MaxNLocator
import datetime
# 设置中文显示
plt.rcParams["font.family"] = ["SimHei", "Microsoft YaHei", "SimSun", "KaiTi", "FangSong"]
plt.rcParams["axes.unicode_minus"] = False # 正确显示负号
def get_stock_data(code, start_date, end_date):
"""从baostock获取股票数据"""
# 登录baostock
lg = bs.login()
if lg.error_code != '0':
print(f"登录失败:{lg.error_msg}")
return None
# 获取股票数据
rs = bs.query_history_k_data_plus(
code,
"date,open,high,low,close,volume",
start_date=start_date,
end_date=end_date,
frequency="d",
adjustflag="3" # 复权类型,3表示不复权
)
# 处理数据
data_list = []
while (rs.error_code == '0') & rs.next():
data_list.append(rs.get_row_data())
# 登出baostock
bs.logout()
# 转换为DataFrame并处理
if not data_list:
print("没有获取到数据")
return None
df = pd.DataFrame(data_list, columns=rs.fields)
# 转换数据类型
df['date'] = pd.to_datetime(df['date'])
df['open'] = df['open'].astype(float)
df['high'] = df['high'].astype(float)
df['low'] = df['low'].astype(float)
df['close'] = df['close'].astype(float)
df['volume'] = df['volume'].astype(float)
df.set_index('date', inplace=True)
return df
def calculate_rsi(data, window=14):
"""计算RSI指标"""
delta = data.loc[:, 'close'].diff()
# 分离涨跌
gain = (delta.where(delta > 0, 0)).rolling(window=window).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=window).mean()
# 计算RSI
rs = gain / loss
data['rsi'] = 100 - (100 / (1 + rs))
return data
def generate_signals(data):
"""生成交易信号"""
# 初始化信号:0表示无信号,1表示买入,-1表示卖出
data['signal'] = 0
# RSI < 20 发出买入信号
data.loc[data['rsi'] < 20, 'signal'] = 1
# RSI > 80 发出卖出信号
data.loc[data['rsi'] > 80, 'signal'] = -1
return data
def backtest_strategy(data, initial_capital=100000):
"""回测策略"""
# 初始化资金和持仓,明确使用浮点数类型
portfolio = pd.DataFrame(index=data.index).fillna(0.0)
portfolio['cash'] = float(initial_capital)
portfolio['shares'] = 0 # shares保持为整数
portfolio['total'] = float(initial_capital)
in_position = False # 是否持仓
for i in range(1, len(data)):
date = data.index[i]
prev_date = data.index[i-1]
# 复制前一天的持仓和资金状态
portfolio.loc[date, 'cash'] = portfolio.loc[prev_date, 'cash']
portfolio.loc[date, 'shares'] = portfolio.loc[prev_date, 'shares']
# 检查交易信号
if data.loc[date, 'signal'] == 1 and not in_position:
# 买入信号且未持仓,执行买入
price = data.loc[date, 'close']
max_shares = int(portfolio.loc[date, 'cash'] / price)
if max_shares > 0:
portfolio.loc[date, 'shares'] = max_shares
portfolio.loc[date, 'cash'] -= max_shares * price
in_position = True
print(f"{date.date()} 发出买入信号,价格: {price:.2f}, 买入 {max_shares} 股")
elif data.loc[date, 'signal'] == -1 and in_position:
# 卖出信号且持仓,执行卖出
price = data.loc[date, 'close']
shares = portfolio.loc[date, 'shares']
portfolio.loc[date, 'cash'] += shares * price
portfolio.loc[date, 'shares'] = 0
in_position = False
print(f"{date.date()} 发出卖出信号,价格: {price:.2f}, 卖出 {shares} 股")
# 计算总资产
portfolio.loc[date, 'total'] = portfolio.loc[date, 'cash'] + portfolio.loc[date, 'shares'] * data.loc[date, 'close']
# 将回测结果合并到数据中
data['portfolio'] = portfolio['total']
return data
def plot_results(data):
"""绘制结果图表"""
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(16, 18), sharex=True)
# 价格和交易信号图
ax1.plot(data.index, data['close'], label='收盘价', linewidth=2)
ax1.scatter(data.index[data['signal'] == 1], data['close'][data['signal'] == 1],
marker='^', color='g', label='买入信号', zorder=3)
ax1.scatter(data.index[data['signal'] == -1], data['close'][data['signal'] == -1],
marker='v', color='r', label='卖出信号', zorder=3)
ax1.set_title('股票价格与交易信号')
ax1.set_ylabel('价格 (元)')
ax1.legend()
ax1.grid(True)
# RSI指标图
ax2.plot(data.index, data['rsi'], label='RSI (14)', color='purple', linewidth=2)
ax2.axhline(70, color='orange', linestyle='--', alpha=0.7)
ax2.axhline(30, color='orange', linestyle='--', alpha=0.7)
ax2.axhline(80, color='r', linestyle='--')
ax2.axhline(20, color='g', linestyle='--')
ax2.set_title('RSI指标 (14周期)')
ax2.set_ylabel('RSI值')
ax2.set_ylim(0, 100)
ax2.legend()
ax2.grid(True)
# 资金曲线
ax3.plot(data.index, data['portfolio'], label='策略资产', color='b', linewidth=2)
# 计算买入持有策略的资产
initial_capital = data['portfolio'].iloc[0]
buy_hold = initial_capital * (data['close'] / data['close'].iloc[0])
ax3.plot(data.index, buy_hold, label='买入持有', color='gray', linestyle='--', linewidth=2)
ax3.set_title('策略表现与买入持有对比')
ax3.set_xlabel('日期')
ax3.set_ylabel('资产 (元)')
ax3.legend()
ax3.grid(True)
# 设置x轴日期格式
ax3.xaxis.set_major_locator(mdates.MonthLocator(interval=3))
ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
return fig
def calculate_performance_metrics(data):
"""计算绩效指标"""
initial_capital = data['portfolio'].iloc[0]
final_capital = data['portfolio'].iloc[-1]
# 计算策略总收益
total_return = (final_capital - initial_capital) / initial_capital * 100
# 计算买入持有总收益
initial_price = data['close'].iloc[0]
final_price = data['close'].iloc[-1]
buy_hold_return = (final_price - initial_price) / initial_price * 100
# 计算交易次数
buy_signals = sum(data['signal'] == 1)
sell_signals = sum(data['signal'] == -1)
# 计算持有天数
days_held = len(data)
# 计算年化收益率
years = days_held / 252 # 假设一年252个交易日
annualized_return = (pow((final_capital / initial_capital), 1/years) - 1) * 100 if years > 0 else 0
print("\n====== 策略绩效指标 ======")
print(f"回测时间段: {data.index[0].date()} 至 {data.index[-1].date()}")
print(f"持有天数: {days_held} 天")
print(f"初始资金: {initial_capital:.2f} 元")
print(f"最终资金: {final_capital:.2f} 元")
print(f"策略总收益率: {total_return:.2f}%")
print(f"买入持有总收益率: {buy_hold_return:.2f}%")
print(f"策略年化收益率: {annualized_return:.2f}%")
print(f"买入信号次数: {buy_signals}")
print(f"卖出信号次数: {sell_signals}")
return {
'total_return': total_return,
'buy_hold_return': buy_hold_return,
'annualized_return': annualized_return,
'buy_signals': buy_signals,
'sell_signals': sell_signals,
'days_held': days_held
}
def main():
# 设置股票代码和日期范围(最近5年数据)
stock_code = "sh.600938" # 600938的证券代码
end_date = datetime.datetime.now().strftime("%Y-%m-%d")
start_date = (datetime.datetime.now() - datetime.timedelta(days=5*365)).strftime("%Y-%m-%d")
print(f"获取 {stock_code} 从 {start_date} 到 {end_date} 的数据...")
# 获取股票数据
data = get_stock_data(stock_code, start_date, end_date)
if data is None or len(data) == 0:
print("无法获取足够的股票数据进行分析")
return
# 计算RSI指标
data = calculate_rsi(data)
# 生成交易信号
data = generate_signals(data)
# 回测策略
data = backtest_strategy(data)
# 计算并显示绩效指标
metrics = calculate_performance_metrics(data)
# 绘制结果图表
plot_results(data)
if __name__ == "__main__":
main()
'''
脚本现在可以顺利执行RSI交易策略回测,输出显示: 成功获取了sh.600938股票的数据
生成了多个买入和卖出交易信号
计算了详细的策略绩效指标,
包括:
总收益率:61.16%
年化收益率:15.33%
与买入持有策略对比(买入持有收益率:92.31%)
'''