<?php declare(strict_types=1);

namespace Shopware\Core\System\SalesChannel\Context;

use Shopware\Core\Checkout\Cart\Delivery\Struct\ShippingLocation;
use Shopware\Core\Checkout\Cart\Price\Struct\CartPrice;
use Shopware\Core\Checkout\Cart\Tax\TaxDetector;
use Shopware\Core\Checkout\Customer\Aggregate\CustomerAddress\CustomerAddressEntity;
use Shopware\Core\Checkout\Customer\CustomerEntity;
use Shopware\Core\Checkout\Payment\Exception\UnknownPaymentMethodException;
use Shopware\Core\Checkout\Payment\PaymentMethodEntity;
use Shopware\Core\Framework\Api\Context\SalesChannelApiSource;
use Shopware\Core\Framework\Context;
use Shopware\Core\Framework\DataAbstractionLayer\EntityRepositoryInterface;
use Shopware\Core\Framework\DataAbstractionLayer\Pricing\CashRoundingConfig;
use Shopware\Core\Framework\DataAbstractionLayer\Search\Criteria;
use Shopware\Core\Framework\DataAbstractionLayer\Search\Filter\EqualsFilter;
use Shopware\Core\Framework\DataAbstractionLayer\Search\Filter\MultiFilter;
use Shopware\Core\Framework\Feature;
use Shopware\Core\Framework\Plugin\Exception\DecorationPatternException;
use Shopware\Core\System\Currency\Aggregate\CurrencyCountryRounding\CurrencyCountryRoundingEntity;
use Shopware\Core\System\SalesChannel\BaseContext;
use Shopware\Core\System\SalesChannel\Event\SalesChannelContextPermissionsChangedEvent;
use Shopware\Core\System\SalesChannel\SalesChannelContext;
use Shopware\Core\System\Tax\Aggregate\TaxRule\TaxRuleCollection;
use Shopware\Core\System\Tax\Aggregate\TaxRule\TaxRuleEntity;
use Shopware\Core\System\Tax\TaxCollection;
use Shopware\Core\System\Tax\TaxRuleType\TaxRuleTypeFilterInterface;
use Symfony\Component\EventDispatcher\EventDispatcherInterface;
use function array_unique;

/**
 * Creates the SalesChannelContext for a context token.
 *
 * The sales-channel wide data is taken from the base context factory; this class
 * adds the parts that depend on the given options and the logged in customer:
 * active addresses, customer group, tax rules, payment method and cash rounding.
 */
class SalesChannelContextFactory extends AbstractSalesChannelContextFactory
{
    private EntityRepositoryInterface $customerRepository;

    private EntityRepositoryInterface $customerGroupRepository;

    private EntityRepositoryInterface $addressRepository;

    private EntityRepositoryInterface $paymentMethodRepository;

    private TaxDetector $taxDetector;

    /**
     * @var iterable<TaxRuleTypeFilterInterface>
     */
    private iterable $taxRuleTypeFilter;

    private EventDispatcherInterface $eventDispatcher;

    private EntityRepositoryInterface $currencyCountryRepository;

    private AbstractBaseContextFactory $baseContextFactory;

    /**
     * @internal
     *
     * @param iterable<TaxRuleTypeFilterInterface> $taxRuleTypeFilter
     */
    public function __construct(
        EntityRepositoryInterface $customerRepository,
        EntityRepositoryInterface $customerGroupRepository,
        EntityRepositoryInterface $addressRepository,
        EntityRepositoryInterface $paymentMethodRepository,
        TaxDetector $taxDetector,
        iterable $taxRuleTypeFilter,
        EventDispatcherInterface $eventDispatcher,
        EntityRepositoryInterface $currencyCountryRepository,
        AbstractBaseContextFactory $baseContextFactory
    ) {
        $this->customerRepository = $customerRepository;
        $this->customerGroupRepository = $customerGroupRepository;
        $this->addressRepository = $addressRepository;
        $this->paymentMethodRepository = $paymentMethodRepository;
        $this->taxDetector = $taxDetector;
        $this->taxRuleTypeFilter = $taxRuleTypeFilter;
        $this->eventDispatcher = $eventDispatcher;
        $this->currencyCountryRepository = $currencyCountryRepository;
        $this->baseContextFactory = $baseContextFactory;
    }

