#include "TCPSocket.h"

#include <fcntl.h>
#include <sys/types.h>

#if !defined(WIN32)
#include <netdb.h>
#include <arpa/inet.h>
#include <sys/errno.h>
#include <sys/time.h>
#include <unistd.h>
#include <sys/ioctl.h>
#else // WIN32
#define ENOTCONN WSAENOTCONN
#endif // !WIN32

#if !defined(HAVE_INET_NTOA_R)
#ifdef __cplusplus
extern "C" {
extern char *inet_ntoa_r(struct in_addr in, char *buf, int len);
} // extern "C"
#endif
#endif

/**
 * Creates a socket object that does not yet own a descriptor (see Open()).
 * The remote endpoint fields start at their sentinel values and the object
 * is neither connected nor listening.
 */
TCPSocket::TCPSocket()
    : mDescriptor(SOCKET_INVALID_DESCRIPTOR)
    , mRemoteHost(0)
    , mRemotePort(SOCKET_INVALID_PORT)
    , mConnected(false)
    , mListening(false)
{
    // mRemoteHostName is left default-constructed.
}

/**
 * Wraps a descriptor produced by accept() and records the peer's address.
 *
 * The object takes ownership of inDescriptor. The destructor normally closes
 * it, and this constructor closes it itself if it throws, because the
 * destructor does not run for a partially-constructed object.
 *
 * @param inDescriptor     connected socket descriptor (e.g. from Accept())
 * @param inRemoteAddress  peer address as filled in by accept(); must not be NULL
 * @throws socket_error    if SO_NOSIGPIPE is available but cannot be enabled
 *
 * NOTE(review): mRemoteHost and mRemotePort are copied straight from the
 * sockaddr_in, i.e. in network byte order, whereas Connect() stores them in
 * host byte order. Left as-is in case callers compensate for it -- confirm.
 */
TCPSocket::TCPSocket(int inDescriptor, struct sockaddr_in* inRemoteAddress)
	: mConnected(true), mListening(false)
{
	mDescriptor = inDescriptor;
	mRemoteHost = inRemoteAddress->sin_addr.s_addr;
    // inet_ntoa isn't thread safe, so we use this instead
    // some systems define inet_ntoa_r, but we also define it in compat/inet_ntoa_r.c
    // 18 bytes is enough for the longest dotted quad ("255.255.255.255") plus NUL.
    char tmpAddr[18];
    inet_ntoa_r(inRemoteAddress->sin_addr, tmpAddr, sizeof(tmpAddr));
	mRemoteHostName.assign(tmpAddr);
	mRemotePort = inRemoteAddress->sin_port;

#if defined(SO_NOSIGPIPE)
    // Suppress SIGPIPE for writes on this socket where the platform offers
    // the option (hence the #if).
    int on = 1;
	if (setsockopt(mDescriptor, SOL_SOCKET, SO_NOSIGPIPE, (void*)&on, sizeof(on)) == -1)
	{
		// Save errno before Close(), which may overwrite it, and release the
		// descriptor we now own so it is not leaked by the throw.
		int savedErrno = errno;
		Close();
		throw socket_error(savedErrno);
	}
    DEBUG_CALL(printf("SO_NOSIGPIPE is on\n"); fflush(stdout));
#endif
}

/**
 * Tears down the connection and releases the descriptor.
 * Shutdown() only issues shutdown() when the socket is connected, and
 * Close() skips closing when no descriptor is held, so this is safe for
 * objects that were never opened.
 */
TCPSocket::~TCPSocket()
{
    // Order matters: shut the connection down while the descriptor is still valid.
    this->Shutdown();
    this->Close();
}

void TCPSocket::Open()
{
	if ((mDescriptor = socket(AF_INET, SOCK_STREAM, 0)) < 0)
	{
		throw socket_error(errno);
	}
}

void TCPSocket::Close()
{
    mConnected = false;
    mListening = false;
    
    if (mDescriptor != SOCKET_INVALID_DESCRIPTOR)
	{
        DEBUG_CALL(printf("closing socket\n"); fflush(stdout));

		int err = ::closesocket(mDescriptor);

        if (err != 0)
        {
            DEBUG_CALL(printf("error closing socket: %d\n", err));
            // throw an exception or what?
        }
		mDescriptor = SOCKET_INVALID_DESCRIPTOR;
	}
	
	mRemotePort = SOCKET_INVALID_PORT;
}

