/* $USAGI: mip6nl.c,v 1.7 2003/11/13 17:10:09 nakam Exp $ */

/*
 * Copyright (C)2003 USAGI/WIDE Project
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */
/*
 * Authors:
 *	Noriaki TAKAMIYA @USAGI
 *	Masahide NAKAMURA @USAGI
 */
#include <assert.h>
#include <errno.h>
#include <stdio.h>
#include <string.h>

#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <asm/types.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <linux/xfrm.h>

#include <glue.h>
#include <mip6log.h>
#include <mip6nl.h>

/* Size in bytes of each on-stack send/receive buffer used by mip6nl_talk(). */
#define	BUFLEN	8192

/*
 * Payload layout of an XFRM policy-info request as built by
 * make_msg_policy_info(): the fixed xfrm_userpolicy_info, then one
 * XFRMA_TMPL route-attribute header, then the template array that forms
 * that attribute's data.  The members are laid out back to back, so the
 * attribute data is expected to start exactly at tmpls[0].
 * NOTE(review): this relies on sizeof(struct xfrm_userpolicy_info) and
 * sizeof(struct rtattr) being multiples of RTA_ALIGNTO -- confirm on new ABIs.
 */
struct nlmsgdata_policy {
	struct xfrm_userpolicy_info pol;
	struct rtattr rta;	/* header of the XFRMA_TMPL attribute */
	struct xfrm_user_tmpl tmpls[3];	/* room for at most three templates */
};

/*
 * Payload layout of an XFRM SA-info request as built by
 * make_msg_state_info().  Only the fixed xfrm_usersa_info is sent; the
 * route attribute that would follow it is compiled out.
 */
struct nlmsgdata_state {
	struct xfrm_usersa_info sa;
#if 0
 	struct rtattr rta;
#endif
};

/*
 * Fill an XFRM selector from the request parameters.  Every message type
 * (policy info/id, state info/id, ...) currently uses the same selector.
 *
 * An address in mp->sel that is not the unspecified address (::) is matched
 * as a /128; the unspecified address leaves its prefix length at 0
 * (presumably "any address").  The selector is tagged AF_INET6 and with the
 * calling process's real uid.  Always returns 0.
 */
static int set_selector(const struct mip6nl_parms *mp, struct xfrm_selector *sel)
{
	const size_t dlen = sizeof(sel->daddr.a6);
	const size_t slen = sizeof(sel->saddr.a6);

	memset(sel, 0, sizeof(*sel));

	memcpy(&sel->daddr.a6, &mp->sel.daddr, dlen);
	sel->prefixlen_d = memcmp(&sel->daddr.a6, &in6addr_any, dlen) ? 128 : 0;

	memcpy(&sel->saddr.a6, &mp->sel.saddr, slen);
	sel->prefixlen_s = memcmp(&sel->saddr.a6, &in6addr_any, slen) ? 128 : 0;

	sel->family = AF_INET6;
	sel->proto = mp->sel.proto;
	sel->ifindex = mp->sel.ifindex;
	sel->user = getuid();

	return 0;
}

/*
 * Initialise an XFRM lifetime configuration: byte and packet limits are
 * set to XFRM_INF and every *_expires_seconds field is set to 0.
 * Always returns 0.
 */
static int set_lifetime(struct xfrm_lifetime_cfg *lft)
{
	/* Volume limits. */
	lft->soft_byte_limit = lft->hard_byte_limit = XFRM_INF;
	lft->soft_packet_limit = lft->hard_packet_limit = XFRM_INF;

	/* Time-based limits. */
	lft->soft_add_expires_seconds = lft->hard_add_expires_seconds = 0;
	lft->soft_use_expires_seconds = lft->hard_use_expires_seconds = 0;

	return 0;
}


/*
 * Fill one XFRM user template.
 *
 * tmpl:     template to initialise (zeroed first)
 * proto:    stored in tmpl->id.proto
 * id_daddr: if non-NULL, copied into tmpl->id.daddr
 * id_saddr: if non-NULL, copied into tmpl->saddr
 *
 * The SPI is always MIP6_SPI, mode is 1 and the template is marked
 * optional; reqid, share and the algorithm masks stay 0.  Always returns 0.
 */
