/* 
 * Copyright (C) 2005  Network Applied Communication Laboratory Co., Ltd.
 *
 * This file is part of Rast.
 * See the file COPYING for redistribution information.
 *
 */

#include <apr_strings.h>
#include <st.h>

#include "rast/ruby.h"
#include "rast/config.h"
#include "rast/text_index.h"
#include "rast/local_db.h"
#include "rast/query.h"
#include "rast/xmlrpc_client.h"
#include "rast/merger.h"

/*
 * Ruby class handles shared by the functions in this file.  All of them are
 * assigned in Init_rast_test(): the first four are looked up as existing
 * constants under the Rast module, the last two are defined there.
 */
static VALUE cTerm, cTermFrequency, cCandidate, cQueryResult;
static VALUE cTextIndexer, cEncoding;

static rast_encoding_module_t *
get_encoding_module(const char *name)
{
    rast_encoding_module_t *encoding_module;

    rast_rb_raise_error(rast_get_encoding_module(name, &encoding_module));
    return encoding_module;
}

/*
 * Wrap the encoding module called `name` in a Rast::Encoding object.
 *
 * Returns Qnil (after destroying the error) when the lookup fails.  The
 * wrapper is created without mark or free callbacks.
 */
static VALUE
encoding_new(char *name)
{
    rast_encoding_module_t *module;
    rast_error_t *status = rast_get_encoding_module(name, &module);

    if (status == RAST_OK) {
        return Data_Wrap_Struct(cEncoding, NULL, NULL, module);
    }
    rast_error_destroy(status);
    return Qnil;
}

/*
 * Rast::Encoding[name]: converts `name` to a C string and delegates to
 * encoding_new().
 */
static VALUE
encoding_s_aref(VALUE self, VALUE name)
{
    char *encoding_name = StringValuePtr(name);

    return encoding_new(encoding_name);
}

/*
 * Rast::Encoding#register_tokenize(text)
 *
 * Walks `text` with a register tokenizer and yields a three-element array
 * [token bytes, token.pos, token.is_complete as true/false] for each token.
 * The same array object is refilled and yielded on every iteration.
 * Always returns nil.
 */
static VALUE
encoding_register_tokenize(VALUE self, VALUE text)
{
    rast_encoding_module_t *module;
    rast_tokenizer_t *tokenizer;
    rast_token_t token;
    apr_pool_t *pool;
    VALUE vpool, entry;

    StringValue(text);
    pool = rast_rb_pool_new(&vpool);
    Data_Get_Struct(self, rast_encoding_module_t, module);

    tokenizer = rast_register_tokenizer_create(pool, module,
                                               RSTRING(text)->ptr,
                                               RSTRING(text)->len);
    entry = rb_ary_new2(3);
    for (; !rast_register_tokenizer_is_done(tokenizer);
         rast_register_tokenizer_next(tokenizer)) {
        rast_rb_raise_error(rast_register_tokenizer_get_current(tokenizer,
                                                                &token));
        rb_ary_store(entry, 0, rb_str_new(token.ptr, token.nbytes));
        rb_ary_store(entry, 1, INT2NUM(token.pos));
        rb_ary_store(entry, 2, token.is_complete ? Qtrue : Qfalse);
        rb_yield(entry);
    }
    return Qnil;
}

/*
 * Rast::Encoding#search_tokenize(text)
 *
 * Walks `text` with a search tokenizer and yields a three-element array
 * [token bytes, token.pos, token.is_complete as true/false] for each token.
 * The same array object is refilled and yielded on every iteration.
 * Always returns nil.
 */
static VALUE
encoding_search_tokenize(VALUE self, VALUE text)
{
    rast_encoding_module_t *module;
    rast_tokenizer_t *tokenizer;
    rast_token_t token;
    apr_pool_t *pool;
    VALUE vpool, entry;

    StringValue(text);
    pool = rast_rb_pool_new(&vpool);
    Data_Get_Struct(self, rast_encoding_module_t, module);

    tokenizer = rast_search_tokenizer_create(pool, module,
                                             RSTRING(text)->ptr,
                                             RSTRING(text)->len);
    entry = rb_ary_new2(3);
    for (; !rast_search_tokenizer_is_done(tokenizer);
         rast_search_tokenizer_next(tokenizer)) {
        rast_search_tokenizer_get_current(tokenizer, &token);
        rb_ary_store(entry, 0, rb_str_new(token.ptr, token.nbytes));
        rb_ary_store(entry, 1, INT2NUM(token.pos));
        rb_ary_store(entry, 2, token.is_complete ? Qtrue : Qfalse);
        rb_yield(entry);
    }
    return Qnil;
}

