#include <unistd.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <errno.h>
#include "sites.h"
#include "network_dsm.h"

/* Per-site connection descriptors, indexed by site number.  Assigned in
   main() from start_server(); handle_messages() closes a site's
   descriptor and stores 0 in its slot once the site sends MSG_DONE, so
   0 doubles as the "site finished" marker. */
static int *fds;

/* Central copy of memory: mems[n] is the server's authoritative copy of
   block `n' and mem_sizes[n] is its size in bytes, both recorded when
   site 0 allocates the block during init_all(). */
static char *mems[NUM_MEMS];
static int mem_sizes[NUM_MEMS];

/* Address that site clients use for each block of memory.  NULL until
   site 0 reports the address it picked; later sites must report the
   same address. */
static char *mem_addrs[NUM_MEMS];

/* Track changes to main memory that need to be sent to each site:
   mem_change_masks[n][s] has one byte per byte of block `n'; a nonzero
   byte means another site changed that byte and site `s' has not yet
   been sent the change.  Cleared for a range when the range is read. */
static char *mem_change_masks[NUM_MEMS][NUM_SITES];

/* Semaphores: semas[n] is the current count of semaphore `n';
   sema_wait_counts[n] is how many sites are queued on it, and
   sema_waits[n][0 .. count-1] holds their site numbers, oldest first.
   NOTE(review): these arrays are sized by NUM_MEMS although init_all()
   uses NUM_SEMAS as the semaphore limit -- confirm which bound is
   intended. */
static int semas[NUM_MEMS];
static int sema_wait_counts[NUM_MEMS];
static int sema_waits[NUM_MEMS][NUM_SITES];

/* Allocate `size' zero-filled bytes, aborting the server if memory is
   exhausted.  At least one byte is requested so that a zero-sized block
   still yields a valid, non-NULL pointer. */
static char *alloc_zeroed_or_abort(int size)
{
  char *p = (char *)calloc(size > 0 ? (size_t)size : 1, 1);

  if (!p) {
    fprintf(stderr, "out of memory\n");
    abort();
  }
  return p;
}

/* Run the start-up handshake with every client site, in site order.

   Each request from a site is a pair of ints: a message type followed
   by one argument (a block size for MSG_ALLOC_MEMORY, an initial count
   for MSG_ALLOC_SEMA).  Site 0 defines the memory blocks, their
   addresses, and the semaphore initial counts; every later site must
   repeat exactly the same sequence or the server aborts.  A site ends
   its handshake with MSG_INIT_DONE.

   Any protocol violation or allocation failure terminates the server
   with abort(). */
static void init_all(void)
{
  /* Bound semaphore ids by the real array capacity as well as by
     NUM_SEMAS, so an id can never index past the semaphore tables. */
  const int sema_capacity = (int)(sizeof(semas) / sizeof(semas[0]));
  int i;
  int buffer[2];

  for (i = 0; i < NUM_SITES; i++) {
    int mem_counter = 0, sema_counter = 0;

    /* Read messages from client site until it says that it's done
       initializing. */
    while (1) {
      read_all(fds[i], buffer, 2 * sizeof(int));
      if (buffer[0] == MSG_INIT_DONE)
        break;
      else if (buffer[0] == MSG_ALLOC_MEMORY) {
        int n = mem_counter++;
        int size = buffer[1];
        char *addr;

        if (mem_counter > NUM_MEMS) {
          fprintf(stderr, "more memory blocks than expected");
          abort();
        }
        /* The size comes off the wire; a negative value would turn into
           a huge size_t when passed to the allocator. */
        if (size < 0) {
          fprintf(stderr, "invalid memory-block size: %d\n", size);
          abort();
        }
        /* If site 0, then create the central copy */
        if (i == 0) {
          mem_sizes[n] = size;
          mems[n] = alloc_zeroed_or_abort(size);
        } else if (mem_sizes[n] != size) {
          fprintf(stderr, "different memory-block size than expected");
          abort();
        }
        /* At first, no changes to send to the site: */
        mem_change_masks[n][i] = alloc_zeroed_or_abort(size);

        /* Report the memory-block id back to the client: */
        buffer[0] = n;
        write_all(fds[i], buffer, sizeof(int));

        /* Also report back the address to use for this block
           of memory; it will be NULL at first, so the first site
           will pick the actual address. */
        addr = mem_addrs[n];
        write_all(fds[i], &addr, sizeof(void*));

        /* Read the address used by the site, and if the site isn't
           the first one, check that it is the same as all other
           sites. */
        read_all(fds[i], &addr, sizeof(void*));
        if (i == 0)
          mem_addrs[n] = addr;
        else if (addr != mem_addrs[n]) {
          fprintf(stderr, "inconsistent addresses: %p vs. %p",
                  (void *)mem_addrs[n], (void *)addr);
          abort();
        }
      } else if (buffer[0] == MSG_ALLOC_SEMA) {
        int n = sema_counter++;

        /* Limit the number of semaphores (the count that matters here
           is sema_counter, not mem_counter). */
        if (sema_counter > NUM_SEMAS || n >= sema_capacity) {
          fprintf(stderr, "more semaphores than expected");
          abort();
        }
        /* For the first site, set the initial state; for later sites,
           make sure they're consistent with the first site */
        if (i == 0) {
          semas[n] = buffer[1];
        } else if (semas[n] != buffer[1]) {
          fprintf(stderr, "different semaphore initial state than expected");
          abort();
        }
        /* Report back the semaphore id: */
        buffer[0] = n;
        write_all(fds[i], buffer, sizeof(int));
      } else {
        fprintf(stderr, "unknown message: %d\n", buffer[0]);
        abort();
      }
    }
  }
}