static int set_template(struct xfrm_user_tmpl *tmpl, __u8 proto,
			const struct in6_addr *id_daddr, const struct in6_addr *id_saddr)
{
	memset(tmpl, 0, sizeof(*tmpl));

	/* XXX: tmpl->id.daddr and tmpl->saddr may be required when tunneling;
	 * see struct xfrm_tmpl and xfrm_tmpl_resolve(). */
	if (id_daddr != NULL)
		memcpy(&tmpl->id.daddr, id_daddr, sizeof(tmpl->id.daddr.a6));
	if (id_saddr != NULL)
		memcpy(&tmpl->saddr.a6, id_saddr, sizeof(tmpl->saddr.a6));

	tmpl->id.proto = proto;
	tmpl->id.spi = MIP6_SPI;

	/* XXX: mip6 is always used as tunnel mode. */
	tmpl->mode = 1;

	/* XXX: presumably lets resolution continue when no state is found;
	 * see xfrm_tmpl_resolve(). */
	tmpl->optional = 1;

	/* reqid, share, aalgos, ealgos and calgos remain 0 from the memset. */
	return 0;
}

/*
 * Build the payload of an XFRM_MSG_NEWPOLICY/UPDPOLICY request: an
 * xfrm_userpolicy_info followed by one XFRMA_TMPL attribute that carries
 * the leading entries of mp->id.tmpls, up to the first one whose proto is
 * 0 (or TMPL_MAX entries).
 *
 * buf must provide at least sizeof(struct nlmsgdata_policy) writable bytes.
 * It is zeroed here, so callers need not pre-clear it.
 *
 * On success, *len is set to the exact payload length (ending with the
 * aligned template attribute) and 0 is returned.  -EINVAL is returned when
 * more templates are requested than struct nlmsgdata_policy can hold.
 */
static int make_msg_policy_info(const struct mip6nl_parms *mp, void *buf, unsigned short int *len)
{
	struct nlmsgdata_policy *data = buf;
	struct xfrm_userpolicy_info *pol = &data->pol;
	struct rtattr *rta = &data->rta;
	const int max_tmpls = (int)(sizeof(data->tmpls) / sizeof(data->tmpls[0]));
	int rta_len = 0;
	int i;

	/* Clear everything, including curlft, padding and unused template slots. */
	memset(data, 0, sizeof(*data));

	set_selector(mp, &pol->sel);

	/* lifetime_cfg; pol->curlft is needless from userland and stays 0. */
	set_lifetime(&pol->lft);

	pol->priority = mp->pol.priority;
	pol->index = mp->pol.index;
	pol->dir = mp->pol.dir;
	pol->action = XFRM_POLICY_ALLOW; /* XFRM_POLICY_ALLOW or XFRM_POLICY_BLOCK */
	pol->flags = mp->pol.flags;
	pol->share = XFRM_SHARE_ANY; /* XXX: no idea */

	/* The templates are carried as the data of a single XFRMA_TMPL attribute. */
	rta->rta_type = XFRMA_TMPL;

	for (i = 0; i < TMPL_MAX && mp->id.tmpls[i].proto; ++i) {
		/* data->tmpls is fixed-size; never write past it. */
		if (i >= max_tmpls) {
			__eprintf("too many templates: at most %d are supported\n", max_tmpls);
			return -EINVAL;
		}
		set_template(&data->tmpls[i], mp->id.tmpls[i].proto,
			     &mp->id.tmpls[i].daddr, &mp->id.tmpls[i].saddr);
		rta_len += sizeof(data->tmpls[i]);
	}
	rta->rta_len = RTA_LENGTH(rta_len);

	/* The payload ends where the (aligned) template attribute ends, so the
	 * kernel does not see unused template slots as trailing data. */
	*len = (unsigned short int)(((char *)rta - (char *)data) + RTA_ALIGN(rta->rta_len));

	return 0;
}

