"""Gradio chat interface for a fine-tuned Llama 3.2 3B Instruct model,
served through the Hugging Face Inference API (KTH ID2223, Lab 2)."""

import gradio as gr
from huggingface_hub import InferenceClient

# Fine-tuned model hosted on the Hugging Face Hub.
MODEL_ID = "Marcus719/Llama-3.2-3B-Instruct-Lab2"

# Client for the Hugging Face serverless Inference API.
client = InferenceClient(model=MODEL_ID)
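# Gated or private models require authentication: InferenceClient falls back to
# a token saved via `huggingface-cli login` or the HF_TOKEN environment variable.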
def chat(message, history, system_message, max_tokens, temperature, top_p):
    """Generate a streaming response via the Hugging Face Inference API."""
    # Rebuild the conversation in the OpenAI-style message format.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Stream tokens, yielding the accumulated text so the UI updates live.
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        if chunk.choices and chunk.choices[0].delta.content:
            token = chunk.choices[0].delta.content
            response += token
            yield response
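# For one prior exchange, `chat` builds a payload shaped like (illustrative values):
#   [{"role": "system", "content": "You are a helpful ..."},
#    {"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"},
#    {"role": "user", "content": "<new message>"}]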
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."

with gr.Blocks(theme=gr.themes.Soft(), title="🦙 Llama 3.2 ChatBot") as demo:
    gr.Markdown(
        """
# 🦙 Llama 3.2 3B Instruct - Fine-tuned on FineTome

**KTH ID2223 Scalable Machine Learning - Lab 2**

This chatbot uses my fine-tuned Llama 3.2 3B model trained on the FineTome-100k dataset.

📦 Model: [Marcus719/Llama-3.2-3B-Instruct-Lab2](https://huggingface.co/Marcus719/Llama-3.2-3B-Instruct-Lab2)
"""
    )
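    # The Chatbot uses Gradio's (user, assistant) tuple history format, which is
    # what the [message, None] pairs appended by the handlers below rely on.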
    chatbot = gr.Chatbot(label="Chat", height=450, show_copy_button=True)

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Type your message here...",
            scale=4,
            container=False,
            autofocus=True,
        )
        submit_btn = gr.Button("Send 🚀", scale=1, variant="primary")

    with gr.Accordion("⚙️ Settings", open=False):
        system_prompt = gr.Textbox(
            label="System Prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=2,
        )
        with gr.Row():
            max_tokens = gr.Slider(64, 1024, value=512, step=32, label="Max Tokens")
            temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")

    with gr.Row():
        clear_btn = gr.Button("🗑️ Clear Chat")
        retry_btn = gr.Button("🔄 Regenerate")

    gr.Examples(
        examples=[
            "Hello! Can you introduce yourself?",
            "Explain machine learning in simple terms.",
            "What is the difference between fine-tuning and pre-training?",
            "Write a short poem about AI.",
        ],
        inputs=msg,
        label="💡 Try these examples",
    )
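    # Two-step event pattern: `user_input` appends the user's turn and clears
    # the textbox immediately; `bot_response` then streams the reply in place.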
    def user_input(message, history):
        # Clear the textbox and append the new turn with a pending (None) reply.
        return "", history + [[message, None]]

    def bot_response(history, system_prompt, max_tokens, temperature, top_p):
        if not history:
            return
        message = history[-1][0]
        history_for_model = history[:-1]
        for response in chat(message, history_for_model, system_prompt, max_tokens, temperature, top_p):
            history[-1][1] = response
            yield history

    def retry_last(history, system_prompt, max_tokens, temperature, top_p):
        # Discard the last reply and regenerate it for the same user message.
        if history:
            history[-1][1] = None
            message = history[-1][0]
            history_for_model = history[:-1]
            for response in chat(message, history_for_model, system_prompt, max_tokens, temperature, top_p):
                history[-1][1] = response
                yield history

    # queue=False lets the textbox clear instantly; the streaming step chained
    # via .then() still runs through Gradio's queue.
    msg.submit(user_input, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, system_prompt, max_tokens, temperature, top_p], chatbot
    )
    submit_btn.click(user_input, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, system_prompt, max_tokens, temperature, top_p], chatbot
    )
    clear_btn.click(lambda: [], None, chatbot, queue=False)
    retry_btn.click(retry_last, [chatbot, system_prompt, max_tokens, temperature, top_p], chatbot)
    gr.Markdown(
        """
---
### 📝 About This Project

**Fine-tuning Details:**
- Base Model: `meta-llama/Llama-3.2-3B-Instruct`
- Dataset: [FineTome-100k](https://huggingface.co/datasets/mlabonne/FineTome-100k)
- Method: QLoRA (4-bit quantization + LoRA)
- Framework: [Unsloth](https://github.com/unslothai/unsloth)

Built with ❤️ for KTH ID2223 Lab 2
"""
    )


if __name__ == "__main__":
    demo.launch()
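# Minimal dependencies for this app (e.g. in a Space's requirements.txt):
#   gradio
#   huggingface_hub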