    /**
     * This implementation does not wrap another factory, so it always throws.
     *
     * @throws DecorationPatternException
     */
    public function getDecorated(): AbstractSalesChannelContextFactory
    {
        throw new DecorationPatternException(self::class);
    }

    /**
     * Assembles the sales channel context for the given token and sales channel.
     *
     * Option keys read here (all SalesChannelContextService constants): CUSTOMER_ID,
     * DOMAIN_ID and PERMISSIONS. The helpers additionally read PAYMENT_METHOD_ID,
     * BILLING_ADDRESS_ID and SHIPPING_ADDRESS_ID. The full option array is also
     * forwarded to the base context factory.
     *
     * @param array<string, mixed> $options
     *
     * @throws UnknownPaymentMethodException when the customer's payment method cannot be loaded
     */
    public function create(string $token, string $salesChannelId, array $options = []): SalesChannelContext
    {
        // we split the context generation to allow caching of the base context
        $base = $this->baseContextFactory->create($salesChannelId, $options);

        // load the logged in customer (including active addresses) if a customer id is given
        $customer = null;
        if (isset($options[SalesChannelContextService::CUSTOMER_ID])) {
            $customer = $this->loadCustomer($options, $base->getContext());
        }

        // start with the base context values, override them with the customer's data
        $shippingLocation = $base->getShippingLocation();
        $customerGroup = $base->getCurrentCustomerGroup();

        if ($customer !== null) {
            /** @var CustomerAddressEntity $activeShippingAddress */
            $activeShippingAddress = $customer->getActiveShippingAddress();
            $shippingLocation = ShippingLocation::createFromAddress($activeShippingAddress);

            $criteria = new Criteria([$customer->getGroupId()]);
            $criteria->setTitle('context-factory::customer-group');

            // keep the base context group when the customer's group cannot be found
            $customerGroup = $this->customerGroupRepository->search($criteria, $base->getContext())->first() ?? $customerGroup;
        }

        //loads tax rules based on active customer and delivery address
        $taxRules = $this->getTaxRules($base, $customer, $shippingLocation);

        //detect active payment method, first check if checkout defined other payment method, otherwise validate if customer logged in, at least use shop default
        $payment = $this->getPaymentMethod($options, $base, $customer);

        [$itemRounding, $totalRounding] = $this->getCashRounding($base, $shippingLocation);

        // rebuild the framework context from the base context, using the item rounding resolved above
        $context = new Context(
            $base->getContext()->getSource(),
            [],
            $base->getCurrencyId(),
            $base->getContext()->getLanguageIdChain(),
            $base->getContext()->getVersionId(),
            $base->getCurrency()->getFactor(),
            true,
            CartPrice::TAX_STATE_GROSS,
            $itemRounding
        );

        // defaults to the resolved customer group; replaced by the base context's
        // fallback group whenever Feature::callSilentIfInactive executes the closure
        $fallbackGroup = $customerGroup;
        Feature::callSilentIfInactive('v6.5.0.0', function () use ($base, &$fallbackGroup): void {
            $fallbackGroup = $base->getFallbackCustomerGroup();
        });

        $salesChannelContext = new SalesChannelContext(
            $context,
            $token,
            $options[SalesChannelContextService::DOMAIN_ID] ?? null,
            $base->getSalesChannel(),
            $base->getCurrency(),
            $customerGroup,
            $fallbackGroup,
            $taxRules,
            $payment,
            $base->getShippingMethod(),
            $shippingLocation,
            $customer,
            $itemRounding,
            $totalRounding,
            []
        );

        if (\array_key_exists(SalesChannelContextService::PERMISSIONS, $options)) {
            $permissions = $options[SalesChannelContextService::PERMISSIONS];

            $salesChannelContext->setPermissions($permissions);

            // the event is dispatched before the permissions get locked
            $this->eventDispatcher->dispatch(
                new SalesChannelContextPermissionsChangedEvent($salesChannelContext, $permissions)
            );

            $salesChannelContext->lockPermissions();
        }

        $salesChannelContext->setTaxState($this->taxDetector->getTaxState($salesChannelContext));

        return $salesChannelContext;
    }