/*
 * Build the payload of an XFRM_MSG_DELPOLICY/GETPOLICY request: an
 * xfrm_userpolicy_id that identifies the policy by selector, index and
 * direction.
 *
 * buf must provide at least sizeof(struct xfrm_userpolicy_id) bytes.  It is
 * zeroed first so that padding bytes are never sent with stale contents.
 * Sets *len to the payload size and returns 0.
 */
static int make_msg_policy_id(const struct mip6nl_parms *mp, void *buf, unsigned short int *len)
{
	struct xfrm_userpolicy_id *pol = buf;

	memset(pol, 0, sizeof(*pol));

	set_selector(mp, &pol->sel);
	pol->index = mp->pol.index;
	pol->dir = mp->pol.dir;

	*len = sizeof(*pol);

	return 0;
}

/*
 * Build the payload of an XFRM_MSG_NEWSA/UPDSA request: one xfrm_usersa_info
 * for protocol mp->id.ext.proto with SPI MIP6_SPI, family AF_INET6, mode 1
 * and lifetimes from set_lifetime().
 *
 * The SA destination/source addresses are copied from mp->id.ext only for
 * IPPROTO_IPV6, IPPROTO_DSTOPTS and IPPROTO_ROUTING.  For any other protocol
 * they stay zero.  Sets *len and returns 0.
 */
static int make_msg_state_info(const struct mip6nl_parms *mp, void *buf, unsigned short int *len)
{
	struct nlmsgdata_state *msg = buf;
	struct xfrm_usersa_info *info = &msg->sa;
	int copy_addrs;

	memset(info, 0, sizeof(*info));

	switch (mp->id.ext.proto) {
	case IPPROTO_IPV6:
	case IPPROTO_DSTOPTS:
	case IPPROTO_ROUTING:
		copy_addrs = 1;
		break;
	default:
		copy_addrs = 0;
		break;
	}

	set_selector(mp, &info->sel);

	info->id.proto = mp->id.ext.proto;
	info->id.spi = MIP6_SPI;
	if (copy_addrs) {
		memcpy(&info->id.daddr.a6, &mp->id.ext.daddr, sizeof(info->id.daddr.a6));
		memcpy(&info->saddr.a6, &mp->id.ext.saddr, sizeof(info->saddr.a6));
	}

	set_lifetime(&info->lft);

	info->family = AF_INET6;
	info->mode = 1; /* always 1 */
	/* seq, reqid and replay_window stay 0 from the memset; curlft and
	 * stats are needless from userland. */

	*len = sizeof(*msg);
	return 0;
}

/*
 * Build the payload of an XFRM_MSG_DELSA/GETSA request: an xfrm_usersa_id
 * naming the SA by the destination/source addresses in mp->id.ext, SPI
 * MIP6_SPI, family AF_INET6 and protocol mp->id.ext.proto.
 * Sets *len to the payload size and returns 0.
 */
static int make_msg_state_id(const struct mip6nl_parms *mp, void *buf, unsigned short int *len)
{
	struct xfrm_usersa_id *id = buf;

	memset(id, 0, sizeof(*id));

	id->family = AF_INET6;
	id->proto = mp->id.ext.proto;
	id->spi = MIP6_SPI;

	memcpy(&id->daddr.a6, &mp->id.ext.daddr, sizeof(id->daddr.a6));
	memcpy(&id->saddr.a6, &mp->id.ext.saddr, sizeof(id->saddr.a6));

	*len = sizeof(*id);
	return 0;
}

/*
 * Build one complete netlink request (header + XFRM payload) in buf for
 * mp->type.  The header flags are NLM_F_REQUEST OR-ed with mp->flag.
 *
 * On success returns 0 and stores NLMSG_SPACE(payload length) in *len.
 * For an unsupported mp->type it returns -EINVAL; a payload builder's
 * non-zero result is returned as is.  On any error neither the header nor
 * *len is written.
 */
