#include <cstddef>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

/// Error thrown by queue operations that cannot proceed.
///
/// Derives from std::runtime_error so the message passed to the constructor
/// is retrievable through what(), and so generic
/// `catch (const std::exception&)` handlers also see it. Handlers that catch
/// QueueException directly keep working unchanged.
class QueueException : public std::runtime_error {
public:
    QueueException(const std::string& s) : std::runtime_error(s) {}
};

/// A fixed-capacity FIFO queue guarded by an internal mutex.
///
/// put() and take() each hold the mutex for their whole duration, so they
/// may be called concurrently from several threads. Neither operation waits:
/// put() on a full queue throws QueueException("Full queue") and take() on an
/// empty queue throws QueueException("Empty queue").
///
/// Storage is a ring buffer of `capacity` slots allocated up front, so E must
/// be default-constructible and copy-assignable.
template <typename E>
class SafeQueue {
    // Ring buffer; its size is the queue's capacity.
    std::vector<E> elements;
    // Slot index of the next element to be dequeued. Kept reduced modulo the
    // capacity, so it stays in [0, capacity) and can never overflow no matter
    // how many operations the queue sees.
    std::size_t head = 0;
    // Number of occupied slots, always in [0, capacity].
    std::size_t count = 0;
    // Guards every member above.
    std::mutex lock;
public:
    /// Creates an empty queue able to hold `capacity` elements.
    /// A capacity of 0 gives a queue on which every put() reports full and
    /// every take() reports empty. A negative capacity converts to a huge
    /// unsigned size that the std::vector constructor rejects by throwing.
    SafeQueue(int capacity) : elements(capacity) {}

    /// Appends a copy of `element`; throws QueueException if the queue is full.
    void put(const E& element);
    /// Removes and returns the oldest element; throws QueueException if empty.
    E take();
};

template <typename E>
void SafeQueue<E>::put(const E& element) {
    std::lock_guard<std::mutex> guard(lock);
    const std::size_t capacity = elements.size();
    if (count == capacity)
        throw QueueException("Full queue");
    // Here capacity > 0, head < capacity and count < capacity, so
    // head + count cannot overflow. Wrapping it yields the first free slot.
    elements[(head + count) % capacity] = element;
    // Publish the slot only after the assignment succeeded, so an exception
    // thrown by E's assignment leaves the queue unchanged.
    ++count;
}

template <typename E>
E SafeQueue<E>::take() {
    std::lock_guard<std::mutex> guard(lock);
    if (count == 0)
        throw QueueException("Empty queue");
    // The slot becomes free once head advances and is only ever overwritten
    // by a later put(), so the element can be moved out instead of copied.
    E element = std::move(elements[head]);
    head = (head + 1) % elements.size();
    --count;
    return element;
}

/// Thread body: takes up to 500000 elements from `q` and increments
/// `counter` once for every element equal to 1.
///
/// If the queue runs dry before the quota is reached, take() throws
/// QueueException. The exception is caught here and the function returns
/// early, because an exception escaping a std::thread entry point calls
/// std::terminate and kills the whole process.
///
/// `counter` is written without synchronization, so each concurrently
/// running taker must be given its own counter.
void taker(SafeQueue<int>& q, size_t& counter) {
    for (size_t i = 0; i < 500000; ++i) {
        int v = 0;
        try {
            v = q.take();
        } catch (const QueueException&) {
            // Queue drained before the quota was reached: stop consuming.
            return;
        }
        if (v == 1) {
            ++counter;
        }
    }
}

int main() {
    SafeQueue<int> q(1000000);
    for  (size_t i = 0; i < 1000000; ++i) {
        q.put(i % 2);
    }

    size_t counter_A = 0;
    size_t counter_B = 0;
    std::thread A(taker, std::ref(q), std::ref(counter_A));
    std::thread B(taker, std::ref(q), std::ref(counter_B));
    A.join();
    B.join();
    std::cout << "Count in A " << counter_A << "; Counter in B " << counter_B << std::endl;
    std::cout << "In total " << counter_A + counter_B << std::endl;
}

