| """ | |
| CTI Bench Evaluation Runner | |
| This script provides a command-line interface to run the CTI Bench evaluation | |
| with your Retrieval Supervisor system. | |
| """ | |

import argparse
import os
import sys
from pathlib import Path

from dotenv import load_dotenv
from huggingface_hub import login as huggingface_login

# Add the project root to the Python path so we can import from src
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from src.evaluation.cti_bench.evaluator import CTIBenchEvaluator
from src.agents.retrieval_supervisor.supervisor import RetrievalSupervisor


def setup_environment(
    dataset_dir: str = "cti_bench/datasets", output_dir: str = "cti_bench/eval_output"
):
    """Set up the environment for evaluation."""
    # Load environment variables from a .env file, if present
    load_dotenv()

    # Re-export API keys so downstream libraries can pick them up, and log in
    # to the Hugging Face Hub if a token is available
    if os.getenv("GOOGLE_API_KEY"):
        os.environ["GOOGLE_API_KEY"] = os.getenv("GOOGLE_API_KEY")
    if os.getenv("GROQ_API_KEY"):
        os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
    if os.getenv("OPENAI_API_KEY"):
        os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
    if os.getenv("HF_TOKEN"):
        huggingface_login(token=os.getenv("HF_TOKEN"))

    # Create necessary directories
    os.makedirs(dataset_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    # Check that the dataset files exist
    dataset_path = Path(dataset_dir)
    ate_file = dataset_path / "cti-ate.tsv"
    mcq_file = dataset_path / "cti-mcq.tsv"
    if not ate_file.exists() or not mcq_file.exists():
        print("ERROR: CTI Bench dataset files not found!")
        print("Expected files:")
        print(f"  - {ate_file}")
        print(f"  - {mcq_file}")
        print(
            "Please download the CTI Bench dataset and place the files in the correct location."
        )
        sys.exit(1)

    return True


def run_evaluation_quick_test(
    dataset_dir: str,
    output_dir: str,
    llm_model: str,
    kb_path: str,
    max_iterations: int,
    num_samples: int = 2,
    datasets: str = "all",
):
    """Run a quick test with a few samples."""
    print("Running quick test evaluation...")
    try:
        # Initialize supervisor
        supervisor = RetrievalSupervisor(
            llm_model=llm_model,
            kb_path=kb_path,
            max_iterations=max_iterations,
        )

        # Initialize evaluator
        evaluator = CTIBenchEvaluator(
            supervisor=supervisor,
            dataset_dir=dataset_dir,
            output_dir=output_dir,
        )

        # Load and filter the datasets
        ate_df, mcq_df = evaluator.load_datasets()
        ate_filtered = evaluator.filter_dataset(ate_df, "ate")
        mcq_filtered = evaluator.filter_dataset(mcq_df, "mcq")

        # Test with the specified number of samples
        print(f"Testing with first {num_samples} samples of each dataset...")
        ate_sample = ate_filtered.head(num_samples)
        mcq_sample = mcq_filtered.head(num_samples)

        # Run evaluations based on dataset selection
        ate_metrics = None
        mcq_metrics = None

        if datasets in ["ate", "all"]:
            print("\nEvaluating ATE dataset...")
            ate_results = evaluator.evaluate_ate_dataset(ate_sample)
            ate_metrics = evaluator.calculate_ate_metrics(ate_results)

        if datasets in ["mcq", "all"]:
            print("\nEvaluating MCQ dataset...")
            mcq_results = evaluator.evaluate_mcq_dataset(mcq_sample)
            mcq_metrics = evaluator.calculate_mcq_metrics(mcq_results)

        # Print results
        print("\nQuick Test Results:")
        if ate_metrics:
            print(f"ATE - Macro F1: {ate_metrics.get('macro_f1', 0.0):.3f}")
            print(f"ATE - Success Rate: {ate_metrics.get('success_rate', 0.0):.3f}")
        if mcq_metrics:
            print(f"MCQ - Accuracy: {mcq_metrics.get('accuracy', 0.0):.3f}")
            print(f"MCQ - Success Rate: {mcq_metrics.get('success_rate', 0.0):.3f}")

        return True
    except Exception as e:
        print(f"Quick test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def run_csv_metrics_calculation(
    csv_path: str,
    output_dir: str,
    model_name: str = None,
):
    """Calculate metrics from an existing CSV results file."""
    print("Calculating metrics from existing CSV file...")
    try:
        # Initialize the evaluator; neither a supervisor nor a dataset
        # directory is needed for CSV post-processing
        evaluator = CTIBenchEvaluator(
            supervisor=None,
            dataset_dir="",
            output_dir=output_dir,
        )

        # Calculate metrics from the CSV file
        evaluator.calculate_metrics_from_csv(
            csv_path=csv_path,
            model_name=model_name,
        )

        print("CSV metrics calculation completed successfully!")
        return True
    except Exception as e:
        print(f"CSV metrics calculation failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def run_full_evaluation(
    dataset_dir: str,
    output_dir: str,
    llm_model: str,
    kb_path: str,
    max_iterations: int,
    datasets: str = "all",
):
    """Run the complete evaluation."""
    print("Running full evaluation...")
    try:
        # Initialize supervisor
        supervisor = RetrievalSupervisor(
            llm_model=llm_model,
            kb_path=kb_path,
            max_iterations=max_iterations,
        )

        # Initialize evaluator
        evaluator = CTIBenchEvaluator(
            supervisor=supervisor,
            dataset_dir=dataset_dir,
            output_dir=output_dir,
        )

        # Run the evaluation for the selected dataset(s)
        if datasets == "all":
            evaluator.run_full_evaluation()
        elif datasets == "ate":
            evaluator.run_ate_evaluation()
        elif datasets == "mcq":
            evaluator.run_mcq_evaluation()
        else:
            print(f"Invalid dataset selection: {datasets}")
            return False

        print("Full evaluation completed successfully!")
        return True
    except Exception as e:
        print(f"Full evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_supervisor_connection(llm_model: str, kb_path: str):
    """Test the supervisor connection."""
    try:
        supervisor = RetrievalSupervisor(
            llm_model=llm_model,
            kb_path=kb_path,
            max_iterations=1,
        )
        response = supervisor.invoke_direct_query("Test query: What is T1071?")
        print("Supervisor connection successful!")
        print(f"Sample response length: {len(str(response))} characters")
        return True
    except Exception as e:
        print(f"Supervisor connection failed: {e}")
        return False


def parse_arguments():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="CTI Bench Evaluation Runner",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run quick test with default settings
  python cti_bench_evaluation.py --mode quick

  # Run full evaluation with custom settings
  python cti_bench_evaluation.py --mode full --llm-model google_genai:gemini-2.0-flash --max-iterations 5

  # Run full evaluation on the ATE dataset only
  python cti_bench_evaluation.py --mode full --datasets ate

  # Run full evaluation on the MCQ dataset only
  python cti_bench_evaluation.py --mode full --datasets mcq

  # Test the supervisor connection
  python cti_bench_evaluation.py --mode test

  # Run quick test with 5 samples
  python cti_bench_evaluation.py --mode quick --num-samples 5

  # Calculate metrics from an existing CSV file
  python cti_bench_evaluation.py --mode csv --csv-path cti_bench/eval_output/cti-ate_gemini-2.0-flash_20251024_193022.csv

  # Calculate metrics from a CSV file with a custom model name
  python cti_bench_evaluation.py --mode csv --csv-path results.csv --csv-model-name my-model
        """,
    )
    parser.add_argument(
        "--mode",
        choices=["quick", "full", "test", "csv"],
        required=True,
        help="Evaluation mode: 'quick' for quick test, 'full' for complete evaluation, 'test' for connection test, 'csv' for processing existing CSV files",
    )
    parser.add_argument(
        "--datasets",
        choices=["ate", "mcq", "all"],
        default="all",
        help="Which datasets to evaluate: 'ate' for CTI-ATE only, 'mcq' for CTI-MCQ only, 'all' for both (default: all)",
    )
    parser.add_argument(
        "--dataset-dir",
        default="cti_bench/datasets",
        help="Directory containing CTI Bench dataset files (default: cti_bench/datasets)",
    )
    parser.add_argument(
        "--output-dir",
        default="cti_bench/eval_output",
        help="Directory for evaluation output files (default: cti_bench/eval_output)",
    )
    parser.add_argument(
        "--llm-model",
        default="google_genai:gemini-2.0-flash",
        help="LLM model to use (default: google_genai:gemini-2.0-flash)",
    )
    parser.add_argument(
        "--kb-path",
        default="./cyber_knowledge_base",
        help="Path to knowledge base (default: ./cyber_knowledge_base)",
    )
    parser.add_argument(
        "--max-iterations",
        type=int,
        default=3,
        help="Maximum iterations for supervisor (default: 3)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=2,
        help="Number of samples for quick test (default: 2)",
    )

    # CSV processing arguments
    parser.add_argument(
        "--csv-path",
        help="Path to existing CSV results file (required for csv mode)",
    )
    parser.add_argument(
        "--csv-model-name",
        help="Model name to use in summary (optional, will be extracted from filename if not provided)",
    )

    return parser.parse_args()


def main():
    """Main function."""
    args = parse_arguments()

    print("CTI Bench Evaluation Runner")
    print("=" * 50)

    # Set up the environment (skip dataset validation for CSV mode)
    if args.mode != "csv":
        if not setup_environment(args.dataset_dir, args.output_dir):
            return
    else:
        # For CSV mode, just create the output directory
        os.makedirs(args.output_dir, exist_ok=True)

    # Execute based on mode
    if args.mode == "quick":
        success = run_evaluation_quick_test(
            dataset_dir=args.dataset_dir,
            output_dir=args.output_dir,
            llm_model=args.llm_model,
            kb_path=args.kb_path,
            max_iterations=args.max_iterations,
            num_samples=args.num_samples,
            datasets=args.datasets,
        )
    elif args.mode == "full":
        success = run_full_evaluation(
            dataset_dir=args.dataset_dir,
            output_dir=args.output_dir,
            llm_model=args.llm_model,
            kb_path=args.kb_path,
            max_iterations=args.max_iterations,
            datasets=args.datasets,
        )
    elif args.mode == "test":
        success = test_supervisor_connection(
            llm_model=args.llm_model, kb_path=args.kb_path
        )
    elif args.mode == "csv":
        # Validate CSV mode arguments
        if not args.csv_path:
            print("ERROR: --csv-path is required for csv mode")
            sys.exit(1)
        # Check that the CSV file exists
        if not os.path.exists(args.csv_path):
            print(f"ERROR: CSV file not found: {args.csv_path}")
            sys.exit(1)

        success = run_csv_metrics_calculation(
            csv_path=args.csv_path,
            output_dir=args.output_dir,
            model_name=args.csv_model_name,
        )

    if success:
        print("\nOperation completed successfully!")
    else:
        print("\nOperation failed!")
        sys.exit(1)


if __name__ == "__main__":
    main()