#include <linux/init.h>
#include <linux/module.h>
#include <linux/version.h>
#include <net/tcp.h>
#include <linux/netfilter.h>
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_acct.h>
#include <linux/skbuff.h>
#include <net/ip.h>
#include <linux/types.h>
#include <net/sock.h>
#include <linux/etherdevice.h>
#include <linux/cdev.h>
#include <linux/device.h>
#include "pkt_filter.h"
#include <linux/netfilter_bridge.h>  

#include "pkt_filter.h"
#include "filter_proc.h"
#include "filter_log.h"

/* Configured domain entries (domain_node_t); walked by both the domain
 * blacklist and whitelist lookups in this file. */
struct list_head domain_head = LIST_HEAD_INIT(domain_head);
/* Configured IP entries (ip_node_t); walked by both the IP blacklist and
 * whitelist lookups in this file. */
struct list_head ip_head = LIST_HEAD_INIT(ip_head);

/* Protects domain_head and ip_head. The _bh lock variants are used because
 * the lists are also read from the netfilter hooks in this file
 * (presumably softirq context). */
DEFINE_RWLOCK(af_feature_lock);

/*
 * Lock helpers. They expand to a single expression with no trailing
 * semicolon, so "feature_list_read_lock();" is exactly one statement and
 * the macros are safe in an unbraced if/else.
 */
#define feature_list_read_lock()		read_lock_bh(&af_feature_lock)
#define feature_list_read_unlock()		read_unlock_bh(&af_feature_lock)
#define feature_list_write_lock()		write_lock_bh(&af_feature_lock)
#define feature_list_write_unlock()		write_unlock_bh(&af_feature_lock)


/*
 * Format the four bytes of @ulIP as a dotted-quad string in @szIP.
 *
 * The bytes are printed in memory order, so a value that holds an IPv4
 * address in network byte order prints as the usual "a.b.c.d".
 * @szIP must have room for at least 16 bytes ("255.255.255.255" + NUL).
 */
void ip2str(char *szIP, int ulIP)
{
	const uint8_t *octet = (const uint8_t *)&ulIP;

	sprintf(szIP, "%d.%d.%d.%d",
		octet[0], octet[1], octet[2], octet[3]);
}

int add_domain_node(const char *domain)
{
	domain_node_t *node = NULL;
	node = kzalloc(sizeof(domain_node_t), GFP_KERNEL);
	if (node == NULL) {
		printk("malloc feature memory error\n");
		return -1;
	}
	else {
		strcpy(node->domain, domain);
		feature_list_write_lock();
		list_add(&(node->head), &domain_head);
		feature_list_write_unlock();
	}
	return 0;
}

int add_ip_node(const char *ip)
{
	ip_node_t *node = NULL;
	node = kzalloc(sizeof(ip_node_t), GFP_KERNEL);
	if (node == NULL) {
		printk("malloc feature memory error\n");
		return -1;
	}
	else {
		strcpy(node->ip, ip);
		feature_list_write_lock();
		list_add(&(node->head), &ip_head);
		feature_list_write_unlock();
	}
	return 0;
}

void domain_clean_list(void)
{
	domain_node_t *n,*node;
	feature_list_write_lock();
	while(!list_empty(&domain_head)) {
		node = list_first_entry(&domain_head, domain_node_t, head);
		list_del(&(node->head));
		kfree(node);
	}
	feature_list_write_unlock();
}

void ip_clean_list(void)
{
	ip_node_t *n,*node;
	feature_list_write_lock();
	while(!list_empty(&ip_head)) {
		node = list_first_entry(&ip_head, ip_node_t, head);
		list_del(&(node->head));
		kfree(node);
	}
	feature_list_write_unlock();
}

/*
 * Return true if @pattern occurs anywhere in @domain, ignoring ASCII case.
 * DNS names are case-insensitive, and some resolvers randomise letter case,
 * so a case-sensitive search would let "WwW.Example.COM" slip past an
 * "example.com" entry. An empty pattern never matches, so a blank entry
 * cannot match every name.
 */
static bool pf_domain_match(const char *domain, const char *pattern)
{
	size_t plen = strlen(pattern);
	const char *p;

	if (plen == 0)
		return false;

	for (p = domain; *p != '\0'; p++) {
		if (strncasecmp(p, pattern, plen) == 0)
			return true;
	}
	return false;
}

