# linkedin-generator / linkedin_gradio.py
# Last commit 2731895 by willsh1997: "cast to bfloat16 and delete quant"
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from transformers import pipeline
import pandas as pd
import gradio as gr
import os
import copy
import spaces
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer, TextIteratorStreamer
# Load Llama 3.2 3B Instruct once at import time and wrap it in a
# text-generation pipeline that every request reuses.
_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"

# Prefer CUDA, then Apple-Silicon MPS, then CPU.
# `torch.backends.mps.is_available()` is the documented MPS check; the previous
# `torch.mps.is_available()` is not part of the documented `torch.mps` API and
# may be missing in some torch releases, which would raise AttributeError on
# hosts without CUDA (where the CUDA check does not short-circuit).
torch_device = (
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
# bfloat16 weights on accelerators; float32 on CPU.
torch_dtype = torch.bfloat16 if torch_device in ("cuda", "mps") else torch.float32

llama_model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    torch_dtype=torch_dtype,
    device_map=torch_device,
    # NOTE: on memory-constrained devices, 4-bit loading can be re-enabled with
    # quantization_config=BitsAndBytesConfig(load_in_4bit=True).
)
llama_tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
llama32_3b_pipe = pipeline(
    "text-generation",
    model=llama_model,
    tokenizer=llama_tokenizer,
)
# System prompt that frames every request as a LinkedIn-post rewrite task.
_LINKEDIN_SYSTEM_PROMPT = """You are now a LinkedIn post generator. I will just give you a fraction of an idea and you will convert it into a buzzing LinkedIn post full of emojis and excitement, just like every other LinkedIn post.
"""


@spaces.GPU
def llama32_3b_chat(message) -> str:
    """Turn a short idea into a LinkedIn-style post via the Llama 3.2 3B pipeline.

    Args:
        message: The user's idea; it is formatted into the user turn of the chat.

    Returns:
        The ``content`` of the last message in the pipeline's
        ``generated_text`` for the final output, i.e. only the newly
        generated reply rather than the whole conversation.
    """
    # TODO: account for the model's context window (e.g. truncate long input).
    conversation = [
        {"role": "system", "content": _LINKEDIN_SYSTEM_PROMPT},
        {"role": "user", "content": f"{message}"},
    ]
    results = llama32_3b_pipe(conversation, max_new_tokens=512)
    final_turn = results[-1]["generated_text"][-1]
    return final_turn["content"]
def create_interface():
    """Build the Gradio Blocks UI for the LinkedIn post generator.

    The layout has three rows: an input textbox (pre-filled with an example
    idea), a "Generate LinkedIn Post" button, and a read-only output textbox.
    Clicking the button sends the input text to ``llama32_3b_chat`` and
    shows its return value in the output box.

    Returns:
        The constructed ``gr.Blocks`` app; it is not launched here.
    """
    with gr.Blocks() as app:
        with gr.Row():
            idea_box = gr.Textbox(
                label="input for Linkedin Post Generator",
                value="I am sometimes tired at work",
            )
        with gr.Row():
            generate_button = gr.Button("Generate LinkedIn Post")
        with gr.Row():
            post_box = gr.Textbox(interactive=False)

        # One text input in, one text output back.
        generate_button.click(
            llama32_3b_chat,
            inputs=[idea_box],
            outputs=[post_box],
        )
    return app
# Build the UI and start serving it. The app stays bound to the module-level
# name `demo`, which Gradio's reload mode looks for by default.
demo = create_interface()
demo.launch()