static int make_msg_one(const struct mip6nl_parms *mp, void *buf, unsigned short int *len)
{
	struct nlmsghdr	*hdr = (struct nlmsghdr *)buf;
	unsigned short int data_len = 0;
	int ret;

	switch (mp->type) {
	case XFRM_MSG_NEWPOLICY:
	case XFRM_MSG_UPDPOLICY:
		ret = make_msg_policy_info(mp, NLMSG_DATA(hdr), &data_len);
		break;
	case XFRM_MSG_DELPOLICY:
	case XFRM_MSG_GETPOLICY:
		ret = make_msg_policy_id(mp, NLMSG_DATA(hdr), &data_len);
		break;

	case XFRM_MSG_NEWSA:
	case XFRM_MSG_UPDSA:
		ret = make_msg_state_info(mp, NLMSG_DATA(hdr), &data_len);
		break;

	case XFRM_MSG_DELSA:
	case XFRM_MSG_GETSA:
		ret = make_msg_state_id(mp, NLMSG_DATA(hdr), &data_len);
		break;

	default:
		ret = -EINVAL;
		__eprintf("unrecognized type: %d\n", mp->type);
		break;
	}

	if (ret)
		return ret;

	assert(data_len > 0);

	hdr->nlmsg_len = NLMSG_LENGTH(data_len);
	hdr->nlmsg_type = mp->type; /* XFRM_MSG_... */
	hdr->nlmsg_flags = NLM_F_REQUEST | mp->flag;
	/* Set explicitly so the header is complete even when the caller of
	 * mip6nl_make_msg() did not pre-clear buf. */
	hdr->nlmsg_seq = 0;
	hdr->nlmsg_pid = 0;

	*len = NLMSG_SPACE(data_len);
	return 0;
}

/*
 * Public entry point: build a single netlink XFRM request for mp in buf.
 * Returns 0 on success, with *len set to the aligned message size, or a
 * non-zero error code (for example -EINVAL for an unknown mp->type).
 */
int mip6nl_make_msg(const struct mip6nl_parms *mp, void *buf, unsigned short int *len)
{
	int rc;

	rc = make_msg_one(mp, buf, len);
	return rc;
}

/*
 * Hand the payload of one received message to the callback registered for
 * the request type mp->type.
 *
 * buf points at the message payload and len is the length the caller
 * reports for it.
 *   - Policy request types: the payload is treated as an
 *     xfrm_userpolicy_info and passed to mp->probe_handler_policy.
 *   - SA request types: the payload is treated as an xfrm_usersa_info and
 *     passed to mp->probe_handler_state.
 *
 * Returns -EINVAL if len is smaller than that structure.  Otherwise it
 * returns the callback's result, or 0 when no callback is set or mp->type
 * is anything else.
 */
static int probe_msg_one(const struct mip6nl_parms *mp,
			 void *buf, unsigned int len)
{
	struct nlmsgdata_policy *pdata;
	struct nlmsgdata_state *sdata;

	switch (mp->type) {
	case XFRM_MSG_NEWPOLICY:
	case XFRM_MSG_DELPOLICY:
	case XFRM_MSG_GETPOLICY:
	case XFRM_MSG_UPDPOLICY:
		if (len < sizeof(struct xfrm_userpolicy_info)) {
			__eprintf("nlmsg data is too short:%u\n", len);
			return -EINVAL;
		}
		if (mp->probe_handler_policy == NULL)
			return 0;
		pdata = buf;
		return mp->probe_handler_policy(mp, &pdata->pol);

	case XFRM_MSG_NEWSA:
	case XFRM_MSG_DELSA:
	case XFRM_MSG_GETSA:
	case XFRM_MSG_UPDSA:
		if (len < sizeof(struct xfrm_usersa_info)) {
			__eprintf("nlmsg data is too short:%u\n", len);
			return -EINVAL;
		}
		if (mp->probe_handler_state == NULL)
			return 0;
		sdata = buf;
		return mp->probe_handler_state(mp, &sdata->sa);

	default:
		return 0;
	}
}

