/*  system.c
 *
 * Copyright (c) 2010 SeaD <sead at deep.perm.ru>
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 *  $Id: system.c,v 1.15 2010/07/12 03:46:36 sead Exp $
 *
 */

#include <pwd.h>
#include <signal.h>
#include <paths.h>
#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>

#ifdef __linux__
# include <time.h>
#endif  /* __linux__ */

#include "ipguard.h"

/* Last signal number recorded by sig_func() and consumed by sig_catch().
 * It is written from a signal handler, so it must be volatile sig_atomic_t:
 * the only object type ISO C guarantees a handler may safely assign. */
static volatile sig_atomic_t SIGNAL = 0;
/* log_fp: log file stream opened by log_open().
 * pid_fp: pid file stream, used transiently by pid_check()/pid_creat(). */
static FILE *log_fp = NULL, *pid_fp = NULL;

/*
 * Orderly shutdown. When verbose, log the exit reason and our pid. Then call
 * packet_destroy(), remove the pid file, close the log and exit(reason).
 * sig_func() calls this for SIGINT/SIGTERM, with the signal number as the
 * exit status.
 */
void exit_ipguard(int reason) {
    if (verbose) {
        /* pid_t has no printf specifier of its own; widen to long explicitly. */
        snprintf(s, 64, "Exit on reason %d (pid: %ld)", reason, (long) getpid());
        log_str(NOTICE, s, "");
    }
    packet_destroy();
    pid_unlink();
    log_close();
    exit(reason);
}

void sig_init(void) {
    if (signal(SIGINT, sig_func) == SIG_ERR) {
        log_str(ERROR, "signal(SIGINT):", strerror(errno));
        exit(EXIT_FAILURE);
    }
    if (signal(SIGTERM, sig_func) == SIG_ERR) {
        log_str(ERROR, "signal(SIGTERM):", strerror(errno));
        exit(EXIT_FAILURE);
    }
    if (signal(SIGHUP, sig_func) == SIG_ERR) {
        log_str(ERROR, "signal(SIGHUP):", strerror(errno));
        exit(EXIT_FAILURE);
    }
    if (signal(SIGUSR1, sig_func) == SIG_ERR) {
        log_str(ERROR, "signal(SIGUSR1):", strerror(errno));
        exit(EXIT_FAILURE);
    }
    if (signal(SIGUSR2, sig_func) == SIG_ERR) {
        log_str(ERROR, "signal(SIGUSR2):", strerror(errno));
        exit(EXIT_FAILURE);
    }
}

/*
 * Signal handler installed by sig_init(). It records the signal number in
 * SIGNAL so sig_catch() can act on it from the main flow. SIGINT and SIGTERM
 * terminate immediately through exit_ipguard(signo).
 *
 * NOTE(review): log_str() and exit_ipguard() use stdio and exit(), which are
 * not async-signal-safe. Deferring that work to the main loop would be safer,
 * but changes when shutdown happens; left as-is here.
 */
void sig_func(int signo) {
    /* Private buffer: the shared global `s` may be in use by the code this
     * handler interrupted, and formatting into it would corrupt that message. */
    char num[16];

    if (verbose) {
        snprintf(num, sizeof num, "%d", signo);
        log_str(NOTICE, "SIGNAL received", num);
    }
    SIGNAL = signo;
    if ((signo == SIGINT) || (signo == SIGTERM)) exit_ipguard(signo);
}

/*
 * Act on the last signal recorded by sig_func():
 *   SIGHUP  - ethers_reinit(); also reopen the log when running as uid 0;
 *   SIGUSR1 - pair_dump(), buffer_dump() if buffer_num is non-zero, stat_dump();
 *   SIGUSR2 - buffer_dump2ethers().
 * The flag is read and cleared BEFORE the work is done. A signal that arrives
 * while a (possibly slow) dump runs is then kept for the next call instead of
 * being wiped out. A tiny window between the read and the clear remains.
 */
