#!/usr/bin/env python
# -*- coding: utf-8 -*-

# てかデータベースとかの扱いに慣れてればパフォーマンス的に
# ずっとマトモなモノが作れたような。

from vector3d import *

import re
import sys

# data must be already sorted
def uniq( data ):
  """Collapse runs of equal adjacent items, keeping the first of each run.

  Only *adjacent* duplicates are dropped, so data must already be sorted
  for the result to be duplicate-free.  Returns a new list; data itself is
  left untouched.
  """
  out = []
  for item in data:
    # start a new run when the list is empty or the item differs from the last kept one
    if not out or out[-1] != item:
      out.append( item )
  return( out )

# create a PdbQuery to extract amino acid
# some non-amino acid names are added for AMBER and CHARMM nomenclature
def isAminoAcid():
  """Return a PdbQuery that selects residues with amino-acid names.

  Besides the 20 standard residue names, extra names used in AMBER and
  CHARMM nomenclature are included.  Each name is registered as its own
  residue query, so a residue matching any one of them is selected.
  """
  query = PdbQuery()
  for name in ( "ALA", "GLY", "VAL", "LEU", "ILE", "PHE", "TRP", "CYS",
                "MET", "SER", "THR", "TYR", "ASP", "GLU", "ASN", "GLN",
                "HIS", "LYS", "ARG", "PRO", "HSD", "HSE", "HSP", "HID",
                "HIE", "HIP", "RET", "LYR", "HC4", "GLH", "ASH", "CYX" ):
    query.selectResidues( name )
  return( query )

class Elements:
  """Lookup table of the element symbols recognised when typing atoms.

  Symbols are upper-case; only the subset listed in elements_base is
  accepted.  The commented-out list below is a fuller periodic table kept
  for reference.
  """
  elements_base = [ "H", "LI", "B", "C", "N", "O", "F",
                    "NA", "MG", "P", "S", "CL",
                    "K", "CA", "TI", "V", "MN", "FE",
                    "NI", "ZN", "GA", "GE", "AS",
                    "AG", "CS", "IR", "AU" ]
  #elements_base = [ "H", "HE", "LI", "BE", "B", "C", "N", "O", "F", "NE",
  #                  "NA", "MG", "AL", "SI", "P", "S", "CL", "AR",
  #                  "K", "CA", "SC", "TI", "V", "CR", "MN", "FE", "CO",
  #                  "NI", "CU", "ZN", "GA", "GE", "AS", "SE", "BR", "KR",
  #                  "RB", "SR", "Y", "ZR", "NB", "MO", "TC", "RU", "RH",
  #                  "PD", "AG", "CD", "IN", "SN", "SB", "TE", "I", "XE",
  #                  "CS", "BA", "HF", "TA", "W", "RE", "OS", "IR", "PT",
  #                  "AU", "HG", "TL", "PB", "BI", "PO", "AT", "RN",
  #                  "FR", "RA" ]
  # set form of elements_base for fast membership tests
  elements = set( elements_base )

  @staticmethod
  def isElementName( name ):
    """Return True if name is one of the recognised element symbols."""
    return( name in Elements.elements )

# an atom in pdb file; ATOM and HETATM record
class Atom:
  """One atom of a PDB file, read from / written as an ATOM or HETATM record.

  Fields follow the fixed-column PDB layout: serial number (index), atom
  name, alternate-location indicator (alternative), coordinates (pos),
  residue name, chain ID, residue sequence number (resindex), occupancy,
  B-factor and element symbol (atomtype).
  """

  def __init__( self ):
    self.index = -1
    self.name = ""
    self.alternative = ""
    self.pos = Vector3D()
    self.resname = ""
    self.chainID = ""
    self.resindex = -1
    self.occupancy = 1.0
    self.bfactor = 0.0
    self.atomtype = ""

  def readIn( self, line ):
    """Fill this atom from one ATOM/HETATM record line.

    Raises ValueError (or IndexError) when the line is truncated or a
    numeric column cannot be parsed.
    """
    self.index = int( line[6:11] )
    self.name = line[12:16].strip()
    if line[16] != " ":
      self.alternative = line[16]
    self.resname = line[17:20].strip()
    if line[21] != " ":
      self.chainID = line[21]
    self.resindex = int( line[22:26] )
    self.pos.x = float( line[30:38] )
    self.pos.y = float( line[38:46] )
    self.pos.z = float( line[46:54] )
    self.occupancy = float( line[54:60] )
    self.bfactor = float( line[60:66] )
    # The element symbol sits in columns 77-78 (slice 76:78).  Pdb.read strips
    # each line, which removes a blank charge field (columns 79-80), so a
    # record that carries the element may be only 78 characters long.  Use
    # the column whenever it is non-blank; otherwise guess from the name.
    element = line[76:78].strip()
    if element:
      self.atomtype = element
    else:
      self.atomtype = self._guessElement()

  def _guessElement( self ):
    """Guess the element symbol from the atom name; "" when nothing matches."""
    # NOTE: a whole-name match wins, so "CA" (alpha carbon) is typed as
    # calcium here; this path is only used when the element column is blank.
    if Elements.isElementName( self.name ):
      return( self.name )
    # for 4+ character names prefer the second character, presumably for
    # names with a leading digit such as "1HG2"
    if len( self.name ) >= 4 and Elements.isElementName( self.name[1] ):
      return( self.name[1] )
    for c in self.name:
      if Elements.isElementName( c ):
        return( c )
    return( "" )

  # Python 2 comparison hook; Python 3 uses __eq__/__lt__ below
  def __cmp__( self, at ):
    if self == at:
      return( 0 )
    elif self < at:
      return( -1 )
    return( 1 )

  def __eq__( self, at ):
    """Atoms are equal when serial number, residue number and name agree."""
    if self.index != at.index:
      return( False )
    if self.resindex != at.resindex:
      return( False )
    if self.name != at.name:
      return( False )
    return( True )

  def __lt__( self, at ):
    """Order by serial number, then residue number, then atom name."""
    if self.index != at.index:
      return( self.index < at.index )
    if self.resindex != at.resindex:
      return( self.resindex < at.resindex )
    if self.name != at.name:
      return( self.name < at.name )
    return( False )

  def __str__( self ):
    return( str( self.index ) + self.name + " " + self.alternative + self.resname + " " + str( self.resindex ) + " " + str( self.pos ) + " " + str( self.occupancy ) + " " + str( self.bfactor ) + " " + self.atomtype )

  def write( self, toWhat = sys.stdout ):
    """Write this atom as one PDB line (with newline) to the file-like toWhat."""
    # file.write works on Python 2 and 3, unlike the "print >>" statement
    toWhat.write( self.toPdbString() + "\n" )

  def toPdbString( self ):
    """Format this atom as a fixed-column PDB record (78 columns, no newline).

    The record name is always "ATOM", even for atoms read from HETATM lines.
    """
    ret = "ATOM  " + str( self.index ).rjust( 5 ) + " "
    # atom name: columns 13-16; names shorter than 4 start in column 14
    if len( self.name ) == 4:
      ret += self.name
    else:
      ret += " " + self.name.ljust( 3 )
    # column 17 is the alternate location, columns 18-20 the residue name;
    # without an alt-loc a 4-character residue name may use column 17
    if self.alternative:
      ret += self.alternative + self.resname.rjust( 3 )
    else:
      ret += self.resname.rjust( 4 )
    ret += " "
    # column 22 is the chain ID, columns 23-26 the residue number
    if self.chainID:
      ret += self.chainID + str( self.resindex ).rjust( 4 )
    else:
      ret += str( self.resindex ).rjust( 5 )
    ret += "    "
    ret += "%8.3f%8.3f%8.3f" % ( self.pos.x, self.pos.y, self.pos.z )
    ret += "%6.2f%6.2f" % ( self.occupancy, self.bfactor )
    ret += "          " + self.atomtype.rjust( 2 )
    return( ret )

  def matchQuery( self, query, resQuery = False ):
    """True when query is non-empty and every condition in it holds.

    query is a list of Query conditions combined with AND.  With
    resQuery=True the residue number/name are tested, otherwise the atom
    serial number/name.  Conditions of an unknown kind never match.
    """
    if len( query ) == 0:
      return( False )
    if resQuery:
      index, name = self.resindex, self.resname
    else:
      index, name = self.index, self.name
    for que in query:
      if que.indexQuery:
        if que.index != index:
          return( False )
      elif que.rangeQuery:
        if que.initNum > index or que.lastNum < index:
          return( False )
      elif que.nameQuery:
        if not que.regQuery.search( name ):
          return( False )
      else:
        return( False )
    return( True )

# a residue in pdb file
class Residue:
  """A residue of a PDB file: a name, a sequence number and its atoms."""

  def __init__( self ):
    self.name = ""     # residue name; taken from the first atom if left blank
    self.index = -1    # residue sequence number
    self.atoms = []    # atoms in the order they were added

  def addAtom( self, atom ):
    """Append atom, adopting its stripped residue name if none is set yet."""
    if self.name == "":
      self.name = atom.resname.strip()
    self.atoms.append( atom )

  def empty( self ):
    """True when no atom has been added."""
    return( len( self.atoms ) == 0 )

  def write( self, toWhat = sys.stdout ):
    """Write every atom, in insertion order, to toWhat."""
    for member in self.atoms:
      member.write( toWhat )

  def matchQuery( self, query ):
    """True when query is non-empty and every condition in it holds.

    Index/range conditions test the residue number and name conditions
    search the residue name; a condition of any other kind never matches.
    """
    if len( query ) == 0:
      return( False )
    for cond in query:
      if cond.indexQuery:
        ok = cond.index == self.index
      elif cond.rangeQuery:
        ok = cond.initNum <= self.index <= cond.lastNum
      elif cond.nameQuery:
        ok = cond.regQuery.search( self.name ) is not None
      else:
        ok = False
      if not ok:
        return( False )
    return( True )

  def natom( self ):
    """Number of atoms in this residue."""
    return( len( self.atoms ) )

  def getAll( self ):
    """Return a new list holding this residue's atoms."""
    return( list( self.atoms ) )

  def select( self, query ):
    """Collect atoms picked by query (a PdbQuery); duplicates are possible.

    All atoms are added once per matching residue component, then each atom
    is added once per matching atom component.
    """
    hits = [ conds
             for group in query._residueQuery
             for conds in group
             if self.matchQuery( conds ) ]
    picked = []
    for _ in hits:
      picked.extend( self.getAll() )
    picked.extend( atom
                   for atom in self.atoms
                   for group in query._atomQuery
                   for conds in group
                   if atom.matchQuery( conds, False ) )
    return( picked )

# a molecule in pdb file
class Molecule:
  """A molecule of a PDB file (the atoms between TER records), grouped by residue."""

  def __init__( self ):
    self.index = -1
    # residues keyed by residue name + residue number, e.g. "ALA12"
    # NOTE(review): the key ignores the chain ID, so same-named, same-numbered
    # residues from different chains inside one molecule would be merged.
    self.residues = {}

  def addAtom( self, atom ):
    """Add atom to the residue it belongs to, creating that residue if needed."""
    sname = atom.resname.strip()
    key = sname + str( atom.resindex )
    resid = self.residues.get( key )
    if resid is None:
      resid = Residue()
      resid.index = atom.resindex
      resid.name = sname
      self.residues[key] = resid
    resid.addAtom( atom )

  def empty( self ):
    """True when the molecule has no residues."""
    return( len( self.residues ) == 0 )

  def write( self, toWhat = sys.stdout ):
    """Write all residues to toWhat, ordered by residue number, then name."""
    # self.residues is a plain dict whose iteration order is arbitrary on
    # Python 2; sort so the written file follows the residue numbering.
    for resid in sorted( self.residues.values(), key = lambda r: ( r.index, r.name ) ):
      resid.write( toWhat )

  def matchQuery( self, query ):
    """True when query is non-empty and every condition holds for this molecule.

    Only index and (inclusive) range conditions apply to molecules; a
    condition of any other kind never matches.
    """
    if len( query ) == 0:
      return( False )
    for que in query:
      if que.indexQuery:
        if self.index != que.index:
          return( False )
      elif que.rangeQuery:
        # reject when outside [initNum, lastNum]; this must be "or" --
        # with "and" the test could never be true and every molecule matched
        if self.index < que.initNum or self.index > que.lastNum:
          return( False )
      else:
        return( False )
    return( True )

  def natom( self ):
    """Total number of atoms over all residues."""
    return( sum( [ resid.natom() for resid in self.residues.values() ] ) )

  def getAll( self ):
    """Return a new list with every atom of every residue (unsorted)."""
    ret = []
    for resid in self.residues.values():
      ret.extend( resid.getAll() )
    return( ret )

  def select( self, query ):
    """Collect atoms picked by query (a PdbQuery); duplicates are possible.

    All atoms are added once per matching molecule component, followed by
    whatever each residue selects.
    """
    ret = []
    for que in query._molQuery:
      for quemol in que:
        if self.matchQuery( quemol ):
          ret.extend( self.getAll() )
    for resid in self.residues.values():
      ret.extend( resid.select( query ) )
    return( ret )

# whole pdb data
class Pdb:
  """A whole PDB file: a list of Molecule objects, split at TER records."""

  def __init__( self ):
    self.molecules = []
    self._reg_atom = re.compile( "^(ATOM|HETATM)" )
    self._reg_ter = re.compile( "^TER" )

  def read( self, fromwhat ):
    """Parse ATOM/HETATM lines from an iterable of lines; TER starts a new molecule.

    Every other record type is ignored.
    """
    for raw in fromwhat:
      if self._reg_atom.match( raw ):
        atom = Atom()
        atom.readIn( raw.strip() )
        self.addAtom( atom )
      elif self._reg_ter.match( raw ):
        self.addNewMol()

  def addAtom( self, atom ):
    """Add atom to the last molecule, creating the first molecule if needed."""
    if self.empty():
      self.addNewMol()
    current = self.molecules[-1]
    current.addAtom( atom )

  def addNewMol( self ):
    """Start a new molecule unless the last one is still empty."""
    # avoids piling up empty molecules, e.g. for consecutive TER records
    if not self.empty() and self.molecules[-1].empty():
      return
    mol = Molecule()
    mol.index = len( self.molecules )
    self.molecules.append( mol )

  def natom( self ):
    """Total number of atoms over all molecules."""
    return( sum( [ mol.natom() for mol in self.molecules ] ) )

  def empty( self ):
    """True when no molecule exists yet."""
    return( not self.molecules )

  def write( self, toWhat = sys.stdout ):
    """Write every molecule, in order, to toWhat."""
    for mol in self.molecules:
      mol.write( toWhat )

  def getAll( self ):
    """Return all atoms sorted and with adjacent duplicates removed."""
    atoms = []
    for mol in self.molecules:
      atoms.extend( mol.getAll() )
    return( uniq( sorted( atoms ) ) )

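# query:
# examples -
#   1-10 6       : index range 1 to 10 (inclusive), or index 6
#   1 2 3 4 5    : any of these indices
#   HIS SER LYS  : any of these names
#   HI? SE*      : names given as regular expressions (see below)
#   within(5,23) : not implemented yet
# Space-separated tokens are alternatives (OR); tokens joined with "&" must
# all match (AND).  Name tokens are Python regular expressions searched
# anywhere in the name (re.search, unanchored), not shell-style globs: "HI?"
# matches any name containing "H", and "SE*" any name containing "S".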
class Query:
  """A single selection condition.

  Normally one of the gen*Query methods is called; it raises the matching
  flag and stores the data that kind of condition needs:
    indexQuery -> index
    rangeQuery -> initNum, lastNum (inclusive bounds)
    nameQuery  -> regQuery (compiled regular expression)
  """

  def __init__( self ):
    self.indexQuery = self.rangeQuery = False
    self.nameQuery = self.specialQuery = False

  def genIndexQuery( self, i ):
    """Make this an exact-index condition."""
    self.indexQuery = True
    self.index = i

  def genRangeQuery( self, i, l ):
    """Make this a range condition from i to l."""
    self.rangeQuery = True
    self.initNum, self.lastNum = i, l

  def genNameQuery( self, string ):
    """Make this a name condition; string is compiled as a regular expression."""
    self.nameQuery = True
    self.regQuery = re.compile( string )

  # future plan
  def genSpecialQuery( self ):
    """Placeholder for special queries; currently does nothing."""
    pass

  def empty( self ):
    """True when no kind of condition has been set."""
    return( not ( self.indexQuery or self.rangeQuery or self.nameQuery or self.specialQuery ) )

# Query class: will be used to extract some information from Pdb class
class PdbQuery:
  """Molecule / residue / atom selections applied to a Pdb or an atom list.

  Each select* call stores one parsed query string.  A parsed query is a
  list of components; a component (one whitespace-separated token) is a
  list of Query conditions joined by "&".  An atom is selected when any
  component of any stored query matches it (OR across components, AND
  within a component).
  """

  # hoisted patterns used by parseQuery to drop spaces around "&" and "-"
  _reg_amp = re.compile( "[ ]*&[ ]*" )
  _reg_dash = re.compile( "[ ]*-[ ]*" )

  def __init__( self ):
    self.clear()

  def clear( self ):
    """Forget every stored selection."""
    self._molQuery = []
    self._residueQuery = []
    self._atomQuery = []

  def empty( self ):
    """True when no selection has been stored."""
    return( len( self._molQuery ) == 0 and len( self._residueQuery ) == 0 and len( self._atomQuery ) == 0 )

  def selectMolecules( self, query ):
    self._molQuery.append( self.parseQuery( query ) )

  def selectResidues( self, query ):
    self._residueQuery.append( self.parseQuery( query ) )

  def selectAtoms( self, query ):
    self._atomQuery.append( self.parseQuery( query ) )

  def parseQuery( self, query ):
    """Parse a query string into a list of components (lists of Query).

    Raises ValueError for a malformed range token (see genQuery).
    """
    # remove spaces around "&" and "-" so "1 - 10" and "A & B" stay one token
    query = self._reg_amp.sub( "&", query )
    query = self._reg_dash.sub( "-", query )
    ret = []
    for each in query.split():
      # skip empty operands (from "A&&B" or a lone "&"); they would
      # otherwise become an empty regex that matches every name
      parts = [ q for q in each.split( '&' ) if q ]
      if parts:
        ret.append( [ self.genQuery( q ) for q in parts ] )
    return( ret )

  def genQuery( self, each ):
    """Build one Query from a token.

    "N-M" (digits only) gives a range query, "N" an index query, anything
    else a name (regular expression) query.  Because "-" marks a range, a
    name regex cannot contain "-" (e.g. "[A-C]").

    Raises ValueError for "X-Y" with non-numeric sides or more than one "-".
    """
    r = each.split( '-' )
    if len( r ) > 2:
      raise ValueError( "invalid range query string found: " + each )
    if len( r ) == 2 and not ( r[0].isdigit() and r[1].isdigit() ):
      raise ValueError( "range query permits only number: " + each )
    ret = Query()
    if len( r ) == 2:
      ret.genRangeQuery( int( r[0] ), int( r[1] ) )
    elif each.isdigit():
      ret.genIndexQuery( int( each ) )
    else:
      ret.genNameQuery( each )
    return( ret )

  def select( self, pdb ):
    """Return the sorted, de-duplicated atoms of pdb matched by this query.

    pdb is a Pdb instance or an iterable of Atom objects.  An empty query
    selects everything.  For an atom iterable only residue and atom queries
    are applied; molecule queries are ignored because atoms carry no
    molecule index.
    """
    if self.empty():
      if hasattr( pdb, "getAll" ):
        return( pdb.getAll() )
      # a plain atom list has no getAll(); apply the same sort + uniq
      return( uniq( sorted( pdb ) ) )
    ret = []
    if isinstance( pdb, Pdb ):
      for mol in pdb.molecules:
        ret.extend( mol.select( self ) )
    else:
      for atom in pdb:
        for que in self._residueQuery:
          for queres in que:
            if atom.matchQuery( queres, True ):
              ret.append( atom )
        for que in self._atomQuery:
          for queatm in que:
            if atom.matchQuery( queatm, False ):
              ret.append( atom )
    ret.sort()
    return( uniq( ret ) )
