import { call, select } from 'typed-redux-saga/macro';
import * as process from 'process';

import { STT_WSS_API } from 'api/constants';
import { getAccessToken } from 'features/Auth/selectors';

import {
  startSocket,
  extractClientEvent,
  getClientEventCallback,
  mergeEvents,
  SocketEvent,
} from '../../utils/startSocket';

import {
  CLOSE_REASONS,
  ERRORS,
  METADATA_SCHEMA,
  CURRENT_PROTOCOL_VERSION,
  METADATA_KEYS,
} from './constants';

import type {
  ClientEvents as BaseClientEvents,
  ClientMarkEvent,
  ClientStartEvent,
  EncodedAudio,
  ServerEvents,
  onCallback,
  onCallbackParameters,
  OnCloseBeforeMarkCallback,
  OnTranscriptCallback,
  OnMarkCallback,
  STTTypes,
} from './types';

/**
 * Saga that opens a speech-to-text WebSocket and returns a set of helpers
 * for talking to it.
 *
 * The access token is read from the store with `getAccessToken` and passed
 * to the server in the `auth_token` query parameter.
 *
 * @typeParam CustomClientEvents - extra client events merged with the base ones.
 * @typeParam CustomServerEvents - extra server events merged with the base ones.
 * @param sttType - path segment selecting the STT endpoint (`/api/v1/<sttType>/`).
 * @returns an object exposing the raw `on` / `send` of the socket, a `close`
 *   that records that the client initiated the close, the `sendAudio` /
 *   `sendStart` / `sendMark` senders and the `onTranscript` / `onMark` /
 *   `onCloseBeforeMark` / `onUnauthenticatedClose` subscription sagas.
 * @throws Error with `ERRORS.NO_ACCESS_TOKEN` when the store holds no access token.
 */
function* sttSocket<
  CustomClientEvents extends SocketEvent = never,
  CustomServerEvents extends SocketEvent = never