void sig_catch(void) {
    int sig = SIGNAL;

    if (!sig) return;
    SIGNAL = 0;

    switch (sig) {
        case SIGHUP:
            ethers_reinit();
            if (!getuid()) log_reopen();
            break;
        case SIGUSR1:
            pair_dump();
            if (buffer_num) buffer_dump();
            stat_dump();
            break;
        case SIGUSR2:
            buffer_dump2ethers();
            break;
        default:
            /* SIGINT/SIGTERM never get here: sig_func() exits on them. */
            break;
    }
}

/*
 * Open the log file `log_name` for appending into log_fp. The log itself is
 * unavailable at this point, so failure is reported on stderr and the process
 * exits with EXIT_FAILURE.
 */
void log_open(void) {
    log_fp = fopen(log_name, "a");
    if (log_fp != NULL) return;

    fprintf(stderr, "fopen(%s): %s\n", log_name, strerror(errno));
    exit(EXIT_FAILURE);
}

void log_str(int pri, char *ent, char *err) {
    char p[10];

    switch (pri) {
        case ERROR: strncpy(p, "error", 10); break;
        case WARNING: strncpy(p, "warning", 10); break;
        case NOTICE: strncpy(p, "notice", 10); break;
        case INFO: strncpy(p, "info", 10); break;
        default: strncpy(p, "unknown", 10); break;
    }
    if (debug > 1) fprintf(stderr, "%s ", time_get());
    if (debug) fprintf(stderr, "%s %s\n", ent, err);
    fprintf(log_fp, "%s %s %s %s\n", time_get(), p, ent, err);
    fflush(log_fp);
}

/*
 * Close the log file if it is open, reporting an fclose() failure on stderr.
 * log_fp is reset to NULL afterwards. Per ISO C the stream is disassociated
 * even when fclose() fails, so a later call never touches a closed FILE
 * (fclose(NULL) is undefined behaviour).
 */
void log_close(void) {
    if (!log_fp) return;
    if (fclose(log_fp)) {
        fprintf(stderr, "fclose(%s): %s\n", log_name, strerror(errno));
    }
    log_fp = NULL;
}

/*
 * Close and reopen `log_name` (sig_catch() does this on SIGHUP when running as
 * root), then record the reopen in the fresh log.
 */
void log_reopen(void) {
    log_close();
    log_open();
    log_str(NOTICE, "Log file reopened:", log_name);
}

void pid_check(void) {
    struct stat ps;
    int pid = 0;

    if (stat(pid_name, &ps) == -1) {
        if (errno == ENOENT) return;
        else {
            snprintf(s, 128, "stat(%s):", pid_name);
            log_str(ERROR, s, strerror(errno));
            exit(EXIT_FAILURE);
        }
    } else {
        if (!(pid_fp = fopen(pid_name, "r"))) {
            snprintf(s, 128, "fopen(%s):", pid_name);
            log_str(WARNING, s, strerror(errno));
        }
        fscanf(pid_fp, "%d", &pid);
        fclose(pid_fp);
        if (pid) {
            if (kill(pid, 0)) {
                snprintf(s, 64, "Removing stale pid: %u", pid);
                log_str(WARNING, s, "");
                pid_unlink();
            } else {
                snprintf(s, 64, "Already running, pid: %u", pid);
                log_str(ERROR, s, "");
                exit(EXIT_FAILURE);
            }
        } else {
            snprintf(s, 64, "Wrong pid file found: %s", pid_name);
            log_str(ERROR, s, "");
            exit(EXIT_FAILURE);
        }
    }
}

void pid_creat(void) {
    pid_check();
    if (!(pid_fp = fopen(pid_name, "w"))) {
        snprintf(s, 128, "fopen(%s):", pid_name);
        log_str(ERROR, s, strerror(errno));
        exit(EXIT_FAILURE);
    }
    fprintf(pid_fp, "%d\n", getpid());
    fclose(pid_fp);
}

