#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <unistd.h>
#include <afblib/hostport.h>

/* run chat for an established connection

   protocol formats:

      +
      |
      +

      (an empty packet is used by the client to open a connection
      and by both sides to acknowledge an 'm' packet)

      +-----+----------------------------+
      | 'm' | message                    |
      +-----+----------------------------+

      (regular message packet, needs to be acknowledged
      with an empty packet)

      +-----+
      | 'c' |
      +-----+

      (packet that signals that the connection is closed)

   fd must be a datagram socket that is already connect()ed to the peer.
   Input read from stdin is sent as an 'm' packet; if it is not
   acknowledged within ACK_TIMEOUT_MS it is retransmitted, and the session
   gives up after MAX_TIMEOUTS consecutive timeouts. Only one message is
   outstanding at a time (stop-and-wait): stdin is not read while an
   acknowledgement is pending. Incoming 'm' packets are acknowledged and
   their text is copied to stdout.

   The session ends on EOF/error on stdin (a 'c' packet is sent to the
   peer), on a 'c' (or unknown) packet or read error from the peer, on too
   many timeouts, or when sending fails. fd is closed on every exit path.
*/
void chat_session(int fd) {
   enum { ACK_TIMEOUT_MS = 100, MAX_TIMEOUTS = 10 };
   struct pollfd pollfds[2] = {
      {.fd = fd, .events = POLLIN},           /* incoming packets from the socket */
      {.fd = STDIN_FILENO, .events = POLLIN}, /* incoming input from stdin */
   };
   /* timeout is ACK_TIMEOUT_MS while we wait for an empty acknowledge
      packet and -1 (wait indefinitely) otherwise;
      timeouts counts the consecutive timeouts of the pending message */
   int timeout = -1;
   int timeouts = 0;
   /* last message sent ('m' tag + text), kept for retransmission */
   char inbuf[BUFSIZ];
   ssize_t inbytes = 0;
   for (;;) {
      /* stop-and-wait: poll() ignores entries with a negative fd, so no new
         line can overwrite inbuf before the pending one is acknowledged */
      pollfds[1].fd = timeout < 0 ? STDIN_FILENO : -1;
      int res = poll(pollfds, 2, timeout);
      if (res < 0) {
         if (errno == EINTR) continue; /* interrupted by a signal: retry */
         break;
      }
      if (res == 0) {
         /* timeout: retransmit last input line */
         if (++timeouts > MAX_TIMEOUTS) {
            printf("Too many timeouts, giving up.\n");
            break;
         }
         if (write(fd, inbuf, inbytes) < 0) break;
         continue;
      }
      if (pollfds[0].revents != 0) {
         /* read from socket */
         char buf[BUFSIZ];
         ssize_t nbytes = read(fd, buf, sizeof buf);
         if (nbytes < 0) {
            printf("Closed.\n");
            break;
         }
         if (nbytes == 0) {
            timeouts = 0; timeout = -1; /* acknowledged */
         } else if (buf[0] == 'm') {
            /* acknowledge it with an empty packet */
            if (write(fd, buf, 0) < 0) break;
            /* copy it to stdout; going through stdio keeps it in order
               with the printf() status messages */
            fwrite(buf + 1, 1, (size_t) nbytes - 1, stdout);
            fflush(stdout);
         } else {
            printf("Closed.\n");
            break;
         }
      }
      if (pollfds[1].revents != 0) {
         /* read from stdin, leaving room for the 'm' tag in front */
         inbytes = read(STDIN_FILENO, inbuf + 1, sizeof inbuf - 1);
         if (inbytes <= 0) {
            write(fd, "c", 1); /* best effort: tell the peer we are gone */
            break;
         }
         inbuf[0] = 'm'; ++inbytes;
         timeout = ACK_TIMEOUT_MS;
         if (write(fd, inbuf, inbytes) < 0) break;
      }
   }
   close(fd);
}

/* usage: cmd (-c|-s) hostport

   -c: client mode, connects to hostport and opens the session with an
       empty packet
   -s: server mode, binds to hostport and talks to whichever peer sends
       the first packet

   22022 is passed to parse_hostport as the default port. */
int main(int argc, char** argv) {
   char* cmdname = *argv++; --argc;
   if (argc != 2 || argv[0][0] != '-' ||
         (argv[0][1] != 'c' && argv[0][1] != 's') || argv[0][2]) {
      fprintf(stderr, "Usage: %s (-c|-s) hostport\n", cmdname);
      exit(1);
   }
   char mode = argv[0][1]; ++argv; --argc;
   char* hostport_string = *argv++; --argc;
   hostport hp;
   if (!parse_hostport(hostport_string, &hp, 22022)) {
      fprintf(stderr, "%s: invalid hostport: %s\n",
         cmdname, hostport_string);
      exit(1);
   }

   int fd;
   if ((fd = socket(hp.domain, SOCK_DGRAM, hp.protocol)) < 0) {
      perror("socket"); exit(1);
   }
   if (mode == 'c') {
      /* client mode */
      if (connect(fd, (struct sockaddr*) &hp.addr, hp.namelen) < 0) {
         perror("connect"); exit(1);
      }
      /* send initial empty packet to the server */
      if (write(fd, "", 0) < 0) {
         perror("write"); exit(1);
      }
   } else {
      /* server mode */
      if (bind(fd, (struct sockaddr*) &hp.addr, hp.namelen) < 0) {
         perror("bind"); exit(1);
      }
      /* wait for the first packet of the client; only its sender
         address matters, which recvfrom stores into hp.addr */
      char buf[BUFSIZ];
      if (recvfrom(fd, buf, sizeof buf, 0,
            (struct sockaddr*) &hp.addr, &hp.namelen) < 0) {
         /* hp.addr would not hold a valid peer address */
         perror("recvfrom"); exit(1);
      }
      /* now we are just talking to this client,
         hence we can use connect() to tie this
         socket to our peer */
      if (connect(fd, (struct sockaddr*) &hp.addr, hp.namelen) < 0) {
         perror("connect"); exit(1);
      }
   }
   printf("Connected.\n");
   /* flush so this status line precedes any chat output even when
      stdout is fully buffered (pipe or file) */
   fflush(stdout);
   chat_session(fd);
   return 0;
}
