Coverage for src/stable_yield_lab/rebalance.py: 0%
86 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-04 20:38 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-04 20:38 +0000
"""Rebalancing utilities for portfolio weight schedules and turnover metrics.

The functions in this module generate deterministic outputs that can feed into
visualisation components. The default engine applies a simple momentum-style
overlay to derive time-varying target weights, computes the resulting trading
turnover and estimates trading fees given a basis-point cost assumption. The
logic is intentionally lightweight to keep the demo fast while still producing
realistically shaped data for charts.
"""
11from __future__ import annotations
13from dataclasses import dataclass
15import pandas as pd
@dataclass(frozen=True)
class RebalanceResult:
    """Immutable bundle of the outputs produced by one rebalancing run.

    Attributes
    ----------
    target_weights:
        Allocation held after trading on each date, i.e. the post-trade
        weights (rows are timestamps, columns are assets).
    pre_trade_weights:
        Allocation just before trading on each date, after the previously
        held weights have drifted with the market.
    turnover:
        Share of the portfolio traded on each rebalance date, as a decimal.
    fees:
        Trading cost incurred in each period, as a fraction of portfolio NAV.
    """

    # Declaration order defines the positional constructor signature; keep it.
    target_weights: pd.DataFrame
    pre_trade_weights: pd.DataFrame
    turnover: pd.Series
    fees: pd.Series
def _normalise_rows(df: pd.DataFrame) -> pd.DataFrame:
    """Rescale every row of ``df`` so that it sums to one.

    Missing values are treated as zero weight. Rows whose NaN-skipping sum is
    exactly zero, including rows that are entirely missing, are replaced with
    equal weights across all columns. Empty frames (no rows or no columns)
    are returned unchanged.

    Parameters
    ----------
    df:
        Numeric weight frame; rows are timestamps and columns are assets.

    Returns
    -------
    pd.DataFrame
        A new ``float64`` frame with the same index and columns as ``df`` and
        no NaNs. For an empty ``df`` the input object itself is returned.
    """

    # ``DataFrame.empty`` is also True when there are no columns, so this
    # single guard covers the zero-asset case too.
    if df.empty:
        return df

    # Missing entries add nothing to a row total and end up as zero weight,
    # so filling them first is equivalent to skipping them. Casting to float
    # up front avoids writing fractional weights into integer columns, which
    # pandas deprecates (incompatible-dtype setitem).
    values = df.fillna(0.0).astype(float)
    totals = values.sum(axis=1)

    normalised = values.div(totals, axis=0)
    # Positional boolean mask (no label alignment), so duplicate index labels
    # are handled row by row instead of overwriting each other.
    degenerate = (totals == 0.0).to_numpy()
    normalised.loc[degenerate] = 1.0 / len(values.columns)
    # A non-finite total (e.g. inf - inf) would leave NaNs behind; zero them
    # so the result is always NaN-free.
    return normalised.fillna(0.0)
def _prepare_targets(
    returns: pd.DataFrame,
    target_weights: pd.DataFrame | pd.Series | None,
    *,
    momentum_window: int = 4,
) -> pd.DataFrame:
    """Derive a target weight schedule aligned with ``returns``.

    Parameters
    ----------
    returns:
        Wide frame of periodic returns (rows are timestamps, columns assets).
    target_weights:
        ``None`` uses a rolling momentum proxy: the trailing mean return over
        ``momentum_window`` rows, clipped at zero and scaled to sum to one.
        A Series is broadcast to every period (assets it omits get zero).
        A DataFrame is aligned to ``returns.columns`` and forward-filled onto
        ``returns.index``.
    momentum_window:
        Look-back length (in rows) for the default momentum proxy. It is
        ignored when ``target_weights`` is given.

    Returns
    -------
    pd.DataFrame
        Row-normalised weights with the same index and columns as ``returns``.
        Rows without any usable signal fall back to equal weights. An empty
        ``returns`` yields a copy of ``returns``.

    Raises
    ------
    ValueError
        If ``momentum_window`` is smaller than 1.
    """

    if returns.empty:
        return returns.copy()
    if momentum_window < 1:
        raise ValueError("momentum_window must be a positive integer")

    cols = returns.columns
    index = returns.index

    if target_weights is None:
        window = min(momentum_window, len(returns))
        momentum = returns.rolling(window=window, min_periods=1).mean()
        positive = momentum.clip(lower=0.0)
        totals = positive.sum(axis=1)
        # Mask zero totals with NaN (not ``pd.NA``) so the frame stays
        # float64. Those rows become all-zero below and then equal-weight.
        weights = positive.div(totals.where(totals != 0.0), axis=0)
    elif isinstance(target_weights, pd.Series):
        base = target_weights.reindex(cols).fillna(0.0)
        weights = pd.DataFrame([base] * len(index), index=index, columns=cols)
    else:
        weights = target_weights.reindex(columns=cols)
        if not weights.index.equals(index):
            # ``reindex(method="ffill")`` needs a monotonic source index;
            # sorting lets unordered schedules forward-fill in time order.
            weights = weights.sort_index().reindex(index=index, method="ffill")
        # ``fillna(method=...)`` is deprecated; ``ffill`` is the replacement.
        weights = weights.ffill()

    weights = weights.reindex(index=index, columns=cols).fillna(0.0)
    return _normalise_rows(weights)