/*
 * Remove the pid file. A failure (including "file does not exist") is only
 * logged at NOTICE level; it never stops the caller.
 */
void pid_unlink(void) {
    if (unlink(pid_name) != -1) return;

    snprintf(s, 128, "unlink(%s):", pid_name);
    log_str(NOTICE, s, strerror(errno));
}

/*
 * Detach into the background:
 *   - fork(), and the parent exits with EXIT_SUCCESS;
 *   - point stdin/stdout/stderr at _PATH_DEVNULL;
 *   - start a new session with setsid();
 *   - chdir("/").
 * Every failure is logged and exits with EXIT_FAILURE.
 */
void daemonize(void) {
    pid_t pid;
    int null;

    if ((pid = fork()) == -1) {
        log_str(ERROR, "fork():", strerror(errno));
        exit(EXIT_FAILURE);
    } else if (pid > 0) exit(EXIT_SUCCESS);

    if ((null = open(_PATH_DEVNULL, O_RDWR)) == -1) {
        log_str(ERROR, "open(_PATH_DEVNULL):", strerror(errno));
        exit(EXIT_FAILURE);
    }

    if (dup2(null, 0) == -1) {
        log_str(ERROR, "dup2(STDIN):", strerror(errno));
        exit(EXIT_FAILURE);
    }
    if (dup2(null, 1) == -1) {
        log_str(ERROR, "dup2(STDOUT):", strerror(errno));
        exit(EXIT_FAILURE);
    }
    if (dup2(null, 2) == -1) {
        log_str(ERROR, "dup2(STDERR):", strerror(errno));
        exit(EXIT_FAILURE);
    }

    /* Descriptors 0-2 now refer to the null device. The original descriptor is
     * no longer needed unless it is itself one of them, so don't leak it. */
    if (null > 2) close(null);

    if (setsid() == -1) {
        log_str(ERROR, "setsid():", strerror(errno));
        exit(EXIT_FAILURE);
    }

    if (chdir("/") == -1) {
        log_str(ERROR, "chdir(/):", strerror(errno));
        exit(EXIT_FAILURE);
    }
}

/*
 * Drop privileges to the account given by `suser`. It may be a user name or a
 * plain decimal uid that has a passwd entry; anything else is rejected with
 * "Invalid user" and EXIT_FAILURE. The group is changed before the user,
 * because after setuid() to an unprivileged uid the process could no longer
 * call setgid(). Any failure is fatal. When verbose, the resulting uid/gid
 * is logged.
 *
 * The previous code had the numeric test inverted. An unknown name (or "")
 * made atoi() return 0 and the process kept running as root, while a real
 * numeric uid was rejected.
 */
void set_user(void) {
    struct passwd *pw;
    uid_t uid;
    gid_t gid;

    if (!(pw = getpwnam(suser))) {
        char *end;
        long num;

        errno = 0;
        num = strtol(suser, &end, 10);
        if (end == suser || *end != '\0' || errno == ERANGE || num < 0 ||
            (long) (uid_t) num != num || !(pw = getpwuid((uid_t) num))) {
            log_str(ERROR, "Invalid user", suser);
            exit(EXIT_FAILURE);
        }
    }
    uid = pw->pw_uid;
    gid = pw->pw_gid;

    /* NOTE(review): supplementary groups inherited from the caller are not
     * dropped (no setgroups()/initgroups()); consider doing so before setgid(). */
    if (setgid(gid) == -1) {
        log_str(ERROR, "setgid():", strerror(errno));
        exit(EXIT_FAILURE);
    }
    if (setuid(uid) == -1) {
        log_str(ERROR, "setuid():", strerror(errno));
        exit(EXIT_FAILURE);
    }

    snprintf(s, 128, "%u/%u", (unsigned) uid, (unsigned) gid);
    if (verbose) log_str(NOTICE, "Running as uid/gid:", s);
}

