#!/usr/bin/python
# -*- coding:utf-8 -*-

import re
import sys

import imagesize

def _debug_print(*args):
    '''Write the tuple of positional arguments to stderr, followed by a newline.

    The output is the tuple's string form, e.g. _debug_print("a", 1)
    emits "('a', 1)".
    '''
    # The "print >> stream" chevron form only exists in Python 2; writing to
    # the stream directly produces the same text on every interpreter version.
    sys.stderr.write(str(args) + '\n')

def create_mode(lexi, mode_name):
    '''Build the Mode object for mode_name.

    The mode's attribute dict is read from lexi["modes"][mode_name]; a
    missing mode name therefore raises KeyError.
    '''
    mode_attrs = lexi.get("modes")[mode_name]
    return Mode(lexi, mode_attrs, mode_name)

class Mode(object):
    '''One mode of the lexical definition, optionally extending a base mode.

    A mode wraps the attribute dict found under lexi["modes"][name].
    Attribute and rule lookups fall back to the mode named by the
    "extends" attribute when this mode does not define them itself.
    '''

    def __init__(self, lexi, attrs, mode_name):
        '''Create a mode.

        lexi      -- the whole lexical definition dict
        attrs     -- this mode's own attribute dict
        mode_name -- the key this mode is registered under
        '''
        self._lexi = lexi
        self._attrs = attrs
        self.name = mode_name
        base_mode = attrs.get("extends", False)
        # The parent mode is resolved eagerly so later lookups can simply
        # walk the _base chain.
        if base_mode:
            self._base = create_mode(self._lexi, base_mode)
        else:
            self._base = None

    def test(self, attr_name, text):
        '''Return True if the regexp stored under attr_name matches text.

        A missing or empty attribute never matches.
        '''
        rex = self.attr(attr_name, False)
        return bool(rex and re.search(rex, text))

    def has_attr(self, attr_name):
        '''Return True if attr_name is defined on this mode or a base mode.'''
        if attr_name in self._attrs:
            return True
        if self._base:
            return self._base.has_attr(attr_name)
        return False

    def attr(self, attr_name, default=None):
        '''Return the value of attr_name, searching base modes as needed.

        default is returned when no mode in the chain defines attr_name.
        '''
        if attr_name in self._attrs:
            return self._attrs[attr_name]
        if self._base:
            return self._base.attr(attr_name, default)
        return default

    def is_true(self, attr_name):
        '''Return True if attr_name is set to a truthy value.'''
        # Coerce to bool so the result type matches is_false().
        return bool(self.attr(attr_name, False))

    def is_false(self, attr_name):
        '''Return True if attr_name is set to a falsy value.

        An attribute that is not defined at all is not considered false.
        '''
        return not self.attr(attr_name, True)

    def has_rule(self, rule_name):
        '''Return True if rule_name is among the rules this mode applies.'''
        return rule_name in self._rules()

    def _rules(self):
        '''Return a new dict merging every rule visible from this mode.

        Later sources override earlier ones: base mode rules, then this
        mode's own "rules", then each rule set named in "includeRule".
        '''
        # Always build a fresh dict: updating the dicts stored in the lexi
        # would leak this mode's rules into the shared definition (and into
        # the base mode's rule set).
        rules = self._base._rules() if self._base else {}
        rules.update(self._attrs.get('rules', {}))
        shared_rules = self._lexi.get('rules', {})
        for sub_rule in self.attr('includeRule', []):
            rules.update(shared_rules.get(sub_rule, {}))
        return rules

    def rules(self):
        '''Return the rule names sorted by descending "priority".

        Rules without a "priority" entry count as priority 0.
        '''
        rules = self._rules()
        return sorted(rules,
                      key=lambda name: rules[name].get("priority", 0),
                      reverse=True)

    def rule(self, rule_name):
        '''Return the rule dict for rule_name, or None if it is unknown.

        Lookup order: the "rules" attribute, then the rule sets listed in
        "includeRule", then the base mode.
        '''
        rules = self.attr('rules', {})
        if rule_name in rules:
            return rules.get(rule_name)
        shared_rules = self._lexi.get('rules', {})
        for sub_rule in self.attr('includeRule', []):
            rule = shared_rules.get(sub_rule, {})
            if rule_name in rule:
                return rule.get(rule_name)
        if self._base:
            return self._base.rule(rule_name)
        return None

    def _run_hook(self, hook_name, writer, text):
        '''Apply the hook dict stored under hook_name and return the text.

        "insert" is written to writer; "replace" substitutes the text.
        '''
        if not self.has_attr(hook_name):
            return text
        hook = self.attr(hook_name)
        if "insert" in hook:
            writer.write(hook["insert"])
        if "replace" in hook:
            text = hook["replace"]
        return text

    def on_start(self, writer, text):
        '''Run the "onStart" hook when the mode is entered; return the text.'''
        return self._run_hook("onStart", writer, text)

    def on_exit(self, writer, text):
        '''Run the "onFinished" hook when the mode is left; return the text.'''
        return self._run_hook("onFinished", writer, text)


