import time
import uuid
from typing import Any, Dict, List

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

DEFAULT_SYSTEM_PROMPT = "You are an expert Minecraft Forge mod developer for version 1.21.11. Write clean, efficient, and well-structured Java code."
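

# Custom handler for a Hugging Face Inference Endpoint: it loads the quantized base
# model, attaches the fine-tuned PEFT adapter and tokenizer from `path`, and serves
# both OpenAI-style chat requests and plain text-generation requests.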
class EndpointHandler:
    def __init__(self, path: str = ""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = "hwding/forge-coder-v1.21.11"

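        # 4-bit NF4 quantization keeps the 6.7B base model within a single-GPU memory
        # budget; compute runs in bfloat16 with double quantization enabled.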
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

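        # Load the tokenizer and the quantized DeepSeek Coder base model, then attach
        # the fine-tuned PEFT adapter stored at `path` (the model repository directory).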
        base_model_id = "deepseek-ai/deepseek-coder-6.7b-instruct"
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model = PeftModel.from_pretrained(self.model, path)
        self.model.eval()

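    # Flatten an OpenAI-style message list into the "### System:" / "### User:" /
    # "### Assistant:" prompt layout used throughout this handler, injecting the
    # default system prompt when the caller does not supply one.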
    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        prompt_parts = []
        has_system = False
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")
            if role == "system":
                prompt_parts.append(f"### System:\n{content}")
                has_system = True
            elif role == "user":
                prompt_parts.append(f"### User:\n{content}")
            elif role == "assistant":
                prompt_parts.append(f"### Assistant:\n{content}")
        if not has_system:
            prompt_parts.insert(0, f"### System:\n{DEFAULT_SYSTEM_PROMPT}")
        prompt_parts.append("### Assistant:\n")
        return "\n\n".join(prompt_parts)

    def _generate(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
        model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            # A non-positive temperature switches off sampling (greedy decoding);
            # 1.0 is passed in that case only so the value passes validation.
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature if temperature > 0 else 1.0,
                top_p=top_p,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        if "### Assistant:" in generated_text:
            generated_text = generated_text.split("### Assistant:")[-1].strip()
        return generated_text

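    # Entry point invoked by the Inference Endpoint: route OpenAI-style chat payloads
    # and plain "inputs" payloads to the matching handler.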
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        messages = data.get("messages")
        if messages:
            return self._handle_openai_format(data)
        return self._handle_simple_format(data)

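    # OpenAI-compatible chat path: build the prompt from the message list and wrap the
    # completion in a chat.completion-shaped response with token usage counts.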
    def _handle_openai_format(self, data: Dict[str, Any]) -> Dict[str, Any]:
        messages = data.get("messages", [])
        max_tokens = data.get("max_tokens", 512)
        temperature = data.get("temperature", 0.7)
        top_p = data.get("top_p", 0.95)

        prompt = self._format_messages(messages)
        generated_text = self._generate(prompt, max_tokens, temperature, top_p)

        # Usage is recomputed by re-encoding the prompt and completion, so the counts
        # are close to, but not guaranteed identical to, what generate() consumed.
        prompt_tokens = len(self.tokenizer.encode(prompt))
        completion_tokens = len(self.tokenizer.encode(generated_text))

        return {
            "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": self.model_id,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text,
                },
                "finish_reason": "stop",
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

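    # Standard text-generation path ({"inputs": ..., "parameters": {...}}): wrap bare
    # prompts in the default system prompt unless the caller already formatted them.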
    def _handle_simple_format(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})
        max_new_tokens = parameters.get("max_new_tokens", 512)
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.95)

        if not inputs.startswith("### System:"):
            prompt = f"### System:\n{DEFAULT_SYSTEM_PROMPT}\n\n### User:\n{inputs}\n\n### Assistant:\n"
        else:
            prompt = inputs

        generated_text = self._generate(prompt, max_new_tokens, temperature, top_p)
        return {"generated_text": generated_text}