/* implements lookup, bind, err, close & poll functions. Also defines the 
   utp_prot structure */

#include <asm/system.h>   
#include <asm/uaccess.h>
#include <linux/types.h>
#include <linux/sched.h>
#include <linux/fcntl.h>
#include <linux/socket.h>   
#include <linux/sockios.h>
#include <linux/in.h>
#include <linux/errno.h>
#include <linux/timer.h>
#include <linux/termios.h>
#include <linux/mm.h>
#include <linux/config.h>
#include <linux/inet.h> 
#include <linux/netdevice.h>
#include <linux/poll.h>
#include <net/snmp.h>
#include <net/ip.h>
#include <net/protocol.h>
#include <net/tcp.h>
#include <linux/skbuff.h>
#include <net/sock.h>
#include <net/icmp.h>
#include <net/utp.h>
#include <linux/utp/utppcb.h>
#include <linux/utp/utppdu.h>

/* Based on udp.c. */

/*
 *	All we need to do is get the socket, and then do a checksum. 
 */

/*
 * Slow-path socket lookup: walk the hash chain selected by the destination
 * port and return the best matching socket, or NULL when nothing matches.
 *
 * dport is converted with ntohs() and compared against each socket's num;
 * sockets that are both dead and in TCP_CLOSE are skipped.  For the rest,
 * daddr is checked against rcv_saddr, saddr against daddr and sport against
 * dport.  A zero field on the socket acts as a wildcard; a non-zero field
 * that differs rejects the socket.
 *
 * Every candidate starts with a score of 1 and gains one point per non-zero
 * field that matched.  A score of 4 ends the search at once; otherwise the
 * first socket with the highest score is returned.
 *
 * dif is accepted but unused: the bound_dev_if check is disabled here.
 */
struct sock *utp_v4_lookup_longway(u32 saddr, u16 sport, u32 daddr, u16 dport,
				   int dif)
{
	unsigned short port = ntohs(dport);
	struct sock *cand;
	struct sock *best = NULL;
	int best_score = 0;
	int score;

	for (cand = utp_hash[port & (UTP_HTABLE_SIZE - 1)];
	     cand != NULL;
	     cand = cand->next) {
		if (cand->num != port)
			continue;
		if (cand->dead && cand->state == TCP_CLOSE)
			continue;

		/* A bound field that disagrees with the packet disqualifies. */
		if (cand->rcv_saddr && cand->rcv_saddr != daddr)
			continue;
		if (cand->daddr && cand->daddr != saddr)
			continue;
		if (cand->dport && cand->dport != sport)
			continue;

		/* One base point plus one per field that is actually bound. */
		score = 1;
		if (cand->rcv_saddr)
			score++;
		if (cand->daddr)
			score++;
		if (cand->dport)
			score++;

		/* Fully specified match: nothing can beat it. */
		if (score == 4)
			return cand;

		if (score > best_score) {
			best = cand;
			best_score = score;
		}
	}
	return best;
}

/*
 * Last hit UTP socket cache, this is ipv4 specific so make it static.
 *
 * utp_v4_lookup() records here the address/port tuple and the result
 * (possibly NULL) of its most recent lookup made with dif == 0.
 * utp_v4_unhash() and utp_v4_rehash() reset uth_cache_sk to NULL when the
 * cached socket is the one being removed or moved.
 */
static u32 uth_cache_saddr, uth_cache_daddr;
static u16 uth_cache_dport, uth_cache_sport;
static struct sock *uth_cache_sk = NULL; 

/*
 * Insert sk at the head of the hash chain selected by its port number
 * (sk->num) and remember the bucket index in sk->hashent.
 */
static void utp_v4_hash(struct sock *sk)
{
	int bucket = sk->num & (UTP_HTABLE_SIZE - 1);
	struct sock **head = &utp_hash[bucket];

	SOCKHASH_LOCK();
	sk->next = *head;
	*head = sk;
	sk->hashent = bucket;
	SOCKHASH_UNLOCK();
}

/*
 * Remove sk from the hash chain selected by sk->num and drop it from the
 * last-hit lookup cache if it is the cached socket.  Nothing is unlinked
 * when sk is not found on that chain.
 */
void utp_v4_unhash(struct sock *sk)
{
	struct sock **link = &utp_hash[sk->num & (UTP_HTABLE_SIZE - 1)];

	SOCKHASH_LOCK();
	for (; *link != NULL; link = &((*link)->next)) {
		if (*link != sk)
			continue;
		/*
		 * Splice sk out.  The self-link guard was added for testing:
		 * if sk->next points back at sk, terminate the chain instead
		 * of leaving sk on it.
		 */
		if (sk->next == sk)
			*link = NULL;
		else
			*link = sk->next;
		break;
	}
	if (uth_cache_sk == sk)
		uth_cache_sk = NULL;
	SOCKHASH_UNLOCK();
}