/*
 * Replace every literal 'x' in the MAC template `mac` with a random lowercase
 * hex digit taken from rand(); all other characters are left unchanged.
 * Scanning stops at the terminating NUL or after 18 bytes, whichever comes
 * first ("xx:xx:xx:xx:xx:xx" plus its NUL is 18 bytes). Shorter strings are
 * therefore never read or written past their end.
 */
void mac_rand(char *mac) {
    static const char hex[] = "0123456789abcdef";
    int n;

    for (n = 0; n < 18 && mac[n] != '\0'; n++)
        if (mac[n] == 'x') mac[n] = hex[rand() % 16];
}

/*
 * Re-randomize the fake MAC address `mac` in place, at most once every
 * fake_regen seconds (calls inside that window return immediately). Every hex
 * digit (0-9, A-F, a-f) becomes a random lowercase hex digit from rand(); other
 * characters such as ':' are kept. Scanning stops at the NUL or after 18 bytes,
 * whichever comes first. With debug > 1 the old and new values go to stderr.
 *
 * NOTE(review): if `mac` starts with the first octet, its low bit (the
 * group/multicast bit) is randomized too. Roughly half of the generated
 * addresses are then multicast; confirm that is acceptable.
 */
void mac_regen(char *mac) {
    static const char hex[] = "0123456789abcdef";
    static time_t count = 0;
    time_t t = time(NULL);
    int n;

    if (t < count) return;
    count = t + fake_regen;

    if (debug > 1) fprintf(stderr, "FAKE REGEN: pre %-17s ", mac);

    for (n = 0; n < 18 && mac[n] != '\0'; n++)
        if ((mac[n] >= '0' && mac[n] <= '9') ||
            (mac[n] >= 'A' && mac[n] <= 'F') ||
            (mac[n] >= 'a' && mac[n] <= 'f'))
            mac[n] = hex[rand() % 16];

    if (debug > 1) fprintf(stderr, "new %-17s\n", mac);
}

/*
 * Return the current local time as "Mmm dd hh:mm:ss" (15 characters). The
 * text is ctime() output without the weekday and year. The result lives in a
 * static buffer that the next call overwrites.
 * ctime() returns NULL when the time cannot be converted. In that case a
 * fixed placeholder of the same width is returned instead of dereferencing
 * NULL.
 */
char *time_get(void) {
    static char cur_time[16];
    time_t t = time(NULL);
    const char *full = ctime(&t);

    if (full == NULL) {
        memcpy(cur_time, "--- -- --:--:--", 16);
        return cur_time;
    }
    /* ctime(): "Www Mmm dd hh:mm:ss yyyy\n" -- keep bytes 4..18. */
    memcpy(cur_time, full + 4, 15);
    cur_time[15] = '\0';
    return cur_time;
}

/*
 * Poll the ethers file `ethers_name`, at most once every ethers_update seconds.
 * The first successful stat() records a baseline mtime. Afterwards, any change
 * of mtime triggers ethers_reinit(). The check is "!=", not "<", so a file
 * replaced by an older copy (e.g. restored with preserved timestamps) is still
 * reloaded. stat() failures are logged as warnings and retried on the next
 * poll.
 */
void ethers_stat(void) {
    struct stat es;
    static time_t ethers_mtime = 0, count = 0;
    static int have_mtime = 0;
    time_t t = time(NULL);

    if (t < count) return;
    count = t + ethers_update;

    if (stat(ethers_name, &es) == -1) {
        int err = errno;    /* snprintf() may legally modify errno */
        snprintf(s, 128, "stat(%s):", ethers_name);
        log_str(WARNING, s, strerror(err));
        return;
    }

    if (!have_mtime) {
        ethers_mtime = es.st_mtime;
        have_mtime = 1;
        return;
    }

    if (es.st_mtime != ethers_mtime) {
        ethers_reinit();
        ethers_mtime = es.st_mtime;
    }
}