/* Map a client-side address to the index of the memory block that
   contains it.  A block `k' covers the half-open range
   [mem_addrs[k], mem_addrs[k] + mem_sizes[k]).  If no block contains
   `addr', an error is printed and the server aborts. */
static int find_mem(void *addr)
{
  const unsigned long target = (unsigned long)addr;
  int idx = 0;

  while (idx < NUM_MEMS) {
    const unsigned long first = (unsigned long)mem_addrs[idx];
    const unsigned long past_end = first + mem_sizes[idx];

    if (!(target < first) && target < past_end)
      return idx;
    idx++;
  }

  fprintf(stderr, "address not found: %p\n", addr);
  abort();
  return 0;
}

static void handle_messages()
{
  int i, max_fd, buffer[1];
  fd_set rds;

  /* Loop while any site is still running: */
  while (1) {
    FD_ZERO(&rds);
    max_fd = 0;

    /* Collect fds to detect when a message is available: */
    for (i = 0; i < NUM_SITES; i++) {
      if (fds[i]) {
        FD_SET(fds[i], &rds);
        if (fds[i] > max_fd) max_fd = fds[i];
      }
    }
    if (!max_fd) {
      /* No sites left */
      return;
    }

    /* Wait for a message from any client site: */
    select(max_fd+1, &rds, NULL, NULL, NULL);

    for (i = 0; i < NUM_SITES; i++) {
      if (fds[i]) {
        if (FD_ISSET(fds[i], &rds)) {
          /* There was a message from site `i': */
          read_all(fds[i], buffer, sizeof(int));
          switch(buffer[0]) {
          case MSG_SIGNAL_SEMA:
            {
              int n;
              read_all(fds[i], buffer, sizeof(int));
              n = buffer[0];
              if (sema_wait_counts[n] > 0) {
                /* Let some waiting site on the queue continue: */
                write_all(fds[sema_waits[n][0]], buffer, sizeof(int));
                sema_wait_counts[n]--;
                memmove(&sema_waits[n][0], &sema_waits[n][1], sizeof(int) * sema_wait_counts[n]);
              } else {
                /* No site waiting, so just increment the semaphore count: */
                semas[n]++;
              }
            }
            break;
          case MSG_WAIT_SEMA:
            {
              int n;
              read_all(fds[i], buffer, sizeof(int));
              n = buffer[0];
              if (semas[n]) {
                /* Decrement count and let site continue: */
                semas[n]--;
                write_all(fds[i], buffer, sizeof(int));
              } else {
                /* Add site to wait queue (and don't reply for now): */
                sema_waits[n][sema_wait_counts[n]] = i;
                sema_wait_counts[n]++;
              }
            }
            break;
          case MSG_READ:
            {
              int n, size, delta;
              void *addr;

              /* Get address and size */
              read_all(fds[i], &addr, sizeof(void*));
              read_all(fds[i], &size, sizeof(int));

              n = find_mem(addr);
              delta = (char *)addr - mem_addrs[n];

              /* Send current copy along with the mask that indicates
                 which bytes were changed by other sites: */
              write_all(fds[i], mems[n] + delta, size);
              write_all(fds[i], mem_change_masks[n][i] + delta, size);
              
              /* Since all changes now have been sent, clear the
                 change mask: */
              memset(mem_change_masks[n][i] + delta, 0, size);
            }
            break;
          case MSG_WRITE:
            {
              int n, size, j, k, delta;
              void *addr;
              char *data, *mask;

              /* Get address and size */
              read_all(fds[i], &addr, sizeof(void*));
              read_all(fds[i], &size, sizeof(int));

              n = find_mem(addr);
              delta = (char *)addr - mem_addrs[n];

              /* Read copy of memory from client site along with the mask
                 indicating which bytes it changed: */
              data = (char *)malloc(size);
              mask = (char *)malloc(size);
              read_all(fds[i], data, size);
              read_all(fds[i], mask, size);

              /* For each changed byte, update central copy of memory
                 and add byte to change mask of all other sites: */
              for (k = 0; k < size; k++) {
                if (mask[k]) {
                  mems[n][k+delta] = data[k];
                  for (j = 0; j < NUM_SITES; j++) {
                    if (j != i) {
                      mem_change_masks[n][j][k+delta] = 1;
                    }
                  }
                }
              }

              free(data);
              free(mask);
            }
            break;
          case MSG_DONE:
            {
              /* Site `i' is done: */
              close(fds[i]);
              fds[i] = 0;
            }
            break;
          default:
            fprintf(stderr, "unknown message: %d\n", buffer[0]);
            abort();
            break;
          }
        }
      }
    }
  }
}

/* Server entry point.

   argv[1] is handed unchanged to start_server() together with
   NUM_SITES; its meaning is defined by start_server().  The server then
   runs the initialization handshake with every site and serves requests
   until all sites are done.

   Returns EXIT_SUCCESS after all sites have finished, or EXIT_FAILURE
   when the required argument is missing or start_server() yields no
   descriptor table. */
int main(int argc, char **argv, char **envp)
{
  (void)envp; /* unused */

  if (argc < 2) {
    fprintf(stderr, "usage: %s <arg>\n", argc > 0 ? argv[0] : "server");
    return EXIT_FAILURE;
  }

  fds = start_server(argv[1], NUM_SITES);
  /* NOTE(review): assumes a NULL result signals failure -- confirm
     against start_server(); init_all() would dereference it otherwise. */
  if (!fds) {
    fprintf(stderr, "could not start server\n");
    return EXIT_FAILURE;
  }

  init_all();

  handle_messages();

  return EXIT_SUCCESS;
}