bool TCPSocket::IsConnected()
{
	return mConnected;
}

/**
 * Switches the descriptor between blocking and non-blocking I/O.
 *
 * @param nonBlock  true for non-blocking mode, false for blocking mode
 * @return true if the mode was applied, false on failure (on both platforms)
 */
bool TCPSocket::SetNonBlock(bool nonBlock)
{
#if defined(WIN32)
	unsigned long flag = nonBlock ? 1UL : 0UL;
	return ::ioctlsocket(mDescriptor, FIONBIO, &flag) == 0;
#else
	int flags = fcntl(mDescriptor, F_GETFL, 0);
	if (flags == -1)
		return false;   // could not read the current flags, so nothing was changed

	if (nonBlock)
		flags |= O_NONBLOCK;
	else
		flags &= ~O_NONBLOCK;

	// fcntl(F_SETFL) returns 0 on success and -1 on error; converting it
	// straight to bool would report success as false.
	return fcntl(mDescriptor, F_SETFL, flags) != -1;
#endif // !WIN32
}
/*
bool TCPSocket::SetAsync(bool isAsync)
{
#if 0
   pid_t pid = getpid();
   int flag = isAsync;
   if (isAsync)
      ioctl(mDescriptor, FIOSETOWN, &pid);
   ioctl(mDescriptor, FIOASYNC, &flag);
   return true;
#else
	int oldFlags = fcntl(mDescriptor, F_GETFL, 0);
	if (oldFlags == -1)
		return true;
	
   if (isAsync)
   {
      fcntl(mDescriptor, F_SETOWN, getpid());
		oldFlags |= O_ASYNC;
	}
   else
	   oldFlags &= ~O_ASYNC;
	
   return fcntl(mDescriptor, F_SETFL, oldFlags);
#endif
}
*/

/**
 * Accepts one pending connection on this (listening) socket.
 *
 * @param acceptAddr  receives the peer's address
 * @return the new connected descriptor. The caller is responsible for it,
 *         for example by handing it to TCPSocket(int, sockaddr_in*).
 * @throws socket_error carrying accept()'s errno on failure
 */
int TCPSocket::Accept( struct sockaddr_in& acceptAddr)
{
	socklen_t addrSize = sizeof(acceptAddr);

	// Socket must be created, not connected, and listening
	//assert(mDescriptor != SOCKET_INVALID_DESCRIPTOR);
	//assert(!mConnected);
	//assert(mListening);

	// Accept a remote connection.  Raise on failure
	int descriptor = accept(mDescriptor, (struct sockaddr *)&acceptAddr, &addrSize);
	if (descriptor < 0)
	{
		// Capture errno first: printf/fflush/strerror below may overwrite it.
		int acceptErrno = errno;
		DEBUG_CALL(printf("accept failed(%d) %s\n", acceptErrno, strerror(acceptErrno)); fflush(stdout));
		throw socket_error(acceptErrno);
	}

	return descriptor;
}

void TCPSocket::SetLinger(bool inOnOff)
{
	struct linger optval;

	optval.l_onoff = inOnOff;
	optval.l_linger = 10;

	if (setsockopt(mDescriptor, SOL_SOCKET, SO_LINGER, (char *)&optval, sizeof(struct linger)) < 0)
    {
        DEBUG_CALL(printf("set linger failed: %d\n", errno); fflush(stdout));
    }
}

/**
 * Turns SO_KEEPALIVE on or off for this socket.
 * A failure does not throw; it is only reported through DEBUG_CALL.
 */
void TCPSocket::SetKeepAlive(bool inOn)
{
    int enable = inOn ? 1 : 0;
    const int rc = ::setsockopt(mDescriptor, SOL_SOCKET, SO_KEEPALIVE,
                                (netdata_t*)&enable, sizeof(enable));
    if (rc < 0)
    {
        DEBUG_CALL(printf("set keepalive failed: %d\n", errno); fflush(stdout));
    }
}

unsigned int TCPSocket::GetSndBufSize()
{
    int bufSize = 0;
	socklen_t optSize = sizeof(bufSize);
    if (::getsockopt(mDescriptor, SOL_SOCKET, SO_SNDBUF, (netdata_t*)&bufSize, &optSize) < 0)
    {
        DEBUG_CALL(printf("get snd buf size failed: %d\n", errno); fflush(stdout));
    }
	return bufSize;
}