/*
 * Walk the netlink messages in a receive buffer and dispatch them.
 *
 * buf points at the first nlmsghdr.  *lenp is the number of valid bytes in
 * buf counting from that header.  Data messages are handed one after another
 * (by recursion) to probe_msg_one() with their payload length, and *lenp is
 * reduced by each consumed message's aligned length.
 *
 * Processing stops at:
 *   - the end of the buffer or NLMSG_DONE: returns 0;
 *   - NLMSG_ERROR: calls mp->error_handler if set and returns
 *     -err_hdr->error (0 when the error field is 0);
 *   - NLMSG_OVERRUN or NLMSG_NOOP: returns 0;
 *   - a malformed message: returns -EINVAL;
 *   - a non-zero probe_msg_one() result: returns that result.
 *
 * *lenp is not reduced for the message that stopped processing, except
 * for NLMSG_DONE, which is consumed.
 */
int mip6nl_probe_msg(const struct mip6nl_parms *mp,
		     void *buf, unsigned int *lenp)
{
	struct nlmsghdr	*hdr = (struct nlmsghdr *)buf;
	void *nextbuf;
	unsigned int len = *lenp;
	unsigned int consumed;
	int ret;

	if (len == 0) {
		__dprintf("no more buffer(len=0)\n");
		goto done;
	}
	if (len < sizeof(*hdr)) {
		__eprintf("nlmsg cannot be parsed: len=%u\n", len);
		return -EINVAL;
	}
	if (!NLMSG_OK(hdr, len)) {
		__eprintf("nlmsg is not ok: len=%u, type=%u\n", len, hdr->nlmsg_type);

		/* XXX: a message with type == 0 can show up here as well... */
		if (hdr->nlmsg_type == NLMSG_DONE || !hdr->nlmsg_type) {
			__eprintf("ignored error. this is ok.\n");
			goto done;
		}

		return -EINVAL;
	}

	__dprintf("nlmsg len=%u\n", len);

	/* Bytes taken by this message including alignment padding.  The value is
	 * clamped so that an unpadded final message cannot make len wrap. */
	consumed = NLMSG_ALIGN(hdr->nlmsg_len);
	if (consumed > len)
		consumed = len;

	switch (hdr->nlmsg_type) {

	case NLMSG_ERROR:	/* Error		*/
	{
		struct nlmsgerr *err_hdr = (struct nlmsgerr *)NLMSG_DATA(hdr);

		/* Never read past this message if it cannot hold a struct nlmsgerr. */
		if (NLMSG_PAYLOAD(hdr, 0) < sizeof(*err_hdr)) {
			__eprintf("nlmsg data is too short:%u\n",
				  (unsigned int)NLMSG_PAYLOAD(hdr, 0));
			return -EINVAL;
		}

		__eprintf("nlmsg code: %s(%d)\n", strerror(-(err_hdr->error)),-(err_hdr->error));

		ret = -err_hdr->error;

		if (mp->error_handler)
			mp->error_handler(mp, &err_hdr->msg, err_hdr->error);

		break;
	}
	case NLMSG_OVERRUN:	/* Data lost		*/
		ret = 0;
		__eprintf("nlmsg data lost; type=OVERRUN\n");
		break;

	case NLMSG_NOOP:	/* Nothing.		*/
		ret = 0;
		__eprintf("nlmsg data is nothing; type=NOOP\n");
		break;

	case NLMSG_DONE:	/* End of a dump	*/
		__dprintf("nlmsg data is done; type=DONE\n");

		/* DONE carries no XFRM payload, so there is nothing to probe.
		 * Probing it would always fail the size check and turn a normal
		 * end of dump into an error. */
		*lenp = len - consumed;
		goto done;

	default:
		__dprintf("nlmsg data type=%d\n", hdr->nlmsg_type);

		/* Pass the payload length rather than nlmsg_len, which includes the
		 * header, so that probe_msg_one()'s size check is meaningful. */
		ret = probe_msg_one(mp, NLMSG_DATA(hdr), (unsigned int)NLMSG_PAYLOAD(hdr, 0));
		if (ret != 0)
			break;

		nextbuf = (char *)hdr + consumed;
		*lenp = len - consumed;
		/* recursive call for the next message */
		ret = mip6nl_probe_msg(mp, nextbuf, lenp);

		break;
	}

	__dprintf("returns=%d; message remains len=%lld\n", ret, (long long int)*lenp);
	return ret;

 done:
	__dprintf("returns=0; message remains len=%lld\n", (long long int)*lenp);
	return 0;
}

