404 lines
16 KiB
Python
404 lines
16 KiB
Python
import streamlit as st
|
|
import os
|
|
import io
|
|
import zipfile
|
|
import logging
|
|
import requests
|
|
import concurrent.futures
|
|
from generator import configure_genai, generate_carousel_content, generate_background_image
|
|
from utils import format_slide, download_fonts
|
|
from styles import STYLES
|
|
from PIL import Image
|
|
|
|
# Configure CLI logging: INFO and above, timestamped, written through a single
# stream handler (sys.stderr unless configured otherwise).
_cli_handler = logging.StreamHandler()
logging.basicConfig(
    handlers=[_cli_handler],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
# Configure page: browser-tab title and icon, plus the wide layout used by the
# two-column input area and the 3-column slide grid.
st.set_page_config(
    layout="wide",
    page_icon="✨",
    page_title="LinkedIn Carousel Generator",
)
|
|
|
|
# Initialize Session State: each key is seeded only when absent, so generated
# content, slides and logs survive Streamlit's top-to-bottom reruns.
for _state_key, _initial in (
    ("generated_content", None),
    ("generated_slides", []),
    ("logs", ""),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial


# Pricing used for the cost estimate, in USD: Gemini 3 Flash Preview text
# rates and the Pro Image per-image rate at 1K/2K resolution.
TEXT_INPUT_COST_PER_1M = 0.50   # per 1M input tokens
TEXT_OUTPUT_COST_PER_1M = 3.00  # per 1M output tokens
IMAGE_COST_PER_UNIT = 0.134     # per generated 1K/2K image

# Running cost estimate for the current session.
if "total_cost" not in st.session_state:
    st.session_state.total_cost = 0.0
|
|
|
|
def calculate_text_cost(input_tokens, output_tokens, *, input_rate=None, output_rate=None):
    """Estimate the USD cost of one text-model call from its token counts.

    Args:
        input_tokens: Prompt tokens billed. ``None`` is treated as 0 so an
            absent count in a usage record cannot raise ``TypeError``.
        output_tokens: Generated tokens billed; ``None`` is treated as 0.
        input_rate: USD per 1M input tokens. Defaults to
            ``TEXT_INPUT_COST_PER_1M`` (looked up at call time).
        output_rate: USD per 1M output tokens. Defaults to
            ``TEXT_OUTPUT_COST_PER_1M`` (looked up at call time).

    Returns:
        The estimated cost in USD as a float.
    """
    # Resolve defaults lazily so the module-level pricing constants remain
    # the single source of truth when no override is given.
    if input_rate is None:
        input_rate = TEXT_INPUT_COST_PER_1M
    if output_rate is None:
        output_rate = TEXT_OUTPUT_COST_PER_1M

    input_tokens = input_tokens or 0
    output_tokens = output_tokens or 0

    return (input_tokens / 1_000_000 * input_rate) + \
        (output_tokens / 1_000_000 * output_rate)
|
|
|
|
def add_log(message: str):
    """Append *message* to the on-page execution log and echo it to the CLI.

    The text is stored in ``st.session_state.logs`` (one ``[LOG]``-prefixed
    line per call), which the UI renders, and is also sent to the module
    logger at INFO level.
    """
    entry = "[LOG] {}\n".format(message)
    st.session_state.logs = st.session_state.logs + entry
    logger.info(message)
|
|
|
|
def download_image_from_url(url: str, timeout: float = 10) -> "Image.Image | None":
    """Download *url* and decode it into a PIL image.

    Best effort: any network, HTTP-status or decoding failure is recorded via
    ``add_log`` and ``None`` is returned, so callers can simply proceed
    without a source image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds passed to ``requests.get`` (default 10, as before).

    Returns:
        A fully decoded ``PIL.Image.Image``, or ``None`` on any failure.
    """
    # NOTE(review): *url* comes from model output derived from user-pasted
    # text and is fetched server-side with no size cap; consider restricting
    # hosts (SSRF) and bounding the download size.
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
        # Image.open only parses the header; force the pixel decode here so a
        # truncated/corrupt payload fails inside this try (and is logged)
        # instead of later, inside a slide-generation worker.
        image.load()
        return image
    except Exception as e:
        add_log(f"Failed to download source image from {url}: {e}")
        return None
|
|
|
|
def process_slide_task(idx, slide, style_name, total_slides, reference_image=None, source_image=None):
    """Render and frame a single carousel slide.

    Invoked directly for the seed slide and from a ThreadPoolExecutor for the
    rest. It never touches ``st.session_state``; its log lines are returned
    so the caller can replay them through ``add_log``.

    Returns:
        ``(idx, image, error, cost, logs)``. On success ``error`` is None and
        ``cost`` is ``IMAGE_COST_PER_UNIT``. On any exception ``image`` is
        None, ``error`` is the exception text and ``cost`` is 0.0.
    """
    task_log = []
    try:
        task_log.append(f"DEBUG: Starting task for Slide {idx+1}")

        # Gemini renders the slide itself (headline, body and the code
        # snippet, presumably optional), optionally guided by a style
        # reference and a source image.
        rendered = generate_background_image(
            slide.image_prompt, slide.headline, slide.body, slide.code_snippet, style_name,
            reference_image=reference_image, source_image=source_image,
        )

        # Apply the style's framing; slide numbers are 1-based.
        framed = format_slide(rendered, idx + 1, total_slides, STYLES[style_name])

        return idx, framed, None, IMAGE_COST_PER_UNIT, task_log
    except Exception as exc:
        task_log.append(f"ERROR in task: {str(exc)}")
        return idx, None, str(exc), 0.0, task_log
|
|
|
|
# Sidebar: API key, running cost estimate and advanced options.
with st.sidebar:
    st.header("Settings")

    # Priority 1: key typed into the sidebar. Priority 2: GOOGLE_API_KEY env var.
    # Both are stripped: a key pasted with a stray space/newline would
    # otherwise reach configure_genai verbatim and fail, and a whitespace-only
    # entry would shadow a valid environment key.
    user_api_key = (
        st.text_input("Gemini API Key", type="password", help="Get your key from ai.google.dev") or ""
    ).strip()
    env_api_key = (os.environ.get("GOOGLE_API_KEY") or "").strip()

    active_key = user_api_key or env_api_key

    if active_key:
        configure_genai(active_key)
        if user_api_key:
            st.success("API Key configured from input.")
        else:
            st.success("API Key configured from environment.")
    else:
        st.warning("Please enter your API Key to proceed.")

    # Cost Metric
    # NOTE(review): this is drawn before the generation logic further down the
    # script runs, so during a generating run it still shows the previous
    # total; the updated figure appears on the next rerun.
    st.divider()
    st.metric(label="Estimated Cost", value=f"${st.session_state.total_cost:.4f}")

    # Style Reference Uploader
    st.divider()
    st.subheader("Advanced")
    uploaded_ref_file = st.file_uploader(
        "Upload Style Reference (Optional)",
        type=["png", "jpg", "jpeg"],
        help="Upload an image to define the exact style/layout for all slides.",
    )

    # Clears the log buffer and also resets the running cost estimate.
    if st.button("Clear Logs"):
        st.session_state.logs = ""
        st.session_state.total_cost = 0.0
        st.rerun()
|
|
|
|
st.title("✨ LinkedIn Carousel Generator")
st.markdown("Create professional carousels from text using **Gemini 3** models and the latest **Google GenAI SDK**.")

# Input Section: source text (2/3 width) beside the configuration panel (1/3).
col1, col2 = st.columns([2, 1])

source_text = col1.text_area(
    "Paste your article or text here:",
    height=300,
    placeholder="Once upon a time in the world of AI...",
)

col2.subheader("Configuration")

# An uploaded reference image pins the style to "Reference-Based"; otherwise
# the user chooses among the named styles, minus that reference-only entry.
if not uploaded_ref_file:
    manual_styles = [s for s in STYLES if s != "Reference-Based"]
    style_name = col2.selectbox("Select Style", manual_styles)
    style_desc = STYLES[style_name]["description"]
    col2.caption(f"**Description:** {style_desc}")
else:
    style_name = col2.selectbox("Style (Locked to Reference)", ["Reference-Based"], disabled=True)
    col2.info("Style is being derived from your uploaded image.")

# Language the generated carousel text should be written in.
language = col2.selectbox(
    "Output Language",
    [
        "English", "Spanish", "French", "German", "Italian", "Portuguese",
        "Dutch", "Russian", "Chinese", "Japanese", "Korean",
    ],
    index=0,
)

# Generation is only possible once an API key has been resolved.
generate_btn = col2.button("Generate Carousel", type="primary", disabled=not active_key)
|
|
|
|
# Logs Expander (Persistent)
# One placeholder owns the "Execution Logs" expander. Logs kept from the
# previous run are rendered *inside* it, so when a new generation later writes
# into log_container it replaces them, rather than leaving a second, stale
# expander on the page next to the live one.
log_container = st.empty()
if st.session_state.logs:
    with log_container:
        with st.expander("🛠️ Execution Logs", expanded=True):
            st.code(st.session_state.logs)
|
|
|
|
# Main Logic


def _render_execution_logs():
    """Show the current session log, collapsed, inside the shared log placeholder."""
    with log_container:
        with st.expander("🛠️ Execution Logs", expanded=False):
            st.code(st.session_state.logs)


def _load_user_reference(uploaded_file):
    """Open the optional uploaded style reference; return None if absent or unreadable.

    ``Image.open`` is lazy: it reads the header and defers pixel decoding to
    first use. The reference is later shared by up to five worker threads, so
    it is decoded here, on the script thread, to keep those workers from
    racing on the shared upload stream and to surface a bad upload at once.
    """
    if not uploaded_file:
        return None
    try:
        image = Image.open(uploaded_file)
        image.load()
    except Exception as e:
        add_log(f"Error loading uploaded reference image: {e}")
        return None
    add_log("User provided a custom style reference image. Using it for all slides.")
    return image


def _generate_structure(text, style, lang, status_text, progress_bar):
    """Step 1: ask the text model for the slide plan.

    Returns the content object (exposing ``.slides`` and ``.post_text``) after
    storing it in ``st.session_state.generated_content``, or None after the
    failure has been logged and shown in the UI.
    """
    try:
        status_text.markdown("### 📝 Analyzing text and generating structure...")
        add_log("Calling Gemini 3 Flash Preview for text analysis...")
        progress_bar.progress(10)

        content, usage = generate_carousel_content(text, style, lang)

        # Read token counts once, treating missing/None values as 0, so a
        # sparse usage record can neither crash the cost log line nor turn a
        # successful call into a reported failure.
        usage = usage or {}
        input_tokens = usage.get("input_tokens") or 0
        output_tokens = usage.get("output_tokens") or 0
        text_cost = calculate_text_cost(input_tokens, output_tokens)
        st.session_state.total_cost += text_cost
        add_log(f"Text Gen Cost: ${text_cost:.4f} ({input_tokens} in, {output_tokens} out)")

        add_log(f"Content generated successfully. {len(content.slides)} slides planned.")
        status_text.markdown("### ✅ Structure generated! Preparing image generation...")
        progress_bar.progress(20)
    except Exception as e:
        add_log(f"ERROR during text generation: {str(e)}")
        st.error(f"Failed to generate content: {e}")
        status_text.empty()
        progress_bar.empty()
        return None

    st.session_state.generated_content = content
    return content


def _generate_slides(content, style, user_reference_img, status_text, progress_bar):
    """Steps 2-3: render every planned slide; return ``{slide_index: image}`` for successes.

    With an uploaded reference, every slide is rendered in parallel against it.
    Otherwise slide 1 is rendered first and, if it succeeds, reused as the
    style reference for the remaining slides (which are attempted either way).
    Each successful image's cost is added to ``st.session_state.total_cost``.
    """
    slides_data = content.slides
    total_slides = len(slides_data)  # the caller guarantees this is > 0
    generated_images_map = {}
    reference_slide_img = None

    # Slides share the 20%..99% stretch of the bar; 100% is set on finalize.
    progress_per_slide = 80 / total_slides
    current_progress = 20

    if user_reference_img is not None:
        # Case A: user provided reference -> all slides run in parallel immediately.
        reference_slide_img = user_reference_img
        slides_to_process_parallel = list(range(total_slides))
        add_log("Using uploaded image as style reference for ALL slides.")
    else:
        # Case B: render slide 1 synchronously and use it as the seed reference.
        slides_to_process_parallel = list(range(1, total_slides))
        add_log("No custom reference provided. Generating Slide 1 as the seed reference...")
        status_text.markdown(f"### 🎨 Generating Reference Slide (1/{total_slides})...")
        try:
            first_slide = slides_data[0]
            first_slide_source_img = None
            if first_slide.source_image_url:
                add_log(f"Downloading source image for Slide 1: {first_slide.source_image_url}")
                first_slide_source_img = download_image_from_url(first_slide.source_image_url)

            _, img, error, img_cost, task_logs = process_slide_task(
                0,
                first_slide,
                style,
                total_slides,
                reference_image=None,
                source_image=first_slide_source_img,
            )
            for log_msg in task_logs:
                add_log(log_msg)

            if error:
                add_log(f"Error generating reference slide: {error}")
                status_text.error("Failed to generate reference slide.")
            else:
                generated_images_map[0] = img
                reference_slide_img = img
                st.session_state.total_cost += img_cost
                add_log(f"Reference Slide 1 completed. Cost: +${img_cost:.3f}")
        except Exception as e:
            add_log(f"Critical error in reference slide gen: {e}")

        current_progress += progress_per_slide
        progress_bar.progress(min(int(current_progress), 99))

    if not slides_to_process_parallel:
        return generated_images_map

    if reference_slide_img is None:
        add_log("WARNING: Reference slide generation failed. Proceeding with independent generation (consistency mode disabled).")
    else:
        add_log(f"Starting parallel generation for {len(slides_to_process_parallel)} slides using reference...")
    status_text.markdown(f"### 🚀 Parallel Generating {len(slides_to_process_parallel)} Slides...")

    # Source images are fetched up front on the script thread, so their log
    # lines go through add_log here; the workers only return log lines.
    slide_source_images = {}
    for idx in slides_to_process_parallel:
        url = slides_data[idx].source_image_url
        if url:
            add_log(f"Downloading source image for Slide {idx+1}...")
            slide_source_images[idx] = download_image_from_url(url)
        else:
            slide_source_images[idx] = None

    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        future_to_slide = {
            executor.submit(
                process_slide_task,
                idx,
                slides_data[idx],
                style,
                total_slides,
                reference_image=reference_slide_img,  # uploaded or generated reference
                source_image=slide_source_images[idx],
            ): idx
            for idx in slides_to_process_parallel
        }

        for future in concurrent.futures.as_completed(future_to_slide):
            idx = future_to_slide[future]
            try:
                r_idx, img, error, img_cost, task_logs = future.result()
                for log_msg in task_logs:
                    add_log(log_msg)

                if error:
                    add_log(f"Error generating slide {r_idx+1}: {error}")
                    status_text.warning(f"Failed to generate Slide {r_idx+1}")
                else:
                    generated_images_map[r_idx] = img
                    headline = slides_data[r_idx].headline
                    st.session_state.total_cost += img_cost
                    add_log(f"Slide {r_idx+1} completed: {headline} (Cost: +${img_cost:.3f})")
                    status_text.markdown(f"### 🎨 Finished Slide {r_idx+1}/{total_slides}")
            except Exception as exc:
                add_log(f"Unexpected exception for slide {idx+1}: {exc}")

            # Failed slides still advance the bar so it keeps moving.
            current_progress += progress_per_slide
            progress_bar.progress(min(int(current_progress), 99))
            _render_execution_logs()

    return generated_images_map


def _finalize(generated_images_map, total_slides, status_text, progress_bar):
    """Store the rendered slides in slide order and report the outcome.

    Success (with balloons) is announced only when every planned slide was
    produced; partial and total failures are reported as such.
    """
    generated_images = []
    for i in range(total_slides):
        img = generated_images_map.get(i)
        if img is None:
            add_log(f"Warning: Slide {i+1} missing or invalid in final set.")
        else:
            generated_images.append(img)
    st.session_state.generated_slides = generated_images

    progress_bar.progress(100)
    total_cost = st.session_state.total_cost
    add_log(f"Generation completed. Total Estimated Cost: ${total_cost:.4f}")

    produced = len(generated_images)
    if produced == total_slides:
        status_text.success(f"### 🎉 Carousel Generated! (Est. Cost: ${total_cost:.4f})")
        st.balloons()
    elif produced:
        status_text.warning(
            f"### ⚠️ Carousel partially generated: {produced}/{total_slides} slides "
            f"(Est. Cost: ${total_cost:.4f})"
        )
    else:
        status_text.error(f"### ❌ No slides could be generated (Est. Cost: ${total_cost:.4f})")


def _run_generation(text, style, lang, uploaded_file):
    """Run one full generation: reset state, plan slides, render them, report."""
    # Start from a clean slate. Clearing the previous content/slides matters:
    # otherwise a failed text step would fall through to re-rendering (and
    # re-billing) the *previous* run's slide plan.
    st.session_state.logs = ""
    st.session_state.total_cost = 0.0  # Reset cost
    st.session_state.generated_content = None
    st.session_state.generated_slides = []
    log_container.empty()  # drop the previous run's log expander
    add_log(f"Starting generation with style: {style}")

    # Pre-check fonts to avoid race conditions in threads
    try:
        download_fonts()
        add_log("Fonts verified/downloaded.")
    except Exception as e:
        add_log(f"Warning: Font download failed: {e}")

    user_reference_img = _load_user_reference(uploaded_file)

    # Containers for progress feedback
    status_text = st.empty()
    progress_bar = st.progress(0)

    content = _generate_structure(text, style, lang, status_text, progress_bar)
    if content is None:
        return

    total_slides = len(content.slides)
    if total_slides == 0:
        # Guards the per-slide progress division and the pointless image step.
        add_log("Text model returned no slides; nothing to render.")
        status_text.warning("The model did not plan any slides. Try again or adjust your text.")
        progress_bar.empty()
        return

    generated_images_map = _generate_slides(content, style, user_reference_img, status_text, progress_bar)
    _finalize(generated_images_map, total_slides, status_text, progress_bar)


if generate_btn and not (source_text or "").strip():
    st.warning("Please paste some text before generating the carousel.")
elif generate_btn:
    _run_generation(source_text, style_name, language, uploaded_ref_file)
    # Final update to log container (on success and failure paths alike).
    _render_execution_logs()
|
|
|
|
# Display Results


def _build_carousel_zip(slides, post_text):
    """Return the bytes of a ZIP archive holding the carousel.

    Each image in *slides* is saved as ``slide_<n>.png`` (n starting at 1).
    *post_text*, when non-empty, is added as ``post_text.md``.
    """
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w") as zf:
        for number, img in enumerate(slides, start=1):
            png_buffer = io.BytesIO()
            img.save(png_buffer, format="PNG")
            zf.writestr(f"slide_{number}.png", png_buffer.getvalue())
        if post_text:
            zf.writestr("post_text.md", post_text)
    return zip_buffer.getvalue()


if st.session_state.generated_slides:
    st.divider()
    st.subheader("Preview & Download")

    # The post text is guarded once and reused for both the preview and the
    # ZIP, so a missing content object cannot crash the page.
    generated_content = st.session_state.generated_content
    post_text = generated_content.post_text if generated_content is not None else None

    # Display slides in a 3-column grid.
    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width / width=; switch once the pinned
    # Streamlit version supports the replacement.
    cols = st.columns(3)
    for idx, img in enumerate(st.session_state.generated_slides):
        with cols[idx % 3]:
            st.image(img, caption=f"Slide {idx+1}", use_column_width=True)

    # Post Text
    if post_text:
        with st.expander("LinkedIn Post Text", expanded=True):
            st.text_area("Copy this for your post:", value=post_text, height=200)

    # Download ZIP
    st.download_button(
        label="Download All Slides (ZIP)",
        data=_build_carousel_zip(st.session_state.generated_slides, post_text),
        file_name="linkedin_carousel.zip",
        mime="application/zip",
        type="primary",
    )

elif not generate_btn:
    st.info("Enter text and click Generate to start.")
|