/*
 * Move sk from the hash chain it was filed under (sk->hashent) to the chain
 * selected by its current port number (sk->num), and drop it from the
 * last-hit lookup cache if it is the cached socket.
 *
 * The unlink walk must always run to completion.  An earlier version broke
 * out of the walk when signal_pending(current) was true; in that case sk was
 * never removed from its old chain but was still pushed onto the new one,
 * leaving it linked twice and corrupting (or cycling) the hash chains.
 */
static void utp_v4_rehash(struct sock *sk)
{
	struct sock **skp;
	int num = sk->num & (UTP_HTABLE_SIZE - 1);
	int oldnum = sk->hashent;

	SOCKHASH_LOCK();

	/* Unlink from the old chain, if present there. */
	for (skp = &utp_hash[oldnum]; *skp != NULL; skp = &((*skp)->next)) {
		if (*skp == sk) {
			*skp = sk->next;
			break;
		}
	}

	/* Push onto the head of the new chain. */
	sk->next = utp_hash[num];
	utp_hash[num] = sk;
	sk->hashent = num;

	if (uth_cache_sk == sk)
		uth_cache_sk = NULL;
	SOCKHASH_UNLOCK();
}

/*
 * Lookup function for UTP.
 *
 * When dif is zero, a one-entry cache of the last lookup is consulted first
 * and refreshed afterwards with whatever utp_v4_lookup_longway() returns
 * (including NULL).  When dif is non-zero the cache is neither read nor
 * written and the call goes straight to the long-way lookup.
 */
__inline__ struct sock *utp_v4_lookup(u32 saddr, u16 sport, u32 daddr, u16 dport, int dif)
{
	struct sock *found;

	if (dif)
		return utp_v4_lookup_longway(saddr, sport, daddr, dport, dif);

	if (uth_cache_sk != NULL &&
	    uth_cache_saddr == saddr &&
	    uth_cache_sport == sport &&
	    uth_cache_dport == dport &&
	    uth_cache_daddr == daddr)
		return uth_cache_sk;

	found = utp_v4_lookup_longway(saddr, sport, daddr, dport, dif);

	uth_cache_sk    = found;
	uth_cache_saddr = saddr;
	uth_cache_daddr = daddr;
	uth_cache_sport = sport;
	uth_cache_dport = dport;

	return found;
}


/*
 * This routine is called by the ICMP module when it gets some 
 * sort of error condition.  If err < 0 then the socket should
 * be closed and the error returned to the user.  If err > 0 
 * it's just the icmp type << 8 | icmp code.
 */

void utp_err(struct sk_buff *skb, unsigned char *dp, int len)
{
        struct iphdr *iph = (struct iphdr*)dp;
        struct utphdr *uh = (struct utphdr*)(dp+(iph->ihl<<2));
        int type = skb->h.icmph->type;
        int code = skb->h.icmph->code;
        struct sock *sk;
        int harderr;
        u32 info;
        int err;

        if (len < (iph->ihl<<2)+sizeof(struct utphdr)) {
                icmp_statistics.IcmpInErrors++;
                return;
        }

        sk = utp_v4_lookup(iph->daddr, uh->dest, iph->saddr, uh->source, skb->dev->ifindex);
        if (sk == NULL) {
                icmp_statistics.IcmpInErrors++;
                return; /* No socket for error */
        }

        err = 0;
        info = 0;
        harderr = 0;

        switch (type) {
        default:
        case ICMP_TIME_EXCEEDED:
                err = EHOSTUNREACH;
                break;
        case ICMP_SOURCE_QUENCH:
                return;
        case ICMP_PARAMETERPROB:
                err = EPROTO;
                info = ntohl(skb->h.icmph->un.gateway)>>24;
                harderr = 1;
                break;
        case ICMP_DEST_UNREACH:
                if (code == ICMP_FRAG_NEEDED) { /* Path MTU discovery */
                        if (sk->ip_pmtudisc != IP_PMTUDISC_DONT) {
                                err = EMSGSIZE;
                                info = ntohs(skb->h.icmph->un.frag.mtu);
                                harderr = 1;
                                break;
                        }
                        return;
                }
                err = EHOSTUNREACH;
                if (code <= NR_ICMP_UNREACH) {
                        harderr = icmp_err_convert[code].fatal;
                        err = icmp_err_convert[code].errno;
                }
                break;
        }

        /*
         *      Various people wanted BSD UDP semantics. Well they've come
         *      back out because they slow down response to stuff like dead
         *      or unreachable name servers and they screw term users something
         *      chronic. Oh and it violates RFC1122. So basically fix your
         *      client code people.
         */

        /*
         *      RFC1122: OK.  Passes ICMP errors back to application, as per
         *      4.1.3.3. After the comment above, that should be no surprise.
         */

        if (!harderr && !sk->ip_recverr)
                return;

        /*
         *      4.x BSD compatibility item. Break RFC1122 to
         *      get BSD socket semantics.
         */
        if(sk->bsdism && sk->state!=TCP_ESTABLISHED)
                return;

        if (sk->ip_recverr)
                ip_icmp_error(sk, skb, err, uh->dest, info, (u8*)(uh+1));
        sk->err = err;
        sk->error_report(sk);
}