    /**
     * Reduces the rules of every tax of the base context to at most one rule:
     * the first one (after sorting by type position) that is accepted by any of
     * the injected tax rule type filters for the given customer and location.
     *
     * Taxes without loaded rules are left untouched. Note that the tax entities
     * of the base context are modified in place via setRules().
     */
    private function getTaxRules(BaseContext $context, ?CustomerEntity $customer, ShippingLocation $shippingLocation): TaxCollection
    {
        $taxes = $context->getTaxRules()->getElements();

        foreach ($taxes as $tax) {
            $rules = $tax->getRules();
            if ($rules === null) {
                continue;
            }

            $rules = $rules->filter(
                fn (TaxRuleEntity $taxRule): bool => $this->matchesAnyTaxRuleTypeFilter($taxRule, $customer, $shippingLocation)
            );
            $rules->sortByTypePosition();
            $firstMatch = $rules->first();

            $matchingRules = new TaxRuleCollection();
            if ($firstMatch) {
                $matchingRules->add($firstMatch);
            }
            $tax->setRules($matchingRules);
        }

        return new TaxCollection($taxes);
    }

    /**
     * Tells whether at least one injected tax rule type filter matches the rule.
     */
    private function matchesAnyTaxRuleTypeFilter(TaxRuleEntity $taxRule, ?CustomerEntity $customer, ShippingLocation $shippingLocation): bool
    {
        foreach ($this->taxRuleTypeFilter as $ruleTypeFilter) {
            if ($ruleTypeFilter->match($taxRule, $customer, $shippingLocation)) {
                return true;
            }
        }

        return false;
    }

    /**
     * Resolves the active payment method.
     *
     * The payment method of the base context is used when no customer is logged in,
     * when the options carry an explicit payment method id, or when the customer's
     * last (or, as fallback, default) payment method is the same one. Otherwise the
     * customer's payment method is loaded from the repository.
     *
     * @group not-deterministic
     * NEXT-21735 - This is covered randomly
     * @codeCoverageIgnore
     *
     * @param array<string, mixed> $options
     *
     * @throws UnknownPaymentMethodException when the customer's payment method does not exist
     */
    private function getPaymentMethod(array $options, BaseContext $context, ?CustomerEntity $customer): PaymentMethodEntity
    {
        if ($customer === null || isset($options[SalesChannelContextService::PAYMENT_METHOD_ID])) {
            return $context->getPaymentMethod();
        }

        $id = $customer->getLastPaymentMethodId() ?? $customer->getDefaultPaymentMethodId();
        if ($id === $context->getPaymentMethod()->getId()) {
            // NEXT-21735 - does not execute on every test run
            return $context->getPaymentMethod();
        }

        $criteria = new Criteria([$id]);
        $criteria->addAssociation('media');
        $criteria->setTitle('context-factory::payment-method');

        /** @var PaymentMethodEntity|null $paymentMethod */
        $paymentMethod = $this->paymentMethodRepository->search($criteria, $context->getContext())->get($id);

        if ($paymentMethod === null) {
            throw new UnknownPaymentMethodException($id);
        }

        return $paymentMethod;
    }

