File size: 4,932 Bytes
b9bae32 5b25c5f b9bae32 5b25c5f b9bae32 5b25c5f b9bae32 5b25c5f b9bae32 5b25c5f b9bae32 5b25c5f b9bae32 5b25c5f b9bae32 5b25c5f b9bae32 5b25c5f b9bae32 5b25c5f b9bae32 5b25c5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
from typing import Dict, Any, List
import torch
import time
import uuid
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
# System prompt injected when a request carries no system message of its own
# (chat requests without a "system" role, or raw "inputs" that do not start
# with a "### System:" header).
DEFAULT_SYSTEM_PROMPT = (
    "You are an expert Minecraft Forge mod developer for version 1.21.11. "
    "Write clean, efficient, and well-structured Java code."
)
class EndpointHandler:
    """Inference handler serving the Forge coder LoRA adapter.

    Presumably used as a Hugging Face Inference Endpoints custom handler
    (``EndpointHandler(path)`` then ``handler(data)``) -- confirm deployment.

    Loads ``deepseek-ai/deepseek-coder-6.7b-instruct`` quantized to 4-bit NF4,
    applies the PEFT adapter found at ``path`` and answers two request shapes:

    * chat: ``{"messages": [...], "max_tokens": ..., "temperature": ...,
      "top_p": ...}`` -> an OpenAI-style ``chat.completion`` dict;
    * text: ``{"inputs": "...", "parameters": {...}}`` ->
      ``{"generated_text": ...}``.
    """

    # Turn headers of the prompt template. Generated text is cut at the first
    # of these the model emits on a new line, so hallucinated follow-up turns
    # are never returned as part of the answer.
    _ROLE_HEADERS = {
        "system": "### System:",
        "user": "### User:",
        "assistant": "### Assistant:",
    }

    def __init__(self, path: str = ""):
        """Load the tokenizer, the 4-bit base model and the PEFT adapter.

        Args:
            path: Location passed to both ``AutoTokenizer.from_pretrained`` and
                ``PeftModel.from_pretrained``, i.e. where the tokenizer and the
                adapter weights are loaded from.
        """
        # Only used to place tokenized inputs; the model itself is placed by
        # device_map="auto" below.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Reported in the "model" field of chat-completion responses.
        self.model_id = "hwding/forge-coder-v1.21.11"
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
        base_model_id = "deepseek-ai/deepseek-coder-6.7b-instruct"
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # NOTE(review): trust_remote_code executes Python shipped with the base
        # model repo; confirm it is actually required for this checkpoint.
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model = PeftModel.from_pretrained(self.model, path)
        self.model.eval()

    @staticmethod
    def _param(source: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``source[key]``, or *default* when the key is missing or None.

        Unlike ``dict.get(key, default)`` this also replaces an explicit JSON
        ``null``, which would otherwise crash comparisons such as
        ``temperature > 0``. Falsy values like ``0`` are kept.
        """
        value = source.get(key)
        return default if value is None else value

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten a chat message ``content`` value to plain text.

        Strings pass through, ``None`` becomes ``""``, and a list of content
        parts contributes the ``text`` of its ``{"type": "text"}`` parts,
        joined by newlines. Anything else is converted with ``str``.
        """
        if content is None:
            return ""
        if isinstance(content, list):
            return "\n".join(
                str(part.get("text", ""))
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
        return str(content)

    def _format_messages(self, messages: List[Dict[str, Any]]) -> str:
        """Render chat messages into the ``### Role:`` prompt template.

        Messages whose role is not system/user/assistant are skipped. If no
        system message is present, DEFAULT_SYSTEM_PROMPT is prepended. The
        prompt ends with an open assistant header for the model to complete.
        """
        prompt_parts = []
        has_system = False
        for msg in messages:
            role = msg.get("role", "")
            header = self._ROLE_HEADERS.get(role)
            if header is None:
                continue
            if role == "system":
                has_system = True
            content = self._content_to_text(msg.get("content", ""))
            prompt_parts.append(f"{header}\n{content}")
        if not has_system:
            prompt_parts.insert(
                0, f"{self._ROLE_HEADERS['system']}\n{DEFAULT_SYSTEM_PROMPT}"
            )
        prompt_parts.append(f"{self._ROLE_HEADERS['assistant']}\n")
        return "\n\n".join(prompt_parts)

    @classmethod
    def _truncate_at_role_header(cls, text: str) -> tuple:
        """Cut *text* before the first line that opens a new template turn.

        Returns:
            ``(kept_text, truncated)`` where *truncated* tells whether a role
            header was found and removed (together with everything after it).
        """
        cut = len(text)
        for header in cls._ROLE_HEADERS.values():
            pos = text.find("\n" + header)
            if pos != -1:
                cut = min(cut, pos)
        return text[:cut], cut < len(text)

    @staticmethod
    def _finish_reason(
        completion_tokens: int,
        max_new_tokens: int,
        last_token_id: Any,
        eos_token_id: Any,
        truncated: bool,
    ) -> str:
        """Classify why generation ended, in OpenAI ``finish_reason`` terms.

        "stop" when generation ended before the budget, ended on EOS, or was
        cut at a role header; "length" when the budget was exhausted.
        """
        if truncated or completion_tokens < max_new_tokens:
            return "stop"
        if eos_token_id is not None and last_token_id == eos_token_id:
            return "stop"
        return "length"

    def _run_generation(
        self, prompt: str, max_new_tokens: int, temperature: float, top_p: float
    ) -> tuple:
        """Generate a completion for *prompt*.

        Returns:
            ``(text, prompt_tokens, completion_tokens, finish_reason)``.
            *text* contains only newly generated tokens, cut at the first
            template role header the model emits, and stripped.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_tokens = int(inputs["input_ids"].shape[-1])
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                # temperature <= 0 selects greedy decoding; a neutral 1.0 is
                # passed in that case instead of the non-positive value.
                temperature=temperature if temperature > 0 else 1.0,
                top_p=top_p,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # outputs[0] starts with the prompt ids; keep only the new ids so the
        # prompt can never leak into the answer.
        new_ids = outputs[0][prompt_tokens:]
        completion_tokens = int(new_ids.shape[-1])
        text = self.tokenizer.decode(new_ids, skip_special_tokens=True)
        text, truncated = self._truncate_at_role_header(text)
        last_token_id = int(new_ids[-1]) if completion_tokens else None
        finish_reason = self._finish_reason(
            completion_tokens,
            max_new_tokens,
            last_token_id,
            self.tokenizer.eos_token_id,
            truncated,
        )
        return text.strip(), prompt_tokens, completion_tokens, finish_reason

    def _generate(
        self, prompt: str, max_new_tokens: int, temperature: float, top_p: float
    ) -> str:
        """Return only the generated completion text for *prompt*."""
        return self._run_generation(prompt, max_new_tokens, temperature, top_p)[0]

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch a request: non-empty ``messages`` -> chat, else text format."""
        messages = data.get("messages")
        if messages:
            return self._handle_openai_format(data)
        return self._handle_simple_format(data)

    def _handle_openai_format(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Answer an OpenAI-style chat request with a ``chat.completion`` dict.

        Token usage is taken from the tokenized prompt and the ids actually
        generated; ``finish_reason`` is "length" when ``max_tokens`` ran out.
        """
        messages = data.get("messages") or []
        max_tokens = self._param(data, "max_tokens", None)
        if max_tokens is None:
            # Accept OpenAI's newer max_completion_tokens as an alias.
            max_tokens = self._param(data, "max_completion_tokens", 512)
        temperature = self._param(data, "temperature", 0.7)
        top_p = self._param(data, "top_p", 0.95)
        prompt = self._format_messages(messages)
        generated_text, prompt_tokens, completion_tokens, finish_reason = (
            self._run_generation(prompt, max_tokens, temperature, top_p)
        )
        return {
            "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": self.model_id,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text,
                },
                "finish_reason": finish_reason,
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

    def _handle_simple_format(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Answer a text request ``{"inputs": str, "parameters": {...}}``.

        Inputs that already start with a "### System:" header are used as the
        full prompt; anything else is wrapped as a single user turn under
        DEFAULT_SYSTEM_PROMPT.

        Raises:
            TypeError: if ``inputs`` is present but not a string.
        """
        inputs = self._param(data, "inputs", "")
        if not isinstance(inputs, str):
            raise TypeError(
                f"'inputs' must be a string, got {type(inputs).__name__}"
            )
        parameters = self._param(data, "parameters", {})
        max_new_tokens = self._param(parameters, "max_new_tokens", 512)
        temperature = self._param(parameters, "temperature", 0.7)
        top_p = self._param(parameters, "top_p", 0.95)
        if inputs.startswith("### System:"):
            prompt = inputs
        else:
            prompt = f"### System:\n{DEFAULT_SYSTEM_PROMPT}\n\n### User:\n{inputs}\n\n### Assistant:\n"
        generated_text = self._generate(prompt, max_new_tokens, temperature, top_p)
        return {"generated_text": generated_text}
|