|
|
import gradio as gr |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForSequenceClassification, |
|
|
AutoModelForTokenClassification, |
|
|
pipeline |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fine-tuned FinBERT checkpoint that scores a (NER-prefixed) headline per sector.
_SECTOR_MODEL_ID = "alemmrr/finbert-gics-sector-classifier"

tokenizer = AutoTokenizer.from_pretrained(_SECTOR_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(_SECTOR_MODEL_ID)

# Sector classifier pipeline.
#   top_k=None -> return a score for every label, not just the best one.
#   device=-1  -> pin this pipeline to the CPU.
clf = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    device=-1,
    top_k=None,
)

# Named-entity recognizer used to build the entity prefix for each headline.
# aggregation_strategy="simple" merges sub-word tokens into whole entities, so
# each result carries an "entity_group" tag and the merged "word" text.
ner_pipeline = pipeline(
    task="ner",
    model="Jean-Baptiste/roberta-large-ner-english",
    aggregation_strategy="simple",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_headline_variant3(headline):
    """Prefix *headline* with the named entities detected in it.

    Runs ``ner_pipeline`` on the headline and groups entity strings by tag.
    Only the ORG, LOC, PER and GPE groups are kept, always in that order.
    Each non-empty group is rendered as ``[TAG] a | b``. Groups are joined by
    a single space, and the prefix is followed by ``" [SEP] "`` and then the
    untouched headline.

    Args:
        headline: Raw headline text.

    Returns:
        The formatted string. It is the headline itself when none of the kept
        tags were found.
    """
    # The tuple order fixes the order of the groups in the prefix.
    # NOTE(review): this NER model presumably emits PER/ORG/LOC/MISC, so GPE may
    # never match and MISC entities are dropped — confirm against training data.
    tag_order = ("ORG", "LOC", "PER", "GPE")
    found = {label: [] for label in tag_order}

    for entity in ner_pipeline(headline):
        label, text = entity["entity_group"], entity["word"]
        if label in found:
            found[label].append(text)

    segments = [
        f"[{label}] " + " | ".join(words)
        for label, words in found.items()
        if words
    ]
    if not segments:
        return headline

    # strip() also trims any trailing whitespace carried by the last entity string.
    return " ".join(segments).strip() + " [SEP] " + headline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(text):
    """Score a plain-text headline against every sector label.

    The headline is first rewritten by ``format_headline_variant3``, which adds
    an NER entity prefix. It is then scored by ``clf``, and every label the
    classifier returns is reported.

    Args:
        text: Headline entered by the user. ``None`` or blank input is accepted.

    Returns:
        A list of ``{"label": str, "confidence": float}`` dicts, sorted from most
        to least confident. ``confidence`` is the score as a percentage rounded
        to 2 decimals. Blank input gives an empty list.
    """
    # A blank submission carries no signal. Skip both models instead of
    # reporting confident-looking scores for an empty string.
    if text is None or not text.strip():
        return []

    formatted = format_headline_variant3(text)

    # The text-classification pipeline does not truncate unless asked. Without
    # this, an input longer than the encoder's position limit (a long paste plus
    # the NER prefix) raises instead of returning scores. With the tokenizer's
    # default right-side truncation, the entity prefix and the start of the
    # headline are kept.
    # NOTE(review): assumes the tokenizer config sets model_max_length — confirm.
    outputs = clf(formatted, truncation=True)

    # For a single string, the scores may come back as a flat list of dicts or
    # wrapped in one extra list (presumably transformers-version dependent).
    # Normalize to the flat form.
    if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
        outputs = outputs[0]

    scores = [
        {"label": o["label"], "confidence": round(float(o["score"]) * 100, 2)}
        for o in outputs
    ]
    return sorted(scores, key=lambda s: s["confidence"], reverse=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: one text box in, the full sorted list of sector scores out (as JSON).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, label="Enter a financial headline (plain text)"),
    outputs=gr.JSON(label="All Sector Scores"),
    title="FinBERT GICS Sector Classifier (Auto-Formatted)",
    # The previous description stopped mid-sentence ("...applies NER tagging ").
    description=(
        "Enter a plain financial news headline. The app automatically applies NER tagging "
        "(organizations, locations, people), prepends the detected entities to the headline, "
        "and returns a confidence score for every sector label, sorted from highest to lowest."
    ),
)

demo.launch()
|
|
|