/*
 * Rast::Encoding#normalize_text(src_text)
 *
 * Runs the encoding module's normalize_text callback over `src_text` and
 * returns the produced bytes as a new Ruby string.
 */
static VALUE
encoding_normalize_text(VALUE self, VALUE src_text)
{
    rast_encoding_module_t *module;
    apr_pool_t *pool;
    VALUE vpool;
    char *normalized;
    rast_size_t normalized_len;

    StringValue(src_text);
    pool = rast_rb_pool_new(&vpool);
    Data_Get_Struct(self, rast_encoding_module_t, module);

    module->normalize_text(pool,
                           RSTRING(src_text)->ptr, RSTRING(src_text)->len,
                           &normalized, &normalized_len);
    return rb_str_new(normalized, normalized_len);
}

/*
 * Rast::Encoding#normalize_chars(src)
 *
 * Runs the encoding module's normalize_chars callback over `src` and
 * returns the produced bytes as a new Ruby string.
 */
static VALUE
encoding_normalize_chars(VALUE self, VALUE src)
{
    rast_encoding_module_t *module;
    apr_pool_t *pool;
    VALUE vpool;
    char *normalized;
    rast_size_t normalized_len;

    StringValue(src);
    pool = rast_rb_pool_new(&vpool);
    Data_Get_Struct(self, rast_encoding_module_t, module);

    module->normalize_chars(pool,
                            RSTRING(src)->ptr, RSTRING(src)->len,
                            &normalized, &normalized_len);
    return rb_str_new(normalized, normalized_len);
}

/*
 * Payload stored inside a Rast::TextIndexer object.
 *
 *   indexer - the indexer driven by TextIndexer#add and TextIndexer#commit
 *   pool    - APR pool destroyed together with the wrapper, or NULL when the
 *             wrapper does not own a pool
 */
typedef struct {
    rast_text_indexer_t *indexer;
    apr_pool_t *pool;
} text_indexer_data_t;

/*
 * Free callback for Rast::TextIndexer objects: destroys the owned pool (if
 * any) and releases the payload struct.
 */
static void
text_indexer_free(text_indexer_data_t *data)
{
    if (data) {
        if (data->pool) {
            apr_pool_destroy(data->pool);
        }
        xfree(data);
    }
}

/*
 * Wrap `indexer` in a Rast::TextIndexer object.
 *
 * The accessors and the free callback all read the wrapped pointer as a
 * text_indexer_data_t, so the indexer is boxed in one here instead of being
 * wrapped directly.  This function is not given a pool, so the wrapper owns
 * none (pool == NULL) and text_indexer_free() only releases the box.
 */
static VALUE
text_indexer_new(rast_text_indexer_t *indexer)
{
    text_indexer_data_t *data;

    data = ALLOC(text_indexer_data_t);
    data->indexer = indexer;
    data->pool = NULL;
    return Data_Wrap_Struct(cTextIndexer, NULL, text_indexer_free, data);
}

/*
 * Rast::TextIndexer#add(str)
 *
 * Feeds the bytes of `str` to the wrapped indexer via
 * rast_text_indexer_add() and passes the status to rast_rb_raise_error().
 * Returns nil.
 */
static VALUE
text_indexer_add(VALUE self, VALUE str)
{
    text_indexer_data_t *wrapper;

    StringValue(str);
    Data_Get_Struct(self, text_indexer_data_t, wrapper);
    rast_rb_raise_error(rast_text_indexer_add(wrapper->indexer,
                                              RSTRING(str)->ptr,
                                              RSTRING(str)->len, NULL));
    return Qnil;
}

/*
 * Rast::TextIndexer#commit
 *
 * Calls rast_text_indexer_commit() on the wrapped indexer and passes the
 * status to rast_rb_raise_error().  Returns nil.
 */
static VALUE
text_indexer_commit(VALUE self)
{
    text_indexer_data_t *wrapper;
    rast_error_t *status;

    Data_Get_Struct(self, text_indexer_data_t, wrapper);
    status = rast_text_indexer_commit(wrapper->indexer);
    rast_rb_raise_error(status);
    return Qnil;
}

/*
 * Payload attached to a Rast::TextIndex object.
 */
typedef struct {
    /* index handle stored by text_index_initialize() */
    rast_text_index_t *index;
    /* pool handed to rast_text_index_open(); destroyed in text_index_free() */
    apr_pool_t *pool;
    /* set to 1 by text_index_close(); text_index_free() skips the close then */
    int closed;
} text_index_data_t;

