import logging
from hashlib import md5
from statistics import mean

import backoff
import requests
from faust import Event
from google.protobuf.json_format import MessageToDict
from requests import RequestException
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession

import config
from custom_types import (
    ComputedItem,
    MonolithDataRequest,
    RawItem,
    RequestDataItem,
    SubtopicComputedData,
    SubtopicData,
)
from decorators import WithSession
from models import HiddenState
from services.datasphere.datasphere import DataSphere
from services.datasphere.yandex import YandexDataSphere
from services.state.state import State


logger = logging.getLogger(__name__)


class KafkaReader:
    """Read events from Kafka and store them as raw items in the shared ``State``."""

    def __init__(self, state: State):
        self.state = state

    @staticmethod
    def _hash_student_class_type(student_id: int, class_type_id: int) -> str:
        """Return the md5 hex digest of ``f'{student_id}_{class_type_id}'``.

        The digest is only an identifier for a (student, class type) pair, not a
        security primitive. ``usedforsecurity=False`` says so explicitly, which
        keeps md5 usable on FIPS-restricted Python builds. The digest value is
        unchanged.
        """
        key = f'{student_id}_{class_type_id}'
        return md5(key.encode(), usedforsecurity=False).hexdigest()

    def execute(self, event: Event):
        """Convert one Kafka event into a ``RawItem`` and hand it to ``State``.

        The protobuf payload is converted with ``MessageToDict``, keeping the
        original proto field names. The item is tagged with the student/class-type
        hash and recorded together with the event's partition and offset.
        """
        payload = event.value
        hash_value = self._hash_student_class_type(payload.student_id, payload.class_type_id)
        fields = MessageToDict(payload, preserving_proto_field_name=True)
        message = RawItem(hash_value=hash_value, **fields)
        self.state.add_raw_item(message, event.message.partition, event.message.offset)


