前回(COMで弱参照を実現する)の続きで弱参照の実装、マルチスレッド対応編です。今回のコードは前回からがらりと変わっています。

今回はクラスObjectWithWeakReferenceにすべてまとめました。

  • 対象オブジェクトのメモリ領域
  • IWeakReferenceSourceの実装
  • IWeakReferenceの実装

対象オブジェクト

対象オブジェクトはATL::CComContainedObjectを使い、aligned_storageで確保した領域に構築させています。C++11のunionが使えればよいのですが、Visual C++にはまだありません。

typedef ATL::CComContainedObject<T> Obj;
typename std::aligned_storage<
  sizeof (Obj), std::alignment_of<Obj>::value>::type m_contained;

m_contained上のオブジェクトはObjectWithWeakReferenceのコンストラクタで構築します。そしてm_strongRefが0になったら破棄します。

std::atomic<ULONG> m_strongRef = 1;
 
ObjectWithWeakReference()
{
  ::new(&m_contained) Obj(
    static_cast<IUnknown*>(GetWeakReferenceSource()));
}
 
STDMETHOD_(ULONG, ObjectRelease)() override
{
  auto p = GetContainedObject();
  auto c = --m_strongRef;
  if (c == 0)
  {
    p->FinalRelease();
    p->~Obj();
 
    // FinalConstructでInternalAddRefしたものに対応する。
    InternalRelease();
  }
  return c;
}
 
Obj* GetContainedObject()
{
  assert(m_strongRef > 0);
  return reinterpret_cast<Obj*>(&m_contained);
}

GetWeakReferenceSource()は小細工入りのアップキャストで、IWeakReferenceSource*を返します。

なお、このObjectWithWeakReferenceでは、強参照がなくなっても弱参照がなくなるまでオブジェクトの大きさ分メモリを占有し続けます。こうした理由は、弱参照を長期間放置するような使い方は希だろうと判断したためです。ちなみに、boost::make_sharedも同じ挙動です。

IWeakReferenceSource

VTableの順番さえあっていればOK、というのがCOMです。IUnknownのメソッドを別名にしたIWeakReferenceSourceImplから派生させることで、弱参照用のIUnknownと強参照用のIUnknownを1クラスに共存させています。C++的に綺麗な実装方法もあるのですが、このほうが簡単なのです。

この手法はATLでも_IDispEvent::_LocDEQueryInterfaceで使用されています。

class DECLSPEC_NOVTABLE IWeakReferenceSourceImpl
{
public:
  STDMETHOD(ObjectQueryInterface)(
    _In_ REFIID riid, _COM_Outptr_ void** ppv) = 0;
  STDMETHOD_(ULONG, ObjectAddRef)() = 0;
  STDMETHOD_(ULONG, ObjectRelease)() = 0;
  STDMETHOD(GetWeakReference)(
    _COM_Outptr_ egtra::IWeakReference** weakReference) = 0;
};
 