/*
 * Free callback for Rast::TextIndex objects.
 *
 * Closes the index unless it was already marked closed, then destroys the
 * pool and releases the payload.  A NULL payload is ignored.
 */
static void
text_index_free(text_index_data_t *data)
{
    if (data == NULL) {
        return;
    }
    if (data->closed == 0) {
        rast_text_index_close(data->index);
    }
    apr_pool_destroy(data->pool);
    xfree(data);
}

/*
 * Allocator for Rast::TextIndex: creates the wrapper with a NULL payload;
 * text_index_initialize() installs the real one.
 */
static VALUE
text_index_alloc(VALUE klass)
{
    VALUE obj = Data_Wrap_Struct(klass, NULL, text_index_free, NULL);

    return obj;
}

/*
 * Context handed through rb_hash_foreach() to rb_hash_to_doc_id_table_i().
 */
typedef struct {
    /* APR hash being filled with doc id -> doc id entries */
    apr_hash_t *hash;
    /* pool the copied keys and values are allocated from */
    apr_pool_t *pool;
} doc_id_table_arg_t;

/*
 * rb_hash_foreach() callback: converts one Ruby key/value pair to
 * rast_doc_id_t values, copies both into the pool and stores them in the
 * APR hash.  Always returns ST_CONTINUE.
 */
static int
rb_hash_to_doc_id_table_i(VALUE vkey, VALUE vvalue, doc_id_table_arg_t *arg)
{
    rast_doc_id_t key_id, value_id;
    void *stored_key, *stored_value;

    key_id = NUM2INT(vkey);
    value_id = NUM2INT(vvalue);

    /* The hash keeps pointers, so both ids need pool-backed copies. */
    stored_key = apr_pmemdup(arg->pool, &key_id, sizeof(rast_doc_id_t));
    stored_value = apr_pmemdup(arg->pool, &value_id, sizeof(rast_doc_id_t));
    apr_hash_set(arg->hash, stored_key, sizeof(rast_doc_id_t), stored_value);

    return ST_CONTINUE;
}

/*
 * Build an APR hash in `pool` from the Ruby Hash `vhash`, converting every
 * key and value with rb_hash_to_doc_id_table_i().  Check_Type() rejects
 * anything that is not a T_HASH.
 */
static apr_hash_t *
rb_hash_to_doc_id_table(VALUE vhash, apr_pool_t *pool)
{
    doc_id_table_arg_t ctx;

    Check_Type(vhash, T_HASH);
    ctx.pool = pool;
    ctx.hash = apr_hash_make(pool);
    rb_hash_foreach(vhash, rb_hash_to_doc_id_table_i, (VALUE) &ctx);
    return ctx.hash;
}

/*
 * Rast::TextIndex.optimize(old_index_name, new_index_name, doc_id_table)
 *
 * Opens the index named `old_index_name` read-only with the "utf8" encoding
 * module and a position block size of 512, then calls
 * rast_text_index_optimize() with `new_index_name` and the doc id table
 * converted from the Ruby Hash `vdoc_id_table`.
 *
 * The old index is closed on both the success and the failure path of the
 * optimize call.  Returns nil; errors are reported through
 * rast_rb_raise_error().
 */
static VALUE
text_index_s_optimize(VALUE self, VALUE vold_index_name, VALUE vnew_index_name,
                      VALUE vdoc_id_table)
{
    apr_pool_t *pool;
    apr_hash_t *doc_id_table;
    rast_error_t *error;
    VALUE vpool;
    DB_ENV *bdb_env = NULL;
    DB_TXN *bdb_txn = NULL;
    rast_encoding_module_t *encoding_module;
    rast_text_index_t *old_index;
    rast_size_t pos_block_size = 512;

    SafeStringValue(vold_index_name);
    SafeStringValue(vnew_index_name);
    Check_Type(vdoc_id_table, T_HASH);

    pool = rast_rb_pool_new(&vpool);
    doc_id_table = rb_hash_to_doc_id_table(vdoc_id_table, pool);

    encoding_module = get_encoding_module("utf8");
    error = rast_text_index_open(&old_index, StringValuePtr(vold_index_name),
                                 RAST_DB_RDONLY, encoding_module, bdb_env, 0,
                                 pos_block_size, pool);
    rast_rb_raise_error(error);

    error = rast_text_index_optimize(bdb_env, bdb_txn, 0, old_index,
                                     StringValuePtr(vnew_index_name),
                                     doc_id_table);
    if (error != RAST_OK) {
        /*
         * Do not leave the old index open when optimize fails.  The
         * optimize error is the one reported; a secondary close error is
         * discarded.
         */
        rast_error_t *close_error = rast_text_index_close(old_index);

        if (close_error != RAST_OK) {
            rast_error_destroy(close_error);
        }
        rast_rb_raise_error(error);
    }

    rast_rb_raise_error(rast_text_index_close(old_index));

    return Qnil;
}

