dungeon29 committed
Commit 4faed86 · verified · 1 Parent(s): 0527e37

Upload app.py

Files changed (1)
1. app.py +39 -18
app.py CHANGED
@@ -1,6 +1,6 @@
  import torch
  import torch.nn.functional as F
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
  from huggingface_hub import hf_hub_download
  import gradio as gr
  import requests
@@ -20,8 +20,8 @@ from rag_engine import RAGEngine
  from llm_client import LLMClient
 
  # --------- Config ----------
- REPO_ID = "dungeon29/deberta-lstm-detect-phishing"
- CKPT_NAME = "pytorch_model.bin"
+ REPO_ID = "dungeon29/phishing-deberta-lstm"  # HF repo that holds the checkpoint
+ CKPT_NAME = "deberta_lstm_checkpoint.pt"  # the .pt file name
  MODEL_NAME = "microsoft/deberta-base"  # base tokenizer/backbone
  LABELS = ["benign", "phishing"]  # adjust to your classes
 
@@ -33,27 +33,48 @@ LABELS = ["benign", "phishing"]  # adjust to your classes
  device = "cuda" if torch.cuda.is_available() else "cpu"
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
- ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
+ # Check if checkpoint exists locally, otherwise download from HF
+ if os.path.exists(CKPT_NAME):
+     print(f"📂 Found local checkpoint: {CKPT_NAME}")
+     ckpt_path = CKPT_NAME
+ else:
+     print(f"⬇️ Downloading checkpoint {CKPT_NAME} from HF Hub...")
+     try:
+         ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_NAME)
+     except Exception as e:
+         print(f"⚠️ Could not download from HF: {e}")
+         # Fallback to pytorch_model.bin if the new name fails (optional, but good for safety)
+         print("🔄 Trying fallback to pytorch_model.bin...")
+         ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
+
  checkpoint = torch.load(ckpt_path, map_location=device)
 
  # If you saved hyperparams in the checkpoint, use them:
- model = DeBERTaLSTMClassifier()
+ if isinstance(checkpoint, dict):
+     model_args = checkpoint.get("model_args", {})  # e.g., {"lstm_hidden": 256, "num_labels": 2, ...}
+ else:
+     model_args = {}
+ model = DeBERTaLSTMClassifier(**model_args)
 
  # Load weights
  try:
-     model.load_state_dict(checkpoint)
- except RuntimeError as e:
-     if "attention" in str(e):
-         # Old model without attention layer - initialize attention layer and load partial state
-         state_dict = checkpoint["model_state_dict"]
-         model_dict = model.state_dict()
-         # Filter out attention layer parameters
-         filtered_dict = {k: v for k, v in state_dict.items() if "attention" not in k}
-         model_dict.update(filtered_dict)
-         model.load_state_dict(model_dict)
-         print("Loaded model without attention layer, using newly initialized attention weights")
+     state_dict = torch.load(ckpt_path, map_location=device)
+
+     # Handle files saved as a full checkpoint (with a "model_state_dict" key)
+     if "model_state_dict" in state_dict:
+         state_dict = state_dict["model_state_dict"]
+
+     model.load_state_dict(state_dict, strict=False)
+
+     # Check the attention layer
+     if hasattr(model, 'attention') and 'attention.weight' not in state_dict:
+         print("⚠️ Loaded model without attention layer, using newly initialized attention weights")
      else:
-         raise e
+         print("✅ Loaded weights successfully!")
+
+ except Exception as e:
+     print(f"❌ Error while loading weights: {e}")
+     raise e
 
  model.to(device).eval()
 
@@ -360,7 +381,7 @@ def rag_predict_fn(text: str):
 
      if fetched_content:
          # Limit content length to avoid token overflow
-         truncated_content = fetched_content[:4000]
+         truncated_content = fetched_content[:1500]
          analysis_context = f"URL: {input_text}\n\nWebsite Content:\n{truncated_content}\n..."
          print(f"✅ Successfully fetched {len(fetched_content)} chars from URL.")
      else:
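
Note on the checkpoint layout: the new loader accepts either a bare state dict or a wrapped checkpoint containing "model_args" and "model_state_dict". A minimal save-side sketch that would match it, assuming DeBERTaLSTMClassifier takes the kwargs stored under "model_args" (the specific key names below are illustrative, not taken from this commit):

import torch
import torch.nn as nn

# Stand-in module so the sketch runs on its own;
# the app's real model is DeBERTaLSTMClassifier.
model = nn.Linear(4, 2)

# Wrapped layout the loader above understands: "model_args" feeds
# DeBERTaLSTMClassifier(**model_args); "lstm_hidden"/"num_labels"
# are assumed names, not confirmed by the commit.
checkpoint = {
    "model_args": {"lstm_hidden": 256, "num_labels": 2},
    "model_state_dict": model.state_dict(),
}
torch.save(checkpoint, "deberta_lstm_checkpoint.pt")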
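
Side note on the strict=False load: PyTorch's load_state_dict returns the keys it could not match, which gives a more direct signal than checking 'attention.weight' in the dict by hand. A small self-contained sketch (stand-in module, not the app's classifier):

import torch.nn as nn

# Two-layer stand-in; imagine layer "1" is a newly added attention head.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

# Simulate an old checkpoint that predates layer "1".
partial = {k: v for k, v in model.state_dict().items() if not k.startswith("1.")}

result = model.load_state_dict(partial, strict=False)
# missing_keys lists parameters left at their fresh initialization.
print("missing:", result.missing_keys)        # ['1.weight', '1.bias']
print("unexpected:", result.unexpected_keys)  # []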