// Copyright DEWETRON GmbH 2018

#include "dt_stream_packet_c.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#ifdef _WIN32
#include <WinSock2.h>
#include <ws2tcpip.h>
#include <io.h>
#else
#include <arpa/inet.h>
#include <netinet/in.h>
#include <unistd.h>
#endif

#ifdef _WIN32
typedef SOCKET socket_t;

/**
 * Initialize WinSock 2.2.
 * On failure it prints the WSAStartup error code and executes `return 1`
 * in the enclosing function, so it is intended for use inside main().
 * Wrapped in do/while(0) so it behaves as a single statement and keeps its
 * temporaries out of the caller's scope.
 */
#define WSA_INIT()                                          \
    do {                                                    \
        WSADATA wsaData;                                    \
        int iResult = WSAStartup(MAKEWORD(2, 2), &wsaData); \
        if (iResult != 0) {                                 \
            printf("WSAStartup failed: %d\n", iResult);     \
            return 1;                                       \
        }                                                   \
    } while (0)

/** Release the WinSock resources acquired by WSA_INIT(). */
#define WSA_FINISH()       WSACleanup()
#else
typedef int socket_t;

/* Outside Windows no socket library initialization is needed. */
#define WSA_INIT()          ((void)0)
#define WSA_FINISH()        ((void)0)
#endif

/**
 * Connect to an OXYGEN data stream plugin, then read and print its
 * welcome message.
 * @param dt_server IPv4 address of the server (parsed with inet_pton(AF_INET))
 * @param port      TCP port of the data stream plugin
 * @return connected socket, or -1 on failure
 */
socket_t connectTo(const char* dt_server, uint32_t port);

/**
 * Read one complete packet (header + body) from the socket and process
 * all of its sub packets.
 * @param sock to read from
 * @return > 0 if more packets are expected, 0 after the packet whose
 *         PacketInfo has ST_LAST_PACKET set, < 0 on error
 */
int32_t readPacket(socket_t sock);


/**
 * Check the start token of a packet header and extract the packet size
 * stored right after it.
 * @return 0 on success, -1 if the start token does not match
 */
int32_t processHeader(const uint8_t* buffer, uint32_t* packet_size);
/**
 * Dispatch every sub packet of a packet body to its handler, stopping
 * after the footer sub packet.
 * @return result of the last processPacketInfo call, or 1 if the packet
 *         contained none; readPacket treats negative results as errors
 */
int32_t processPacket(const uint8_t* buffer, uint32_t buffer_size);
/**
 * Copy target_size bytes from pos into target.
 * @return pos advanced by target_size
 */
const uint8_t* readData(const uint8_t* pos, void* target, uint32_t target_size);


/*
 * Sub packet handlers, called from processPacket.
 * buffer points just past the sub packet's size and type fields;
 * buffer_size is the sub packet's size field, which also counts those fields.
 * processPacketInfo additionally reports whether ST_LAST_PACKET is set
 * (0 = last packet).
 */
int32_t processPacketInfo(const uint8_t* buffer, uint32_t buffer_size);
void processXmlConfig(const uint8_t* buffer, uint32_t buffer_size);
void processSyncFixed(const uint8_t* buffer, uint32_t buffer_size);
void processSyncVariable(const uint8_t* buffer, uint32_t buffer_size);
void processAsyncFixed(const uint8_t* buffer, uint32_t buffer_size);
void processAsyncVariable(const uint8_t* buffer, uint32_t buffer_size);
void processFooter(const uint8_t* buffer, uint32_t buffer_size);


/**
 * dt_client_c main
 * Implements: OXYGEN DATA TRANSFER Protocol 1.5
 * default address: 127.0.0.1
 * default port:    5555
 * 
 * Example to access OXYGEN: dt_client_c 192.168.0.155 5555
 */
int main(int argc, char* argv[])
{
    char* address;
    uint32_t port;
    socket_t sock;

    // initialize WINSOCK
    WSA_INIT();
    
    switch(argc)
    {
    case 1:
        address = "127.0.0.1";
        port = 5555;
        break;
    case 3:
        address = argv[1];
        port = atoi(argv[2]);
        break;
    }

    sock = connectTo(address, port);
    if (sock < 0)
    {
        fprintf(stderr, "connectTo %s:%d failed\n", address, port);
        WSA_FINISH();
        return -1;
    }


    while(readPacket(sock) > 0)
    {

    }


    // cleanup WINSOCK
    WSA_FINISH();

    return 0;
}

