Update pipeline.py
pipeline.py  +268 -1
@@ -742,4 +742,271 @@ class UltraRobustCallAnalytics:
             with torch.no_grad():
                 logits = model(**inputs).logits
-                probs = torch.softmax(logits, dim=-1)[0].
+                probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
+
+            # Map labels to age buckets (aggregating across genders)
+            # Labels usually look like: 'female_20-29', 'male_20-29', etc.
+            labels = model.config.id2label
+            age_scores = defaultdict(float)
+
+            for i, score in enumerate(probs):
+                label = labels[i]
+                # Extract age part (assuming format gender_age)
+                parts = label.split('_')
+                if len(parts) > 1:
+                    age_group = parts[-1]  # e.g., "20-29"
+                    age_scores[age_group] += score
+
+            # Get best age bracket
+            if age_scores:
+                best_age = max(age_scores, key=age_scores.get)
+                return best_age
+
+            return "UNKNOWN"
+
+        except Exception as e:
+            print(f" ⚠ Age detection failed: {e}")
+            return "UNKNOWN"
+
+    def _run_enhanced_diarization(self, wav, sr, file_path):
+        """
+        Run Pyannote diarization, or fall back to simple segmentation
+        """
+        if self.diarization_pipeline is None:
+            print(" ⚠ No auth token provided, using energy-based fallback segmentation")
+            return self._energy_based_segmentation(wav, sr)
+
+        try:
+            # Run the diarization pipeline on the audio file
+            diarization = self.diarization_pipeline(file_path)
+
+            segments = []
+            for turn, _, speaker in diarization.itertracks(yield_label=True):
+                segments.append({
+                    "start": turn.start,
+                    "end": turn.end,
+                    "speaker": speaker
+                })
+            return segments
+
+        except Exception as e:
+            print(f" ⚠ Diarization error: {e}, using fallback")
+            return self._energy_based_segmentation(wav, sr)
+
+    def _energy_based_segmentation(self, wav, sr):
+        """Fallback if deep-learning diarization fails"""
+        # Simple energy detection to split speech from silence,
+        # treating the recording as a single speaker (SPEAKER_00)
+        intervals = librosa.effects.split(wav, top_db=30)
+        segments = []
+        for start, end in intervals:
+            segments.append({
+                "start": start / sr,
+                "end": end / sr,
+                "speaker": "SPEAKER_00"
+            })
+        return segments
+
+    def _merge_segments_smart(self, segments, min_gap=0.5):
+        """Merge segments from the same speaker that are close together"""
+        if not segments:
+            return []
+
+        merged = []
+        current = segments[0]
+
+        for next_seg in segments[1:]:
+            # Same speaker and the gap between turns is small
+            if (next_seg['speaker'] == current['speaker'] and
+                    (next_seg['start'] - current['end']) < min_gap):
+                # Extend the current segment
+                current['end'] = next_seg['end']
+            else:
+                merged.append(current)
+                current = next_seg
+
+        merged.append(current)
+        return merged
+
+    def _is_silence(self, chunk, threshold=0.005):
+        """Check if an audio chunk is essentially silence"""
+        return np.max(np.abs(chunk)) < threshold
+
+    def _detect_emotion(self, chunk):
+        """Detect emotion from an audio chunk"""
+        try:
+            # Ensure the chunk is long enough for the model (at least 0.5 s at 16 kHz)
+            if len(chunk) < 16000 * 0.5:
+                return "neutral"
+
+            # Use the classification pipeline loaded in __init__
+            # Note: the pipeline accepts a file path or a numpy array
+            preds = self.emotion_classifier(chunk, top_k=1)
+            return preds[0]['label']
+        except Exception:
+            return "neutral"
+
+    def _calculate_tone_advanced(self, chunk, sr, text):
+        """
+        Calculate pitch, jitter, and shimmer using Parselmouth (Praat)
+        """
+        try:
+            if len(chunk) < sr * 0.1:
+                return {"pitch_hz": 0, "jitter": 0, "shimmer": 0}
+
+            snd = parselmouth.Sound(chunk, sampling_frequency=sr)
+
+            # Pitch: mean over voiced frames only
+            pitch = snd.to_pitch()
+            pitch_val = pitch.selected_array['frequency']
+            pitch_val = pitch_val[pitch_val != 0]
+            avg_pitch = np.mean(pitch_val) if len(pitch_val) > 0 else 0
+
+            # Glottal pulses for jitter/shimmer (Praat commands via parselmouth's call())
+            point_process = call(snd, "To PointProcess (periodic, cc)", 75, 500)
+
+            try:
+                jitter = call(point_process, "Get jitter (local)", 0, 0, 0.0001, 0.02, 1.3)
+            except Exception:
+                jitter = 0
+
+            try:
+                shimmer = call([snd, point_process], "Get shimmer (local)", 0, 0, 0.0001, 0.02, 1.3, 1.6)
+            except Exception:
+                shimmer = 0
+
+            return {
+                "pitch_hz": round(float(avg_pitch), 1),
+                "jitter": round(float(jitter * 100), 2),    # percentage
+                "shimmer": round(float(shimmer * 100), 2)   # percentage (local shimmer, not dB)
+            }
+        except Exception:
+            return {"pitch_hz": 0, "jitter": 0, "shimmer": 0}
+
+    def _assign_roles_smart(self, results):
+        """
+        Assign AGENT vs CUSTOMER roles based on content analysis
+        """
+        speakers = set(r['speaker'] for r in results)
+        if len(speakers) == 1:
+            # Monologue - assume Agent recording
+            for r in results:
+                r['role'] = "AGENT"
+            return results
+
+        speaker_scores = defaultdict(int)
+
+        # Agent keywords
+        agent_keywords = [
+            "thank you for calling", "my name is", "how can i help",
+            "assist you", "recording", "company", "representative"
+        ]
+
+        # Customer keywords
+        customer_keywords = [
+            "issue", "problem", "not working", "bill", "complain",
+            "cancel", "help me", "fix"
+        ]
+
+        for res in results:
+            text = res['text'].lower()
+            spk = res['speaker']
+
+            # Scoring: agent phrases add points, customer phrases subtract
+            if any(k in text for k in agent_keywords):
+                speaker_scores[spk] += 2
+            if any(k in text for k in customer_keywords):
+                speaker_scores[spk] -= 2
+
+        # First speaker is often the agent (intro)
+        first_spk = results[0]['speaker']
+        speaker_scores[first_spk] += 1
+
+        # Identify the agent (highest score)
+        agent_spk = max(speaker_scores, key=speaker_scores.get)
+
+        # Assign roles
+        for res in results:
+            if res['speaker'] == agent_spk:
+                res['role'] = "AGENT"
+            else:
+                res['role'] = "CUSTOMER"
+
+        return results
+
+    def _analyze_customer_journey(self, results):
+        """Analyze sentiment flow of the customer"""
+        cust_segments = [r for r in results if r['role'] == "CUSTOMER"]
+
+        if not cust_segments:
+            return {"emotional_arc": "No customer audio", "impact_score": 0}
+
+        # Map emotions to scores
+        emo_map = {
+            "happy": 1.0, "joy": 1.0, "neutral": 0.1,
+            "sad": -0.5, "angry": -1.0, "frustrated": -1.0
+        }
+
+        # Average sentiment over the first and last three customer segments
+        start_score = sum(emo_map.get(s['emotion'], 0) for s in cust_segments[:3]) / min(3, len(cust_segments))
+        end_score = sum(emo_map.get(s['emotion'], 0) for s in cust_segments[-3:]) / min(3, len(cust_segments))
+
+        impact = end_score - start_score
+
+        if impact > 0.2:
+            arc = "Positive Resolution"
+        elif impact < -0.2:
+            arc = "Negative Escalation"
+        else:
+            arc = "Neutral/Unresolved"
+
+        return {
+            "emotional_arc": arc,
+            "start_sentiment": round(start_score, 2),
+            "end_sentiment": round(end_score, 2),
+            "impact_score": round(impact, 2)
+        }
+
+    def _analyze_agent_kpi(self, results, customer_impact):
+        """Calculate Agent performance metrics"""
+        agent_segments = [r for r in results if r['role'] == "AGENT"]
+
+        if not agent_segments:
+            return {"overall_score": 0}
+
+        # 1. Politeness (keyword based)
+        polite_words = ["please", "thank", "sorry", "apologize", "appreciate"]
+        total_words = sum(len(s['text'].split()) for s in agent_segments)
+        polite_count = sum(1 for s in agent_segments if any(w in s['text'].lower() for w in polite_words))
+
+        politeness_score = min(100, (polite_count / max(1, len(agent_segments))) * 200)
+
+        # 2. Tone consistency (variance of jitter across segments)
+        jitter_vals = [s['tone']['jitter'] for s in agent_segments]
+        tone_stability = 100 - min(100, np.std(jitter_vals) * 10) if jitter_vals else 50
+
+        # 3. Resolution impact (from the customer journey)
+        # Map the -1.0..1.0 impact range to 0..100
+        resolution_score = 50 + (customer_impact * 50)
+        resolution_score = max(0, min(100, resolution_score))
+
+        # Overall weighted score: 30% politeness, 20% tone stability, 50% resolution
+        overall = (
+            (politeness_score * 0.3) +
+            (tone_stability * 0.2) +
+            (resolution_score * 0.5)
+        )
+
+        return {
+            "overall_score": int(overall),
+            "politeness": int(politeness_score),
+            "tone_stability": int(tone_stability),
+            "resolution_effectiveness": int(resolution_score)
+        }
+
+    def _flush_memory(self):
+        """Aggressive memory cleanup"""
+        gc.collect()
+        if self.device == "cuda":
+            torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+    # Example usage
+    print("Initialize with: analyzer = UltraRobustCallAnalytics(hf_token='YOUR_TOKEN')")
+    print("Process with: result = analyzer.process_call('path/to/audio.wav')")
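For reference, the role-assignment, customer-journey, and KPI helpers in this hunk all read from a shared list of per-segment records that is built earlier in the file (outside this diff). The sketch below is a hypothetical illustration of one such record, inferred only from the keys those helpers access; the values are placeholders, not real output.

# Hypothetical shape of one entry in `results`, inferred from the fields read above.
segment = {
    "speaker": "SPEAKER_00",          # from diarization or the energy-based fallback
    "text": "thank you for calling",  # transcript text for this segment
    "emotion": "neutral",             # label returned by _detect_emotion
    "tone": {"pitch_hz": 182.4, "jitter": 1.2, "shimmer": 3.4},  # from _calculate_tone_advanced
}
# _assign_roles_smart later adds segment["role"] = "AGENT" or "CUSTOMER".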