class ModeStack(object):
    '''Last-in, first-out stack of modes backed by the list in self.stack.'''

    def __init__(self):
        self.stack = list()

    def push(self, mode):
        '''Put mode on top of the stack.'''
        self.stack.append(mode)

    def pop(self):
        '''Remove the top entry and return it; IndexError if the stack is empty.'''
        top = self.stack.pop()
        return top

    def current(self):
        '''Return the top entry without removing it, or None when empty.'''
        if not self.stack:
            return None
        return self.stack[-1]

class Store(dict):
    '''Variable storage with one namespace per mode plus a global namespace.

    Per-mode values are kept in mode_stores, keyed by whatever the mode
    stack currently reports; global values are kept in global_store.
    '''

    def __init__(self, mode_stack):
        self._mode_stack = mode_stack
        self.global_store = {}
        self.mode_stores = {}

    def _mode(self):
        '''Return the current entry of the mode stack.'''
        return self._mode_stack.current()

    def save_global(self, key, value):
        '''Store value under key in the global namespace.'''
        self.global_store[key] = value

    def save(self, key, value):
        '''Store value under key in the current mode's namespace.'''
        self.mode_stores.setdefault(self._mode(), {})[key] = value

    def load(self, key, default=None):
        '''Return the value for key.

        The current mode's namespace is consulted first, then the global
        namespace; default is returned when neither holds the key.
        '''
        scoped = self.mode_stores.get(self._mode())
        if scoped is not None and key in scoped:
            return scoped[key]
        return self.global_store.get(key, default)

    def delete(self, key):
        '''Remove key from the current mode's namespace if it is present.'''
        scoped = self.mode_stores.get(self._mode())
        if scoped is not None:
            scoped.pop(key, None)

    def clear(self):
        '''Drop the current mode's whole namespace, if it has one.'''
        self.mode_stores.pop(self._mode(), None)


