# GNU Solfege - free ear training software
# vim: set fileencoding=utf-8 :
# Copyright (C) 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009 Tom Cato Amundsen
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


import logging
import sqlite3
import hashlib
import pickle
import os
import time
import filesystem
import lessonfile
import utils
import mpd

# Shared database handle. The statistics classes in this module reach the
# database through this name (db.conn, db.last_test_result). It starts out
# as None and has to be bound to a DB instance before those classes are
# used -- presumably by application start-up code that is not part of this
# file.
db = None

def hash_lessonfile_text(s):
    """
    Return the SHA-1 hex digest of the string s, after filtering out:
    * lines starting with '#'
    * empty lines

    The remaining lines are joined with a single newline before hashing,
    so adding or removing comments and blank lines does not change the
    hash value.

    Byte strings are hashed as they are. Text (unicode) strings are
    encoded as UTF-8 first, because the hash function only accepts bytes;
    for pure ASCII text this gives the same digest as the byte string.
    """
    kept = [line for line in s.split("\n")
            if line and not line.startswith("#")]
    payload = "\n".join(kept)
    # type(u"") is the text type on every Python version, without having
    # to name `unicode` (Python 2) or `str` (Python 3) explicitly.
    if isinstance(payload, type(u"")):
        payload = payload.encode("utf-8")
    return hashlib.sha1(payload).hexdigest()


def hash_of_lessonfile(filename):
    """
    Return the hash value of the lesson file named filename, as computed
    by hash_lessonfile_text() on the file's content.

    Raises IOError if the file cannot be opened or read.
    """
    f = open(filename, 'r')
    try:
        return hash_lessonfile_text(f.read())
    finally:
        # Close explicitly instead of relying on garbage collection; this
        # function is called once per lesson file when scanning.
        f.close()


