/*
   Small utility to execute a performance test for network services:

      Usage: netperftest hostport #processes #connections_per_process msg...

   The hostport specifies the to be tested service per a host/port combination
   in conformance to RFC 2396.

   netperftest hammers the network service with many concurrent sessions and
   measures the average session time. In total #processes processes are
   spawned, each of them having #connections_per_process connections to the
   service.

   netperftest expects a protocol of the form where the service sends
   a greeting upon which a request can be sent, which in turn is
   acknowledged by another response packet. netperftest is not familiar
   with any particular protocol. Instead, a series of messages is specified
   which are sent to the service one by one, each one whenever a packet
   has been received.

   Within the message strings the escape sequences \n, \r, \t, and \\
   are interpreted as linefeed, carriage-return, tab and a backslash,
   respectively. In addition, %p is expanded to the process id of the
   process that opens a connection. %% can be used to have a single %.

   Example:

      netperftest 0:25 10 100 "HELO localhost\r\n" "QUIT\r\n"

   This example forks 10 processes, each of which connects to the
   local SMTP service on port 25 a hundred times, one connection after
   another. First, the SMTP greeting is awaited, then the HELO command
   is sent. When the response to HELO is received, QUIT is sent and the
   connection is closed.

   Sample output:

      total time = 1.1500 s
      time per connection: 11.5000 ms
      throughput: 869.565 connections/s

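   The first line gives the overall real time which was consumed until
   all processes were finished. Time per connection specifies the
   average time for a connection, i.e. the total time divided by
   #connections_per_process. Throughput is the total number of
   connections (#processes * #connections_per_process) divided by the
   total time.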

   Please use this tool responsibly. Do not hammer any network services
   without authorization.
*/
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <stralloc.h>
#include <sys/times.h>
#include <sys/wait.h>
#include <unistd.h>
#include <afblib/hostport.h>
#include <afblib/inbuf.h>
#include <afblib/outbuf.h>

/* Return the elapsed real time in seconds as derived from the tick
   counter of times(), which counts from an arbitrary but fixed point in
   the past -- only the difference of two calls is meaningful.
   The resolution is one clock tick, i.e. 1/sysconf(_SC_CLK_TCK) s. */
double walltime() {
   static int ticks_per_second;	/* looked up on first use; 0 = not yet */
   struct tms ignored;		/* the CPU times are not of interest */
   if (ticks_per_second == 0) ticks_per_second = sysconf(_SC_CLK_TCK);
   clock_t ticks = times(&ignored);
   return (double) ticks / ticks_per_second;
}

/* Expand arg into msg, replacing any previous contents of msg.
   Escape sequences: \n, \r, \t and \\ become linefeed, carriage-return,
   tab and backslash. Fill-in sequences: %p becomes the process id of the
   calling process, %% becomes a single %.
   Terminates the process with exit status 1 on an unknown or incomplete
   sequence, or if msg cannot be grown. */
void construct_msg(stralloc* msg, const char* arg) {
   msg->len = 0;
   for (const char* cp = arg; *cp; ++cp) {
      char ch;
      if (*cp == '\\') {
	 ++cp;
	 switch (*cp) {
	    case 0:
	       fprintf(stderr, "escape sequence at end of string\n");
	       exit(1);
	    case 'n':
	       ch = '\n';
	       break;
	    case 'r':
	       ch = '\r';
	       break;
	    case 't':
	       ch = '\t';
	       break;
	    case '\\':
	       ch = '\\';
	       break;
	    default:
	       fprintf(stderr, "unknown escape sequence: \\%c\n", *cp);
	       exit(1);
	 }
      } else if (*cp == '%') {
	 ++cp;
	 switch (*cp) {
	    case 0:
	       /* without this case, a trailing % would be reported as
		  an unknown sequence with a NUL byte as its character */
	       fprintf(stderr, "fill-in sequence at end of string\n");
	       exit(1);
	    case '%':
	       ch = '%';
	       break;
	    case 'p':
	       if (!stralloc_catint(msg, getpid())) {
		  fprintf(stderr, "out of memory\n"); exit(1);
	       }
	       continue;
	    default:
	       fprintf(stderr, "unknown fill-in sequence: %%%c\n", *cp);
	       exit(1);
	 }
      } else {
	 ch = *cp;
      }
      /* a failed append would silently truncate the message */
      if (!stralloc_append(msg, &ch)) {
	 fprintf(stderr, "out of memory\n"); exit(1);
      }
   }
}

/* Parse a strictly positive decimal number that fits into an int.
   Returns 0 for empty input, trailing garbage or out-of-range values
   (unlike atoi, which accepts "10x" and overflows on huge numbers). */
static int parse_count(const char* s) {
   errno = 0;
   char* end;
   long val = strtol(s, &end, 10);
   if (end == s || *end != '\0' || errno == ERANGE ||
	 val <= 0 || val > INT_MAX) {
      return 0;
   }
   return (int) val;
}

/* Work of one child process: open nof_connections sessions, one after
   another, to the service given by hp. Within a session, whenever input
   is received the next of the nof_msgs messages is sent; the session is
   closed after the last message was sent or when no further input can
   be read. msg must provide nof_msgs stralloc objects which are
   (re)filled from msgargs. Exits with status 1 if a socket cannot be
   created or the service cannot be connected persistently. */
static void run_client(const hostport* hp, int nof_connections,
      int nof_msgs, char** msgargs, stralloc* msg) {
   enum { MAX_CONNECT_ATTEMPTS = 20, INBUF_SIZE = 8192 };

   /* %p must expand to the pid of this child, not that of the parent;
      the pid is the same for all connections, hence this is done once
      and not within the measured connection loop */
   for (int ai = 0; ai < nof_msgs; ++ai) {
      construct_msg(&msg[ai], msgargs[ai]);
   }

   for (int ci = 0; ci < nof_connections; ++ci) {
      int fd = -1;
      int connect_errno = 0;
      int attempt;
      for (attempt = 0; attempt < MAX_CONNECT_ATTEMPTS; ++attempt) {
	 if ((fd = socket(hp->domain, SOCK_STREAM, hp->protocol)) < 0) {
	    perror("socket"); exit(1);
	 }
	 if (connect(fd, (const struct sockaddr *) &hp->addr,
	       hp->namelen) == 0) {
	    break;
	 }
	 /* close may clobber errno; keep the reason of the failure */
	 connect_errno = errno;
	 close(fd);
      }
      if (attempt == MAX_CONNECT_ATTEMPTS) {
	 errno = connect_errno;
	 perror("connect fails persistently"); exit(1);
      }

      inbuf in = {fd}; inbuf_alloc(&in, INBUF_SIZE); outbuf out = {fd};
      for (int i = 0; i < nof_msgs; ++i) {
	 if (inbuf_getchar(&in) < 0) break; /* read a package */
	 in.pos = in.buf.len; /* mark entire input buffer as read */
	 outbuf_write(&out, msg[i].s, msg[i].len);
	 outbuf_flush(&out);
      }
      inbuf_free(&in); outbuf_free(&out);
      close(fd);
   }
}

/* Usage: netperftest hostport #processes #connections_per_process msg...
   Forks #processes clients, waits for all of them and reports timing
   results. Exit status is 0 if all processes succeeded, 1 otherwise. */
int main (int argc, char** argv) {
   char* cmdname = *argv++; --argc;
   if (argc < 4) {
      fprintf(stderr,
	 "Usage: %s hostport #processes #connections_per_process msg...\n",
	 cmdname);
      exit(1);
   }
   char* hostport_string = *argv++; --argc;
   hostport hp;
   if (!parse_hostport(hostport_string, &hp, 0)) {
      fprintf(stderr, "%s: hostport in conformance to RFC 2396 expected\n",
         cmdname);
      exit(1);
   }
   int nof_processes = parse_count(*argv++); --argc;
   int nof_connections = parse_count(*argv++); --argc;
   if (nof_processes <= 0 || nof_connections <= 0) {
      fprintf(stderr, "%s: invalid argument\n", cmdname); exit(1);
   }

   stralloc* msg = calloc(argc, sizeof(stralloc));
   if (!msg) {
      perror("calloc"); exit(1);
   }
   /* expand all messages once in the parent: construct_msg exits on
      malformed sequences, which are thereby reported once and before
      any process is forked; the children redo this for their own %p */
   for (int ai = 0; ai < argc; ++ai) {
      construct_msg(&msg[ai], argv[ai]);
   }

   double t0 = walltime();
   for (int pi = 0; pi < nof_processes; ++pi) {
      pid_t pid = fork();
      if (pid < 0) {
	 perror("fork"); exit(1);
      }
      if (pid == 0) {
	 run_client(&hp, nof_connections, argc, argv, msg);
	 exit(0);
      }
   }
   int wstat; int count = 0;
   while (wait(&wstat) > 0) {
      if (WIFEXITED(wstat) && WEXITSTATUS(wstat) == 0) ++count;
   }
   if (count != nof_processes) {
      if (count == 0) {
	 printf("no results as all processes failed\n");
      } else {
	 printf("no results as %d processes failed\n", nof_processes - count);
      }
      return 1;
   }
   double t1 = walltime() - t0;
   double time_per_connection = t1 / nof_connections;
   printf("total time = %.4lf s\n", t1);
   printf("time per connection: %.4lf ms\n", time_per_connection * 1000);
   /* multiply in double: the int product may overflow */
   double throughput = (double) nof_connections * nof_processes / t1;
   printf("throughput: %lg connections/s\n", throughput);
   return 0;
}
