// Copyright DEWETRON GmbH 2018

#include "dt_client.h"
#include "dt_cmdline.h"
#include "dt_log.h"
#include "dt_stream_packet.h"
#include "dt_stopwatch.h"
#include <cassert>
#include <cinttypes>
#include <iostream>
#include <string>
#include <string.h>
#include <sstream>

#ifdef _WIN32
#include <WinSock2.h>
#include <ws2tcpip.h>
#include <io.h>
#else
#include <arpa/inet.h>
#include <netinet/in.h>
#include <unistd.h>
#endif

// Native socket handle type, deduced from the platform's socket() return type
// (so it differs between the POSIX and the Winsock build).
using socket_t = decltype(socket(0, 0, 0));

/**
 * Connect to an OXYGEN data stream plugin, then read and
 * log its welcome message.
 * @param dt_server IPv4 address of the server in dotted-decimal text form
 *                  (it is parsed with inet_pton(AF_INET, ...), so host names
 *                  are not resolved)
 * @param port      TCP port of the data stream
 * @return connected socket; the welcome message has already been consumed
 * @throws std::runtime_error if the socket cannot be created, the address is
 *         invalid, the connection fails or no welcome message is received
 */
socket_t connectTo(const std::string& dt_server, uint32_t port);

/**
 * Read one packet from the stream: first the header
 * (DT_PACKET_HEADER_SIZE bytes), then the remaining
 * packet.getPacketSize() - DT_PACKET_HEADER_SIZE bytes.
 * @param sock   connected socket to read from
 * @param packet updated by processPacketHeader() and processSubPackets()
 * @return 0 if successful, -1 on a short read or when the header or
 *         sub-packets cannot be processed (the cause is logged via DTLOG_ERROR)
 */
int32_t readPacket(socket_t sock, DtStreamPacket& packet);

/**
 * Print benchmark data to std::cout: elapsed time, amount received
 * (bytes / 1024 / 1024, labelled "MB") and the resulting rate in MB/s.
 * @param entry_title    heading printed before the figures
 * @param runtime_ms     length of the measured interval in milliseconds
 * @param bytes_received bytes received during that interval
 *                       (NOTE(review): uint32_t, so totals above 4 GiB wrap)
 */
void printPerformanceOutput(const char* entry_title, uint64_t runtime_ms, uint32_t bytes_received);

/**
 * dt_client main
 * Implements: OXYGEN DATA TRANSFER Protocol 1.5
 * default address: 127.0.0.1
 * default port:    5555
 */
int main(int argc, char* argv[])
{
#ifdef _WIN32
    WSADATA wsaData;
    int iResult;
    // Initialize Winsock
    iResult = WSAStartup(MAKEWORD(2, 2), &wsaData);
    if (iResult != 0) {
        printf("WSAStartup failed: %d\n", iResult);
        return 1;
    }
#endif

    DtClientConfig config;
    DtStreamPacket packet;
    DT_StopWatchHandle sw;
    DT_StopWatch_Create(&sw);

    try
    {
        if (!parseAruments(argc, argv, config))
        {
            throw std::runtime_error("");
        }

        DTLOG_INFO(1, "Starting dt_client");

        socket_t sock = connectTo(config.m_address, config.m_port);
        if (sock < 0)
        {
            throw std::runtime_error("Socket invalid");
        }

        DTLOG_INFO(0, "Streaming started");

        if (config.m_performance)
        {
            std::cout << "Performance statistics enabled" << std::endl;
        }

        uint32_t num_recv_packets = 0;
        uint32_t bytes_received = 0;       
        uint32_t old_bytes_received = 0;
        DT_StopWatch_Start(&sw);
        uint64_t start_time_ms = DT_StopWatch_GetMeantimeMS(&sw);
        uint64_t last_run_time_ms = start_time_ms;
        while (true)
        {
            int32_t success = readPacket(sock, packet);
            if (success < 0)
            {
                throw std::runtime_error("readDataPacket error");
            }
            ++num_recv_packets;

            DTLOG_INFO(1, packet.getPacketInfo());

            if (config.m_print_samples)
            {
                packet.printChannelSamples();
            }

            if (config.m_performance)
            {
                bytes_received += packet.getPacketSize();
            }

            if (config.m_performance)
            {
                auto run_time_ms = DT_StopWatch_GetMeantimeMS(&sw);
                if (run_time_ms/1000 != last_run_time_ms/1000)
                {
                    printPerformanceOutput("Meantime", run_time_ms - last_run_time_ms, bytes_received - old_bytes_received);
                    last_run_time_ms = run_time_ms;
                    old_bytes_received = bytes_received;
                }
            }

            packet.clearChannels();

            if (packet.isLastPacket()) break;
        }

        DT_StopWatch_Stop(&sw);
        auto time_ms = DT_StopWatch_GetMS(&sw);

        DTLOG_INFO(0, "Streaming stopped. Received packets: " << num_recv_packets);
        if (config.m_performance)
        {
            printPerformanceOutput("Summary", time_ms - start_time_ms, bytes_received);
        }

    }
    catch (std::exception& e)
    {
        if (!std::string(e.what()).empty())
        {
            DTLOG_ERROR(e.what());
        }
    }

    DT_StopWatch_Destroy(&sw);

#ifdef _WIN32
    WSACleanup();
#endif

    return 0;
}


/**
 * Connect to an OXYGEN data stream plugin and consume its welcome message.
 * @param dt_server IPv4 address in dotted-decimal text form (parsed with inet_pton)
 * @param port      TCP port; must fit into 16 bits
 * @return connected socket, positioned right after the welcome message
 * @throws std::runtime_error on an invalid port or address, socket creation or
 *         connection failure, or when the welcome message cannot be read. The
 *         socket is closed before throwing, so failures do not leak it.
 */