/* Return true if @domain matches any entry on domain_head. */
static bool pf_domain_listed(const char *domain)
{
	domain_node_t *node;
	bool found = false;

	feature_list_read_lock();
	list_for_each_entry(node, &domain_head, head) {
		if (pf_domain_match(domain, node->domain)) {
			found = true;
			break;
		}
	}
	feature_list_read_unlock();
	return found;
}

/*
 * Blacklist lookup.
 * Returns -1 (drop) if @domain matches an entry on domain_head, else 0.
 */
int cmp_dmain_list(const char *domain)
{
	if (pf_domain_listed(domain)) {
		AF_DEBUG("Drop Domain: %s\n", domain);
		return -1;
	}
	return 0;
}

/*
 * Whitelist lookup.
 * Returns 0 (allow) if @domain matches an entry on domain_head, else -1.
 */
int cmp_dmain_whitelist(const char *domain)
{
	if (pf_domain_listed(domain)) {
		AF_DEBUG("White list Domain: %s\n", domain);
		return 0;
	}
	return -1;
}


/*
 * Format @ip, an IPv4 address in host byte order, as a dotted quad
 * (most significant byte first) into @ipaddr. At most @size bytes are
 * written, and the result is always NUL-terminated when @size > 0.
 */
static void
IP2Str(char *ipaddr, int size, uint32_t ip)
{
	unsigned int b3 = (ip >> 24) & 0xff;
	unsigned int b2 = (ip >> 16) & 0xff;
	unsigned int b1 = (ip >> 8) & 0xff;
	unsigned int b0 = ip & 0xff;

	snprintf(ipaddr, size, "%d.%d.%d.%d", b3, b2, b1, b0);
}


/* Return true if the address string @ip equals (strcmp) any entry on ip_head. */
static bool pf_ip_listed(const char *ip)
{
	ip_node_t *node;
	bool found = false;

	feature_list_read_lock();
	list_for_each_entry(node, &ip_head, head) {
		/* strlen() yields size_t, hence %zu. */
		AF_DEBUG("func:%s  ip_str=%s len=%zu; node->ip:%s len:%zu\n",
			 __func__, ip, strlen(ip), node->ip, strlen(node->ip));
		if (strcmp(ip, node->ip) == 0) {
			found = true;
			break;
		}
	}
	feature_list_read_unlock();
	return found;
}

/*
 * IP blacklist lookup.
 * Returns -1 (drop) if @ip equals an entry on ip_head, else 0.
 */
int cmp_ip_list(char * ip)
{
	if (pf_ip_listed(ip)) {
		AF_DEBUG("Drop IP: %s\n", ip);
		return -1;
	}
	return 0;
}

/*
 * IP whitelist lookup.
 * Returns 0 (allow) if @ip equals an entry on ip_head, else -1.
 */
int cmp_ip_whitelist(char * ip)
{
	if (pf_ip_listed(ip)) {
		AF_DEBUG("White List IP: %s\n", ip);
		return 0;
	}
	return -1;
}


/*
 * Decode the uncompressed DNS name at @str into the dotted string @name.
 * The encoding is a sequence of length-prefixed labels ending in a
 * zero-length label (RFC 1035 section 3.1), e.g.
 * "\3www\7example\3com\0" -> "www.example.com".
 *
 * @name must point to at least 64 bytes. The decoded name plus its
 * terminating NUL never exceeds 64 bytes, and at most 65 bytes of @str are
 * read, so a hostile length byte cannot overrun either buffer.
 *
 * Returns the number of bytes the encoded name occupies in @str, including
 * the terminating zero label (1 for the root name ""). Returns 0 if an
 * argument is NULL or the decoded name would not fit in 64 bytes. Length
 * bytes >= 64 (compression pointers, reserved label types) always fail
 * that size check. On failure @name holds an empty string.
 */