/*
 * Check whether sk may bind to local port snum.
 *
 * Returns 1 when another socket on the same port conflicts, 0 otherwise.
 * Another socket is only considered when the local addresses overlap, i.e.
 * either side has a zero rcv_saddr or both have the same rcv_saddr.  An
 * overlapping socket conflicts unless both sockets have reuse set and the
 * other socket is not in TCP_LISTEN state.
 */
static int utp_verify_bind(struct sock *sk, unsigned short snum)
{
	struct sock *other;
	int conflict = 0;
	int reuse = sk->reuse;

	SOCKHASH_LOCK();
	for (other = utp_hash[snum & (UTP_HTABLE_SIZE - 1)];
	     other != NULL;
	     other = other->next) {
		if (other->num != snum || other == sk)
			continue;

		/*
		 * Two sockets can be bound to the same port if they're
		 * bound to different, non-wildcard local addresses.
		 */
		if (other->rcv_saddr && sk->rcv_saddr &&
		    other->rcv_saddr != sk->rcv_saddr)
			continue;

		if (!reuse || !other->reuse || other->state == TCP_LISTEN) {
			conflict = 1;
			break;
		}
	}
	SOCKHASH_UNLOCK();

	return conflict;
}



/*
 * Return 1 if some socket on the hash chain for num has exactly that port
 * number, 0 otherwise.  Takes no lock itself.
 */
static inline int utp_lport_inuse(u16 num)
{
	struct sock *sk;

	sk = utp_hash[num & (UTP_HTABLE_SIZE - 1)];
	while (sk != NULL) {
		if (sk->num == num)
			return 1;
		sk = sk->next;
	}
	return 0;
}

/*
 * Pick a local port number for a UTP socket.
 *
 * Starting from the port handed out last time (kept in a static variable
 * and clamped to sysctl_local_port_range), up to UTP_HTABLE_SIZE consecutive
 * port numbers are examined:
 *   - if the hash bucket for a port is empty, that port (wrapped back into
 *     the local port range if needed) is returned immediately;
 *   - otherwise the port whose bucket has the shortest chain is remembered.
 * If no empty bucket was found, ports mapping to the shortest bucket are
 * tried in steps of UTP_HTABLE_SIZE until one is not in use.
 *
 * The whole search runs under SOCKHASH_LOCK.
 */
unsigned short utp_good_socknum(void)
{
	static int start = 0;
	int result;
	int i, best, best_size_so_far;

	SOCKHASH_LOCK();
	if (start > sysctl_local_port_range[1] ||
	    start < sysctl_local_port_range[0])
		start = sysctl_local_port_range[0];

	best_size_so_far = 32767;	/* "big" num */
	best = result = start;

	for (i = 0; i < UTP_HTABLE_SIZE; i++, result++) {
		struct sock *sk;
		int size;

		sk = utp_hash[result & (UTP_HTABLE_SIZE - 1)];

		if (!sk) {
			/* Empty bucket: take it, wrapped into range. */
			if (result > sysctl_local_port_range[1])
				result = sysctl_local_port_range[0]
					+ ((result - sysctl_local_port_range[0])
					   & (UTP_HTABLE_SIZE - 1));
			goto out;
		}

		/*
		 * Is this one better than our best so far?  Stop counting as
		 * soon as the chain is known to be no shorter than the best.
		 * (This replaces a "goto next" to a label placed directly
		 * before the closing brace, which is not valid ISO C.)
		 */
		size = 0;
		do {
			if (++size >= best_size_so_far)
				break;
		} while ((sk = sk->next) != NULL);

		if (size < best_size_so_far) {
			best_size_so_far = size;
			best = result;
		}
	}

	result = best;

	for (;; result += UTP_HTABLE_SIZE) {
		/* Get into range (but preserve hash bin)... */
		if (result > sysctl_local_port_range[1])
			result = sysctl_local_port_range[0]
				+ ((result - sysctl_local_port_range[0])
				   & (UTP_HTABLE_SIZE - 1));
		if (!utp_lport_inuse(result))
			break;
	}
out:
	start = result;
	SOCKHASH_UNLOCK();
	return result;
}