socket_t connectTo(const char* dt_server, uint32_t port)
{
    char welcome_buffer[DT_WELCOME_MSG_SIZE];
    struct sockaddr_in serv_addr;
    socket_t sock = 0;

    sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock < 0) return -1;

    memset(&serv_addr, 0x00, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_port = htons(port);

    if (inet_pton(AF_INET, dt_server, &serv_addr.sin_addr) <= 0)
    {
        fprintf(stderr, "Invalid address\n");
        return -1;
    }

    if (connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0)
    {
        fprintf(stderr, "Connection to %s:%d failed\n", dt_server, port);
        return -1;
    }

    // after successful the client has to read the connection welcome message.
    memset(welcome_buffer, 0, DT_WELCOME_MSG_SIZE);
    int32_t bc = recv(sock, welcome_buffer, DT_WELCOME_MSG_SIZE, 0);
    if (bc == 0)
    {
        fprintf(stderr, "Could not read welcome message");
        return -1;
    }

    printf("Data stream product name: %s\n", welcome_buffer);

    return sock;
}


/**
 * Read one complete packet (header + body) from the socket and process it.
 * @param sock to read from
 * @return > 0 if more packets are expected, 0 after the last packet,
 *         -1 on any read, validation, allocation or processing error
 */
int32_t readPacket(socket_t sock)
{
    uint8_t packet_header_buffer[DT_PACKET_HEADER_SIZE];
    uint8_t* packet_data = NULL;
    uint32_t packet_size = 0;
    uint32_t body_size = 0;
    int32_t ret;
    int32_t bc;

    // MSG_WAITALL: a TCP recv may otherwise return only part of the header
    bc = recv(sock, (char*)packet_header_buffer, DT_PACKET_HEADER_SIZE, MSG_WAITALL);
    if (bc != (int32_t)DT_PACKET_HEADER_SIZE)
    {
        fprintf(stderr, "Could not read header\n");
        return -1;
    }

    if (processHeader(packet_header_buffer, &packet_size) < 0)
    {
        fprintf(stderr, "Could not process packet header\n");
        return -1;
    }

    // packet_size includes the header; a smaller value would underflow below
    if (packet_size < DT_PACKET_HEADER_SIZE)
    {
        fprintf(stderr, "Invalid packet size %u\n", (unsigned)packet_size);
        return -1;
    }

    // read rest of the packet
    body_size = packet_size - DT_PACKET_HEADER_SIZE;
    // never request 0 bytes: malloc(0) may legitimately return NULL
    packet_data = malloc(body_size > 0 ? body_size : 1);
    if (packet_data == NULL)
    {
        fprintf(stderr, "Could not allocate %u bytes of packet data\n", (unsigned)body_size);
        return -1;
    }

    bc = recv(sock, (char*)packet_data, body_size, MSG_WAITALL);
    if (bc < 0 || (uint32_t)bc != body_size)
    {
        fprintf(stderr, "Could not read all packet data\n");
        free(packet_data);
        return -1;
    }

    ret = processPacket(packet_data, body_size);
    free(packet_data);

    if (ret < 0)
    {
        fprintf(stderr, "Could not process packet\n");
        return -1;
    }

    return ret;
}

/**
 * Validate the start token at the beginning of a packet header and
 * extract the packet size stored directly after it.
 * @param buffer      raw header bytes
 * @param packet_size receives the packet size field
 * @return 0 on success, -1 if the start token does not match
 */
int32_t processHeader(const uint8_t* buffer, uint32_t* packet_size)
{
    char token[DT_START_TOKEN_SIZE];
    const uint8_t* cursor = readData(buffer, token, DT_START_TOKEN_SIZE);

    if (strncmp(DT_START_TOKEN, token, DT_START_TOKEN_SIZE) != 0)
    {
        fprintf(stderr, "Invalid packet start token\n");
        return -1;
    }

    // the size field immediately follows the start token
    readData(cursor, packet_size, sizeof(uint32_t));
    return 0;
}

/**
 * Walk the sub packets of one packet body and dispatch each to its handler,
 * stopping after the footer sub packet.
 *
 * Each sub packet starts with a uint32_t size (covering the whole sub packet,
 * size and type fields included) followed by a uint32_t type. Sizes are
 * validated against the remaining bytes, so malformed data can neither make
 * a handler read past the buffer nor loop forever on a zero size.
 *
 * @param buffer      packet body (bytes following the packet header)
 * @param buffer_size number of bytes in buffer
 * @return result of processPacketInfo (1 = more packets, 0 = last packet),
 *         1 if no PacketInfo sub packet was present, -1 on malformed data
 */