uint16_t get_name_str(uint8_t *str, char *name)
{
	enum { NAME_BUF_SIZE = 64 };
	size_t in = 0;	/* read offset into @str */
	size_t out = 0;	/* characters written to @name, excluding NUL */

	if (str == NULL || name == NULL)
		return 0;

	for (;;) {
		size_t len = str[in++];
		size_t sep = (out != 0);	/* '.' before every label but the first */

		if (len == 0)
			break;

		/* Need room for the separator, the label and the final NUL. */
		if (out + sep + len + 1 > NAME_BUF_SIZE) {
			name[0] = '\0';
			return 0;
		}

		if (sep)
			name[out++] = '.';
		memcpy(name + out, str + in, len);
		out += len;
		in += len;
	}

	name[out] = '\0';
	return (uint16_t)in;
}
int  process_domain_name_list(struct sk_buff *skb)
{
        uint8_t *p;
        struct dns_reply_head *dnshead = NULL;
        //struct dns_reply_query dnsquery;
        struct dns_reply_answer *dnsanswer = NULL;
        char req_name[64];
        uint16_t req_name_len = 0;
        uint16_t rep_num = 0;
        // struct ip_white_list ip;
        uint16_t req_type = 0;
		struct iphdr *iph = NULL;

		iph = ip_hdr(skb);
		if (!iph) {
			return 0;
		}

        p = skb->data + iph->ihl * 4 + 8;
        dnshead =(struct dns_reply_head *)p;

 		// if((dnshead->flag & 0xfff0) != 0)
        //         return 0;//have error

        if(dnshead->qnum == 0)
                return 0;

        AF_DEBUG("dns req num %d rep num %d\n", dnshead->qnum, dnshead->anum);
        rep_num = dnshead->anum;
        p += 12;

        memset(req_name, 0, 64);
        req_name_len = get_name_str(p, req_name);
        if(req_name_len == 0)
                return 0;

        AF_DEBUG("req name %s len %d \n", req_name, req_name_len);

        p += req_name_len;
        memcpy(&req_type, p, 2);

        if(req_type != 0x01)//A req ip
                return 0;
		
        if(cmp_dmain_list(req_name))
                 return -1;

		return 0;
}

int  process_domain_name_whitelist(struct sk_buff *skb)
{
        uint8_t *p;
        struct dns_reply_head *dnshead = NULL;
        //struct dns_reply_query dnsquery;
        struct dns_reply_answer *dnsanswer = NULL;
        char req_name[64];
        uint16_t req_name_len = 0;
        uint16_t rep_num = 0;
        // struct ip_white_list ip;
        uint16_t req_type = 0;
		struct iphdr *iph = NULL;

		iph = ip_hdr(skb);
		if (!iph) {
			return 0;
		}

        p = skb->data + iph->ihl * 4 + 8;
        dnshead =(struct dns_reply_head *)p;

 		// if((dnshead->flag & 0xfff0) != 0)
        //         return 0;//have error

        if(dnshead->qnum == 0)
                return 0;

        AF_DEBUG("dns req num %d rep num %d\n", dnshead->qnum, dnshead->anum);
        rep_num = dnshead->anum;
        p += 12;

        memset(req_name, 0, 64);
        req_name_len = get_name_str(p, req_name);
        if(req_name_len == 0)
                return 0;

        AF_DEBUG("req name %s len %d \n", req_name, req_name_len);

        p += req_name_len;
        memcpy(&req_type, p, 2);

        if(req_type != 0x01)//A req ip
                return 0;
		
        if(cmp_dmain_whitelist(req_name))
                 return -1;

		return 0;
}

int  process_ipv6_domain_name_list(struct sk_buff *skb)
{
        uint8_t *p;
        struct dns_reply_head *dnshead = NULL;
        //struct dns_reply_query dnsquery;
        struct dns_reply_answer *dnsanswer = NULL;
        char req_name[64];
        uint16_t req_name_len = 0;
        uint16_t rep_num = 0;
        uint16_t req_type = 0;
		struct ipv6hdr *iph6;;

		iph6 = ipv6_hdr(skb);
   		 if (!iph6)
        return NF_ACCEPT;
		//ipv6 header size 40, udp header size 8
        p = skb->data + 40 + 8;
        dnshead =(struct dns_reply_head *)p;

 		// if((dnshead->flag & 0xfff0) != 0)
        //         return 0;//have error

        if(dnshead->qnum == 0)
                return 0;

        AF_DEBUG("dns req num %d rep num %d\n", dnshead->qnum, dnshead->anum);
        rep_num = dnshead->anum;
        p += 12;

        memset(req_name, 0, 64);
        req_name_len = get_name_str(p, req_name);
        if(req_name_len == 0)
                return 0;

        AF_DEBUG("req name %s len %d \n", req_name, req_name_len);

        p += req_name_len;
        memcpy(&req_type, p, 2);

        if(req_type != 0x01)//A req ip
                return 0;
		
        if(cmp_dmain_list(req_name))
                 return -1;

		return 0;

}

