import gradio as gr
import cv2
import numpy as np
from PIL import Image
import torch
from transformers import (
    CLIPProcessor, CLIPModel,
    LlamaForCausalLM, LlamaTokenizer,
    pipeline
)
import requests
from io import BytesIO
import os
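# Pipeline overview (descriptive summary of the code below):
#   1. CLIP-ViT (openai/clip-vit-base-patch32) scores the uploaded image against
#      candidate object labels and scene categories (zero-shot classification).
#   2. LLaMA (huggyllama/llama-7b) turns that analysis into a narrative and a poem;
#      a DialoGPT pipeline and hand-written templates act as fallbacks.
#   3. A Gradio Blocks UI shows the story, the poem, and an annotated image.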
class ImageStoryteller:
    def __init__(self):
        print("Initializing Image Storyteller with CLIP-ViT + LLaMA...")

        # Vision encoder: CLIP-ViT for zero-shot object and scene recognition
        try:
            self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            print("CLIP-ViT model loaded successfully!")
        except Exception as e:
            print(f"CLIP loading failed: {e}")
            self.clip_model = None
            self.clip_processor = None

        # Language model: LLaMA-7B for narrative/poetry generation.
        # Initialize the attributes up front so the later hasattr/None checks are safe.
        self.llama_model = None
        self.llama_tokenizer = None
        self.story_pipeline = None
        try:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                "huggyllama/llama-7b",
                torch_dtype=torch.float16,
                device_map="auto",
                load_in_8bit=True  # requires bitsandbytes + accelerate
            )
            self.llama_tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
            print("LLaMA model loaded successfully!")
        except Exception as e:
            print(f"LLaMA loading failed: {e}")

            # Fallback: a lightweight text-generation pipeline if LLaMA is unavailable
            try:
                self.story_pipeline = pipeline(
                    "text-generation",
                    model="microsoft/DialoGPT-medium",
                    torch_dtype=torch.float32
                )
                print("Fallback story pipeline initialized!")
            except Exception as e:
                print(f"Fallback pipeline failed: {e}")
                self.story_pipeline = None

        # Candidate labels for CLIP zero-shot object recognition
        self.common_objects = [
            'person', 'people', 'human', 'man', 'woman', 'child', 'baby',
            'dog', 'cat', 'animal', 'bird', 'horse', 'cow', 'sheep',
            'car', 'vehicle', 'bus', 'truck', 'bicycle', 'motorcycle',
            'building', 'house', 'skyscraper', 'architecture',
            'tree', 'forest', 'nature', 'mountain', 'sky', 'clouds',
            'water', 'ocean', 'river', 'lake', 'beach',
            'food', 'fruit', 'vegetable', 'meal',
            'indoor', 'outdoor', 'urban', 'rural'
        ]

        # Candidate labels for CLIP zero-shot scene classification
        self.scene_categories = [
            "portrait", "landscape", "cityscape", "indoor scene", "outdoor scene",
            "nature", "urban", "beach", "mountain", "forest", "street",
            "party", "celebration", "sports", "action", "still life",
            "abstract", "art", "architecture", "wildlife", "pet"
        ]
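    # CLIP zero-shot recognition: the image and each candidate label are embedded
    # by the same model; logits_per_image holds the image-text similarity scores,
    # and a softmax over the label axis turns them into the per-label confidences
    # used in analyze_image_with_clip below.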
    def analyze_image_with_clip(self, image):
        """Analyze image using CLIP to understand content and scene"""
        if self.clip_model is None or self.clip_processor is None:
            return self.fallback_image_analysis(image)

        try:
            image_rgb = image.convert('RGB')

            # Zero-shot object recognition: score the image against candidate object labels
            object_inputs = self.clip_processor(
                text=self.common_objects,
                images=image_rgb,
                return_tensors="pt",
                padding=True
            )

            with torch.no_grad():
                object_outputs = self.clip_model(**object_inputs)
                object_logits = object_outputs.logits_per_image
                object_probs = object_logits.softmax(dim=1)

            # Keep the top-5 object labels above a small confidence threshold
            top_object_indices = torch.topk(object_probs, 5, dim=1).indices[0]
            detected_objects = []
            for idx in top_object_indices:
                obj_name = self.common_objects[idx]
                confidence = object_probs[0][idx].item()
                if confidence > 0.1:
                    detected_objects.append({
                        'name': obj_name,
                        'confidence': confidence
                    })

            # Zero-shot scene classification with the second label set
            scene_inputs = self.clip_processor(
                text=self.scene_categories,
                images=image_rgb,
                return_tensors="pt",
                padding=True
            )

            with torch.no_grad():
                scene_outputs = self.clip_model(**scene_inputs)
                scene_logits = scene_outputs.logits_per_image
                scene_probs = scene_logits.softmax(dim=1)

            # Keep the top-3 scene types
            top_scene_indices = torch.topk(scene_probs, 3, dim=1).indices[0]
            scene_types = []
            for idx in top_scene_indices:
                scene_name = self.scene_categories[idx]
                confidence = scene_probs[0][idx].item()
                scene_types.append({
                    'type': scene_name,
                    'confidence': confidence
                })

            return {
                'objects': detected_objects,
                'scenes': scene_types,
                'success': True
            }

        except Exception as e:
            print(f"CLIP analysis failed: {e}")
            return self.fallback_image_analysis(image)
    def fallback_image_analysis(self, image):
        """Fallback image analysis when CLIP fails"""
        # Force 3-channel RGB so the OpenCV colour-space conversions below are valid
        img_np = np.array(image.convert('RGB'))
        height, width = img_np.shape[:2]

        # Simple colour heuristics in HSV space
        hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)

        objects = []
        scenes = []

        # Masks are 0/255, so count pixels rather than summing raw values.
        # Large blue regions suggest sky / an outdoor scene
        blue_mask = cv2.inRange(hsv, (100, 50, 50), (130, 255, 255))
        if np.count_nonzero(blue_mask) > height * width * 0.1:
            objects.append({'name': 'sky', 'confidence': 0.6})
            scenes.append({'type': 'outdoor scene', 'confidence': 0.7})

        # Large green regions suggest vegetation / nature
        green_mask = cv2.inRange(hsv, (35, 50, 50), (85, 255, 255))
        if np.count_nonzero(green_mask) > height * width * 0.1:
            objects.append({'name': 'nature', 'confidence': 0.6})
            scenes.append({'type': 'nature', 'confidence': 0.7})

        # Skin-tone regions suggest a person / portrait
        skin_mask = cv2.inRange(hsv, (0, 30, 60), (20, 150, 255))
        if np.count_nonzero(skin_mask) > 1000:
            objects.append({'name': 'person', 'confidence': 0.5})
            scenes.append({'type': 'portrait', 'confidence': 0.6})

        # Dense edges suggest man-made structures / an urban scene
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(gray, 50, 150)
        if np.count_nonzero(edges) > height * width * 0.05:
            objects.append({'name': 'building', 'confidence': 0.5})
            scenes.append({'type': 'urban', 'confidence': 0.6})

        return {
            'objects': objects,
            'scenes': scenes,
            'success': False
        }
    def create_visualization(self, image, analysis_result):
        """Create a visualization showing detected elements"""
        img_np = np.array(image.convert('RGB'))
        viz_image = img_np.copy()
        height, width = img_np.shape[:2]

        # Text overlay settings
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        font_color = (255, 255, 255)
        background_color = (0, 0, 0)
        thickness = 2

        # Build the overlay text from the analysis result
        text_lines = ["CLIP-ViT Analysis:"]

        if analysis_result['objects']:
            text_lines.append("Objects:")
            for obj in analysis_result['objects'][:3]:
                text_lines.append(f"  {obj['name']} ({obj['confidence']:.2f})")

        if analysis_result['scenes']:
            text_lines.append("Scene:")
            for scene in analysis_result['scenes'][:2]:
                text_lines.append(f"  {scene['type']} ({scene['confidence']:.2f})")

        # Draw each line on a filled rectangle so it stays readable on any background
        y_offset = 30
        for line in text_lines:
            text_size = cv2.getTextSize(line, font, font_scale, thickness)[0]

            cv2.rectangle(viz_image,
                          (10, y_offset - text_size[1] - 5),
                          (10 + text_size[0] + 10, y_offset + 5),
                          background_color, -1)

            cv2.putText(viz_image, line, (15, y_offset),
                        font, font_scale, font_color, thickness)
            y_offset += 25

        return Image.fromarray(viz_image)
    def generate_narrative_with_llama(self, analysis_result, image_size):
        """Generate narrative using LLaMA based on CLIP analysis"""
        # Summarize the CLIP analysis for the prompt
        objects_text = ", ".join([obj['name'] for obj in analysis_result['objects'][:5]])
        scenes_text = analysis_result['scenes'][0]['type'] if analysis_result['scenes'] else "unknown scene"

        width, height = image_size

        prompt = f"""Based on this image analysis:
Image Size: {width}x{height}
Detected Objects: {objects_text}
Scene Type: {scenes_text}

Please write a beautiful, descriptive narrative story about this image. Focus on the emotional and visual elements, creating a compelling story that brings the scene to life."""

        try:
            if hasattr(self, 'llama_model') and self.llama_model is not None:
                # Move the token ids to the model's device (device_map="auto" may place it on GPU)
                inputs = self.llama_tokenizer(prompt, return_tensors="pt").to(self.llama_model.device)

                with torch.no_grad():
                    outputs = self.llama_model.generate(
                        inputs.input_ids,
                        max_length=300,
                        temperature=0.7,
                        do_sample=True,
                        top_p=0.9,
                        pad_token_id=self.llama_tokenizer.eos_token_id
                    )

                narrative = self.llama_tokenizer.decode(outputs[0], skip_special_tokens=True)

                # Strip the echoed prompt so only the generated story remains
                if narrative.startswith(prompt):
                    narrative = narrative[len(prompt):].strip()
                return narrative

            elif hasattr(self, 'story_pipeline') and self.story_pipeline is not None:
                result = self.story_pipeline(
                    prompt,
                    max_length=250,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=50256  # GPT-2-style end-of-text token used by DialoGPT
                )
                return result[0]['generated_text']

        except Exception as e:
            print(f"LLaMA narrative generation failed: {e}")

        # Template-based narrative if no model is available or generation failed
        return self.fallback_narrative(analysis_result, image_size)
    def fallback_narrative(self, analysis_result, image_size):
        """Fallback narrative generation"""
        width, height = image_size
        objects = [obj['name'] for obj in analysis_result['objects']]
        scene = analysis_result['scenes'][0]['type'] if analysis_result['scenes'] else "scene"

        if 'person' in objects:
            return f"In this captivating {width}x{height} {scene}, human presence tells a story of connection and experience. " \
                   f"The composition speaks of moments frozen in time, where light and shadow dance together to reveal " \
                   f"the beauty of ordinary moments made extraordinary through the lens of perception."

        elif 'nature' in objects:
            return f"This breathtaking {width}x{height} natural landscape captures the essence of Earth's timeless beauty. " \
                   f"Each element harmonizes with the next, creating a symphony of visual poetry that whispers " \
                   f"ancient stories of growth, change, and the enduring power of the natural world."

        elif 'building' in objects:
            return f"Architectural elegance defines this {width}x{height} {scene}, where human ingenuity meets artistic vision. " \
                   f"The structures stand as silent witnesses to countless stories, their forms telling tales " \
                   f"of aspiration, community, and the relentless march of progress through time."

        else:
            return f"In this compelling {width}x{height} composition, visual elements converge to create a unique narrative. " \
                   f"The scene invites contemplation, asking viewers to explore the relationships between forms, " \
                   f"colors, and spaces that together tell a story beyond words."
    def generate_poetry(self, narrative):
        """Generate poetic verses based on the narrative"""
        prompt = f"""Based on this image description: "{narrative}"

Create a beautiful 6-line poem that captures the essence and emotion of the scene:"""

        try:
            if hasattr(self, 'llama_model') and self.llama_model is not None:
                inputs = self.llama_tokenizer(prompt, return_tensors="pt").to(self.llama_model.device)

                with torch.no_grad():
                    outputs = self.llama_model.generate(
                        inputs.input_ids,
                        # The prompt embeds the whole narrative, so cap new tokens
                        # rather than total length to leave room for the poem.
                        max_new_tokens=150,
                        temperature=0.8,
                        do_sample=True,
                        top_p=0.9,
                        pad_token_id=self.llama_tokenizer.eos_token_id
                    )

                poetry = self.llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
                if poetry.startswith(prompt):
                    poetry = poetry[len(prompt):].strip()

                # Reshape the output into at most six lines
                lines = [line.strip() for line in poetry.split('.') if line.strip()]
                if len(lines) >= 4:
                    return '\n'.join(lines[:6])
                return poetry

            elif hasattr(self, 'story_pipeline') and self.story_pipeline is not None:
                result = self.story_pipeline(
                    prompt,
                    max_new_tokens=150,
                    temperature=0.8,
                    do_sample=True
                )
                poetry = result[0]['generated_text']
                lines = [line.strip() for line in poetry.split('.') if line.strip()]
                if len(lines) >= 4:
                    return '\n'.join(lines[:6])
                return poetry

        except Exception as e:
            print(f"Poetry generation failed: {e}")

        # Template-based poem if generation is unavailable
        return self.fallback_poetry(narrative)
    def fallback_poetry(self, narrative):
        """Fallback poetry generation"""
        if 'person' in narrative.lower():
            return """A figure stands where light does fall
Their silent story captures all
In moments caught by lens and eye
Where truth and beauty never die
Each breath a verse, each glance a call
To understand, to stand in awe"""

        elif 'nature' in narrative.lower():
            return """Where trees reach up to touch the sky
And gentle streams go flowing by
The earth reveals her ancient art
In every leaf, in every part
Nature's truth will never die
In landscape's soul, we learn to fly"""

        elif 'building' in narrative.lower():
            return """Stone and glass against the blue
Tell stories old and stories new
Where human hands have shaped the space
With vision, time, and careful grace
Each structure holds a different view
Of dreams that humans can pursue"""

        else:
            return """In frames of light and color bold
A thousand stories wait untold
Each element with voice unique
In visual language they all speak
Of mysteries that unfold
More precious than the purest gold"""
    def process_image(self, image):
        """Main processing function"""
        try:
            # 1. Understand the image with CLIP (or the colour-heuristic fallback)
            analysis_result = self.analyze_image_with_clip(image)

            # 2. Turn the analysis into a narrative
            narrative = self.generate_narrative_with_llama(analysis_result, image.size)

            # 3. Derive a short poem from the narrative
            poetry = self.generate_poetry(narrative)

            # 4. Overlay the analysis on the image
            viz_image = self.create_visualization(image, analysis_result)

            return narrative, poetry, viz_image

        except Exception as e:
            error_msg = f"An error occurred while processing the image: {str(e)}"
            return error_msg, "Unable to generate poetry due to processing error.", image
# Instantiate the storyteller once so the models are loaded a single time
storyteller = ImageStoryteller()

# Collect local example images named obj_01.jpg ... obj_09.jpg, if present
example_images = []
for i in range(1, 10):
    filename = f"obj_{i:02d}.jpg"
    if os.path.exists(filename):
        example_images.append([filename])
        print(f"Found example image: {filename}")

if not example_images:
    print("No local example images found, using placeholder")
    # Plain grey 300x300 image so gr.Examples always has something to show
    example_images = [[np.ones((300, 300, 3), dtype=np.uint8) * 100]]
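# Gradio interface: an image goes in; the annotated analysis image, the
# narrative, and the poem come out.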
with gr.Blocks(title="CLIP + LLaMA Image Storyteller", theme="soft") as demo:
    gr.Markdown("# 🎨 CLIP + LLaMA Image Storyteller")
    gr.Markdown("**Upload any image and watch AI understand the scene using CLIP-ViT and create beautiful stories with LLaMA!**")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="🖼️ Upload Your Image",
                height=300
            )
            process_btn = gr.Button("✨ Analyze Image & Create Story", variant="primary", size="lg")

        with gr.Column():
            analysis_output = gr.Image(
                label="🔍 CLIP-ViT Analysis",
                height=300,
                show_download_button=True
            )

    with gr.Row():
        with gr.Column():
            with gr.Tab("📖 Narrative Story"):
                narrative_output = gr.Textbox(
                    label="Image Narrative",
                    lines=5,
                    max_lines=8,
                    placeholder="Your image's story will appear here...",
                    show_copy_button=True
                )

            with gr.Tab("📝 Poetic Verses"):
                poetry_output = gr.Textbox(
                    label="6-Line Poetry",
                    lines=6,
                    max_lines=7,
                    placeholder="Poetic interpretation will appear here...",
                    show_copy_button=True
                )

    gr.Markdown("### 🎯 Try These Examples")
    gr.Examples(
        examples=example_images,
        inputs=input_image,
        outputs=[narrative_output, poetry_output, analysis_output],
        fn=storyteller.process_image,
        cache_examples=True
    )

    with gr.Accordion("💡 How It Works", open=False):
        gr.Markdown("""
        **The Magic Behind the Stories:**

        1. **CLIP-ViT Analysis**: OpenAI's CLIP model understands image content and scene types
        2. **Object Recognition**: Identifies objects, people, and scenery with confidence scores
        3. **Scene Classification**: Determines the overall scene type (portrait, landscape, urban, etc.)
        4. **LLaMA Storytelling**: Meta's LLaMA model generates compelling narratives
        5. **Poetic Creation**: Transforms the analysis into beautiful 6-line verses

        **Technical Stack:**
        - **CLIP-ViT**: Vision transformer for image understanding
        - **LLaMA**: Large language model for text generation
        - **Transformers**: Hugging Face library for model inference

        **Features:**
        - Semantic image understanding
        - Context-aware storytelling
        - Emotional narrative generation
        - Beautiful poetic interpretations
        - Annotated analysis visualization

        **Perfect for:**
        - Personal photography
        - Landscape and nature scenes
        - Urban and architectural photography
        - Artistic compositions
        - Memory preservation
        """)

    process_btn.click(
        fn=storyteller.process_image,
        inputs=input_image,
        outputs=[narrative_output, poetry_output, analysis_output]
    )
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
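# Minimal programmatic usage sketch (bypasses the UI; "example.jpg" is a
# hypothetical local file, not shipped with this script):
#
#   from PIL import Image
#   story, poem, overlay = storyteller.process_image(Image.open("example.jpg"))
#   print(story)
#   print(poem)
#   overlay.save("analysis_overlay.jpg")
#
# Rough dependency list (assumed, not pinned here): gradio, opencv-python,
# numpy, pillow, torch, transformers, plus accelerate/bitsandbytes for the
# 8-bit LLaMA load and sentencepiece for LlamaTokenizer.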