    /**
     * Loads the customer named in the options and attaches its active and default
     * billing/shipping addresses.
     *
     * The customer is only found when it is not bound to a sales channel or bound
     * to the sales channel of the context source. Returns null if no such customer
     * exists. The active addresses come from the options and fall back to the
     * customer's default addresses.
     *
     * @param array<string, mixed> $options
     */
    private function loadCustomer(array $options, Context $context): ?CustomerEntity
    {
        $customerId = $options[SalesChannelContextService::CUSTOMER_ID];

        $criteria = new Criteria([$customerId]);
        $criteria->setTitle('context-factory::customer');
        $criteria->addAssociation('salutation');
        $criteria->addAssociation('defaultPaymentMethod');

        // NOTE(review): assumes the context source is a SalesChannelApiSource - confirm for all callers
        /** @var SalesChannelApiSource $source */
        $source = $context->getSource();

        $criteria->addFilter(new MultiFilter(MultiFilter::CONNECTION_OR, [
            new EqualsFilter('customer.boundSalesChannelId', null),
            new EqualsFilter('customer.boundSalesChannelId', $source->getSalesChannelId()),
        ]));

        /** @var CustomerEntity|null $customer */
        $customer = $this->customerRepository->search($criteria, $context)->get($customerId);

        if ($customer === null) {
            return null;
        }

        $defaultBillingAddressId = $customer->getDefaultBillingAddressId();
        $defaultShippingAddressId = $customer->getDefaultShippingAddressId();

        $activeBillingAddressId = $options[SalesChannelContextService::BILLING_ADDRESS_ID] ?? $defaultBillingAddressId;
        $activeShippingAddressId = $options[SalesChannelContextService::SHIPPING_ADDRESS_ID] ?? $defaultShippingAddressId;

        // all four addresses are fetched with a single search; duplicates are removed first
        $addressIds = [
            $activeBillingAddressId,
            $activeShippingAddressId,
            $defaultBillingAddressId,
            $defaultShippingAddressId,
        ];

        // NOTE(review): the addresses are selected by id only - confirm that option-supplied
        // address ids are validated to belong to this customer before they reach this factory
        $criteria = new Criteria(array_unique($addressIds));
        $criteria->setTitle('context-factory::addresses');
        $criteria->addAssociation('salutation');
        $criteria->addAssociation('country');
        $criteria->addAssociation('countryState');

        $addresses = $this->addressRepository->search($criteria, $context);

        // NOTE(review): get() may not find an address; the annotations below assume it always does
        /** @var CustomerAddressEntity $activeBillingAddress */
        $activeBillingAddress = $addresses->get($activeBillingAddressId);
        $customer->setActiveBillingAddress($activeBillingAddress);
        /** @var CustomerAddressEntity $activeShippingAddress */
        $activeShippingAddress = $addresses->get($activeShippingAddressId);
        $customer->setActiveShippingAddress($activeShippingAddress);
        /** @var CustomerAddressEntity $defaultBillingAddress */
        $defaultBillingAddress = $addresses->get($defaultBillingAddressId);
        $customer->setDefaultBillingAddress($defaultBillingAddress);
        /** @var CustomerAddressEntity $defaultShippingAddress */
        $defaultShippingAddress = $addresses->get($defaultShippingAddressId);
        $customer->setDefaultShippingAddress($defaultShippingAddress);

        return $customer;
    }

    /**
     * Determines the item and total rounding, returned as [itemRounding, totalRounding].
     *
     * If the shipping country equals the country of the base context, the base
     * context roundings are reused. Otherwise a currency/country specific rounding
     * is looked up, falling back to the rounding of the currency.
     *
     * @return CashRoundingConfig[]
     *
     * @group not-deterministic
     * NEXT-21735 - This is covered randomly
     * @codeCoverageIgnore
     */
    private function getCashRounding(BaseContext $context, ShippingLocation $shippingLocation): array
    {
        $countryId = $shippingLocation->getCountry()->getId();

        if ($context->getShippingLocation()->getCountry()->getId() === $countryId) {
            return [$context->getItemRounding(), $context->getTotalRounding()];
        }

        $criteria = new Criteria();
        $criteria->setTitle('context-factory::cash-rounding');
        $criteria->setLimit(1);
        $criteria->addFilter(new EqualsFilter('currencyId', $context->getCurrencyId()));
        $criteria->addFilter(new EqualsFilter('countryId', $countryId));

        /** @var CurrencyCountryRoundingEntity|null $countryConfig */
        $countryConfig = $this->currencyCountryRepository
            ->search($criteria, $context->getContext())
            ->first();

        if ($countryConfig) {
            return [$countryConfig->getItemRounding(), $countryConfig->getTotalRounding()];
        }

        $currency = $context->getCurrency();

        return [$currency->getItemRounding(), $currency->getTotalRounding()];
    }
}