/*
 * Sanity-check a parameter block before use.
 * Returns 0 when mp is non-NULL and mp->sock is non-negative.  Otherwise it
 * prints an error and returns -EINVAL.
 */
int mip6nl_verify(const struct mip6nl_parms *mp)
{
	int rc = 0;

	if (mp == NULL) {
		__eprintf("NULL\n");
		rc = -EINVAL;
	} else if (mp->sock < 0) {
		__eprintf("socket is not available\n");
		rc = -EINVAL;
	}

	return rc;
}

/*
 * Send one XFRM request built from mp over mp->sock and process the reply.
 *
 * The request is built with mip6nl_make_msg().  A single recv() of up to
 * BUFLEN bytes is then parsed with mip6nl_probe_msg().
 *
 * Returns:
 *   - -errno if send() or recv() fails;
 *   - the non-zero result of mip6nl_verify(), mip6nl_make_msg() or
 *     mip6nl_probe_msg();
 *   - 0 otherwise.
 */
int mip6nl_talk(const struct mip6nl_parms *mp)
{
	char sendbuf[BUFLEN];
	char recvbuf[BUFLEN];
	unsigned short int data_len = 0;
	unsigned int rlen;
	ssize_t n;
	int sock;
	int err;

	err = mip6nl_verify(mp);
	if (err)
		return err;

	sock = mp->sock;

	memset(sendbuf, 0, sizeof(sendbuf));

	err = mip6nl_make_msg(mp, sendbuf, &data_len);
	if (err)
		return err;

	assert(data_len > 0);
	assert(data_len < sizeof(sendbuf));

	n = send(sock, sendbuf, data_len, 0);
	if (n < 0) {
		err = -errno;	/* capture before perror() can change errno */
		perror("send");
		return err;
	}

	memset(recvbuf, 0, sizeof(recvbuf));

	n = recv(sock, recvbuf, sizeof(recvbuf), 0);
	if (n < 0) {
		/* Report the failure instead of returning success. */
		err = -errno;
		perror("recv");
		return err;
	}
	assert((size_t)n <= sizeof(recvbuf));

	/* mip6nl_probe_msg() takes and updates an unsigned int length. */
	rlen = (unsigned int)n;
	return mip6nl_probe_msg(mp, recvbuf, &rlen);
}

/*
 * Attach an already-open netlink socket to mp and add NLM_F_ACK to mp->flag.
 * mp->flag is OR-ed into the header flags of every request built from mp.
 */
void mip6nl_init(struct mip6nl_parms *mp, int sock)
{
	mp->flag = mp->flag | NLM_F_ACK;
	mp->sock = sock;
}

int mip6nl_open(struct mip6nl_parms *mp)
{
	int sock = -1;
	int err = 0;

	memset(mp, 0, sizeof(mp));

	sock = socket(PF_NETLINK, SOCK_RAW, NETLINK_XFRM);
	if (sock < 0) {
		perror("socket");
		err = -errno;
		goto fin;
	}

	mip6nl_init(mp, sock);

 fin:
	return err;
}

/*
 * Close mp's netlink socket if one is open.  A NULL mp is ignored.
 * mp->sock is reset to -1 afterwards.  A second call is then a no-op, and
 * mip6nl_talk() (via mip6nl_verify()) rejects mp instead of using a closed,
 * possibly recycled, descriptor.
 */
void mip6nl_close(struct mip6nl_parms *mp)
{
	if (!mp)
		return;
	if (mp->sock >= 0) {
		close(mp->sock);
		mp->sock = -1;
	}
}