/*
 * Rast::TextIndex#initialize(name, flags = RAST_DB_RDWR, block_size = 512)
 *
 * Opens the text index `name` with the "utf8" encoding module and stores the
 * handle, its pool and a cleared `closed` flag as the object's payload.
 *
 * Everything that can raise without needing the pool (argument conversion,
 * encoding lookup) runs before the pool is created, and the pool is
 * destroyed when the open itself fails, so a failed call does not leak it.
 */
static VALUE
text_index_initialize(int argc, VALUE *argv, VALUE self)
{
    VALUE name, vflags, vblock_size;
    int flags = RAST_DB_RDWR;
    rast_encoding_module_t *encoding_module;
    rast_error_t *error;
    apr_pool_t *pool;
    rast_text_index_t *index;
    text_index_data_t *data;
    rast_size_t block_size = 512;

    rb_scan_args(argc, argv, "12", &name, &vflags, &vblock_size);
    if (!NIL_P(vflags)) {
        flags = NUM2INT(vflags);
    }
    if (!NIL_P(vblock_size)) {
        block_size = NUM2INT(vblock_size);
    }
    /* Convert the name up front so a TypeError cannot strand the pool. */
    StringValue(name);
    encoding_module = get_encoding_module("utf8");

    rast_rb_pool_create_ex(&pool, NULL, NULL);
    error = rast_text_index_open(&index, StringValuePtr(name), flags,
                                 encoding_module, NULL, 0, block_size, pool);
    if (error != RAST_OK) {
        /* Same destroy-then-raise order as query_initialize(). */
        apr_pool_destroy(pool);
        rast_rb_raise_error(error);
    }

    data = ALLOC(text_index_data_t);
    data->index = index;
    data->pool = pool;
    data->closed = 0;
    DATA_PTR(self) = data;
    return Qnil;
}

/*
 * Rast::TextIndex#close
 *
 * Closes the wrapped index and marks the payload as closed so that
 * text_index_free() does not close it again.  Calling it on an object that
 * is already closed, or whose payload was never installed, is a no-op.
 * Returns nil.
 */
static VALUE
text_index_close(VALUE self)
{
    text_index_data_t *data;

    Data_Get_Struct(self, text_index_data_t, data);
    if (data == NULL || data->closed) {
        /* Nothing left to close; avoid closing the same handle twice. */
        return Qnil;
    }
    rast_rb_raise_error(rast_text_index_close(data->index));

    data->closed = 1;
    return Qnil;
}

/*
 * Rast::TextIndex#sync
 *
 * Calls rast_text_index_sync() on the wrapped index and passes the status
 * to rast_rb_raise_error().  Returns nil.
 */
static VALUE
text_index_sync(VALUE self)
{
    text_index_data_t *wrapper;
    rast_error_t *status;

    Data_Get_Struct(self, text_index_data_t, wrapper);
    status = rast_text_index_sync(wrapper->index);
    rast_rb_raise_error(status);
    return Qnil;
}

/*
 * Rast::TextIndex#register(doc_id, *texts)
 *
 * With only a doc id: starts a registration and returns a Rast::TextIndexer
 * that owns the registration pool; the pool is destroyed when that object
 * is freed.
 *
 * With additional string arguments: adds each of them to the indexer,
 * commits, destroys the pool and returns nil.
 *
 * Raises ArgumentError when called without arguments; Rast errors are
 * reported through rast_rb_raise_error() after the pool has been destroyed.
 */