socket_t connectTo(const std::string& dt_server, uint32_t port)
{
    DTLOG_INFO(1, "connectTo " << dt_server << ":" << port);

    // htons() takes 16 bits; without this check a larger value is silently truncated.
    if (port > 0xFFFFu)
    {
        throw std::runtime_error("Invalid port: " + std::to_string(port));
    }

    const socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
#ifdef _WIN32
    // The Winsock handle type is unsigned; failure is INVALID_SOCKET, never a negative value.
    const bool sock_invalid = (sock == INVALID_SOCKET);
#else
    const bool sock_invalid = (sock < 0);
#endif
    if (sock_invalid)
    {
        throw std::runtime_error("Socket creation error");
    }

    // From here on every failure must release the socket before throwing.
    auto fail = [sock](const std::string& message) {
#ifdef _WIN32
        closesocket(sock);
#else
        close(sock);
#endif
        throw std::runtime_error(message);
    };

    struct sockaddr_in serv_addr;
    memset(&serv_addr, 0x00, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_port = htons(static_cast<uint16_t>(port));

    if (inet_pton(AF_INET, dt_server.c_str(), &serv_addr.sin_addr) <= 0)
    {
        fail("Invalid address / Address not supported");
    }

    if (connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0)
    {
        std::stringstream err;
        err << "Connection to " << dt_server << ":" << port << " failed";
        fail(err.str());
    }

    // read welcome msg. The message is assumed to be exactly DT_WELCOME_MSG_SIZE
    // bytes; MSG_WAITALL stops a split message from leaving welcome bytes in front
    // of the first packet header.
    ByteBuffer welcome_buffer(DT_WELCOME_MSG_SIZE, 0);
    const auto bc = recv(sock, welcome_buffer.data(), DT_WELCOME_MSG_SIZE, MSG_WAITALL);
    // recv() returns -1 on error and 0 when the peer closed the connection. Keeping
    // the signed result type lets both be detected (a size_t result would hide -1).
    if (bc <= 0)
    {
        fail("Could not read welcome message");
    }

    DTLOG_INFO(1, "Data stream product name: " << std::string(static_cast<char*>(welcome_buffer.data()), DT_WELCOME_MSG_SIZE));

    return sock;
}

/**
 * Read one complete packet (header, then payload) from the socket.
 * @param sock   connected socket to read from
 * @param packet updated by processPacketHeader() and processSubPackets()
 * @return 0 if successful, -1 on a short read, an inconsistent packet size or
 *         when the packet cannot be processed (the cause is logged)
 */
int32_t readPacket(socket_t sock, DtStreamPacket& packet)
{
    // Read the packet header from the socket. MSG_WAITALL makes a header that
    // arrives in several TCP segments be read completely instead of being
    // reported as an error.
    ByteBuffer packet_header_buffer(DT_PACKET_HEADER_SIZE, 0);
    auto bc = recv(sock, packet_header_buffer.data(), DT_PACKET_HEADER_SIZE, MSG_WAITALL);
    if (DT_PACKET_HEADER_SIZE != bc)
    {
        DTLOG_ERROR("Could not read header");
        return -1;
    }

    if (packet.processPacketHeader(packet_header_buffer) < 0)
    {
        DTLOG_ERROR("Could not process packet header");
        return -1;
    }

    // Validate the payload size in 64-bit arithmetic instead of relying on assert()
    // (compiled out in release builds). A packet size smaller than the header
    // would otherwise turn into a negative or huge buffer size.
    const int64_t payload_size = static_cast<int64_t>(packet.getPacketSize())
                               - static_cast<int64_t>(DT_PACKET_HEADER_SIZE);
    if (payload_size <= 0 || payload_size > INT32_MAX)
    {
        DTLOG_ERROR("Invalid packet size");
        return -1;
    }

    // read rest of the packet
    const int32_t packet_size = static_cast<int32_t>(payload_size);
    ByteBuffer packet_buffer(packet_size, 0);
    bc = recv(sock, packet_buffer.data(), packet_size, MSG_WAITALL);
    if (packet_size != bc)
    {
        DTLOG_ERROR("Could not read all packet data");
        return -1;
    }

    if (packet.processSubPackets(packet_buffer) < 0)
    {
        DTLOG_ERROR("Could not process packet");
        return -1;
    }

    return 0;
}

/**
 * Print a throughput report to std::cout.
 * @param entry_title    heading printed before the figures (nullptr prints an empty heading)
 * @param runtime_ms     length of the measured interval in milliseconds
 * @param bytes_received bytes received during that interval
 *
 * Amounts are bytes / 1024 / 1024, labelled "MB". When runtime_ms is 0 no rate
 * can be computed, and the speed is reported as "n/a" instead of inf/nan.
 */
void printPerformanceOutput(const char* entry_title, uint64_t runtime_ms, uint32_t bytes_received)
{
    const double mb_received = bytes_received / 1024.0 / 1024.0;
    // Streaming a null char* into an ostream is undefined behaviour.
    std::cout << (entry_title != nullptr ? entry_title : "") << ":\n";
    std::cout << "  Time:     " << runtime_ms << "ms\n";
    std::cout << "  Received: " << mb_received << "MB\n";
    std::cout << "  Speed:    ";
    if (runtime_ms == 0)
    {
        // No measurable time elapsed: dividing would yield inf (or nan for 0 bytes).
        std::cout << "n/a";
    }
    else
    {
        std::cout << mb_received / (runtime_ms / 1000.0) << "MB/s";
    }
    std::cout << std::endl;
}