int  process_ipv6_domain_name_whitelist(struct sk_buff *skb)
{
        uint8_t *p;
        struct dns_reply_head *dnshead = NULL;
        //struct dns_reply_query dnsquery;
        struct dns_reply_answer *dnsanswer = NULL;
        char req_name[64];
        uint16_t req_name_len = 0;
        uint16_t rep_num = 0;
        uint16_t req_type = 0;
		struct ipv6hdr *iph6;;

		iph6 = ipv6_hdr(skb);
   		 if (!iph6)
        return NF_ACCEPT;
		//ipv6 header size 40, udp header size 8
        p = skb->data + 40 + 8;
        dnshead =(struct dns_reply_head *)p;

 		// if((dnshead->flag & 0xfff0) != 0)
        //         return 0;//have error

        if(dnshead->qnum == 0)
                return 0;

        AF_DEBUG("dns req num %d rep num %d\n", dnshead->qnum, dnshead->anum);
        rep_num = dnshead->anum;
        p += 12;

        memset(req_name, 0, 64);
        req_name_len = get_name_str(p, req_name);
        if(req_name_len == 0)
                return 0;

        AF_DEBUG("req name %s len %d \n", req_name, req_name_len);

        p += req_name_len;
        memcpy(&req_type, p, 2);

        if(req_type != 0x01)//A req ip
                return 0;
		
        if(cmp_dmain_whitelist(req_name))
                 return -1;

		return 0;

}