>(sttType: STTTypes) {
  type ExtendedClientEvents = mergeEvents<BaseClientEvents, CustomClientEvents>;
  type ExtendedServerEvents = mergeEvents<ServerEvents, CustomServerEvents>;
  type ClientEvents = Utils.snakeObjectToCamel<ExtendedClientEvents>;

  const accessToken = yield* select(getAccessToken);

  if (!accessToken) throw new Error(ERRORS.NO_ACCESS_TOKEN);

  // Encode the token so characters such as `+`, `/`, `=` or `&` cannot corrupt
  // the query string. URL-safe tokens are left unchanged by the encoding.
  const url = `${STT_WSS_API}/api/v1/${sttType}/?auth_token=${encodeURIComponent(accessToken)}`;

  const { on, send, close } = yield* call(
    startSocket<ExtendedClientEvents, ExtendedServerEvents>,
    url
  );

  // Set once `close` (the extended one returned below) has been called.
  let closedByClient = false;
  // Number of audio chunks passed to `sendAudio` so far.
  let chunksSent = 0;

  /**
   * Sends one chunk of encoded audio.
   *
   * In the `test` environment the audio is sent as is. Otherwise a metadata
   * header laid out according to `METADATA_SCHEMA` is prepended to the audio
   * bytes before sending.
   */
  const sendAudio = (audio: EncodedAudio) => {
    if (process.env.NODE_ENV === 'test') {
      chunksSent++;
      return send(audio);
    }

    // Total header size is the sum of the lengths of all schema fields.
    const metadataBufferSize = METADATA_SCHEMA.reduce((acm, { length }) => acm + length, 0);

    const buffer = new ArrayBuffer(metadataBufferSize + (audio.length || 0));
    const dataView = new DataView(buffer);

    // Walk the schema, carrying the byte offset of the current field.
    METADATA_SCHEMA.reduce((offset, { key, length }) => {
      switch (key) {
        case METADATA_KEYS.PROTOCOL_VERSION: {
          // Written as a little-endian 16-bit integer.
          dataView.setInt16(offset, CURRENT_PROTOCOL_VERSION, true);
          break;
        }
        case METADATA_KEYS.CHUNK_NUMBER: {
          // Pre-increment: the first chunk is written with number 1.
          // NOTE(review): the value is stored in 16 bits, so it wraps after
          // 32767 chunks — confirm this is acceptable for long sessions.
          dataView.setInt16(offset, ++chunksSent, true);
          break;
        }
        case METADATA_KEYS.RESERVED:
        default: {
          // Nothing is written; the bytes keep the zero value a fresh
          // ArrayBuffer is initialised with.
          break;
        }
      }

      return offset + length;
    }, 0);

    // The audio bytes are placed right after the metadata header.
    const binaryAudioDataWithMetadata = new Uint8Array(buffer);
    binaryAudioDataWithMetadata.set(audio, metadataBufferSize);

    send(binaryAudioDataWithMetadata);
  };

  /**
   * Sends the `start` event. `protocolVersion` is always set to
   * `CURRENT_PROTOCOL_VERSION`, overriding any value present in `payload`.
   */
  const sendStart = ((payload = {}) => {
    send({
      type: 'start',
      payload: {
        ...payload,
        protocolVersion: CURRENT_PROTOCOL_VERSION,
      },
    } as unknown as extractClientEvent<ClientStartEvent, ClientEvents>);
  }) as getClientEventCallback<ClientStartEvent, ClientEvents>;

  // Payload of the mark event without `lastChunkNumber`, which is filled in here.
  type MarkPayload = Omit<
    Parameters<getClientEventCallback<ClientMarkEvent, ClientEvents>>[0],
    'lastChunkNumber'
  >;

  /**
   * Sends the `mark` event carrying the number of the last chunk sent.
   * When no audio has been sent at all the socket is closed instead.
   */
  const sendMark = ((payload: MarkPayload) => {
    // NOTE(review): this path closes the socket without setting
    // `closedByClient`, so `onCloseBeforeMark` reports `closedByClient: false`
    // — confirm this is intended.
    if (!chunksSent) close();
    else {
      send({
        type: 'mark',
        payload: {
          ...(payload || {}),
          // NOTE(review): chunk numbers written by `sendAudio` start at 1,
          // while this is `chunksSent - 1` — confirm against the protocol.
          lastChunkNumber: chunksSent - 1,
        },
      } as unknown as extractClientEvent<ClientMarkEvent, ClientEvents>);
    }
  }) as {} extends MarkPayload ? () => void : (payload: MarkPayload) => void;

  /**
   * Closes the socket and records that the close was requested by the client.
   */
  const extendedClose = () => {
    // The flag is raised before closing so that close handlers see it even if
    // they happen to be invoked synchronously from within `close()`.
    closedByClient = true;
    close();
  };

  /**
   * Subscribes to `transcript` events; `callback` receives
   * `{ text, chunkNumber }` taken from the event's `data` and `chunkNumber`.
   */
  function* onTranscript(callback: OnTranscriptCallback) {
    yield* call<onCallbackParameters<'transcript'>, onCallback<'transcript'>>(
      on,
      'transcript',
      function* (type, { data, chunkNumber }) {
        yield* call(callback, { text: data, chunkNumber });
      }
    );
  }

  /**
   * Subscribes to `mark` events; `callback` receives `{ chunkNumber }` taken
   * from the event's `lastChunkNumber`.
   */
  function* onMark(callback: OnMarkCallback) {
    yield* call<onCallbackParameters<'mark'>, onCallback<'mark'>>(
      on,
      'mark',
      function* (type, { lastChunkNumber }) {
        yield* call(callback, { chunkNumber: lastChunkNumber });
      }
    );
  }

  /**
   * Subscribes to `close` and `mark` events; `callback` is called with
   * `{ closedByClient }` when a `close` arrives and this subscription has not
   * seen a `mark` event before it.
   */
  function* onCloseBeforeMark(callback: OnCloseBeforeMarkCallback) {
    let marked = false;

    yield* call<onCallbackParameters<['close', 'mark']>, onCallback<'close' | 'mark'>>(
      on,
      ['close', 'mark'],
      function* (type) {
        if (type === 'mark') {
          marked = true;
        }
        if (type === 'close') {
          if (!marked) yield* call(callback, { closedByClient });
        }
      }
    );
  }

  /**
   * Subscribes to `close` events; `callback` is run when the close reason
   * equals `CLOSE_REASONS.UNAUTHORIZED`.
   */
  function* onUnauthenticatedClose(callback: () => void | Generator<unknown, void>) {
    yield* call<onCallbackParameters<'close'>, onCallback<'close'>>(
      on,
      'close',
      function* (type, { reason }) {
        // Run through `call` so that a generator callback is actually executed
        // (a bare `callback()` would only create the generator object).
        if (reason === CLOSE_REASONS.UNAUTHORIZED) yield* call(callback);
      }
    );
  }

  // The raw `send` is exposed only for events that are not base client events;
  // base events go through the dedicated senders above.
  type Send = (event: Exclude<ClientEvents, Utils.snakeObjectToCamel<BaseClientEvents>>) => void;
  return {
    on,
    send: send as Send,
    close: extendedClose,
    sendAudio,
    sendStart,
    sendMark,
    onTranscript,
    onMark,
    onCloseBeforeMark,
    onUnauthenticatedClose,
  };
}

export default sttSocket;