class DB(object):
    """
    Wrapper around the sqlite database that stores lesson file metadata
    (table lessonfiles) and the answers given while practising (table
    sessions) and while taking tests (table tests).

    This DB class is currently located in src.statistics.py, but it might
    move to a separate/another file later, since it will store more data
    than just statistics.

    Intended usage, as described by the original design notes:

    At startup on a fresh install:
    * Scan all lesson files (only standard files, not regression) to
      get uuid, filename, title, hash etc into the lessonfiles table.

    On every startup:
    * After the GUI is ready, we scan all lesson files in the background
      to check if any files have changed.

    When we start to practise an exercise:
    * Check if the lesson file has changed, and if so, we delete the
      statistics.
    """

    def __init__(self, callback=None):
        """
        Open the statistics database, creating the tables if necessary.

        callback is called to display progress when scanning lesson files.
        It is only used when the database file did not exist, because only
        then are the lesson files scanned and the old format statistics
        imported.

        If the name testsuite_is_running is defined and true, an in-memory
        database is used and nothing is scanned or imported.
        """
        # testsuite_is_running is not defined in this module; it only
        # exists when something outside this file has injected it.
        try:
            in_testsuite = bool(testsuite_is_running)
        except NameError:
            in_testsuite = False
        if in_testsuite:
            statistics_filename = ":memory:"
            self.must_recreate = False
        else:
            # Also taken when the name is defined but false. The original
            # code left statistics_filename unbound in that case.
            statistics_filename = self.get_statistics_filename()
            self.must_recreate = not os.path.exists(statistics_filename)
        self.conn = sqlite3.connect(statistics_filename)
        self.setup_tables()
        if self.must_recreate:
            self.scan_lessonfiles(callback)
            self.read_old_data()

    @staticmethod
    def get_statistics_filename():
        """
        Return the full path of the sqlite file: "statistics.sqlite" in
        the directory returned by filesystem.app_data().
        """
        return os.path.join(filesystem.app_data(), "statistics.sqlite")

    def setup_tables(self):
        """
        Create the tables that do not exist yet and add the columns that
        are missing from a lessonfiles table created by an older version.

        lessonfiles table.
        ==================
        uuid: the uuid of the lesson file, as written in the lessonfile header.
        hash: hash value of the file the last time some statistics for the
              file was entered into the database.
        test_result: The result of the last time a test was made.
        test_passed: If the test was passed the last time taken.
        title: The lesson file title, from the lesson file header
        module: The exercise module, taken from the lessonfile header
        filename: The name of the lesson file.

        sessions and tests tables.
        ==========================
        Both have the columns uuid, timestamp, answerkey, guessed and
        count, and the combination (uuid, timestamp, answerkey, guessed)
        is unique.
        """
        self.conn.execute('''create table if not exists lessonfiles
            (uuid text primary key not null,
             hash text not null,
             test_result float default 0.0,
             test_passed int default 0
            )''')
        columns = [x[1] for x in
              self.conn.execute('pragma table_info(lessonfiles)').fetchall()]
        # NOTE(review): the declared type is 'txt', not 'text'. SQLite does
        # not recognise 'txt' as a text type, so these columns get NUMERIC
        # affinity. Left unchanged so that new databases keep the same
        # schema as existing ones -- confirm whether that is intended.
        for column in ('title', 'module', 'filename'):
            if column not in columns:
                self.conn.execute(
                    'alter table lessonfiles add column %s txt' % column)
        for table in ('sessions', 'tests'):
            self.conn.execute('''create table if not exists %s
            (uuid text, timestamp int, answerkey text, guessed text, count int, unique (uuid, timestamp, answerkey, guessed) )''' % table)

    def get_field(self, uuid, fields):
        """
        Return column values from the lessonfiles row of uuid.

        fields is either one column name or a tuple/list of column names.
        For a single name the value itself is returned, for a tuple/list
        the whole row is returned as a tuple.
        Return None if there is no row for uuid.

        The column names are inserted directly into the SQL statement, so
        they must come from trusted code.
        """
        many = isinstance(fields, (tuple, list))
        if many:
            names = fields
        else:
            names = [fields]
        row = self.conn.execute('select %s from lessonfiles where uuid=?'
                                % ', '.join(names), (uuid,)).fetchone()
        if many or row is None:
            # The single field form used to raise TypeError for an unknown
            # uuid; now both forms return None in that case.
            return row
        return row[0]

    def scan_lessonfiles(self, callback=None):
        """
        Read info about standard lesson files into the lessonfiles table.
        Current lessonfile manager code cause us to parse the user contributed
        files in ~/lessonfiles too.

        If callback is given, it is called with the progress as an integer
        percentage each time that percentage changes.
        """
        count = 0
        cc = lessonfile.manager.old_count()
        old_uuid = None
        uuid_count = 0
        old_progress = 0
        for uuid, data in lessonfile.manager.parse(False):
            if callback and uuid != old_uuid:
                uuid_count += 1
                old_uuid = uuid
            self.set_lessonfile_data(uuid, data['header']['title'],
                    data['header']['module'], data['filename'])
            # Progress cannot be computed when the expected number of
            # files is 0, so skip reporting instead of dividing by zero.
            if callback and cc:
                # Floor division keeps the percentage an integer.
                progress = uuid_count * 100 // cc
                if old_progress != progress:
                    callback(progress)
                    old_progress = progress
            count += 1
        logging.info("scan_lessonfiles: scanned %i files.", count)

    def set_lessonfile_data(self, uuid, title, module, filename):
        """
        Set the values title, module, filename and hash of uuid
        if the lesson file has changed.

        A new row is inserted if uuid is not in the lessonfiles table yet.
        If the row exists and the hash of the file differs from the stored
        hash, validate_stored_statistics() is called before the row is
        updated. The changes are committed.
        """
        row = self.conn.execute(
            "select uuid, hash from lessonfiles where uuid=?",
            (uuid,)).fetchone()
        hashvalue = hash_of_lessonfile(filename)
        if row is None:
            self.conn.execute("insert into lessonfiles "
                    "(uuid, title, module, filename, hash) "
                    "values (?, ?, ?, ?, ?)",
                    (uuid, title, module, filename, hashvalue))
        elif hashvalue != row[1]:
            self.validate_stored_statistics(uuid, filename)
            # Pass exactly the named parameters the statement uses instead
            # of every local variable.
            self.conn.execute("update lessonfiles "
                "set title=:title, module=:module, filename=:filename, "
                "hash=:hashvalue "
                "where uuid=:uuid",
                {'title': title, 'module': module, 'filename': filename,
                 'hashvalue': hashvalue, 'uuid': uuid})
        self.conn.commit()

    def read_old_data(self):
        """
        This function should be run once to import the old format statistics
        from ~/.solfege/statistics and into the sqlite database.
        We will import statistics for all uuids that exist in the
        lessonfiles table.

        Each uuid has a directory with one pickled dict per practise
        session. The file name is used as the timestamp, and the dict maps
        correct answer -> guess -> count.
        """
        st_home = os.path.join(filesystem.app_data(), 'statistics')
        ign = 0
        imp = 0
        # fetchall() so that we do not insert rows while a select on the
        # same connection is still being iterated.
        uuids = self.conn.execute('select uuid from lessonfiles').fetchall()
        for row in uuids:
            uuid = row[0]
            session_dir = os.path.join(st_home, uuid)
            if not os.path.exists(session_dir):
                ign += 1
                continue
            for timestamp in os.listdir(session_dir):
                # NOTE(review): pickle.load() executes whatever the file
                # tells it to, so these files must be trusted. They are
                # read from the application's own data directory.
                f = open(os.path.join(session_dir, timestamp), 'r')
                try:
                    session = pickle.load(f)
                finally:
                    f.close()
                for correct in session:
                    for guess in session[correct]:
                        self.conn.execute(
                            "insert into sessions values(?, ?, ?, ?, ?)",
                            (uuid, timestamp, unicode(correct),
                             unicode(guess), session[correct][guess]))
                # imp is incremented once per imported session file.
                imp += 1
        logging.info("imported statistics for %i lessonfiles, ignored %i",
                     imp, ign)
        self.conn.commit()

    def last_test_result(self, uuid):
        """
        Return a float in the range 0.0 to 1.0 telling the score of the
        last test run on this uuid.
        Return None if no tests found.
        """
        cursor = self.conn.cursor()
        row = cursor.execute("select timestamp from tests where uuid=? order by -timestamp", (uuid,)).fetchone()
        if not row:
            return None
        # timestamp is now the timestamp of the last test run on this uuid.
        timestamp = row[0]
        correct_count = cursor.execute("select sum(count) from tests where uuid=? and timestamp=? and answerkey=guessed", (uuid, timestamp)).fetchone()[0]
        total_count = cursor.execute("select sum(count) from tests where uuid=? and timestamp=?", (uuid, timestamp)).fetchone()[0]
        if not total_count:
            # Nothing to compute a score from.
            return None
        # sum() is NULL (None) when no row matches, which is what happens
        # when every answer in the test was wrong.
        if correct_count is None:
            correct_count = 0
        return correct_count * 1.0 / total_count

    def is_test_passed(self, uuid):
        """
        Return the test_passed value stored for uuid, or False if uuid is
        not in the lessonfiles table.
        """
        row = self.conn.execute("select test_passed from lessonfiles "
                                "where uuid=?", (uuid,)).fetchone()
        if row:
            return row[0]
        return False

    def validate_stored_statistics(self, uuid, filename):
        """
        Check if the lesson file has changed since last time statistics for
        this lessonfile was added, by comparing the hash stored for uuid
        with the current hash of filename.

        If the hash differs, the rows for uuid in the sessions table are
        deleted, test_passed and test_result are reset and the new hash is
        stored. Rows in the tests table are not deleted.
        If uuid is not in the lessonfiles table, it is added with the
        current hash.
        """
        cursor = self.conn.cursor()
        row = cursor.execute("select hash from lessonfiles where uuid=?",
                             (uuid,)).fetchone()
        cur_lessonfile_hash_value = hash_of_lessonfile(filename)
        if not row:
            # Usually the uuid exists in the database, but when running the
            # test suite, it does not, so we have the code here to add it.
            cursor.execute("insert into lessonfiles (uuid, hash) values(?, ?)",
                    (uuid, cur_lessonfile_hash_value))
            self.conn.commit()
        elif row[0] != cur_lessonfile_hash_value:
            # row[0] is the hash value of the lesson file when all the
            # existing statistics for the lesson file with this uuid was
            # saved.
            cursor.execute("delete from sessions where uuid=?", (uuid,))
            cursor.execute("update lessonfiles "
                "set hash=?, test_passed=0, test_result=0.0 where uuid=?",
                (cur_lessonfile_hash_value, uuid))
            self.conn.commit()

    def search(self, searchfor, columns):
        """
        Search for searchfor in the title column (and possible other places
        later). Return the columns named by the list columns, as a list
        with one tuple for each matching lesson file. The list is empty if
        nothing matches.

        The column names are inserted directly into the SQL statement, so
        they must come from trusted code.
        """
        pattern = "%%%s%%" % searchfor
        return self.conn.execute(
            "select %s " % ", ".join(columns)
            + "from lessonfiles where title like ?", (pattern,)).fetchall()