def run_rebalance(
    returns: pd.DataFrame,
    *,
    target_weights: pd.DataFrame | pd.Series | None = None,
    trading_cost_bps: float = 5.0,
) -> RebalanceResult:
    """Simulate periodic rebalancing and compute turnover/fees.

    On the first timestamp the portfolio adopts the first target allocation
    outright, with no trade and no fee. On each later timestamp the previous
    post-trade weights drift with that period's returns and are then traded
    back to the period's target weights.

    Parameters
    ----------
    returns:
        Wide DataFrame of periodic simple returns. Rows are timestamps and
        columns are asset identifiers. Missing returns are treated as zero.
    target_weights:
        Optional schedule of desired weights. A Series is applied to every
        period, and a DataFrame is forward-filled onto ``returns.index``.
        When ``None``, a simple momentum overlay derives targets
        endogenously.
    trading_cost_bps:
        Assumed round-trip trading cost in basis points. It is charged on
        each period's one-way turnover, which is half the sum of absolute
        weight changes. A full switch from one asset to another therefore
        costs ``trading_cost_bps`` basis points of NAV.

    Returns
    -------
    RebalanceResult
        Post-trade weights, pre-trade (drifted) weights, turnover and fees,
        all indexed like ``returns``. Fees are reported only; they are not
        deducted from the simulated weights.
    """

    index = returns.index
    cols = returns.columns

    # ``DataFrame.empty`` is True for zero rows *or* zero columns.
    if returns.empty:
        # Return independent objects so that mutating one output cannot
        # silently change another. Nothing is traded, so turnover and fees
        # are zero rather than missing.
        return RebalanceResult(
            returns.copy(),
            returns.copy(),
            pd.Series(0.0, index=index, dtype=float),
            pd.Series(0.0, index=index, dtype=float),
        )

    clean_returns = returns.fillna(0.0)
    targets = _prepare_targets(clean_returns, target_weights)
    equal_weight = 1.0 / len(cols)
    cost_rate = trading_cost_bps / 10_000.0

    pre_trade = pd.DataFrame(0.0, index=index, columns=cols)
    post_trade = pd.DataFrame(0.0, index=index, columns=cols)
    turnover = pd.Series(0.0, index=index, dtype=float)
    fees = pd.Series(0.0, index=index, dtype=float)

    # Initial allocation: take the first target as-is. Turnover and fees
    # for this row keep their initial value of 0.0.
    post = targets.iloc[0].copy()
    if float(post.sum()) == 0.0:
        post[:] = equal_weight
    post_trade.iloc[0] = post
    pre_trade.iloc[0] = post

    for i in range(1, len(index)):
        # Drift last period's holdings with this period's returns and
        # renormalise to weights. A non-positive total (e.g. a full
        # wipe-out) keeps the previous weights unchanged.
        drifted = post * (1.0 + clean_returns.iloc[i])
        total = float(drifted.sum())
        if total > 0.0:
            drifted = drifted / total
        else:
            drifted = post.copy()
        pre_trade.iloc[i] = drifted

        # Defensive renormalisation; a zero-sum target means "hold".
        target_row = targets.iloc[i]
        target_total = float(target_row.sum())
        if target_total == 0.0:
            target_row = post.copy()
        else:
            target_row = target_row / target_total
        post_trade.iloc[i] = target_row

        # One-way turnover: half the L1 distance between drifted and target.
        trade_amount = 0.5 * float((target_row - drifted).abs().sum())
        turnover.iloc[i] = trade_amount
        fees.iloc[i] = trade_amount * cost_rate

        post = target_row

    return RebalanceResult(post_trade, pre_trade, turnover, fees)
# Names exported by a star-import of this module.
__all__ = [
    "RebalanceResult",
    "run_rebalance",
]