unsigned int TCPSocket::GetRcvBufSize()
{
    int bufSize = 0;
	socklen_t optSize = sizeof(bufSize);
    if (::getsockopt(mDescriptor, SOL_SOCKET, SO_RCVBUF, (netdata_t*)&bufSize, &optSize) < 0)
    {
        DEBUG_CALL(printf("get rcv buf size failed: %d\n", errno); fflush(stdout));
    }
	return bufSize;
}

int TCPSocket::GetError()
{
    int sockError = 0;
	socklen_t optSize = sizeof(sockError);
    if (::getsockopt(mDescriptor, SOL_SOCKET, SO_ERROR, (netdata_t*)&sockError, &optSize) < 0)
    {
        DEBUG_CALL(printf("get sock error failed: %d\n", errno); fflush(stdout));
    }
	return sockError;
}

/**
 * Queries how many bytes are waiting to be read (FIONREAD).
 *
 * @return the byte count on success; -1 if the query fails with EINVAL
 *         (which this code attributes to a listening socket); 0 on any
 *         other failure.
 */
int TCPSocket::GetAvailBytes()
{
#if defined(WIN32)
	unsigned long pending;
	const bool ok = ::ioctlsocket(mDescriptor, FIONREAD, &pending) != -1;
#else
	int pending;
	const bool ok = ::ioctl(mDescriptor, FIONREAD, &pending) != -1;
#endif
	if (ok)
		return pending;

	// EINVAL: listening socket
	return (errno == EINVAL) ? -1 : 0;
}

/**
 * Shuts down both directions of a connected socket, then clears the
 * connected and listening flags. The descriptor itself stays open; Close()
 * releases it. shutdown() failures are only reported through DEBUG_CALL.
 */
void TCPSocket::Shutdown()
{
    if (mConnected)
    {
        DEBUG_CALL(printf("shutting down socket\n"); fflush(stdout));
        const int rc = shutdown(mDescriptor, SHUT_RDWR);
        if (rc < 0)
        {
            // strerror() is not thread safe, but we'll just take a gamble
            DEBUG_CALL(printf("error shutdown: %s\n", strerror(errno)); fflush(stdout));
        }
    }

    // Regardless of the outcome above, the object no longer counts as
    // connected or listening.
    mConnected = false;
    mListening = false;
}


/**
 * Thin wrapper over recv(): reads up to inCount bytes into outBytes.
 *
 * Preconditions (not checked): the descriptor is created and connected,
 * and inCount != 0.
 *
 * @return recv()'s result unchanged. This is the byte count, 0 when nothing
 *         more can be read (e.g. the peer disconnected while recv was
 *         blocked), or a negative value with errno set on error.
 */
int TCPSocket::Recv(void *outBytes, u_int32_t inCount, u_int32_t inFlags) throw ()
{
	return ::recv(mDescriptor, (netdata_t*)outBytes, inCount, inFlags);
}

/**
 * Thin wrapper over send() with MSG_NOSIGNAL.
 *
 * Preconditions (not checked): the descriptor is created and connected,
 * and inCount != 0.
 *
 * @return send()'s result: the number of bytes queued, or -1 with errno set.
 *         errno is preserved across the debug output below, because callers
 *         such as SendBytes read it after this returns. The connection is not
 *         closed here on EPIPE; the caller decides what to do.
 */
int TCPSocket::Send(const void *inBytes, u_int32_t inCount) throw ()
{
	int sent = ::send(mDescriptor, (const netdata_t*)inBytes, inCount, MSG_NOSIGNAL);

	if (sent == -1 && errno == EPIPE)
	{
		// printf/fflush may overwrite errno; restore it for the caller.
		const int sendErrno = errno;
		DEBUG_CALL(printf("send returned EPIPE\n"); fflush(stdout));
		errno = sendErrno;
	}

	return sent;
}

