43 lines
1.1 KiB
Python
43 lines
1.1 KiB
Python
def _format_message(msg: dict) -> str:
    """Render one message as ``"role: content"``.

    Missing keys fall back to ``"unknown"`` for the role and ``""`` for the content.
    """
    return f"{msg.get('role', 'unknown')}: {msg.get('content', '')}"


def format_takes_diff(takes: dict[str, list[dict]]) -> str:
    """Render several conversation "takes" as shared history plus per-take variations.

    The longest prefix of messages common to every take is printed once under
    ``=== Shared History ===``. Each take's remaining messages are then listed
    under ``=== Variations ===`` beneath a ``[take_name]`` header, followed by
    a blank line.

    Args:
        takes: Mapping of take name to its message history. Each message is
            read via ``msg.get('role')`` and ``msg.get('content')``.

    Returns:
        The formatted text, or ``""`` when ``takes`` is empty.

    Notes:
        - Messages are compared with ``==``, so every key in a message dict
          (not only role/content) must match for it to count as shared.
        - A take whose history is exactly the shared prefix has nothing left
          to show and is omitted from the Variations section.
    """
    if not takes:
        return ""

    histories = list(takes.values())

    # zip(*histories) walks all histories in lockstep and stops at the
    # shortest one, so the prefix can never run past any take's end.
    common_prefix_len = 0
    for column in zip(*histories):
        first = column[0]
        if not all(msg == first for msg in column):
            break
        common_prefix_len += 1

    shared_lines = [_format_message(msg) for msg in histories[0][:common_prefix_len]]
    shared_text = "\n".join(["=== Shared History ===", *shared_lines])

    # No special case is needed for a single take: its whole history is the
    # shared prefix, so its divergent slice is empty and nothing is emitted.
    variation_lines: list[str] = []
    for take_name, history in takes.items():
        divergent = history[common_prefix_len:]
        if not divergent:
            continue
        variation_lines.append(f"[{take_name}]")
        variation_lines.extend(_format_message(msg) for msg in divergent)
        variation_lines.append("")  # blank separator after each take

    # The header keeps its trailing newline even when there are no variations.
    variations_text = "=== Variations ===\n" + "\n".join(variation_lines)

    return shared_text + "\n\n" + variations_text