#if LINUX_VERSION_CODE >= KERNEL_VERSION(4,4,0)
static u_int32_t pkt_filter_hook(void *priv,
			       struct sk_buff *skb,
			       const struct nf_hook_state *state) {
#else
static u_int32_t pkt_filter_hook(unsigned int hook,
						    	struct sk_buff *skb,
					           const struct net_device *in,
					           const struct net_device *out,
					           int (*okfn)(struct sk_buff *)){
#endif
	
	const struct iphdr *iph;  
    const struct udphdr *udph;
	int ret = -1;

	if( unlikely(!skb) ) {  
                return NF_ACCEPT;  
     }  

	if (!skb)  
        return NF_ACCEPT;

	if(skb->protocol != htons(0x0800)) //IPV4 Header
         return NF_ACCEPT;
	
	iph = ip_hdr(skb); 
	if( unlikely(!iph) ) {  
        return NF_ACCEPT;  
    }  

	if(g_ip_filter_enable)
	{
		char ipaddr[17];  
		memset(ipaddr, 0, sizeof(ipaddr));  
		IP2Str(ipaddr, sizeof(ipaddr), ntohl(iph->daddr)); 
		//printk("ipaddr:%s\n",ipaddr);
		//if(cmp_ip_list(ntohl(iph->daddr)))
		if(cmp_ip_list(ipaddr))
		{
			return NF_DROP;
		}
	} 

	if(g_ip_whitelist_enable)
	{
		char ipaddr[17];  
		memset(ipaddr, 0, sizeof(ipaddr));  
		IP2Str(ipaddr, sizeof(ipaddr), ntohl(iph->daddr)); 
		//printk("ipaddr:%s\n",ipaddr);
		//if(cmp_ip_list(ntohl(iph->daddr)))
		if(strcasecmp(ipaddr,"127.0.0.1"))
			return NF_ACCEPT;

		if(cmp_ip_whitelist(ipaddr))
		{
			return NF_DROP;
		}
	} 

	if (iph->protocol == IPPROTO_UDP)//udp
	{
		udph = udp_hdr(skb);
				
		if(unlikely(!udph))
		{
			return NF_ACCEPT;
		}	
		
		if( udph->dest == 53)  //DNS Request
		{
			if(g_domain_filter_enable)
			{
				ret = process_domain_name_list(skb);
				if(ret < 0)
					return NF_DROP;
			}

			if(g_domain_whitelist_enable)
			{
				ret = process_domain_name_whitelist(skb);
				if(ret < 0)
					return NF_DROP;
			}
				
		}	
	}

	return NF_ACCEPT;
}


#if LINUX_VERSION_CODE >= KERNEL_VERSION(4,4,0)
static u_int32_t pkt_filter_hook_ipv6(void *priv,
			       struct sk_buff *skb,
			       const struct nf_hook_state *state) {
#else
static u_int32_t pkt_filter_hook_ipv6(unsigned int hook,
						    	struct sk_buff *skb,
					           const struct net_device *in,
					           const struct net_device *out,
					           int (*okfn)(struct sk_buff *)){
#endif
	
	const struct ipv6hdr *iph6;  
    const struct udphdr *udph;
	int ret = -1;

	if( unlikely(!skb) ) {  
                return NF_ACCEPT;  
     }  

	if (!skb)  
        return NF_ACCEPT;


	if(skb->protocol != htons(0x86dd)) //IPV6 Header
         return NF_ACCEPT;

	iph6 = ipv6_hdr(skb);
    if (!iph6)
        return NF_ACCEPT;
	
	if (iph6->nexthdr != NEXTHDR_UDP)
        return NF_ACCEPT;

	
	udph = udp_hdr(skb);
			
		if(unlikely(!udph))
		{
			return NF_ACCEPT;
		}	
		
		if( udph->dest == 53)  //DNS Request
		{
			//printk("g_domain_filter_enable:%d\n",g_domain_filter_enable);
			if(g_domain_filter_enable)
			{
				AF_DEBUG("g_domain_filter_enable:%d\n",g_domain_filter_enable);
				ret = process_ipv6_domain_name_list(skb);
				if(ret < 0)
					return NF_DROP;
			}	

			if(g_domain_whitelist_enable)
			{
				AF_DEBUG("g_domain_whitelist_enable:%d\n",g_domain_whitelist_enable);
				ret = process_ipv6_domain_name_whitelist(skb);
				if(ret < 0)
					return NF_DROP;
			}	
		}	
	
	return NF_ACCEPT;
}


/*
 * Netfilter registrations: the IPv4 and IPv6 hooks run at PRE_ROUTING with
 * NF_IP_PRI_FIRST priority. The .owner field is only set on kernels older
 * than 4.4; newer kernels dropped it from struct nf_hook_ops.
 */
static struct nf_hook_ops pkt_filter_ops[] __read_mostly = {
	{
		.hook		= pkt_filter_hook,
#if LINUX_VERSION_CODE < KERNEL_VERSION(4,4,0)
		.owner		= THIS_MODULE,
#endif
		.pf		= PF_INET,
		.hooknum	= NF_INET_PRE_ROUTING,
		.priority	= NF_IP_PRI_FIRST,
	},
	{
		.hook		= pkt_filter_hook_ipv6,
#if LINUX_VERSION_CODE < KERNEL_VERSION(4,4,0)
		.owner		= THIS_MODULE,
#endif
		.pf		= PF_INET6,
		.hooknum	= NF_INET_PRE_ROUTING,
		.priority	= NF_IP_PRI_FIRST,
	},
};

/*
 * Module init: set up logging and configuration, then register the
 * netfilter hooks. If hook registration fails, the earlier setup is undone
 * in reverse order and the error is returned so the module load fails.
 */
static int __init pkt_filter_init(void)
{
	int ret;

	printk("-----pkt_filter_init----\n");
	af_log_init();
	init_config();
#if LINUX_VERSION_CODE >= KERNEL_VERSION(4,13,0)
	ret = nf_register_net_hooks(&init_net, pkt_filter_ops, ARRAY_SIZE(pkt_filter_ops));
#else
	ret = nf_register_hooks(pkt_filter_ops, ARRAY_SIZE(pkt_filter_ops));
#endif
	if (ret < 0) {
		printk(KERN_ERR "pkt_filter: netfilter hook registration failed: %d\n", ret);
		deinit_config();
		af_log_exit();
		return ret;
	}
	return 0;
}

static void pkt_filter_exit(void)
{
	printk("-----pkt_filter_exit----\n");
	af_log_exit();
	deinit_config();
#if LINUX_VERSION_CODE >= KERNEL_VERSION(4,13,0)
    nf_unregister_net_hooks(&init_net, pkt_filter_ops, ARRAY_SIZE(pkt_filter_ops));
#else
	nf_unregister_hooks(pkt_filter_ops, ARRAY_SIZE(pkt_filter_ops));
#endif
    return;
}

/* Module entry/exit points and modinfo metadata. */
module_init(pkt_filter_init);
module_exit(pkt_filter_exit);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("dawsen_gao@163.com");
MODULE_DESCRIPTION("IP Or Domain filter module");
MODULE_VERSION("1.0.0");