add fix
This commit is contained in:
parent
16855d1de9
commit
dd283bdc23
@ -13,9 +13,9 @@ from src.settings.base import settings
|
|||||||
|
|
||||||
class GoogleHelper:
|
class GoogleHelper:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.model = model_name
|
self.model = model_name
|
||||||
@ -24,7 +24,7 @@ class GoogleHelper:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_messages_to_prompt(
|
def _serialize_messages_to_prompt(
|
||||||
chats: MessagesForSendToWorkersSchema,
|
chats: MessagesForSendToWorkersSchema,
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
messages_for_request = create_base_message()
|
messages_for_request = create_base_message()
|
||||||
|
|
||||||
@ -41,7 +41,7 @@ class GoogleHelper:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_response_to_json(
|
def _serialize_response_to_json(
|
||||||
response_text: str,
|
response_text: str,
|
||||||
) -> ResponseFromGeminiSchema:
|
) -> ResponseFromGeminiSchema:
|
||||||
cleaned_response = response_text.strip().replace('```json\n', '').replace('\n```', '')
|
cleaned_response = response_text.strip().replace('```json\n', '').replace('\n```', '')
|
||||||
try:
|
try:
|
||||||
@ -52,15 +52,25 @@ class GoogleHelper:
|
|||||||
return ResponseFromGeminiSchema(success=None)
|
return ResponseFromGeminiSchema(success=None)
|
||||||
|
|
||||||
def create_request_ai(
|
def create_request_ai(
|
||||||
self,
|
self,
|
||||||
messages: MessagesForSendToWorkersSchema,
|
messages: MessagesForSendToWorkersSchema,
|
||||||
) -> ResponseFromGeminiSchema:
|
) -> ResponseFromGeminiSchema:
|
||||||
|
print(messages.slice_id, "SLICE ID BEFORE REQUEST")
|
||||||
|
for message in messages.messages:
|
||||||
|
print(message.id, message.user_id, message.chat_id)
|
||||||
contents = self._serialize_messages_to_prompt(messages)
|
contents = self._serialize_messages_to_prompt(messages)
|
||||||
response = self._model.generate_content(contents=contents)
|
response = self._model.generate_content(contents=contents)
|
||||||
return self._serialize_response_to_json(response.text)
|
result = self._serialize_response_to_json(response.text)
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
for i in result.success:
|
||||||
|
print(i.slice_id, i.user_id, "SUCCESS")
|
||||||
|
print(i.reason)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
gemini_helper = GoogleHelper(
|
gemini_helper = GoogleHelper(
|
||||||
api_key=settings.GEMINI.API_KEY,
|
api_key=settings.GEMINI.API_KEY,
|
||||||
model_name=settings.GEMINI.MODEL_NAME,
|
model_name=settings.GEMINI.MODEL_NAME,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user