static VALUE
text_index_register(int argc, VALUE *argv, VALUE self)
{
    text_index_data_t *data;
    text_indexer_data_t *indexer_data;
    rast_error_t *error;
    rast_text_indexer_t *indexer;
    rast_doc_id_t doc_id;
    apr_pool_t *pool;
    int i;

    Data_Get_Struct(self, text_index_data_t, data);
    if (argc < 1) {
        rb_raise(rb_eArgError, "wrong number of arguments (0 for 1)");
    }
    doc_id = NUM2INT(argv[0]);
    /*
     * Convert every text argument before the pool exists, so a TypeError
     * cannot leave an unowned pool behind.
     */
    for (i = 1; i < argc; i++) {
        StringValue(argv[i]);
    }

    apr_pool_create(&pool, NULL);
    error = rast_text_index_register(data->index, doc_id, &indexer, pool);
    if (error != RAST_OK) {
        apr_pool_destroy(pool);
        rast_rb_raise_error(error);
    }

    if (argc == 1) {
        /*
         * TextIndexer methods and text_indexer_free() read the payload as a
         * text_indexer_data_t, so box the indexer together with its pool.
         */
        indexer_data = ALLOC(text_indexer_data_t);
        indexer_data->indexer = indexer;
        indexer_data->pool = pool;
        return Data_Wrap_Struct(cTextIndexer, NULL, text_indexer_free,
                                indexer_data);
    }

    /* Stop at the first failing add; commit only when every add worked. */
    for (i = 1; error == RAST_OK && i < argc; i++) {
        error = rast_text_indexer_add(indexer,
                                      RSTRING(argv[i])->ptr,
                                      RSTRING(argv[i])->len, NULL);
    }
    if (error == RAST_OK) {
        error = rast_text_indexer_commit(indexer);
    }
    /* The indexer is finished either way; release its pool before raising. */
    apr_pool_destroy(pool);
    rast_rb_raise_error(error);
    return Qnil;
}

/*
 * Convert a rast_query_result_t into a Rast::QueryResult.
 *
 * Each entry of result->terms becomes Term.new(term, doc_count); each entry
 * of result->candidates becomes Candidate.new(doc_id, term_frequencies),
 * where every term frequency is TermFrequency.new(count, pos).  The final
 * object is QueryResult.new(terms, candidates).
 */
static VALUE
query_result_new(rast_query_result_t *result)
{
    ID id_new = rb_intern("new");
    rast_term_t *t;
    rast_candidate_t *c;
    rast_term_frequency_t *tf;
    VALUE terms, candidates;

    terms = rb_ary_new();
    t = APR_RING_FIRST(&result->terms);
    while (t != APR_RING_SENTINEL(&result->terms, rast_term_t, link)) {
        VALUE term = rb_funcall(cTerm, id_new, 2,
                                rb_tainted_str_new2(t->term),
                                INT2NUM(t->doc_count));

        rb_ary_push(terms, term);
        t = APR_RING_NEXT(t, link);
    }

    candidates = rb_ary_new();
    c = APR_RING_FIRST(&result->candidates);
    while (c != APR_RING_SENTINEL(&result->candidates,
                                  rast_candidate_t, link)) {
        VALUE frequencies = rb_ary_new();
        VALUE candidate;

        tf = APR_RING_FIRST(&c->terms);
        while (tf != APR_RING_SENTINEL(&c->terms,
                                       rast_term_frequency_t, link)) {
            VALUE frequency = rb_funcall(cTermFrequency, id_new, 2,
                                         INT2NUM(tf->count),
                                         INT2NUM(tf->pos));

            rb_ary_push(frequencies, frequency);
            tf = APR_RING_NEXT(tf, link);
        }
        candidate = rb_funcall(cCandidate, id_new, 2,
                               INT2NUM(c->doc_id), frequencies);
        rb_ary_push(candidates, candidate);
        c = APR_RING_NEXT(c, link);
    }

    return rb_funcall(cQueryResult, id_new, 2, terms, candidates);
}

/*
 * Rast::TextIndex#search(text)
 *
 * Runs rast_text_index_search() on the wrapped index with `text` and a
 * third argument of 1, and converts the result with query_result_new().
 */
static VALUE
text_index_search(VALUE self, VALUE text)
{
    text_index_data_t *wrapper;
    rast_query_result_t *found;
    rast_error_t *status;
    apr_pool_t *pool;
    VALUE vpool;

    pool = rast_rb_pool_new(&vpool);
    Data_Get_Struct(self, text_index_data_t, wrapper);
    status = rast_text_index_search(wrapper->index, StringValuePtr(text), 1,
                                    &found, pool);
    rast_rb_raise_error(status);
    return query_result_new(found);
}

/*
 * Payload attached to a Rast::Query object.
 */
typedef struct {
    /* parsed query; set by query_initialize(), replaced by query_optimize() */
    rast_query_t *query;
    /* pool the query was parsed in; destroyed in query_free() */
    apr_pool_t *pool;
} query_data_t;

/*
 * Free callback for Rast::Query objects: destroys the pool and releases the
 * payload.  A NULL payload is ignored.
 */
static void
query_free(query_data_t *data)
{
    if (data == NULL) {
        return;
    }
    apr_pool_destroy(data->pool);
    xfree(data);
}

/*
 * Allocator for Rast::Query: creates the wrapper with a NULL payload;
 * query_initialize() installs the real one.
 */
