# GNU Solfege - free ear training software
# vim: set fileencoding=utf-8 :
# Copyright (C) 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009 Tom Cato Amundsen
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


import sqlite3
import hashlib
import pickle
import os
import time
import filesystem
import lessonfile
import utils
import mpd

# Module level DB instance shared by the statistics classes below (they read
# db.conn and call db.last_test_result). It is None here; presumably code
# outside this chunk assigns a DB() to it at startup — TODO confirm.
db = None

def hash_lessonfile_text(s):
    """
    Return the sha1 hexdigest of the string s, after filtering out:
    * lines starting with '#'
    * empty lines

    Byte strings are hashed as they are. Text (unicode) strings are
    encoded as UTF-8 first, since sha1 only accepts bytes.
    """
    kept = [line for line in s.split("\n")
            if line and not line.startswith("#")]
    data = "\n".join(kept)
    # In Python 2 a str input stays a byte string here, so hashes of
    # existing files are unchanged; only text input needs encoding.
    if isinstance(data, type(u"")):
        data = data.encode("utf-8")
    return hashlib.sha1(data).hexdigest()


def hash_of_lessonfile(filename):
    """
    Return hash_lessonfile_text() of the contents of the file 'filename'.
    The file is closed before returning, also if reading fails.
    """
    with open(filename, 'r') as f:
        return hash_lessonfile_text(f.read())


class DB(object):
    """
    Wrapper around the sqlite connection holding the practise and test
    statistics for all lesson files.
    """
    def __init__(self):
        # testsuite_is_running is not defined in this module; presumably the
        # test suite injects it (module global or builtin) — TODO confirm.
        # Looking it up once inside try/except NameError also works when it
        # is absent, and a false value now selects the normal database file
        # instead of leaving statistics_filename unassigned.
        try:
            testing = bool(testsuite_is_running)
        except NameError:
            testing = False
        if testing:
            statistics_filename = ":memory:"
            self.must_recreate = False
        else:
            statistics_filename = os.path.join(filesystem.app_data(), "statistics.sqlite")
            self.must_recreate = not os.path.exists(statistics_filename)
        self.conn = sqlite3.connect(statistics_filename)
        self.setup_tables()
        if self.must_recreate:
            self.read_old_data()
    def setup_tables(self):
        """
        Create the tables if they do not already exist.

        lessonfiles table. Columns
        uuid: the uuid of the lesson file, as written in the lessonfile header.
        hash: hash value of the file the last time some statistics for the
              file was entered into the database.
        test_result, test_passed: result of the last test on the file.

        sessions table: practise answers, one row per (uuid, timestamp,
        answerkey, guessed) holding the number of such answers in 'count'.
        tests table: same layout as sessions, for answers given in test mode.
        """
        self.conn.execute('''create table if not exists lessonfiles
            (uuid text primary key not null,
             hash text not null,
             test_result float default 0.0,
             test_passed int default 0
            )''')
        self.conn.execute('''create table if not exists sessions
            (uuid text, timestamp int, answerkey text, guessed text, count int, unique (uuid, timestamp, answerkey, guessed) )''')
        self.conn.execute('''create table if not exists tests
            (uuid text, timestamp int, answerkey text, guessed text, count int, unique (uuid, timestamp, answerkey, guessed) )''')
    def read_old_data(self):
        """
        Import the old format statistics from <app_data>/statistics into the
        sqlite database. Intended to run once, when the database is created.
        Does nothing if the old statistics directory does not exist.
        """
        st_home = os.path.join(filesystem.app_data(), 'statistics')
        if not os.path.isdir(st_home):
            # Nothing to import, for example on a fresh installation.
            return
        for fn in os.listdir(st_home):
            if fn.endswith('_hash') or fn == 'statistics_version':
                continue
            hashvalue_filename = os.path.join(st_home, fn + "_hash")
            if not os.path.exists(hashvalue_filename):
                # The statistics directory contained directories named by the
                # uuid, and files uuid_hash that contained the hash value.
                # The old statistics code would delete the uuid_hash file and
                # the statistics stored in the directory if the lesson file
                # changed, but not delete the directory itself. So there
                # might be empty directories lying around.
                continue
            hashvalue = hash_of_lessonfile(lessonfile.manager.get(fn, 'filename'))
            self.conn.execute("insert into lessonfiles (uuid, hash) values(?, ?)",
                              (fn, hashvalue))
        self.conn.commit()
        uuids = [row[0] for row in self.conn.execute('select uuid from lessonfiles')]
        for uuid in uuids:
            session_dir = os.path.join(st_home, uuid)
            for timestamp in os.listdir(session_dir):
                # NOTE(review): pickle can execute code embedded in the file;
                # this assumes the directory only holds files written by the
                # old statistics code.
                with open(os.path.join(session_dir, timestamp), 'r') as f:
                    session = pickle.load(f)
                for correct, guesses in session.items():
                    for guess in guesses:
                        self.conn.execute("insert into sessions values(?, ?, ?, ?, ?)",
                            (uuid, timestamp, unicode(correct), unicode(guess), guesses[guess]))
        self.conn.commit()
    def last_test_result(self, uuid):
        """
        Return a float in the range 0.0 to 1.0 telling the score of the
        last test run on this uuid.
        Return None if no tests found.
        """
        cursor = self.conn.cursor()
        row = cursor.execute("select timestamp from tests where uuid=? order by -timestamp", (uuid,)).fetchone()
        if not row:
            return None
        # row[0] is the timestamp of the most recent test run on this uuid.
        timestamp = row[0]
        correct_count = cursor.execute("select sum(count) from tests where uuid=? and timestamp=? and answerkey=guessed", (uuid, timestamp)).fetchone()[0]
        total_count = cursor.execute("select sum(count) from tests where uuid=? and timestamp=?", (uuid, timestamp)).fetchone()[0]
        if not total_count:
            return 0.0
        # sum() yields NULL (None) when no answer in the test was correct.
        return (correct_count or 0) * 1.0 / total_count
    def is_test_passed(self, uuid):
        """
        Return the stored test_passed value for uuid, or False if the lesson
        file is not in the lessonfiles table.
        """
        row = self.conn.execute("select test_passed from lessonfiles "
                                "where uuid=?", (uuid,)).fetchone()
        if row:
            return row[0]
        return False

