def _render_take_message(msg: dict) -> str:
    """Render one message dict as a ``role: content`` line.

    A missing ``role`` key is shown as ``unknown``; a missing ``content``
    key is shown as an empty string.
    """
    return f"{msg.get('role', 'unknown')}: {msg.get('content', '')}"


def format_takes_diff(takes: dict[str, list[dict]]) -> str:
    """Format several message histories as a shared prefix plus variations.

    The leading messages that are equal (``==``) across every take are
    printed once under ``=== Shared History ===``. Under
    ``=== Variations ===`` each take that has messages beyond that shared
    prefix is listed as ``[take_name]`` followed by its remaining messages
    and one blank separator line.

    Args:
        takes: Mapping of take name to its message history. Each message is
            a dict read via ``.get('role')`` and ``.get('content')``.

    Returns:
        The formatted text, or ``""`` when ``takes`` is empty. Both section
        headers are always present for non-empty input, even if a section
        has no lines.
    """
    if not takes:
        return ""

    histories = list(takes.values())

    # zip() stops at the shortest history, so this walks exactly the
    # positions that exist in every take and stops at the first mismatch.
    common_prefix_len = 0
    for column in zip(*histories):
        first_msg = column[0]
        if not all(msg == first_msg for msg in column):
            break
        common_prefix_len += 1

    shared_lines = ["=== Shared History ==="]
    shared_lines.extend(
        _render_take_message(msg) for msg in histories[0][:common_prefix_len]
    )
    shared_text = "\n".join(shared_lines)

    # No special case is needed for a single take: its whole history is the
    # common prefix, so the slice below is empty and nothing is emitted.
    variation_lines: list[str] = []
    for take_name, history in takes.items():
        divergent = history[common_prefix_len:]
        if not divergent:
            continue
        variation_lines.append(f"[{take_name}]")
        variation_lines.extend(_render_take_message(msg) for msg in divergent)
        # NOTE(review): blank separator after each take's block; placement
        # reconstructed from whitespace-collapsed source — confirm.
        variation_lines.append("")

    variations_text = "=== Variations ===\n" + "\n".join(variation_lines)
    return shared_text + "\n\n" + variations_text