static VALUE
query_alloc(VALUE klass)
{
    VALUE obj = Data_Wrap_Struct(klass, NULL, query_free, NULL);

    return obj;
}

/*
 * Rast::Query#initialize(str)
 *
 * Parses `str` with the "utf8" encoding module and stores the parsed query
 * and its pool as the object's payload.
 *
 * The encoding lookup can raise, so it runs before the pool is created; the
 * pool is destroyed before a parse error is raised.
 */
static VALUE
query_initialize(VALUE self, VALUE str)
{
    rast_error_t *error;
    apr_pool_t *pool;
    rast_query_t *query;
    query_data_t *data;
    rast_encoding_module_t *encoding_module;

    SafeStringValue(str);
    encoding_module = get_encoding_module("utf8");
    rast_rb_pool_create_ex(&pool, NULL, NULL);
    error = rast_parse_query(pool, encoding_module, RSTRING(str)->ptr, &query);
    if (error != RAST_OK) {
        apr_pool_destroy(pool);
        rast_rb_raise_error(error);
    }
    data = ALLOC(query_data_t);
    data->query = query;
    data->pool = pool;
    DATA_PTR(self) = data;
    return Qnil;
}

/*
 * Rast::Query#inspect
 *
 * Returns the text produced by rast_query_inspect() as a tainted Ruby
 * string.
 */
static VALUE
query_inspect(VALUE self)
{
    query_data_t *wrapper;
    apr_pool_t *pool;
    VALUE vpool;

    Data_Get_Struct(self, query_data_t, wrapper);
    pool = rast_rb_pool_new(&vpool);
    return rb_tainted_str_new2(rast_query_inspect(wrapper->query, pool));
}

/*
 * Rast::Query#exec(db, options = nil)
 *
 * Executes the wrapped query against `db`.  The score method defaults to
 * RAST_SCORE_METHOD_TFIDF; when `options` is given, its "score_method"
 * entry is read with rast_rb_get_int_option().  The result is converted
 * with query_result_new().
 */
static VALUE
query_exec(int argc, VALUE *argv, VALUE self)
{
    VALUE vdb, voptions, vpool;
    query_data_t *wrapper;
    rast_local_db_t *local_db;
    rast_query_result_t *found;
    rast_query_option_t exec_options;
    rast_error_t *status;
    apr_pool_t *pool;

    rb_scan_args(argc, argv, "11", &vdb, &voptions);

    exec_options.score_method = RAST_SCORE_METHOD_TFIDF;
    if (!NIL_P(voptions)) {
        rast_rb_get_int_option(voptions, "score_method",
                               (int *) &exec_options.score_method);
    }

    Data_Get_Struct(self, query_data_t, wrapper);
    local_db = (rast_local_db_t *) rast_rb_get_db(vdb);
    pool = rast_rb_pool_new(&vpool);

    status = rast_query_exec(wrapper->query, local_db, &exec_options,
                             &found, pool);
    rast_rb_raise_error(status);

    return query_result_new(found);
}

/*
 * Rast::Query#optimize
 *
 * Calls rast_query_optimize() with the wrapped query as input and the same
 * payload field as output, using the payload's pool.  Returns self.
 */
static VALUE
query_optimize(VALUE self)
{
    query_data_t *wrapper;
    rast_query_t **slot;

    Data_Get_Struct(self, query_data_t, wrapper);
    slot = &wrapper->query;
    rast_rb_raise_error(rast_query_optimize(*slot, slot, wrapper->pool));
    return self;
}

/*
 * Encode `n` into `s` as a variable-length number: seven bits per byte,
 * least significant group first, with bit 0x80 set on every byte except the
 * last one.
 *
 * Returns the number of bytes written.  Zero is written as a single 0 byte;
 * a negative `n` writes nothing and returns 0.
 */
static int
pack_number(char *s, int n)
{
    char *out = s;

    if (n == 0) {
        *out = 0;
        return 1;
    }
    for (; n > 0; out++) {
        int low7 = n % 128;

        n /= 128;
        /* More groups follow while the remaining quotient is non-zero. */
        *out = (n > 0) ? (low7 | 0x80) : low7;
    }
    return out - s;
}

/*
 * Decode one variable-length number from the first `nbytes` bytes of `s`.
 *
 * Each byte contributes its low seven bits, least significant group first;
 * a byte without bit 0x80 ends the number.  Decoding also stops when the
 * input runs out, in which case the value accumulated so far is stored.
 *
 * The decoded value is written to *np and the number of bytes consumed is
 * returned (0 when nbytes is 0, with *np set to 0).
 *
 * The accumulation is done in unsigned arithmetic so that over-long input
 * wraps around instead of triggering signed-overflow undefined behaviour;
 * well-formed input decodes exactly as before.
 */