void TCPSocket::RecvBytes(void *outBytes, u_int32_t inCount) throw (socket_error)
{
	// this implements a real blocking recv even if
	// MSG_WAITALL isn't supported
	unsigned int recvCount = 0;
	while (recvCount < inCount)
	{
		int status = Recv(&((char *)outBytes)[recvCount],
			inCount - recvCount, MSG_WAITALL);
		if (status < 0)
			throw socket_error(errno);
		else if (status == 0)
		{
			// this means the socket is disconnected
			Close();
			throw socket_error(ENOTCONN);
		}
		else
		{
			recvCount += status;
		}
	}
}

void TCPSocket::SendBytes(const void *inBytes, u_int32_t inCount) throw (socket_error)
{
	const char* bytes = (char *)inBytes;    
	int len = inCount;
	int sent = 0;
	
    while (len > 0)
	{
		sent = Send(bytes, len);
		if (sent == -1)
		{
            throw socket_error(errno);
		}
        else
        {
            bytes += sent;
            len -= sent;
        }
	}
}

/**
 * Connects to an IPv4 peer.
 *
 * @param inAddress  peer address in host byte order (converted with htonl)
 * @param inPort     peer port in host byte order (converted with htons)
 * @throws socket_error carrying connect()'s errno on failure
 *
 * On success the socket is marked connected and the peer is recorded in
 * host byte order.
 */
void TCPSocket::Connect(const u_int32_t inAddress, const u_int16_t inPort)
{
	// Zero-fill so sin_zero (and sin_len on BSD-derived systems) is not
	// left holding stack garbage.
	struct sockaddr_in addr = {};
	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = htonl(inAddress);
	addr.sin_port = htons(inPort);
	if (connect(mDescriptor, (struct sockaddr *)&addr, sizeof(addr)) < 0)
	{
		throw socket_error(errno);
	}
	mConnected = true;
	mRemoteHost = inAddress;
	mRemotePort = inPort;
}

/**
 * Enables SO_REUSEADDR and binds the socket to a local IPv4 address/port.
 *
 * @param inPort     local port in host byte order (converted with htons)
 * @param inAddress  local address in host byte order (converted with htonl)
 * @throws socket_error carrying errno if setsockopt() or bind() fails
 */
void TCPSocket::Bind(u_int16_t inPort, u_int32_t inAddress)
{
	int on = 1;

	// Set a flag so that this address can be re-used immediately after the connection
	// closes.  (TCP normally imposes a delay before an address can be re-used.)
	if (setsockopt(mDescriptor, SOL_SOCKET, SO_REUSEADDR, (netdata_t*)&on, sizeof(on)) < 0)
	{
		throw socket_error(errno);
	}

	// Bind the address to the socket. Zero-fill first so sin_zero (and
	// sin_len on BSD-derived systems) is not left holding stack garbage.
	struct sockaddr_in localAddr = {};
	localAddr.sin_family = AF_INET;
	localAddr.sin_addr.s_addr = htonl(inAddress);
	localAddr.sin_port = htons(inPort);

	if (bind(mDescriptor, (struct sockaddr*)&localAddr, sizeof(localAddr)) < 0)
	{
		throw socket_error(errno);
	}
}

/**
 * Marks the socket as passive so Accept() can be used.
 *
 * @param inMaxPendingConnections  backlog passed to listen()
 * @throws socket_error carrying listen()'s errno on failure
 */
void TCPSocket::Listen(u_int32_t inMaxPendingConnections)
{
	// Test listen()'s result directly. An assignment inside the condition
	// would bind as `ret = (listen(...) < 0)` due to operator precedence.
	if (listen(mDescriptor, inMaxPendingConnections) < 0)
	{
		throw socket_error(errno);
	}
	mListening = true;
}

void TCPSocket::Select(bool &ioRead, bool &ioWrite)
{
    fd_set readSet;
    fd_set writeSet;
    
    FD_ZERO(&readSet);
    FD_ZERO(&writeSet);
    
    if (ioRead)
        FD_SET(mDescriptor, &readSet);
    
    if (ioWrite)
        FD_SET(mDescriptor, &writeSet);
    
    ioRead = false;
    ioWrite = false;
    
    int err = select(mDescriptor + 1, &readSet, &writeSet, NULL, NULL);
    
    if (err > 0)
    {
        if (FD_ISSET(mDescriptor, &readSet))
            ioRead = true;
        
        if (FD_ISSET(mDescriptor, &writeSet))
            ioWrite = true;
    }
    else
    {
        DEBUG_CALL(printf("select error: %d\n", err));
    }
}