class Parser(object):
    """Line-oriented markup engine driven by a lexical definition dict.

    Each input line is processed by the mode on top of the mode stack:
    the mode may end, hand over to a sub-mode, or apply its rules.
    """

    # Positional reference to a regexp group, e.g. "$1".
    _GROUP_REF = re.compile(r'\$([0-9]+)')
    # Reference to a stored variable, e.g. "$name" or "${name}".
    _VAR_REF = re.compile(r'\${?([A-Za-z0-9_]+)}?')

    def __init__(self, lexi):
        """Create Parser object.
        @param {dict} lexi text lexical object
        @return Parser object
        """
        self.lexi = lexi
        self.mode_stack = ModeStack()
        self.mode_stack.push("global")

        self.store = Store(self.mode_stack)
        # Functions a rule may invoke through its "call" entry.
        self.functions = {
            "getImageGeometry": getImageGeometry,
            }
        # Output stream; assigned by markup().
        self.stream_out = None

    def _get_mode(self, mode_name):
        """Return a Mode object for mode_name."""
        return create_mode(self.lexi, mode_name)

    def current_mode(self):
        """Return the Mode on top of the stack, or None if the stack is empty."""
        mode_name = self.mode_stack.current()
        if mode_name:
            return self._get_mode(mode_name)
        return None

    def markup(self, iter_in, stream_out):
        """read from iter_in and output to stream_out
        @param {iterator} iter_in input iterator
        @param {stream} stream_out output stream
        """
        self.stream_out = stream_out
        # iter() is a no-op on iterators and lets plain iterables work too.
        lines = iter(iter_in)
        while self.current_mode():
            # Only the read itself may signal end of input; a StopIteration
            # escaping from rule processing must not be mistaken for it.
            try:
                raw_line = next(lines)
            except StopIteration:
                return
            self.write(self._markup(raw_line.strip('\r\n')))
            self.write('\n')

    def write(self, text):
        """Write text to the current output stream."""
        self.stream_out.write(text)

    def _expand_variable(self, text, match=None):
        """Expand "$N" group references and "$name" variables in text.

        @param {str} text text to expand
        @param {match} match regexp match supplying "$N" groups (optional)
        @return expanded text
        """
        if u'$' not in text:
            return text
        # expand $[0-9]+
        if match:
            # Groups that did not take part in the match expand to nothing.
            text = self._GROUP_REF.sub(
                lambda ref: match.group(int(ref.group(1))) or '', text)

        # expand vars; unknown names expand to the empty string
        return self._VAR_REF.sub(
            lambda ref: self.store.load(ref.group(1), ''), text)

    def _markup(self, text):
        """Process one line with the current mode and return the result."""
        # check global rule: the first two groups of the "globalIdentifier"
        # regexp are saved as a global key/value pair and the line is
        # swallowed
        gi = self.lexi.get("globalIdentifier", False)
        if gi:
            m_gvi = re.search(gi, text)
            if m_gvi:
                self.store.save_global(m_gvi.group(1), m_gvi.group(2))
                return ''

        mode = self.current_mode()
        if mode.test("end", text):
            text = mode.on_exit(self, text)
            self.store.clear()
            self.mode_stack.pop()
            if self.current_mode():
                # let the mode that is now on top process the same line
                text = self._markup(text)
            return text

        if mode.has_attr('transitions'):
            for candidate in mode.attr('transitions'):
                sub_mode = self._get_mode(candidate)
                if sub_mode and sub_mode.test('begin', text):
                    self.mode_stack.push(candidate)
                    text = sub_mode.on_start(self, text)
                    return self._markup(text)

        for key in mode.rules():
            (is_finish, text) = self.apply_rule(mode.rule(key), text)
            if is_finish:
                break

        return text

    def execute_action(self, rex, match, rule, text):
        """Run the actions of a rule whose regexp matched.

        @param {regex} rex compiled regexp of the rule
        @param {match} match result of rex.search(text)
        @param {dict} rule rule definition
        @param {str} text current line
        @return tuple (is_finish, text); is_finish is True when the rule
                sets "continue" to False
        """
        if 'call' in rule:
            param = rule['call']
            (func, args) = param[0], param[1:]
            # unknown function names are ignored
            if func in self.functions:
                context = self.store
                args = [self._expand_variable(x, match) for x in args]
                results = self.functions[func](context, args)
                for (k, v) in results:
                    self.store.save(k, v)

        if 'apply' in rule:
            # NOTE(review): if the applied mode matches its own "end"
            # pattern, _markup() already pops it and the pop below removes
            # one more entry -- confirm applied modes never define "end".
            self.mode_stack.push(rule["apply"])
            text = self._markup(text)
            self.mode_stack.pop()

        if 'switch' in rule:
            for key in rule['switch']:
                value = self.store.load(key)
                if value is None:
                    continue
                if value not in rule['switch'][key]:
                    continue
                # The selected branch replaces the rest of this rule: the
                # outer rule's remaining actions are not executed.
                new_rule = rule['switch'][key][value]
                return self.execute_action(rex, match, new_rule, text)

        if 'store' in rule:
            arg = rule['store']
            if isinstance(arg, list):
                # the n-th name receives the n-th regexp group
                for index, name in enumerate(arg):
                    self.store.save(name, match.group(index + 1))
            else:
                self.store.save(arg, match.group(1))

        if 'unset' in rule:
            self.store.delete(rule['unset'])

        if 'set' in rule:
            arg = rule['set']
            self.store.save(arg[0], arg[1])

        if 'replace' in rule:
            text = rex.sub(rule['replace'], text)
            text = self._expand_variable(text)

        if 'continue' in rule:
            if rule['continue'] == False:
                return (True, text)
        return (False, text)

    def apply_rule(self, rule, text):
        """Apply rule to text if its regexp matches.

        @param {dict} rule rule definition
        @param {str} text current line
        @return tuple (is_finish, text)
        """
        # if 'pass' rule is True, exit
        if rule.get('pass', False):
            return (False, text)

        if 'regexp' in rule:
            rex = re.compile(rule['regexp'])
            match = rex.search(text)
            if match:
                return self.execute_action(rex, match, rule, text)
        return (False, text)


def getImageGeometry(context, args):
    """Return the pixel size of the image file named by args[0].

    @param context store object passed by the caller (unused here)
    @param {list} args argument list; args[0] is the image file path
    @return list of (key, value) pairs: ('width', str) and ('height', str)
    """
    filepath = args[0]
    # Image data is binary: text mode can alter or fail to decode the bytes.
    # The with-block also closes the file if reading raises.
    with open(filepath, 'rb') as f:
        data = f.read()
    (w, h) = imagesize.get_image_size(data)
    return [
        ('width', str(w)),
        ('height', str(h))
        ]
