alemmrr committed on
Commit
ac5d59b
·
verified ·
1 Parent(s): 0cfd049

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -8
app.py CHANGED
@@ -1,7 +1,14 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
 
 
 
 
 
3
 
4
- # Load classifier
 
 
5
  tokenizer = AutoTokenizer.from_pretrained("alemmrr/finbert-gics-sector-classifier")
6
  model = AutoModelForSequenceClassification.from_pretrained("alemmrr/finbert-gics-sector-classifier")
7
 
@@ -9,12 +16,62 @@ clf = pipeline(
9
  "text-classification",
10
  model=model,
11
  tokenizer=tokenizer,
12
- device=-1,
13
- top_k=None
 
 
 
 
 
 
 
 
 
14
  )
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def predict(text):
17
- outputs = clf(text)
 
 
 
18
 
19
  # FIX: Flatten output if it's list-of-lists
20
  if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
@@ -32,12 +89,19 @@ def predict(text):
32
  scores = sorted(scores, key=lambda x: x["confidence"], reverse=True)
33
  return scores
34
 
 
 
 
 
35
  demo = gr.Interface(
36
  fn=predict,
37
- inputs=gr.Textbox(lines=3, label="Enter text"),
38
  outputs=gr.JSON(label="All Sector Scores"),
39
- title="FinBERT GICS Sector Classifier",
40
- description="Returns all sector confidence scores."
 
 
 
41
  )
42
 
43
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import (
3
+ AutoTokenizer,
4
+ AutoModelForSequenceClassification,
5
+ AutoModelForTokenClassification,
6
+ pipeline
7
+ )
8
 
9
+ # -----------------------------
10
+ # Load Your Classifier
11
+ # -----------------------------
12
  tokenizer = AutoTokenizer.from_pretrained("alemmrr/finbert-gics-sector-classifier")
13
  model = AutoModelForSequenceClassification.from_pretrained("alemmrr/finbert-gics-sector-classifier")
14
 
 
16
  "text-classification",
17
  model=model,
18
  tokenizer=tokenizer,
19
+ top_k=None,
20
+ device=-1
21
+ )
22
+
23
+ # -----------------------------
24
+ # Load NER Model (for auto-formatting)
25
+ # -----------------------------
26
# Entity tagger used to build the headline prefix. With "simple" aggregation
# each result exposes the "entity_group" and "word" fields read downstream.
ner_pipeline = pipeline(
    task="ner",
    aggregation_strategy="simple",
    model="Jean-Baptiste/roberta-large-ner-english",
)
31
 
32
+ # -----------------------------
33
+ # Helper: Format headline (Variant 3 Prefixing)
34
+ # -----------------------------
35
def format_headline_variant3(headline):
    """Prefix *headline* with its named entities in the Variant-3 layout.

    Runs ``ner_pipeline`` on the headline and groups the recognised entity
    strings by tag, in the fixed order ORG, LOC, PER, GPE. Each non-empty
    group is rendered as ``[TAG] a | b``; groups are separated by a single
    space and followed by `` [SEP] `` and the original headline. If no
    entity falls into one of those four tags, the headline is returned
    unchanged.

    Args:
        headline: Plain headline text.

    Returns:
        The string to pass to the classifier.
    """
    # Insertion order here determines the order of tags in the prefix.
    # The original author notes these buckets mirror the training format.
    grouped = {tag: [] for tag in ("ORG", "LOC", "PER", "GPE")}

    for entity in ner_pipeline(headline):
        tag, word = entity["entity_group"], entity["word"]
        bucket = grouped.get(tag)
        # Entities with any other tag are dropped.
        if bucket is not None:
            # NOTE(review): `word` is used verbatim (no whitespace trimming);
            # confirm this matches how the training data was formatted.
            bucket.append(word)

    segments = [
        f"[{tag}] " + " | ".join(words)
        for tag, words in grouped.items()
        if words
    ]

    # No usable entities: the classifier receives the raw headline.
    if not segments:
        return headline

    return " ".join(segments).strip() + " [SEP] " + headline
65
+
66
+
67
+ # -----------------------------
68
+ # Main Prediction Function
69
+ # -----------------------------
70
  def predict(text):
71
+ # Auto-format headline → Variant 3
72
+ formatted = format_headline_variant3(text)
73
+
74
+ outputs = clf(formatted)
75
 
76
  # FIX: Flatten output if it's list-of-lists
77
  if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
 
89
  scores = sorted(scores, key=lambda x: x["confidence"], reverse=True)
90
  return scores
91
 
92
+
93
+ # -----------------------------
94
+ # Gradio Interface
95
+ # -----------------------------
96
# One free-text input; the output of `predict` is rendered as JSON.
demo = gr.Interface(
    predict,
    gr.Textbox(lines=3, label="Enter a financial headline (plain text)"),
    gr.JSON(label="All Sector Scores"),
    title="FinBERT GICS Sector Classifier (Auto-Formatted)",
    description=(
        "Enter a plain financial news headline. The app automatically applies NER tagging "
        "and formats the text using the Variant-3 prefix structure before running classification."
    ),
)

# Start serving the interface when the script is executed.
demo.launch()