static int
unpack_number(const char *s, int nbytes, int *np)
{
    const char *p = s, *pend = s + nbytes;
    unsigned int base = 1, value = 0;

    while (p < pend) {
        /* Read as unsigned so bytes >= 0x80 are never sign-extended. */
        unsigned int byte = (unsigned char) *p++;

        value += base * (byte & 0x7F);
        if (!(byte & 0x80)) {
            break;
        }
        base *= 128;
    }
    *np = (int) value;
    return p - s;
}

/*
 * Rast::VNUM.pack(ary)
 *
 * Encodes every element of the Array `ary` with pack_number() and returns
 * the concatenated bytes as a Ruby string.  Raises ArgumentError for a
 * negative element.
 */
static VALUE
vnum_pack(VALUE self, VALUE ary)
{
    VALUE packed;
    char encoded[5];
    int idx, value, len;

    Check_Type(ary, T_ARRAY);
    packed = rb_str_new("", 0);
    idx = 0;
    while (idx < RARRAY(ary)->len) {
        value = NUM2INT(RARRAY(ary)->ptr[idx]);
        if (value < 0) {
            rb_raise(rb_eArgError, "negative value - %d", value);
        }
        len = pack_number(encoded, value);
        rb_str_cat(packed, encoded, len);
        idx++;
    }
    return packed;
}

/*
 * Rast::VNUM.unpack(str)
 *
 * Decodes the String `str` with repeated unpack_number() calls and returns
 * the decoded numbers as a Ruby Array.
 */
static VALUE
vnum_unpack(VALUE self, VALUE str)
{
    VALUE numbers;
    char *cursor, *limit;
    int value;

    Check_Type(str, T_STRING);
    cursor = RSTRING(str)->ptr;
    limit = cursor + RSTRING(str)->len;
    numbers = rb_ary_new();
    for (; cursor < limit; ) {
        cursor += unpack_number(cursor, limit - cursor, &value);
        rb_ary_push(numbers, INT2NUM(value));
    }
    return numbers;
}

/*
 * Rast::DB#register_raw(text, property_values)
 *
 * Registers `text` with the database, passing one rast_value_t per element
 * of the Array `property_values`:
 *   - String   -> RAST_TYPE_STRING
 *   - Fixnum   -> RAST_TYPE_UINT
 *   - DateTime -> RAST_TYPE_DATETIME, formatted with strftime("%FT%T")
 *   - Date     -> RAST_TYPE_DATE, formatted with strftime("%F")
 * Any other element raises a Rast error ("unknown property type").
 *
 * Returns the new document id as a Ruby Integer.
 */
static VALUE
db_register_raw(VALUE self, VALUE text, VALUE vproperty_values)
{
    rast_db_t *db;
    rast_value_t *property_values;
    rast_error_t *error;
    apr_pool_t *pool;
    VALUE vpool;
    rast_doc_id_t doc_id;
    int i, num_properties;

    pool = rast_rb_pool_new(&vpool);
    db = rast_rb_get_db(self);
    Check_Type(vproperty_values, T_ARRAY);
    StringValue(text);

    num_properties = RARRAY(vproperty_values)->len;
    property_values = (rast_value_t *)
        apr_palloc(pool, sizeof(rast_value_t) * num_properties);
    for (i = 0; i < num_properties; i++) {
        VALUE value, str_value;
        char *s;

        value = RARRAY(vproperty_values)->ptr[i];
        switch (TYPE(value)) {
        case T_STRING:
            rast_value_set_string(property_values + i, StringValuePtr(value));
            rast_value_set_type(property_values + i, RAST_TYPE_STRING);
            break;
        case T_FIXNUM:
            rast_value_set_uint(property_values + i, NUM2INT(value));
            rast_value_set_type(property_values + i, RAST_TYPE_UINT);
            break;
        default:
            /*
             * Test DateTime before Date.  NOTE(review): assuming these
             * handles are Ruby's DateTime and Date, DateTime is a subclass
             * of Date, so checking Date first would swallow every DateTime
             * and register it as a plain date.  If the classes are
             * unrelated the order makes no difference.
             */
            if (RTEST(rb_obj_is_kind_of(value, rast_rb_cDateTime))) {
                str_value = rb_funcall(value, rb_intern("strftime"), 1,
                                       rb_str_new2("%FT%T"));
                s = apr_pstrdup(pool, StringValuePtr(str_value));
                rast_value_set_datetime(property_values + i, s);
                rast_value_set_type(property_values + i, RAST_TYPE_DATETIME);
            }
            else if (RTEST(rb_obj_is_kind_of(value, rast_rb_cDate))) {
                str_value = rb_funcall(value, rb_intern("strftime"), 1,
                                       rb_str_new2("%F"));
                s = apr_pstrdup(pool, StringValuePtr(str_value));
                rast_value_set_date(property_values + i, s);
                rast_value_set_type(property_values + i, RAST_TYPE_DATE);
            }
            else {
                rb_raise(rast_rb_eError, "unknown property type");
            }
        }
    }

    error = rast_db_register(db, RSTRING(text)->ptr, RSTRING(text)->len,
                             property_values, &doc_id);
    rast_rb_raise_error(error);
    return INT2NUM(doc_id);
}

