66 lines
2.0 KiB
Python
66 lines
2.0 KiB
Python
import json
import logging
import re
from json import JSONDecodeError
from typing import List

import google.generativeai as genai
from pydantic import ValidationError

from src.gemini_sdk.promt import create_base_message
from src.gemini_sdk.schemas import ResponseFromGeminiSchema
from src.schemas import MessagesForSendToWorkersSchema
from src.settings.base import settings
|
|
|
|
|
|
class GoogleHelper:
    """Thin wrapper around a Gemini ``GenerativeModel`` for worker messages.

    Builds a prompt from ``create_base_message()`` plus the JSON-serialized
    messages, sends it with ``generate_content`` and parses the model's JSON
    reply into ``ResponseFromGeminiSchema``.
    """

    # Matches a reply wrapped in a Markdown code fence such as
    # "```json\n{...}\n```" or "```\n{...}\n```" and captures the inner text.
    _CODE_FENCE_RE = re.compile(r"```(?:json|JSON)?\s*(.*?)\s*```", re.DOTALL)
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        api_key: str,
        model_name: str,
    ) -> None:
        """Configure the Gemini SDK and create the model client.

        Args:
            api_key: API key passed to ``genai.configure``.
            model_name: Name of the Gemini model to instantiate.
        """
        self.api_key = api_key
        # Public attribute holding the model *name*; the client object lives
        # in ``self._model``. Both are kept for backward compatibility.
        self.model = model_name
        # NOTE(review): genai.configure is a module-level SDK call, so it
        # presumably affects every genai user in the process. Confirm before
        # creating several helpers with different keys.
        genai.configure(api_key=api_key)
        self._model = genai.GenerativeModel(model_name=model_name)

    @staticmethod
    def _serialize_messages_to_prompt(
        chats: MessagesForSendToWorkersSchema,
    ) -> List[dict]:
        """Build the ``contents`` list for ``generate_content``.

        Args:
            chats: Messages to forward to the model.

        Returns:
            The base messages from ``create_base_message()`` followed by one
            ``user`` turn whose single text part is the JSON encoding of
            ``chats``.
        """
        # Copy so the append below never mutates a list that
        # create_base_message() might share between calls.
        messages_for_request = list(create_base_message())

        # model_dump(mode="json") already yields JSON-compatible primitives,
        # so a single json.dumps suffices (avoids double serialization).
        # ensure_ascii=False sends non-ASCII text to the model as-is instead
        # of as \uXXXX escapes.
        text_for_request = json.dumps(
            chats.model_dump(mode="json"),
            ensure_ascii=False,
        )

        user_message = {
            "role": "user",
            "parts": [{"text": text_for_request}],
        }
        messages_for_request.append(user_message)
        return messages_for_request

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return ``text`` without surrounding whitespace and Markdown code fence."""
        stripped = text.strip()
        match = GoogleHelper._CODE_FENCE_RE.fullmatch(stripped)
        return match.group(1) if match else stripped

    @staticmethod
    def _serialize_response_to_json(
        response_text: str,
    ) -> ResponseFromGeminiSchema:
        """Parse the model's text reply into ``ResponseFromGeminiSchema``.

        Before decoding, this strips a code fence around the reply
        ("```json ... ```" or a bare "``` ... ```").

        Args:
            response_text: Raw text returned by the model.

        Returns:
            The parsed schema. Returns ``ResponseFromGeminiSchema(success=None)``
            in three cases: the text is not valid JSON, the JSON is not an
            object, or it fails validation (``pydantic.ValidationError``).
        """
        cleaned_response = GoogleHelper._strip_code_fence(response_text)
        try:
            response_as_dict = json.loads(cleaned_response)
            # TypeError here means the JSON was not an object (list, null, ...).
            return ResponseFromGeminiSchema(**response_as_dict)
        except (JSONDecodeError, ValidationError, TypeError):
            # Unparseable replies are reported, not raised: log the offending
            # text and return success=None so callers can detect the failure.
            GoogleHelper._logger.warning(
                "Could not parse Gemini response: %s",
                cleaned_response,
                exc_info=True,
            )
            return ResponseFromGeminiSchema(success=None)

    def create_request_ai(
        self,
        messages: MessagesForSendToWorkersSchema,
    ) -> ResponseFromGeminiSchema:
        """Send ``messages`` to the configured model and parse its reply.

        Args:
            messages: Messages to forward to the model.

        Returns:
            The parsed reply, or ``ResponseFromGeminiSchema(success=None)`` when
            the reply text cannot be parsed.

        Raises:
            Errors from ``generate_content`` propagate unchanged.
            NOTE(review): ``response.text`` may also raise ``ValueError`` when
            the response has no text parts (e.g. a blocked prompt). Confirm
            against the installed google-generativeai version.
        """
        contents = self._serialize_messages_to_prompt(messages)
        response = self._model.generate_content(contents=contents)
        return self._serialize_response_to_json(response.text)
|
|
|
|
|
|
# Shared module-level helper, built from the GEMINI settings section.
# Importing this module therefore constructs GoogleHelper (and runs its
# __init__) at import time.
gemini_helper = GoogleHelper(
    model_name=settings.GEMINI.MODEL_NAME,
    api_key=settings.GEMINI.API_KEY,
)