import streamlit as st
import torch
from PIL import Image
from transformers import pipeline
from transformers import DetrImageProcessor, DetrForObjectDetection
from transformers import CLIPProcessor, CLIPModel
from transformers import BlipProcessor, BlipForQuestionAnswering
# from transformers import YolosImageProcessor, YolosForObjectDetection  # unused alternative detector
from functions import *
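# The helpers used below -- detect_objects, draw_bounding_boxes,
# show_detection_table, get_image_description, generate_story and
# answer_question -- are expected to come from the local functions module
# imported above.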
# Load all models once; st.cache_resource keeps them across Streamlit reruns
# instead of reloading them on every interaction.
@st.cache_resource
def load_models():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # DETR for object detection
    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
    # CLIP for image-text matching
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # BLIP for visual question answering; fp16 on GPU to save memory
    sales_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
    sales_model = BlipForQuestionAnswering.from_pretrained(
        "Salesforce/blip-vqa-base",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)
    return {
        "detector": model,
        "processor": processor,
        "clip": clip_model,
        "clip process": clip_processor,
        "story_teller": pipeline("text-generation", model="nickypro/tinyllama-15M"),
        "sales process": sales_processor,
        "sales model": sales_model,
        "device": device,
    }
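# Hedged sketch: detect_objects is defined in functions.py and its body is not
# shown here. The helper below only illustrates how the DETR detector loaded
# above is typically run (standard transformers post-processing API); the real
# implementation may differ. It is not called by the app.
def _example_detect_objects(image, models, threshold=0.9):
    inputs = models["processor"](images=image, return_tensors="pt")
    outputs = models["detector"](**inputs)
    # Convert raw logits/boxes into labeled detections above the score threshold.
    target_sizes = torch.tensor([image.size[::-1]])  # PIL size is (w, h); DETR wants (h, w)
    results = models["processor"].post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]
    return [
        {
            "label": models["detector"].config.id2label[label.item()],
            "score": round(score.item(), 3),
            "box": [round(v, 1) for v in box.tolist()],
        }
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"])
    ]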
def main():
    st.header("📱 Nano AI Image Analyzer")
    uploaded_file = st.file_uploader("Upload an image")  # optionally restrict: type=["png", "jpg", "jpeg"]
    models = load_models()
    st.write("Models loaded")
    # Buttons UI
    if uploaded_file is not None:
        st.write("Filename:", uploaded_file.name)
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", width=200)
        col1, col2, col3 = st.columns(3)
        with col1:
            detect = st.button("🔍 Detect Objects", key="btn1")
        with col2:
            describe = st.button("📝 Describe Image", key="btn2")
        with col3:
            story = st.button("📖 Generate Story", key="btn3",
                              help="The story is generated from the image caption")
        if detect:
            with st.spinner("Detecting objects..."):
                try:
                    detections = detect_objects(image.copy(), models)
                    annotated_image = draw_bounding_boxes(image, detections)
                    st.image(annotated_image, caption="Detected Objects", use_container_width=True)
                    show_detection_table(detections)
                except Exception as e:
                    st.error(f"Detection failed ({e}); try another image.")
        elif describe:
            with st.spinner("Describing the image..."):
                description = get_image_description(image.copy(), models)
                st.write(description)
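        # Note: how get_image_description works is not shown here; given the
        # CLIP model loaded above, one plausible (assumed, unconfirmed) approach
        # is zero-shot scoring of candidate captions, e.g.:
        #   inputs = models["clip process"](text=candidates, images=image,
        #                                   return_tensors="pt", padding=True)
        #   probs = models["clip"](**inputs).logits_per_image.softmax(dim=1)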
        elif story:
            with st.spinner("Generating a story..."):
                description = get_image_description(image.copy(), models)
                # avoid shadowing the `story` button flag
                story_text = generate_story(description, models)
                st.write(story_text)
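        # generate_story is also defined in functions.py; with the
        # "story_teller" text-generation pipeline loaded above, the core call
        # presumably looks like (assumption, not confirmed):
        #   models["story_teller"](description, max_new_tokens=100)[0]["generated_text"]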
        # Chat interface: history lives in session_state and is re-rendered on
        # every rerun; new turns are appended below.
        if "messages" not in st.session_state:
            st.session_state.messages = []
        chat_container = st.container(height=400)
        with chat_container:
            for message in st.session_state.messages:
                with st.chat_message(message["role"]):
                    st.markdown(message["content"])
        if prompt := st.chat_input("Ask about the image"):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    response = answer_question(image,
                                               prompt,
                                               models["sales process"],
                                               models["sales model"],
                                               models["device"])
                st.markdown(response)
            st.session_state.messages.append({"role": "assistant", "content": response})
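# Hedged sketch of the BLIP VQA call that answer_question (functions.py)
# presumably wraps, using the standard transformers generate/decode API.
# Illustrative only; not called by the app.
def _example_answer_question(image, question, processor, model, device):
    inputs = processor(image, question, return_tensors="pt").to(device)
    # match the model's dtype (fp16 on GPU, fp32 on CPU) for the image tensor
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    out = model.generate(**inputs)  # token ids of a short generated answer
    return processor.decode(out[0], skip_special_tokens=True)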
if __name__ == "__main__":
    main()