class AbstractStatistics(object):
    """
    Record and query the answers given for one lesson file.

    All data is stored through the module level variable db. The lesson
    file is identified by self.m_t.m_P.header.lesson_id, and the current
    practise session by self.m_timestamp.

    Several methods take a 'seconds' argument with this meaning:
        -1  all statistics
         0  statistics from this session
        >0  statistics newer than 'seconds' seconds before the start of
            this session
    """

    def __init__(self, teacher):
        self.m_t = teacher
        # Start time of the current session, set by reset_session().
        self.m_timestamp = None
        # True while answers should be stored in the tests table instead
        # of the sessions table.
        self.m_test_mode = False

    def _lesson_id(self):
        """Return the uuid of the lesson file currently being practised."""
        return self.m_t.m_P.header.lesson_id

    def _time_filter(self, seconds):
        """
        Return (sql, parameters) that limits a query on the sessions table
        according to 'seconds'. See the class docstring.
        """
        if seconds == -1:
            return "", []
        if seconds == 0:
            return " and timestamp=?", [self.m_timestamp]
        return " and timestamp>?", [self.m_timestamp - seconds]

    def _sum_count(self, seconds, condition=None, params=()):
        """
        Return sum(count) for the rows in the sessions table belonging to
        this lesson file that match the time limit 'seconds' and the
        optional extra SQL condition with its parameters params.
        Return 0 if no rows match.
        """
        query = "select sum(count) from sessions where uuid=?"
        args = [self._lesson_id()]
        if condition:
            query += " and " + condition
            args.extend(params)
        time_sql, time_args = self._time_filter(seconds)
        total = db.conn.execute(query + time_sql,
                                args + time_args).fetchone()[0]
        # sum() is NULL (None) when no row matches.
        return total or 0

    def get_keys(self, all_keys=False):
        """
        Return the keys for all questions that have been answered correctly.
        If all_keys is true, also return the correct key for the questions
        that only have been answered wrongly.
        """
        query = "select distinct(answerkey) from sessions where uuid=?"
        if not all_keys:
            query += " and answerkey=guessed"
        return [row[0] for row in
                db.conn.execute(query, (self._lesson_id(),))]

    def get_statistics(self, seconds):
        """
        return a dict with statistics more recent than 'seconds' seconds.
        The keys of dict are the correct answers for the lesson file.
        And the values of dict are new dicts where the keys are all the
        answers the user have given and the values are the number of times
        that particular answer has been given.
        """
        time_sql, time_args = self._time_filter(seconds)
        q = db.conn.execute(
            "select answerkey, guessed, sum(count) from sessions "
            "where uuid=?" + time_sql + " group by answerkey, guessed",
            [self._lesson_id()] + time_args)
        ret = {}
        for answer, guess, count in q.fetchall():
            ret.setdefault(answer, {})[guess] = count
        return ret

    def reset_session(self):
        """
        Start a new practise session.
        """
        self.m_timestamp = int(time.time())

    def enter_test_mode(self):
        """Store the following answers in the tests table."""
        self.m_test_mode = True

    def exit_test_mode(self):
        """
        Leave test mode, and store the result of the last test and whether
        it was passed in the lessonfiles table.
        """
        self.m_test_mode = False
        lesson_id = self._lesson_id()
        test_result = db.last_test_result(lesson_id)
        requirement = self.m_t.m_P.get_test_requirement()
        # test_result is None if no test answers were stored. That counts
        # as not passed; say so explicitly instead of relying on how None
        # compares to numbers.
        passed = test_result is not None and test_result >= requirement
        db.conn.execute("update lessonfiles "
                        "set test_result=?, test_passed=? where uuid=?",
                        (test_result, passed, lesson_id))
        db.conn.commit()

    def _add(self, question, answer):
        """
        Register that for the question 'question' the user answered 'answer'.

        The answer is stored in the tests table when in test mode and in
        the sessions table otherwise. reset_session() must have been
        called first.
        """
        assert self.m_timestamp
        # tuples must be converted to str to store them in sqlite.
        if isinstance(question, tuple):
            question = str(question)
        if isinstance(answer, tuple):
            answer = str(answer)
        if self.m_test_mode:
            table = 'tests'
        else:
            table = 'sessions'
        lesson_id = self._lesson_id()
        cursor = db.conn.cursor()
        row = cursor.execute(
                "select count from %s where uuid=? and timestamp=? "
                "and answerkey=? and guessed=?" % table,
                (lesson_id, self.m_timestamp, question, answer)).fetchone()
        if not row:
            cursor.execute(
                "insert into %s "
                "(uuid, timestamp, answerkey, guessed, count) "
                "values(?, ?, ?, ?, ?)" % table,
                (lesson_id, self.m_timestamp,
                 unicode(question), unicode(answer), 1))
        else:
            # There should never be more than one row for the same
            # question and answer in one session.
            assert cursor.fetchone() is None
            cursor.execute(
                "update %s set count=? where "
                "uuid=? and timestamp=? and answerkey=? and guessed=?" % table,
                (row[0] + 1, lesson_id,
                 self.m_timestamp, question, answer))
        db.conn.commit()

    def add_wrong(self, question, answer):
        """Register a wrong answer 'answer' to the question 'question'."""
        self._add(question, answer)

    def add_correct(self, answer):
        """Register a correct answer."""
        self._add(answer, answer)

    def get_last_test_result(self):
        """
        Return the test result of the last test ran for this lesson file.
        """
        return db.last_test_result(self._lesson_id())

    def get_percentage_correct(self):
        """Will return a 0 <= value <= 100.0 that say how many percent is
        correct in this session. Return 0 if no questions have been
        answered in this session.
        """
        num_asked = self._sum_count(0)
        if not num_asked:
            return 0
        return 100.0 * self._sum_count(0, "answerkey=guessed") / num_asked

    def get_percentage_correct_for_key(self, seconds, key):
        """
        Return the percentage correct answer the last 'seconds' seconds.
        Return 0 if the key has not been asked for in that period.
        """
        num_guess = self.get_num_guess_for_key(seconds, key)
        if num_guess:
            return 100.0 * self.get_num_correct_for_key(seconds, key) / num_guess
        return 0

    def get_num_correct_for_key(self, seconds, key):
        """
        Return the number of correct answers for the given key 'key' the
        last 'seconds' seconds.
        Special meanings of 'seconds':
            -1  all statistics
             0  statistics from this session
        """
        return self._sum_count(seconds, "answerkey=? and guessed=?",
                               (key, key))

    def get_num_guess_for_key(self, seconds, key):
        """
        Return the number of answers, correct or wrong, given when 'key'
        was the correct answer.
        See get_num_correct_for_key docstring.
        """
        return self._sum_count(seconds, "answerkey=?", (key,))