int32_t processPacket(const uint8_t* buffer, uint32_t buffer_size)
{
    const uint32_t sub_header_size = (uint32_t)(sizeof(uint32_t) + sizeof(uint32_t));
    const uint8_t* pos = buffer;
    uint32_t remaining = buffer_size;
    uint32_t hit_footer = 0;
    int32_t  ret = 1;

    while ((remaining > 0) && (!hit_footer))
    {
        const uint8_t* payload = NULL;
        uint32_t packet_size = 0;
        uint32_t packet_type = 0;

        if (remaining < sub_header_size)
        {
            fprintf(stderr, "Truncated SubPacket header (remaining: %u)\n", (unsigned)remaining);
            return -1;
        }

        payload = readData(pos, &packet_size, sizeof(packet_size));
        payload = readData(payload, &packet_type, sizeof(packet_type));

        // smaller than its own header => no progress; larger => out of bounds
        if (packet_size < sub_header_size || packet_size > remaining)
        {
            fprintf(stderr, "Invalid SubPacket size: %u (type: %u, remaining: %u)\n",
                    (unsigned)packet_size, (unsigned)packet_type, (unsigned)remaining);
            return -1;
        }

        switch(packet_type)
        {
        case SBT_PACKET_INFO:
            ret = processPacketInfo(payload, packet_size);
            if (ret < 0) return ret;
            break;
        case SBT_XML_CONFIG:        processXmlConfig(payload, packet_size);  break;
        case SBT_SYNC_FIXED:        processSyncFixed(payload, packet_size);  break;
        case SBT_SYNC_VARIABLE:     processSyncVariable(payload, packet_size);  break;
        case SBT_ASYNC_FIXED:       processAsyncFixed(payload, packet_size);  break;
        case SBT_ASYNC_VARIABLE:    processAsyncVariable(payload, packet_size);  break;
        case SBT_PACKET_FOOTER:
            processFooter(payload, packet_size);
            hit_footer = 1; // ensure to end the loop after footer processing
            break;
        default:
            fprintf(stderr, "Unsupported SubPacketType: %u (size: %u)\n",
                    (unsigned)packet_type, (unsigned)packet_size);
            break;
        }

        // iterate to next subpacket
        pos += packet_size;
        remaining -= packet_size;
    }

    return ret;
}

/**
 * Copy target_size raw bytes from pos into target.
 * @return pointer to the first byte after the copied range
 */
const uint8_t* readData(const uint8_t* pos, void* target, uint32_t target_size)
{
    const uint8_t* const next = pos + target_size;
    memcpy(target, pos, target_size);
    return next;
}

int32_t processPacketInfo(const uint8_t* buffer, uint32_t buffer_size)
{
    DtPacketInfo sub_packet;
    const uint8_t* pos = buffer;

    pos = readData(pos, &sub_packet.protocol_version, sizeof(sub_packet.protocol_version));
    pos = readData(pos, &sub_packet.stream_id, sizeof(sub_packet.stream_id));
    pos = readData(pos, &sub_packet.sequence_number, sizeof(sub_packet.sequence_number));
    pos = readData(pos, &sub_packet.stream_status, sizeof(sub_packet.stream_status));
    pos = readData(pos, &sub_packet.seed, sizeof(sub_packet.seed));
    pos = readData(pos, &sub_packet.number_of_subpackets, sizeof(sub_packet.number_of_subpackets));

    fprintf(stdout, "PacketInfo:\n");
    fprintf(stdout, "  Version:            %0x\n", sub_packet.protocol_version);
    fprintf(stdout, "  Stream ID:          %d\n", sub_packet.stream_id);
    fprintf(stdout, "  Seq Number:         %d\n", sub_packet.sequence_number);
    fprintf(stdout, "  Stream status:      %x\n", sub_packet.stream_status);
    fprintf(stdout, "  Stream seed:        %x\n", sub_packet.seed);
    fprintf(stdout, "  Num sub packets:    %d\n", sub_packet.number_of_subpackets);

    return (sub_packet.stream_status & ST_LAST_PACKET) == 0;
}

/**
 * Print the XML configuration carried by an XML sub packet.
 * @param buffer      start of the XML content (after the size/type fields)
 * @param buffer_size sub packet size; the XML content is
 *                    buffer_size - DT_PACKET_BASE_SIZE bytes long
 */
