76 lines
2.0 KiB
Python
76 lines
2.0 KiB
Python
import json
|
|
|
|
from src.core.ai_services.base import BaseAiService
|
|
|
|
import google.generativeai as genai
|
|
|
|
from src.core.ai_services.gemini.constants import GEMINI_BASE_MESSAGE
|
|
from src.core.ai_services.schemas import MessageFromChatSchema, ResponseFromAiSchema
|
|
from src.core.settings.base import settings
|
|
|
|
|
|
class GoogleHelper(BaseAiService):
    """AI service that talks to Google Gemini through ``google.generativeai``.

    Args:
        api_key: Gemini API key, forwarded to ``genai.configure``.
        model_name: Name of the Gemini model used to build the client.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str,
    ) -> None:
        self.api_key = api_key
        # ``self.model`` holds the model *name*; the actual client object is
        # ``self._model``. Kept unchanged because callers may read ``.model``.
        self.model = model_name

        # ``genai.configure`` is a module-level function, so the key presumably
        # applies to every ``genai`` client in the process, not just this one.
        genai.configure(api_key=api_key)
        self._model = genai.GenerativeModel(model_name=model_name)

    @staticmethod
    def _serialize_messages_to_promt(
        messages: list[MessageFromChatSchema],
    ) -> list[dict]:
        """Build the ``contents`` payload for a Gemini request.

        The result is a copy of ``GEMINI_BASE_MESSAGE`` with one extra
        ``"user"`` turn appended; that turn has a single text part containing
        the JSON document ``{"messages": [...]}``.

        Args:
            messages: Chat messages to embed in the prompt.

        Returns:
            A new list; ``GEMINI_BASE_MESSAGE`` itself is not mutated.
        """
        # A shallow copy suffices: only the outer list is appended to below.
        messages_for_request = GEMINI_BASE_MESSAGE.copy()
        # NOTE(review): presumably model_dump_with_datetime() yields
        # JSON-serializable dicts (datetimes converted) — confirm in schemas.
        dumped_messages = [msg.model_dump_with_datetime() for msg in messages]
        text_for_request = json.dumps({"messages": dumped_messages})

        messages_for_request.append(
            {
                "role": "user",
                "parts": [{"text": text_for_request}],
            }
        )
        return messages_for_request

    @staticmethod
    def _serialize_response_to_json(
        response_text: str,
    ) -> ResponseFromAiSchema:
        """Parse the model's textual reply into a ``ResponseFromAiSchema``.

        Args:
            response_text: Raw text returned by the model; expected to be a
                JSON object whose keys match ``ResponseFromAiSchema`` fields.

        Returns:
            The schema instance built from the decoded JSON object.

        Raises:
            json.JSONDecodeError: If ``response_text`` is not valid JSON.
        """
        # strict=False lets raw control characters (e.g. literal newlines that
        # models sometimes emit inside string values) through instead of
        # raising, and keeps them in the parsed text rather than deleting them.
        response_as_dict = json.loads(response_text, strict=False)
        return ResponseFromAiSchema(**response_as_dict)

    async def create_request_ai(
        self,
        messages: list[MessageFromChatSchema],
    ) -> ResponseFromAiSchema:
        """Send ``messages`` to Gemini and return the parsed structured reply.

        Args:
            messages: Chat history to forward to the model.

        Returns:
            The reply text parsed by ``_serialize_response_to_json``.
        """
        contents = self._serialize_messages_to_promt(messages)
        response = await self._model.generate_content_async(contents=contents)
        # NOTE(review): in google-generativeai, ``response.text`` raises
        # ValueError when the reply has no text part (e.g. blocked by safety
        # filters); that case is not handled here — confirm callers expect it.
        return self._serialize_response_to_json(response.text)
|
|
|
|
|
|
|
|
# Shared Gemini service instance built from the GEMINI section of settings.
# Creating it runs GoogleHelper.__init__ (and thus genai.configure) at import.
gemini_helper = GoogleHelper(
    api_key=settings.GEMINI.API_KEY, model_name=settings.GEMINI.MODEL_NAME
)