class LessonStatistics(AbstractStatistics):
    """
    Statistics where the keys are looked up among the questions of the
    lesson file.
    """
    def key_to_pretty_name(self, key):
        """
        Return the object to display for the statistics key 'key'.

        If one of the questions in self.m_t.m_P.m_questions has a name
        whose cval equals key, the name of the first such question is
        evaluated, otherwise key itself is evaluated. If the evaluated
        value is a tuple, a lessonfile.LabelObject made from its two first
        elements is returned, otherwise the value itself.
        """
        def evaluate(x):
            # NOTE(review): eval() runs x as a python expression. The
            # strings come from lesson files and the statistics database,
            # so this is only safe as long as those are trusted.
            value = eval(x)
            if type(value) is not tuple:
                return value
            return lessonfile.LabelObject(value[0], value[1])
        source = key
        for q in self.m_t.m_P.m_questions:
            if q.name.cval == key:
                source = q.name
                break
        return evaluate(source)

class IntervalStatistics(AbstractStatistics):
    """Statistics where the keys are intervals stored as integers."""
    def key_to_pretty_name(self, key):
        """
        Return the name utils.int_to_intervalname() gives for the
        interval int(key) when called with the extra arguments 1, 1.
        """
        interval = int(key)
        return utils.int_to_intervalname(interval, 1, 1)

class HarmonicIntervalStatistics(AbstractStatistics):
    """Statistics where the keys are intervals stored as integers."""
    def key_to_pretty_name(self, key):
        """
        Return the name utils.int_to_intervalname() gives for the
        interval int(key) when called with the extra arguments 1, 0.
        """
        interval = int(key)
        return utils.int_to_intervalname(interval, 1, 0)

class IdToneStatistics(LessonStatistics):
    """Statistics where the keys are notenames."""
    def key_to_pretty_name(self, key):
        """
        Return the user notename of the pitch that
        mpd.MusicalPitch.new_from_notename() creates from key.
        """
        pitch = mpd.MusicalPitch.new_from_notename(key)
        return pitch.get_user_notename()