template<class T>
class ObjectWithWeakReference
  : public ATL::CComObjectRootEx<ATL::CComMultiThreadModelNoCS>
  , public egtra::IWeakReference
  , public IWeakReferenceSourceImpl
{
public:
  virtual HRESULT STDMETHODCALLTYPE GetWeakReference(
    /* [retval][out] */ _COM_Outptr_ egtra::IWeakReference** weakReference
    ) throw() override
  {
    ATLENSURE_RETURN_HR(weakReference != nullptr, E_POINTER);
    *weakReference = this;
    AddRef();
    return S_OK;
  }
 
  egtra::IWeakReferenceSource* GetWeakReferenceSource()
  {
    return reinterpret_cast<egtra::IWeakReferenceSource*>(
      static_cast<IWeakReferenceSourceImpl*>(this));
  }

IWeakReference

Resolveメソッドでは、強参照のカウンタを見てオブジェクトを返して良いか判定しています。1クラスにまとめて実装した理由もここにあり、すなわち強参照のカウンタを読み書きできる必要があるからです。

virtual HRESULT STDMETHODCALLTYPE Resolve(
  /* [in] */ _In_ REFIID riid,
  /* [iid_is][out] */ _COM_Outptr_result_maybenull_ void** ppv
  ) throw() override
{
  ATLENSURE_RETURN_HR(ppv != nullptr, E_POINTER);
  *ppv = nullptr;
  if (!TryAddRef())
  {
    return S_FALSE;
  }
  auto hr = ObjectQueryInterface(riid, ppv);
  ObjectRelease(); // TryAddRefの分
  return hr;
}
 
bool TryAddRef()
{
  for (;;)
  {
    auto c = m_strongRef.load(std::memory_order_relaxed);
    if (c == 0)
    {
      return false;
    }
    if (m_strongRef.compare_exchange_weak(
      c, c + 1, std::memory_order_relaxed))
    {
      return true;
    }
    _mm_pause();
  }
}

ソースコード

以下、ソースコード全体です。参照カウントに対して、積極的にmemory_order_relaxedを使ってみました。

#include <atomic>
#include <iostream>
#include <memory>
#include <cassert>
#include <intrin.h>
#include <atlbase.h>
#include <atlcom.h>
 
class Module : public ATL::CAtlExeModuleT<Module> {};
Module module;
 
namespace egtra
{
  MIDL_INTERFACE("bdcb7ca6-376d-481a-8652-dfd69f723ecc")
  IWeakReference : IUnknown
  {
  public:
    virtual HRESULT STDMETHODCALLTYPE Resolve(
      /* [in] */ __RPC__in REFIID riid,
      /* [iid_is][out] */ __RPC__deref_out void** objectReference) = 0;
  };
 
  MIDL_INTERFACE("de2988fe-a6b7-4e3d-923e-7463ce0e1040")
  IWeakReferenceSource : IUnknown
  {
  public:
    virtual HRESULT STDMETHODCALLTYPE GetWeakReference(
      /* [retval][out] */ __RPC__deref_out IWeakReference** weakReference
      ) = 0;
  };
}
 
template<typename T>
ATL::CComPtr<T> CreateComObject()
{
  auto p = std::make_unique<ATL::CComObject<T>>();
  p->SetVoid(nullptr);
  p->InternalFinalConstructAddRef();
  HRESULT hRes = p->_AtlInitialConstruct();
  if (SUCCEEDED(hRes))
    hRes = p->FinalConstruct();
  if (SUCCEEDED(hRes))
    hRes = p->_AtlFinalConstruct();
  p->InternalFinalConstructRelease();
  return hRes == S_OK
    ? p.release()
    : nullptr;
}
 
class DECLSPEC_NOVTABLE IWeakReferenceSourceImpl
{
public:
  STDMETHOD(ObjectQueryInterface)(
    _In_ REFIID riid, _COM_Outptr_ void** ppv) = 0;
  STDMETHOD_(ULONG, ObjectAddRef)() = 0;
  STDMETHOD_(ULONG, ObjectRelease)() = 0;
  STDMETHOD(GetWeakReference)(
    _COM_Outptr_ egtra::IWeakReference** weakReference) = 0;
};
 
template<class T>
class ObjectWithWeakReference
  : public ATL::CComObjectRootEx<ATL::CComMultiThreadModelNoCS>
  , public egtra::IWeakReference
  , public IWeakReferenceSourceImpl
{
  typedef ATL::CComContainedObject<T> Obj;
 
  BEGIN_COM_MAP(ObjectWithWeakReference)
    COM_INTERFACE_ENTRY(egtra::IWeakReference)
  END_COM_MAP()
 
public:
  DECLARE_PROTECT_FINAL_CONSTRUCT()
 
  ObjectWithWeakReference()
  {
    ::new(&m_contained) Obj(
      static_cast<IUnknown*>(GetWeakReferenceSource()));
  }
 
  HRESULT _AtlInitialConstruct()
  {
    HRESULT hr = GetContainedObject()->_AtlInitialConstruct();
    ATLENSURE_RETURN_HR(SUCCEEDED(hr), hr);
    return __super::_AtlInitialConstruct();
  }
 
  HRESULT FinalConstruct()
  {
    InternalAddRef();
    auto hr = __super::FinalConstruct();
    ATLENSURE_RETURN_HR(SUCCEEDED(hr), hr);
    return GetContainedObject()->FinalConstruct();
  }
 
  void FinalRelease()
  {
    auto c = m_strongRef.load(std::memory_order_relaxed);
    assert(c == 0 || c == 1);
    if (c == 1)
    {
      ObjectRelease();
    }
    __super::FinalRelease();
  }
 
  virtual HRESULT STDMETHODCALLTYPE ObjectQueryInterface(
    _In_ REFIID riid, _COM_Outptr_ void** ppv) throw() override
  {
    ATLENSURE_RETURN_HR(ppv != nullptr, E_POINTER);
    if (riid == __uuidof(egtra::IWeakReferenceSource))
    {
      *ppv = GetWeakReferenceSource();
      ObjectAddRef();
      return S_OK;
    }
    return GetContainedObject()->_InternalQueryInterface(riid, ppv);
  }
 
  virtual ULONG STDMETHODCALLTYPE ObjectAddRef() throw() override
  {
    // 簡単にするためインクリメント前の値をそのまま返す。
    return m_strongRef.fetch_add(1, std::memory_order_relaxed);
  }
 
  virtual ULONG STDMETHODCALLTYPE ObjectRelease() throw() override
  {
    auto p = GetContainedObject();
    // FinalRelease前にはmemory_order_seq_cstを入れようと思い、
    // デクリメント演算子を使用した。
    auto c = --m_strongRef;
    if (c == 0)
    {
      // 上記をmemory_order_relaxedにして
      // ここでstd::atomic_thread_fence(std::memory_order_seq_cst)
      // という手もあるが、コードが複雑になるので採用しなかった。
      p->FinalRelease();
      p->~Obj();
      InternalRelease();
    }
    return c;
  }
 
  virtual HRESULT STDMETHODCALLTYPE GetWeakReference(
    /* [retval][out] */ _COM_Outptr_ egtra::IWeakReference** weakReference
    ) throw() override
  {
    ATLENSURE_RETURN_HR(weakReference != nullptr, E_POINTER);
    *weakReference = this;
    AddRef();
    return S_OK;
  }
 
  virtual HRESULT STDMETHODCALLTYPE Resolve(
    /* [in] */ _In_ REFIID riid,
    /* [iid_is][out] */ _COM_Outptr_result_maybenull_ void** ppv
    ) throw() override
  {
    ATLENSURE_RETURN_HR(ppv != nullptr, E_POINTER);
    *ppv = nullptr;
    if (!TryAddRef())
    {
      return S_FALSE;
    }
    auto hr = ObjectQueryInterface(riid, ppv);
    ObjectRelease(); // TryAddRefの分
    return hr;
  }
 
  Obj* GetContainedObject()
  {
    assert(m_strongRef > 0);
    return reinterpret_cast<Obj*>(&m_contained);
  }
 
private:
  bool TryAddRef()
  {
    for (;;)
    {
      auto c = m_strongRef.load(std::memory_order_relaxed);
      if (c == 0)
      {
        return false;
      }
      if (m_strongRef.compare_exchange_weak(
        c, c + 1, std::memory_order_relaxed))
      {
        return true;
      }
      _mm_pause();
    }
  }
 
  egtra::IWeakReferenceSource* GetWeakReferenceSource()
  {
    return reinterpret_cast<egtra::IWeakReferenceSource*>(
      static_cast<IWeakReferenceSourceImpl*>(this));
  }
 
  typename std::aligned_storage<
    sizeof (Obj), std::alignment_of<Obj>::value>::type m_contained;
  std::atomic<ULONG> m_strongRef = 1;
};
 
template<typename T>
ATL::CComPtr<T> CreateComObjectWithWeakRef()
{
  if (auto weak = CreateComObject<ObjectWithWeakReference<T>>())
  {
    ATL::CComPtr<T> obj;
    obj.Attach(weak->GetContainedObject());
    return obj;
  }
  else
  {
    return nullptr;
  }
}
 
MIDL_INTERFACE("0f9a78ce-4f58-4160-9889-e3bd4485c92d") ITest : IUnknown
{
  virtual HRESULT STDMETHODCALLTYPE Test() = 0;
};
 
class ATL_NO_VTABLE TestImpl
  : public ATL::CComObjectRootEx<ATL::CComMultiThreadModel>
  , public ITest
{
  DECLARE_NOT_AGGREGATABLE(TestImpl)
 
  BEGIN_COM_MAP(TestImpl)
    COM_INTERFACE_ENTRY(ITest)
  END_COM_MAP()
 
public:
  HRESULT FinalConstruct() throw()
  {
    std::cout << "TestImpl::FinalConstruct" << std::endl;
    return S_OK;
  }
 
  HRESULT FinalRelease() throw()
  {
    std::cout << "TestImpl::FinalRelease" << std::endl;
    return S_OK;
  }
 
  virtual HRESULT STDMETHODCALLTYPE Test() throw() override
  {
    std::cout << "TestImpl::Test" << std::endl;
    return S_OK;
  }
 
  void Test2()
  {
    std::cout << "TestImpl::Test2 (non-virtual)" << std::endl;
  }
};
 
int main()
{
  ATL::CComPtr<egtra::IWeakReference> weak;
  {
    auto x = CreateComObjectWithWeakRef<TestImpl>();
    ATL::CComQIPtr<egtra::IWeakReferenceSource> s(x);
    ATLENSURE_SUCCEEDED(s->GetWeakReference(&weak));
 
    ATL::CComPtr<IUnknown> u;
    weak->Resolve(IID_PPV_ARGS(&u));
    std::cout << static_cast<IUnknown*>(u) << std::endl;
  }
 
  ATL::CComPtr<IUnknown> u;
  weak->Resolve(IID_PPV_ARGS(&u));
  std::cout << static_cast<IUnknown*>(u) << std::endl;
}

利用例(main関数)がマルチスレッドな感じになっていないのは申し訳ありません。

これを実現するに当たって、boost::shared_ptrとlibc++のstd::shared_ptrのソースを参考にしました。強参照が生きている間、弱参照のカウントを+1しておくことや、弱参照から強参照を得る際のCASなどはそこから着想を得ました。

スポンサード リンク

この記事のカテゴリ

  • ⇒ COMで弱参照を実現する(マルチスレッド対応)
  • ⇒ COMで弱参照を実現する(マルチスレッド対応)
  • ⇒ COMで弱参照を実現する(マルチスレッド対応)