# Source: Hugging Face Space file "app.py" (2.76 kB), author: alemmrr
# Last commit: 83ba5d0 (verified) -- "Update app.py"
import gradio as gr
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
pipeline
)
# -----------------------------
# Load Your Classifier
# -----------------------------
# One repository id supplies both the tokenizer and the classification weights.
_SECTOR_MODEL_ID = "alemmrr/finbert-gics-sector-classifier"

tokenizer = AutoTokenizer.from_pretrained(_SECTOR_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(_SECTOR_MODEL_ID)

# top_k=None is passed so the pipeline is not limited to a single best label
# (the UI below presents "All Sector Scores"); device=-1 is the transformers
# convention for running on CPU.
clf = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
    device=-1,
)
# -----------------------------
# Load NER Model (for auto-formatting)
# -----------------------------
# NER pipeline used by format_headline_variant3 to find the entities that are
# prepended to each headline. With aggregation_strategy="simple" the results
# are read below through their "entity_group" and "word" fields.
ner_pipeline = pipeline(
    task="ner",
    model="Jean-Baptiste/roberta-large-ner-english",
    aggregation_strategy="simple",
)
# -----------------------------
# Helper: Format headline (Variant 3 Prefixing)
# -----------------------------
def format_headline_variant3(headline):
    """Prefix a headline with the named entities detected in it.

    The module-level ``ner_pipeline`` is run on ``headline`` and the words of
    the entities it returns are grouped by tag. Only the tags ORG, LOC, PER
    and GPE are kept; entities with any other tag are ignored. Each non-empty
    group becomes a segment ``[TAG] word | word`` and the segments, in the
    order ORG, LOC, PER, GPE, are joined and followed by `` [SEP] ``.

    Args:
        headline: Text handed to the NER pipeline and appended after the
            entity prefix.

    Returns:
        ``"<segments> [SEP] <headline>"`` when at least one kept entity was
        found, otherwise ``headline`` with an empty prefix in front of it.
    """
    # Key insertion order determines the order of the segments in the prefix.
    grouped = {"ORG": [], "LOC": [], "PER": [], "GPE": []}

    for entity in ner_pipeline(headline):
        label, surface = entity["entity_group"], entity["word"]
        # NOTE(review): words are stored exactly as the pipeline returns them
        # (no whitespace normalisation) -- confirm this matches the format the
        # classifier was trained on.
        if label in grouped:
            grouped[label].append(surface)

    segments = [
        f"[{label}] " + " | ".join(words)
        for label, words in grouped.items()
        if words
    ]

    # Joining with single spaces and stripping gives the same text as
    # appending "segment + space" repeatedly and stripping the result.
    prefix = " ".join(segments).strip()
    if prefix:
        prefix += " [SEP] "

    return prefix + headline
# -----------------------------
# Main Prediction Function
# -----------------------------
def predict(text):
    """Return the classifier's label scores for a headline, best first.

    The headline is first rewritten by ``format_headline_variant3`` (plain
    text -> entity-prefixed text) and the result is passed to the module-level
    ``clf`` pipeline.

    Args:
        text: Plain headline text to classify.

    Returns:
        A list of dicts with the keys ``"label"`` and ``"confidence"``, where
        confidence is the pipeline score multiplied by 100 and rounded to two
        decimals. The list is ordered from highest to lowest confidence.
    """
    raw = clf(format_headline_variant3(text))

    # The pipeline output may arrive wrapped in a one-element outer list;
    # unwrap it so that ``raw`` is the flat list of per-label results.
    is_wrapped = (
        isinstance(raw, list)
        and len(raw) == 1
        and isinstance(raw[0], list)
    )
    if is_wrapped:
        raw = raw[0]

    ranked = [
        {"label": item["label"], "confidence": round(float(item["score"]) * 100, 2)}
        for item in raw
    ]
    # In-place stable sort, highest confidence first.
    ranked.sort(key=lambda entry: entry["confidence"], reverse=True)
    return ranked
# -----------------------------
# Gradio Interface
# -----------------------------
# Declare the widgets first so the Interface call below reads as plain wiring.
_headline_box = gr.Textbox(lines=3, label="Enter a financial headline (plain text)")
_scores_view = gr.JSON(label="All Sector Scores")

demo = gr.Interface(
    fn=predict,
    inputs=_headline_box,
    outputs=_scores_view,
    title="FinBERT GICS Sector Classifier (Auto-Formatted)",
    description="Enter a plain financial news headline. The app automatically applies NER tagging ",
)

# Launch at module level, exactly as the original script does.
demo.launch()