void processXmlConfig(const uint8_t* buffer, uint32_t buffer_size)
{
    DtXmlSubPacket sub_packet;

    // guard the subtraction below against unsigned underflow
    if (buffer_size < DT_PACKET_BASE_SIZE)
    {
        fprintf(stderr, "Invalid XML sub packet size: %u\n", (unsigned)buffer_size);
        return;
    }

    sub_packet.xml_content_size = buffer_size - DT_PACKET_BASE_SIZE;
    // +1 for a NUL terminator so the content can be printed with %s
    sub_packet.xml_content = malloc(sub_packet.xml_content_size + 1);
    if (sub_packet.xml_content == NULL)
    {
        fprintf(stderr, "Could not allocate XML content buffer\n");
        return;
    }
    sub_packet.xml_content[sub_packet.xml_content_size] = 0;

    readData(buffer, sub_packet.xml_content, sub_packet.xml_content_size);

    fprintf(stdout, "XMLPacket:\n");
    fprintf(stdout, "  xml_content:        %s\n", sub_packet.xml_content);

    free(sub_packet.xml_content);
}

void processSyncFixed(const uint8_t* buffer, uint32_t buffer_size)
{
    DtChannelSyncFixed sub_packet;
    const uint8_t* pos = buffer;

    pos = readData(pos, &sub_packet.channel_data_type, sizeof(sub_packet.channel_data_type));
    pos = readData(pos, &sub_packet.channel_dimension, sizeof(sub_packet.channel_dimension));
    pos = readData(pos, &sub_packet.number_samples, sizeof(sub_packet.number_samples));
    pos = readData(pos, &sub_packet.timestamp, sizeof(sub_packet.timestamp));
    pos = readData(pos, &sub_packet.timebase_frequency, sizeof(sub_packet.timebase_frequency));

    fprintf(stdout, "DtChannelSyncFixed:\n");
    fprintf(stdout, "  channel_data_type:  %s (%d)\n", getDtDataTypeName(sub_packet.channel_data_type), sub_packet.channel_data_type);
    fprintf(stdout, "  channel_dimension:  %d\n", sub_packet.channel_dimension);
    fprintf(stdout, "  number_samples:     %d\n", sub_packet.number_samples);
    fprintf(stdout, "  timestamp:          %" PRIu64 "\n", sub_packet.timestamp);
    fprintf(stdout, "  timebase_frequency  %" PRIu64 "\n", sub_packet.timebase_frequency);
}

void processSyncVariable(const uint8_t* buffer, uint32_t buffer_size)
{
    DtChannelSyncVariable sub_packet;
    const uint8_t* pos = buffer;

    pos = readData(pos, &sub_packet.channel_data_type, sizeof(sub_packet.channel_data_type));
    pos = readData(pos, &sub_packet.channel_dimension, sizeof(sub_packet.channel_dimension));
    pos = readData(pos, &sub_packet.number_samples, sizeof(sub_packet.number_samples));
    pos = readData(pos, &sub_packet.timestamp, sizeof(sub_packet.timestamp));
    pos = readData(pos, &sub_packet.timebase_frequency, sizeof(sub_packet.timebase_frequency));

    fprintf(stdout, "DtChannelSyncVariable:\n");
    fprintf(stdout, "  channel_data_type:  %s (%d)\n", getDtDataTypeName(sub_packet.channel_data_type), sub_packet.channel_data_type);
    fprintf(stdout, "  channel_dimension:  %d\n", sub_packet.channel_dimension);
    fprintf(stdout, "  number_samples:     %d\n", sub_packet.number_samples);
    fprintf(stdout, "  timestamp:          %" PRIu64 "\n", sub_packet.timestamp);
    fprintf(stdout, "  timebase_frequency  %" PRIu64 "\n", sub_packet.timebase_frequency);
}

void processAsyncFixed(const uint8_t* buffer, uint32_t buffer_size)
{
    DtChannelAsyncFixed sub_packet;
    const uint8_t* pos = buffer;

    pos = readData(pos, &sub_packet.channel_data_type, sizeof(sub_packet.channel_data_type));
    pos = readData(pos, &sub_packet.channel_dimension, sizeof(sub_packet.channel_dimension));
    pos = readData(pos, &sub_packet.number_samples, sizeof(sub_packet.number_samples));
    pos = readData(pos, &sub_packet.timebase_frequency, sizeof(sub_packet.timebase_frequency));

    fprintf(stdout, "DtChannelAsyncFixed:\n");
    fprintf(stdout, "  channel_data_type:  %s (%d)\n", getDtDataTypeName(sub_packet.channel_data_type), sub_packet.channel_data_type);
    fprintf(stdout, "  channel_dimension:  %d\n", sub_packet.channel_dimension);
    fprintf(stdout, "  number_samples:     %d\n", sub_packet.number_samples);
    fprintf(stdout, "  timebase_frequency  %" PRIu64 "\n", sub_packet.timebase_frequency);
}