class AbstractStatistics(object):
    """
    Base class for per-exercise statistics.

    Answers are read from and written to the module level 'db' object,
    keyed on the lesson_id of the teacher's current lesson file
    (self.m_t.m_P). While m_test_mode is true answers are stored in the
    'tests' table, otherwise in the 'sessions' table.
    """
    def __init__(self, teacher):
        self.m_t = teacher
        # int(time.time()) of the current practise session; set by
        # reset_session().
        self.m_timestamp = None
        self.m_test_mode = False
    def _lesson_id(self):
        """Return the uuid of the lesson file the teacher is running."""
        return self.m_t.m_P.header.lesson_id
    def _timestamp_condition(self, seconds):
        """
        Return (sql, params): an extra 'and ...' clause plus its parameters
        restricting rows by timestamp, following the 'seconds' convention:
            -1  all statistics
             0  statistics from this session
             N  rows newer than N seconds before the session start
        """
        if seconds == -1:
            return "", ()
        if seconds == 0:
            return " and timestamp=?", (self.m_timestamp,)
        return " and timestamp>?", (self.m_timestamp - seconds,)
    def get_keys(self, all_keys=False):
        """
        Return the distinct answer keys stored in 'sessions' for this lesson
        file. By default only keys that have been answered correctly at
        least once are returned. If all_keys is true, keys of questions that
        have only been answered wrongly are included too.
        """
        sql = "select distinct(answerkey) from sessions where uuid=?"
        if not all_keys:
            sql += " and answerkey=guessed"
        return [row[0] for row in db.conn.execute(sql, (self._lesson_id(),))]
    def get_statistics(self, seconds):
        """
        Return a dict with the statistics selected by 'seconds' (-1 all
        statistics, 0 this session, otherwise more recent than 'seconds'
        seconds before the session start).
        The keys of dict are the correct answers for the lesson file.
        And the values of dict are new dicts where the keys are all the
        answers the user have given and the values are the number of times
        that particular answer has been given.
        """
        cond, params = self._timestamp_condition(seconds)
        q = db.conn.execute(
            "select answerkey, guessed, sum(count) from sessions "
            "where uuid=?" + cond + " group by answerkey, guessed",
            (self._lesson_id(),) + params)
        ret = {}
        for answer, guess, count in q:
            ret.setdefault(answer, {})[guess] = count
        return ret
    def lessonfile_changed(self):
        """
        Check if the lesson file self.m_t.m_P has changed since statistics
        for its uuid were last stored. If it has, delete the old practise
        statistics and reset the stored test result. A lesson file not yet
        in the database is registered with its current hash.
        Must be called after m_t.m_P has been set to the new lesson file.
        """
        lesson_id = self._lesson_id()
        row = db.conn.execute("select hash from lessonfiles where uuid=?", (lesson_id,)).fetchone()
        cur_lessonfile_hash_value = hash_of_lessonfile(self.m_t.m_P.m_filename)
        if not row:
            db.conn.execute("insert into lessonfiles (uuid, hash) values(?, ?)",
                    (lesson_id, cur_lessonfile_hash_value))
        elif row[0] != cur_lessonfile_hash_value:
            # row[0] is the hash value of the lesson file when all the
            # existing statistics for the lesson file with this uuid was saved.
            db.conn.execute("delete from sessions where uuid=?", (lesson_id,))
            db.conn.execute("update lessonfiles set hash=?, test_passed=0, test_result=0.0 where uuid=?",
                (cur_lessonfile_hash_value, lesson_id))
        # Commit in every case, so a newly inserted row is also persisted.
        db.conn.commit()
    def reset_session(self):
        """
        Start a new practise session.
        """
        self.m_timestamp = int(time.time())
    def enter_test_mode(self):
        """Store subsequent answers in the 'tests' table."""
        self.m_test_mode = True
    def exit_test_mode(self):
        """
        Leave test mode and store the result of the last test, and whether
        it met the lesson file's test requirement, in 'lessonfiles'.
        """
        self.m_test_mode = False
        lesson_id = self._lesson_id()
        row = db.conn.execute("select * from lessonfiles where uuid=?",
                (lesson_id,)).fetchone()
        test_result = db.last_test_result(lesson_id)
        # last_test_result() returns None when no test has been run.
        passed = (test_result is not None
                  and test_result >= self.m_t.m_P.get_test_requirement())
        if row:
            db.conn.execute("update lessonfiles "
                          "set test_result=?, test_passed=? where uuid=?",
                          (test_result, passed, lesson_id,))
        else:
            db.conn.execute("insert into lessonfiles  "
                            "(uuid, hash, test_result, test_passed) "
                            "values(?, ?, ?, ?) ",
                   (lesson_id,
                    hash_of_lessonfile(self.m_t.m_P.m_filename),
                    test_result,
                    passed
                    ))
        db.conn.commit()
    def _add(self, question, answer):
        """
        Register that for the question 'question' the user answered 'answer'.
        """
        assert self.m_timestamp
        table = {True: 'tests', False: 'sessions'}[self.m_test_mode]
        # Use the same (unicode) values for lookup and insert.
        params = (self._lesson_id(), self.m_timestamp,
                  unicode(question), unicode(answer))
        cursor = db.conn.cursor()
        cursor.execute(
            "update %s set count=count+1 where "
            "uuid=? and timestamp=? and answerkey=? and guessed=?" % table,
            params)
        if cursor.rowcount == 0:
            # First time this answer is given in this session.
            cursor.execute(
                "insert into %s "
                "(uuid, timestamp, answerkey, guessed, count) "
                "values(?, ?, ?, ?, ?)" % table,
                params + (1,))
        db.conn.commit()
    def add_wrong(self, question, answer):
        self._add(question, answer)
    def add_correct(self, answer):
        self._add(answer, answer)
    def key_to_pretty_name(self, key):
        for question in self.m_t.m_P.m_questions:
            if question.name.cval == key:
                return question.name
        return key
    def get_last_test_result(self):
        """
        Return the test result of the last test ran for this lesson file.
        """
        return db.last_test_result(self._lesson_id())
    def get_percentage_correct(self):
        """
        Return the percentage (0 <= value <= 100.0) of correct answers in
        the current practise session, or 0 if nothing has been answered.
        """
        params = (self.m_timestamp, self._lesson_id())
        num_correct = db.conn.execute("select sum(count) from sessions where answerkey=guessed and timestamp=? and uuid=?", params).fetchone()[0]
        num_asked = db.conn.execute("select sum(count) from sessions where timestamp=? and  uuid=?", params).fetchone()[0]
        if not num_asked:
            return 0
        return 100.0 * (num_correct or 0) / num_asked
    def get_percentage_correct_for_key(self, seconds, key):
        """
        Return the percentage correct answers for 'key' in the period
        selected by 'seconds' (see get_num_correct_for_key), or 0 if the
        key has not been asked in that period.
        """
        num_guess = self.get_num_guess_for_key(seconds, key)
        if num_guess:
            return 100.0 * self.get_num_correct_for_key(seconds, key) / num_guess
        return 0
    def get_num_correct_for_key(self, seconds, key):
        """
        Return the number of correct answers for the given key 'key' the
        last 'seconds' seconds.
        Special meanings of 'seconds':
            -1  all statistics
             0  statistics from this session
        """
        cond, params = self._timestamp_condition(seconds)
        ret = db.conn.execute(
            "select sum(count) from sessions "
            "where answerkey=? and guessed=? and uuid=?" + cond,
            (key, key, self._lesson_id()) + params).fetchone()[0]
        return ret or 0
    def get_num_guess_for_key(self, seconds, key):
        """
        Return how many times 'key' was the correct answer (answered right
        or wrong). See get_num_correct_for_key docstring for 'seconds'.
        """
        cond, params = self._timestamp_condition(seconds)
        ret = db.conn.execute(
            "select sum(count) from sessions "
            "where answerkey=? and uuid=?" + cond,
            (key, self._lesson_id()) + params).fetchone()[0]
        return ret or 0


