akpande2 committed · verified
Commit d37a3c7 · 1 Parent(s): 081fe78

Update pipeline.py

Files changed (1):
  1. pipeline.py +268 -1
pipeline.py CHANGED
@@ -742,4 +742,271 @@ class UltraRobustCallAnalytics:
 
             with torch.no_grad():
                 logits = model(**inputs).logits
-                probs = torch.softmax(logits, dim=-1)[0].
+                probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
+
+            # Map labels to age buckets (aggregating across genders)
+            # Labels usually look like: 'female_20-29', 'male_20-29', etc.
+            labels = model.config.id2label
+            age_scores = defaultdict(float)
+
+            for i, score in enumerate(probs):
+                label = labels[i]
+                # Extract the age part (assuming the format gender_age)
+                parts = label.split('_')
+                if len(parts) > 1:
+                    age_group = parts[-1]  # e.g. "20-29"
+                    age_scores[age_group] += score
+
+            # Get the best age bracket
+            if age_scores:
+                best_age = max(age_scores, key=age_scores.get)
+                return best_age
+
+            return "UNKNOWN"
+
+        except Exception as e:
+            print(f" ⚠ Age detection failed: {e}")
+            return "UNKNOWN"
+
+    def _run_enhanced_diarization(self, wav, sr, file_path):
+        """
+        Run Pyannote diarization, or fall back to simple energy-based segmentation.
+        """
+        if self.diarization_pipeline is None:
+            print(" ⚠ No auth token provided, using energy-based fallback segmentation")
+            return self._energy_based_segmentation(wav, sr)
+
+        try:
+            # Run the diarization pipeline on the original file
+            diarization = self.diarization_pipeline(file_path)
+
+            segments = []
+            for turn, _, speaker in diarization.itertracks(yield_label=True):
+                segments.append({
+                    "start": turn.start,
+                    "end": turn.end,
+                    "speaker": speaker
+                })
+            return segments
+
+        except Exception as e:
+            print(f" ⚠ Diarization error: {e}, using fallback")
+            return self._energy_based_segmentation(wav, sr)
+
+    def _energy_based_segmentation(self, wav, sr):
+        """Fallback if deep-learning diarization fails."""
+        # Simple energy detection to split speech from silence,
+        # treating the whole call as a single speaker (SPEAKER_00)
+        intervals = librosa.effects.split(wav, top_db=30)
+        segments = []
+        for start, end in intervals:
+            segments.append({
+                "start": start / sr,
+                "end": end / sr,
+                "speaker": "SPEAKER_00"
+            })
+        return segments
+
+    def _merge_segments_smart(self, segments, min_gap=0.5):
+        """Merge segments from the same speaker that are close together."""
+        if not segments:
+            return []
+
+        merged = []
+        current = segments[0]
+
+        for next_seg in segments[1:]:
+            # If same speaker and the gap is small
+            if (next_seg['speaker'] == current['speaker'] and
+                    (next_seg['start'] - current['end']) < min_gap):
+                # Extend the current segment
+                current['end'] = next_seg['end']
+            else:
+                merged.append(current)
+                current = next_seg
+
+        merged.append(current)
+        return merged
+
+    def _is_silence(self, chunk, threshold=0.005):
+        """Check if an audio chunk is essentially silence."""
+        return np.max(np.abs(chunk)) < threshold
+
+    def _detect_emotion(self, chunk):
+        """Detect emotion from an audio chunk."""
+        try:
+            # Ensure the chunk is long enough for the model (>= 0.5 s at 16 kHz)
+            if len(chunk) < 16000 * 0.5:
+                return "neutral"
+
+            # Use the pipeline loaded in __init__
+            # Note: the pipeline accepts a file path or a numpy array
+            preds = self.emotion_classifier(chunk, top_k=1)
+            return preds[0]['label']
+        except Exception:
+            return "neutral"
+
+    def _calculate_tone_advanced(self, chunk, sr, text):
+        """
+        Calculate pitch, jitter, and shimmer using Parselmouth (Praat).
+        """
+        try:
+            if len(chunk) < sr * 0.1:
+                return {"pitch_hz": 0, "jitter": 0, "shimmer": 0}
+
+            snd = parselmouth.Sound(chunk, sampling_frequency=sr)
+
+            # Pitch
+            pitch = snd.to_pitch()
+            pitch_val = pitch.selected_array['frequency']
+            pitch_val = pitch_val[pitch_val != 0]
+            avg_pitch = np.mean(pitch_val) if len(pitch_val) > 0 else 0
+
+            # Point process (glottal pulses) for jitter/shimmer
+            point_process = call(snd, "To PointProcess (periodic, cc)", 75, 500)
+
+            try:
+                jitter = call(point_process, "Get jitter (local)", 0, 0, 0.0001, 0.02, 1.3)
+            except Exception:
+                jitter = 0
+
+            try:
+                shimmer = call([snd, point_process], "Get shimmer (local)", 0, 0, 0.0001, 0.02, 1.3, 1.6)
+            except Exception:
+                shimmer = 0
+
+            return {
+                "pitch_hz": round(float(avg_pitch), 1),
+                "jitter": round(float(jitter * 100), 2),   # percentage
+                "shimmer": round(float(shimmer * 100), 2)  # percentage (local shimmer is a ratio, not dB)
+            }
+        except Exception:
+            return {"pitch_hz": 0, "jitter": 0, "shimmer": 0}
+
+    def _assign_roles_smart(self, results):
+        """
+        Assign AGENT vs CUSTOMER roles based on content analysis.
+        """
+        if not results:
+            return results
+
+        speakers = set(r['speaker'] for r in results)
+        if len(speakers) == 1:
+            # Monologue - assume an agent recording
+            for r in results:
+                r['role'] = "AGENT"
+            return results
+
+        speaker_scores = defaultdict(int)
+
+        # Agent keywords
+        agent_keywords = [
+            "thank you for calling", "my name is", "how can i help",
+            "assist you", "recording", "company", "representative"
+        ]
+
+        # Customer keywords
+        customer_keywords = [
+            "issue", "problem", "not working", "bill", "complain",
+            "cancel", "help me", "fix"
+        ]
+
+        for res in results:
+            text = res['text'].lower()
+            spk = res['speaker']
+
+            # Scoring
+            if any(k in text for k in agent_keywords):
+                speaker_scores[spk] += 2
+            if any(k in text for k in customer_keywords):
+                speaker_scores[spk] -= 2
+
+        # The first speaker is often the agent (intro)
+        first_spk = results[0]['speaker']
+        speaker_scores[first_spk] += 1
+
+        # Identify the agent (highest score)
+        agent_spk = max(speaker_scores, key=speaker_scores.get)
+
+        # Assign roles
+        for res in results:
+            if res['speaker'] == agent_spk:
+                res['role'] = "AGENT"
+            else:
+                res['role'] = "CUSTOMER"
+
+        return results
+
+    def _analyze_customer_journey(self, results):
+        """Analyze the sentiment flow of the customer."""
+        cust_segments = [r for r in results if r['role'] == "CUSTOMER"]
+
+        if not cust_segments:
+            return {"emotional_arc": "No customer audio", "impact_score": 0}
+
+        # Map emotions to scores
+        emo_map = {
+            "happy": 1.0, "joy": 1.0, "neutral": 0.1,
+            "sad": -0.5, "angry": -1.0, "frustrated": -1.0
+        }
+
+        start_score = sum(emo_map.get(s['emotion'], 0) for s in cust_segments[:3]) / min(3, len(cust_segments))
+        end_score = sum(emo_map.get(s['emotion'], 0) for s in cust_segments[-3:]) / min(3, len(cust_segments))
+
+        impact = end_score - start_score
+
+        if impact > 0.2:
+            arc = "Positive Resolution"
+        elif impact < -0.2:
+            arc = "Negative Escalation"
+        else:
+            arc = "Neutral/Unresolved"
+
+        return {
+            "emotional_arc": arc,
+            "start_sentiment": round(start_score, 2),
+            "end_sentiment": round(end_score, 2),
+            "impact_score": round(impact, 2)
+        }
+
+    def _analyze_agent_kpi(self, results, customer_impact):
+        """Calculate agent performance metrics."""
+        agent_segments = [r for r in results if r['role'] == "AGENT"]
+
+        if not agent_segments:
+            return {"overall_score": 0}
+
+        # 1. Politeness (keyword based)
+        polite_words = ["please", "thank", "sorry", "apologize", "appreciate"]
+        total_words = sum(len(s['text'].split()) for s in agent_segments)
+        polite_count = sum(1 for s in agent_segments if any(w in s['text'].lower() for w in polite_words))
+
+        politeness_score = min(100, (polite_count / max(1, len(agent_segments))) * 200)
+
+        # 2. Tone consistency (jitter variance)
+        jitter_vals = [s['tone']['jitter'] for s in agent_segments]
+        tone_stability = 100 - min(100, np.std(jitter_vals) * 10) if jitter_vals else 50
+
+        # 3. Resolution impact (from the customer journey)
+        # Map the -1.0..1.0 impact range to 0..100
+        resolution_score = 50 + (customer_impact * 50)
+        resolution_score = max(0, min(100, resolution_score))
+
+        # Overall weighted score
+        overall = (
+            (politeness_score * 0.3) +
+            (tone_stability * 0.2) +
+            (resolution_score * 0.5)
+        )
+
+        return {
+            "overall_score": int(overall),
+            "politeness": int(politeness_score),
+            "tone_stability": int(tone_stability),
+            "resolution_effectiveness": int(resolution_score)
+        }
+
+    def _flush_memory(self):
+        """Aggressive memory cleanup."""
+        gc.collect()
+        if self.device == "cuda":
+            torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+    # Example usage
+    print("Initialize with: analyzer = UltraRobustCallAnalytics(hf_token='YOUR_TOKEN')")
+    print("Process with: result = analyzer.process_call('path/to/audio.wav')")
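
A minimal usage sketch of the entry points referenced by the __main__ block above. It assumes pipeline.py is importable from the working directory; the token and audio path are the placeholders from the print statements, and the structure of the object returned by process_call is not shown in this diff, so the result is only printed as-is.

    from pipeline import UltraRobustCallAnalytics

    # hf_token is used for the Pyannote diarization pipeline; without it the class
    # falls back to energy-based segmentation (see _run_enhanced_diarization above).
    analyzer = UltraRobustCallAnalytics(hf_token="YOUR_TOKEN")

    # Run the full analysis on one call recording and inspect the output.
    result = analyzer.process_call("path/to/audio.wav")
    print(result)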