void processAsyncVariable(const uint8_t* buffer, uint32_t buffer_size)
{
    DtChannelAsyncVariable sub_packet;
    const uint8_t* pos = buffer;

    pos = readData(pos, &sub_packet.channel_data_type, sizeof(sub_packet.channel_data_type));
    pos = readData(pos, &sub_packet.channel_dimension, sizeof(sub_packet.channel_dimension));
    pos = readData(pos, &sub_packet.number_samples, sizeof(sub_packet.number_samples));
    pos = readData(pos, &sub_packet.timebase_frequency, sizeof(sub_packet.timebase_frequency));

    fprintf(stdout, "DtChannelAsyncVariable:\n");
    fprintf(stdout, "  channel_data_type:  %s (%d)\n", getDtDataTypeName(sub_packet.channel_data_type), sub_packet.channel_data_type);
    fprintf(stdout, "  channel_dimension:  %d\n", sub_packet.channel_dimension);
    fprintf(stdout, "  number_samples:     %d\n", sub_packet.number_samples);
    fprintf(stdout, "  timebase_frequency  %" PRIu64 "\n", sub_packet.timebase_frequency);
}

void processFooter(const uint8_t* buffer, uint32_t buffer_size)
{
    DtPacketFooter sub_packet;
    const uint8_t* pos = buffer;

    pos = readData(pos, &sub_packet.checksum, sizeof(sub_packet.checksum));

    fprintf(stdout, "PacketFooter:\n");
    fprintf(stdout, "  Checksum:           %0x\n", sub_packet.checksum);
}


/**
 * Size in bytes of a single value of the given dt_* data type.
 * @param dt data type code
 * @return element size in bytes; -1 for dt_string, dt_binary and dt_CAN;
 *         0 for codes not listed here
 */
int32_t getDtDataTypeSize(uint32_t dt)
{
    int32_t size;

    switch (dt)
    {
    case dt_sint8:
    case dt_uint8:
        size = 1;
        break;
    case dt_sint16:
    case dt_uint16:
        size = 2;
        break;
    case dt_sint24:
    case dt_uint24:
        size = 3;
        break;
    case dt_sint32:
    case dt_uint32:
    case dt_float:
        size = 4;
        break;
    case dt_sint64:
    case dt_uint64:
    case dt_double:
    case dt_complex_float:
        size = 8;
        break;
    case dt_complex_double:
        size = 16;
        break;
    case dt_string:
    case dt_binary:
    case dt_CAN:
        size = -1;
        break;
    default:
        size = 0;
        break;
    }

    return size;
}

/**
 * Printable name of a dt_* data type code.
 * @param dt data type code
 * @return the enumerator's name (e.g. "dt_float"), or "dt_unknown_type"
 *         for codes not listed here
 */
const char* getDtDataTypeName(uint32_t dt)
{
    // '#' stringifies the unexpanded argument, so each name is spelled once
#define DT_TYPE_NAME_CASE(type) case type: return #type
    switch (dt)
    {
    DT_TYPE_NAME_CASE(dt_sint8);
    DT_TYPE_NAME_CASE(dt_uint8);
    DT_TYPE_NAME_CASE(dt_sint16);
    DT_TYPE_NAME_CASE(dt_uint16);
    DT_TYPE_NAME_CASE(dt_sint24);
    DT_TYPE_NAME_CASE(dt_uint24);
    DT_TYPE_NAME_CASE(dt_sint32);
    DT_TYPE_NAME_CASE(dt_uint32);
    DT_TYPE_NAME_CASE(dt_sint64);
    DT_TYPE_NAME_CASE(dt_uint64);
    DT_TYPE_NAME_CASE(dt_float);
    DT_TYPE_NAME_CASE(dt_double);
    DT_TYPE_NAME_CASE(dt_complex_float);
    DT_TYPE_NAME_CASE(dt_complex_double);
    DT_TYPE_NAME_CASE(dt_string);
    DT_TYPE_NAME_CASE(dt_binary);
    DT_TYPE_NAME_CASE(dt_CAN);
    default:
        break;
    }
#undef DT_TYPE_NAME_CASE

    return "dt_unknown_type";
}