class LessonStatistics(AbstractStatistics):
    """
    Statistics whose answer keys are matched against question.name.cval of
    the questions in the lesson file.
    """
    def key_to_pretty_name(self, key):
        """
        Return the name of the first question whose name.cval equals 'key',
        or 'key' itself when no question matches.
        """
        matching_names = (q.name for q in self.m_t.m_P.m_questions
                          if q.name.cval == key)
        return next(matching_names, key)

class IntervalStatistics(AbstractStatistics):
    """
    Statistics whose answer keys are converted with int() and turned into
    interval names by utils.int_to_intervalname().
    """
    def key_to_pretty_name(self, key):
        """Return utils.int_to_intervalname(int(key), 1, 1)."""
        interval = int(key)
        return utils.int_to_intervalname(interval, 1, 1)

class HarmonicIntervalStatistics(AbstractStatistics):
    """
    Like IntervalStatistics, but passes 0 as the last flag to
    utils.int_to_intervalname().
    """
    def key_to_pretty_name(self, key):
        """Return utils.int_to_intervalname(int(key), 1, 0)."""
        interval = int(key)
        return utils.int_to_intervalname(interval, 1, 0)

class IdToneStatistics(LessonStatistics):
    """
    Statistics whose answer keys are note names understood by
    mpd.MusicalPitch.new_from_notename().
    """
    def key_to_pretty_name(self, key):
        """Return the user-visible note name for the note name 'key'."""
        pitch = mpd.MusicalPitch.new_from_notename(key)
        return pitch.get_user_notename()