/*
 * Extension entry point.
 *
 * Requires "rast", looks up the existing Rast::DB, Term, TermFrequency,
 * Candidate and QueryResult constants, and defines the test-only API:
 * Rast.fatal, Rast::Encoding, Rast::VNUM, Rast::TextIndexer,
 * Rast::TextIndex, Rast::Query and Rast::DB#register_raw.
 */
void
Init_rast_test()
{
    VALUE cTextIndex, cQuery;
    VALUE mVNUM;
    VALUE cDB;

    rb_require("rast");

    rb_define_module_function(rast_rb_mRast, "fatal",
                              (VALUE (*)(VALUE)) rb_exc_fatal, 1);

    cDB = rb_const_get(rast_rb_mRast, rb_intern("DB"));
    cTerm = rb_const_get(rast_rb_mRast, rb_intern("Term"));
    cTermFrequency = rb_const_get(rast_rb_mRast, rb_intern("TermFrequency"));
    cCandidate = rb_const_get(rast_rb_mRast, rb_intern("Candidate"));
    cQueryResult = rb_const_get(rast_rb_mRast, rb_intern("QueryResult"));

    /* Encoding objects are only obtained through Encoding[name]. */
    cEncoding = rb_define_class_under(rast_rb_mRast, "Encoding", rb_cObject);
    rb_undef_method(CLASS_OF(cEncoding), "new");
    rb_define_singleton_method(cEncoding, "[]", encoding_s_aref, 1);
    rb_define_method(cEncoding, "register_tokenize",
                     encoding_register_tokenize, 1);
    rb_define_method(cEncoding, "search_tokenize",
                     encoding_search_tokenize, 1);
    rb_define_method(cEncoding, "normalize_text", encoding_normalize_text, 1);
    rb_define_method(cEncoding, "normalize_chars",
                     encoding_normalize_chars, 1);

    mVNUM = rb_define_module_under(rast_rb_mRast, "VNUM");
    rb_define_module_function(mVNUM, "pack", vnum_pack, 1);
    rb_define_module_function(mVNUM, "unpack", vnum_unpack, 1);

    /* TextIndexer objects are only created by TextIndex#register. */
    cTextIndexer = rb_define_class_under(rast_rb_mRast, "TextIndexer",
                                         rb_cObject);
    rb_undef_method(CLASS_OF(cTextIndexer), "new");
    rb_define_method(cTextIndexer, "add", text_indexer_add, 1);
    rb_define_method(cTextIndexer, "commit", text_indexer_commit, 0);

    cTextIndex = rb_define_class_under(rast_rb_mRast, "TextIndex", rb_cObject);
    rb_define_alloc_func(cTextIndex, text_index_alloc);
    rb_define_singleton_method(cTextIndex, "optimize",
                               text_index_s_optimize, 3);
    rb_define_method(cTextIndex, "initialize", text_index_initialize, -1);
    rb_define_method(cTextIndex, "close", text_index_close, 0);
    rb_define_method(cTextIndex, "sync", text_index_sync, 0);
    rb_define_method(cTextIndex, "register", text_index_register, -1);
    rb_define_method(cTextIndex, "search", text_index_search, 1);

    cQuery = rb_define_class_under(rast_rb_mRast, "Query", rb_cObject);
    /* Query.parse is an alias of Query.new. */
    rb_define_alias(CLASS_OF(cQuery), "parse", "new");
    rb_define_alloc_func(cQuery, query_alloc);
    rb_define_method(cQuery, "initialize", query_initialize, 1);
    rb_define_method(cQuery, "inspect", query_inspect, 0);
    rb_define_method(cQuery, "exec", query_exec, -1);
    rb_define_method(cQuery, "optimize", query_optimize, 0);

    rb_define_method(cDB, "register_raw", db_register_raw, 2);
}

/* vim: set filetype=c sw=4 expandtab : */