/*
 * Close a UTP socket: mark it closed, take it off the hash table, flag it
 * dead and hand it to destroy_sock().
 *
 * The timeout argument is not used by this function.
 */
void utp_close(struct sock *sk, long timeout) 
{
        sk->state = TCP_CLOSE;          /* utp_v4_lookup_longway() skips sockets that are dead and in TCP_CLOSE */
        utp_v4_unhash(sk);              /* also clears the last-hit lookup cache entry for sk */
        sk->dead = 1;                   
        destroy_sock(sk);               
	/* NOTE(review): debug trace printed on every close, with no log level - confirm it is still wanted. */
	printk("At the end of utp close\n");
}

/*
 * Poll helper for a socket in TCP_LISTEN state.
 *
 * Returns POLLIN | POLLRDNORM when the ready list of the protocol control
 * block (up_rdyfirst) is non-empty, 0 otherwise.  The wait argument is not
 * used here.
 *
 * The socket lock is released on every path.  An earlier version returned
 * straight out of the locked region when the ready list was non-empty,
 * leaving the socket locked forever.
 *
 * NOTE(review): the control block is hard-coded to pcbs[1] rather than
 * derived from sk - confirm this is intended.
 */
static unsigned int utp_listen_poll(struct sock *sk, poll_table *wait)
{
	pcb_t *p_pcb = &pcbs[1];
	unsigned int mask = 0;

	lock_sock(sk);
	if (p_pcb->up_rdyfirst != NULL)
		mask = POLLIN | POLLRDNORM;
	release_sock(sk);

	return mask;
}

/*
 * Poll entry point for UTP sockets.
 *
 * Registers the caller on the socket's wait queue and returns the current
 * event mask.  Sockets in TCP_LISTEN state are delegated to
 * utp_listen_poll().  For all others:
 *   POLLERR                          - sk->err set or error queue non-empty
 *   POLLHUP                          - receive side shut down
 *   POLLIN | POLLRDNORM              - receive queue non-empty
 *   POLLOUT | POLLWRNORM | POLLWRBAND - socket is writeable
 * When the socket is not writeable, SO_NOSPACE is set in the socket flags.
 */
unsigned int utp_poll(struct file * file, struct socket *sock, poll_table *wait)
{
	struct sock *sk = sock->sk;
	unsigned int events = 0;

	poll_wait(file, sk->sleep, wait);

	if (sk->state == TCP_LISTEN)
		return utp_listen_poll(sk, wait);

	/* exceptional events? */
	if (sk->err || !skb_queue_empty(&sk->error_queue))
		events |= POLLERR;
	if (sk->shutdown & RCV_SHUTDOWN)
		events |= POLLHUP;

	/* readable? */
	if (!skb_queue_empty(&sk->receive_queue))
		events |= POLLIN | POLLRDNORM;

	/* writable? */
	if (!sock_writeable(sk))
		sk->socket->flags |= SO_NOSPACE;
	else
		events |= POLLOUT | POLLWRNORM | POLLWRBAND;

	return events;
}

/* for sendmsg, recvmsg, close, connect we will use the system calls as specified in the specification */

/*
 * Protocol operations table for UTP.
 *
 * The initializer is positional, so the order of the entries must match the
 * field order of struct proto; the trailing comment on each line names the
 * slot it is meant to fill.  Only close, poll (the slot commented "select"),
 * the socket-option handlers (the generic IP ones), the three hash-table
 * hooks, good_socknum and verify_bind are provided; every other operation
 * slot is NULL.  The list pointers are initialised to point back at the
 * table itself.
 */
struct proto utp_prot = {
	(struct sock *)&utp_prot,	/* sklist_next */
	(struct sock *)&utp_prot,	/* sklist_prev */
	utp_close,			/* close */
	NULL,				/* connect */
	NULL,				/* accept */
	NULL,				/* retransmit */
	NULL,				/* write_wakeup */
	NULL,				/* read_wakeup */
	utp_poll,			/* select */
	NULL,				/* ioctl */
	NULL,				/* init */
	NULL,				/* destroy */
	NULL,				/* shutdown */
	ip_setsockopt,			/* setsockopt */
	ip_getsockopt,			/* getsockopt */
	NULL,				/* sendmsg */
	NULL,				/* recvmsg */
	NULL,				/* bind */
	NULL,				/* backlog rcv */
	utp_v4_hash,			/* hash */
	utp_v4_unhash,			/* unhash */
	utp_v4_rehash,			/* rehash */
	utp_good_socknum,		/* good_socknum */
	utp_verify_bind,		/* verify_bind */
	128,				/* max_header */
	0,				/* retransmits */
	"UTP",				/* name */
	0,				/* inuse */
	0				/* highestinuse */
};