class DataHandler:
    """Process the data collected from Kafka and forward the results.

    ``execute`` runs the pipeline in this order:

    1. Read raw items from ``State``.
    2. Build DataSphere request payloads.
    3. Recompute via DataSphere.
    4. Persist hidden states to Postgres.
    5. Send per-subtopic knowledge to the monolith.
    6. Remove the processed raw items from ``State``.
    """

    # Seconds to wait for the monolith before giving up on one attempt.
    # Without a timeout ``requests`` can block indefinitely. backoff's
    # ``max_time`` is only checked between attempts, so it cannot bound a hung
    # request on its own.
    MONOLITH_REQUEST_TIMEOUT: float = 10

    def __init__(self, state: State):
        self.state = state

    @staticmethod
    async def _prepare_data_for_datasphere(hash_value: str, message: RawItem) -> RequestDataItem:
        """Build the DataSphere request payload for one raw item from the monolith.

        Fetches the ``HiddenState`` row for ``hash_value`` via
        ``HiddenState.get_or_create``. Homework and task ids are converted to
        strings.
        """
        hidden_state, _ = await HiddenState.get_or_create(hash=hash_value)
        return RequestDataItem(
            hw_problems_id=[str(problem_id) for problem_id in message['hw_done_ids']],
            results=message['hw_done_grade'],
            performance_problems_id=[str(task['id']) for task in message['task_data']],
            hidden_states=hidden_state.hidden_states,
            rv=hidden_state.rv,
            done_tasks_cnt=message.get('done_tasks_cnt', 0),
        )

    @staticmethod
    def _calculate_subtopics_knowledge(
            subtopic_ids: list[int],
            probably: list[float],
    ) -> list[SubtopicComputedData]:
        """Compute a knowledge score per subtopic.

        ``subtopic_ids`` and ``probably`` are paired element-wise. Each
        subtopic's knowledge is the mean of its paired values, rounded to 2
        decimal places. Subtopics are returned in first-seen order.

        Raises:
            ValueError: if the two lists have different lengths (``zip(strict=True)``).
        """
        accuracy: int = 2
        subtopic_probably: dict[int, list[float]] = {}

        for subtopic_id, probably_value in zip(subtopic_ids, probably, strict=True):
            subtopic_probably.setdefault(subtopic_id, []).append(probably_value)

        return [
            SubtopicComputedData(
                subtopic_id=subtopic_id,
                knowledge=round(mean(knowledge), accuracy),
            )
            for subtopic_id, knowledge in subtopic_probably.items()
        ]

    def _prepare_data_for_monolith(
        self, subtopic_data: dict[str, SubtopicData],
        recomputed_data: dict[str, ComputedItem],
    ) -> list[MonolithDataRequest]:
        """Build the monolith payload: one entry per recomputed hash.

        Each entry pairs the student id with per-subtopic knowledge. Knowledge
        comes from the recomputed task probabilities and the task subtopic ids
        collected in ``subtopic_data``. Every key of ``recomputed_data`` is
        expected to be present in ``subtopic_data``; a missing key raises
        ``KeyError``.
        """
        monolith_data: list[MonolithDataRequest] = []

        for key, raw_value in recomputed_data.items():
            # Values may arrive as plain sequences; normalize to ComputedItem.
            computed = ComputedItem(*raw_value)
            subtopic_data_item = subtopic_data[key]
            monolith_data.append(
                MonolithDataRequest(
                    student_id=subtopic_data_item['student_id'],
                    subtopics=self._calculate_subtopics_knowledge(
                        subtopic_data_item['subtopic_ids'],
                        computed.task_probability,
                    ),
                ),
            )
        return monolith_data

    async def _raw_data_processing(
            self,
            cached_data: dict[str, RawItem],
    ) -> tuple[dict[str, RequestDataItem], dict[str, SubtopicData]]:
        """Process the raw items.

        Returns two dicts keyed by hash: the DataSphere request payloads, and
        the subtopic data (student id plus per-task subtopic ids) needed later
        to build the monolith payload.
        """
        input_data: dict[str, RequestDataItem] = {}
        subtopic_data: dict[str, SubtopicData] = {}

        for hash_value, message in cached_data.items():
            input_data[hash_value] = await self._prepare_data_for_datasphere(
                message=message,
                hash_value=hash_value,
            )
            subtopic_data[hash_value] = SubtopicData(
                student_id=message['student_id'],
                subtopic_ids=[task['subtopic_id'] for task in message['task_data']],
            )
        return input_data, subtopic_data

    @backoff.on_exception(
        backoff.expo,
        requests.exceptions.RequestException,
        max_time=10,
        jitter=backoff.full_jitter,
    )
    def _monolith_send_data(self, data: list[MonolithDataRequest]):
        """POST ``data`` as JSON to the monolith API.

        Retried with exponential backoff (full jitter, up to 10 s in total) on
        any ``RequestException``. That includes timeouts and non-2xx responses,
        which are re-raised as ``RequestException``.

        NOTE(review): this is a blocking call made from async code. It stalls
        the event loop for the duration of the request and the backoff sleeps.
        """
        url = config.ACAD_PERF_MONOLITH_URL + config.ACAD_PERF_MONOLITH_API_URL
        response = requests.post(
            url=url,
            json=data,
            headers={'Api-Key': f'{config.ACAD_PERF_MONOLITH_API_KEY}'},
            timeout=self.MONOLITH_REQUEST_TIMEOUT,
        )

        if not response.ok:
            logger.error('Send to monolith error: %s', response.text)
            raise RequestException(response.text, response=response)

        logger.info(msg='Data sent to monolith')

    @WithSession()
    async def _update_recomputed_data(
        self,
        data: dict[str, ComputedItem],
        *,
        session: AsyncSession,
    ) -> None:
        """Bulk-update ``HiddenState`` rows in Postgres, one parameter set per hash.

        Assumes each value is indexable as follows (TODO confirm against
        ``ComputedItem``):

        - ``value[1]`` holds the new hidden states.
        - ``value[2]`` is a sequence whose first element is the new ``rv``.
        """
        data_list = [
            {
                'hash': hash_,
                'hidden_states': value[1],
                'rv': value[2][0],
            }
            for hash_, value in data.items()
        ]
        if not data_list:
            # Nothing to update: skip rather than issue a bulk UPDATE with an
            # empty parameter list.
            return
        await session.execute(update(HiddenState), data_list)

    async def execute(self):
        """Run one processing pass over all raw items currently held in ``State``.

        Raw items are removed from ``State`` only after the monolith send
        succeeds. If any step raises, the items stay in ``State``.

        NOTE(review): hidden states are persisted before the monolith send. If
        the send fails, the same raw items are reprocessed later on top of the
        already-updated hidden states. Confirm this is intended.
        """
        cached_data = self.state.get_all_raw()
        if not cached_data:
            return

        input_data, subtopic_data = await self._raw_data_processing(cached_data)
        ds = DataSphere(datasphere=YandexDataSphere())
        recomputed_data = ds.recompute_data(input_data)

        await self._update_recomputed_data(data=recomputed_data)

        monolith_data = self._prepare_data_for_monolith(
            subtopic_data=subtopic_data,
            recomputed_data=recomputed_data,
        )
        self._monolith_send_data(monolith_data)
        self.state.clean_raw(list(input_data.keys()))
