diff --git a/encompass-cs/ComponentManager.cs b/encompass-cs/ComponentManager.cs index 43b74f9..8d10ef6 100644 --- a/encompass-cs/ComponentManager.cs +++ b/encompass-cs/ComponentManager.cs @@ -107,9 +107,9 @@ namespace Encompass internal IEnumerable> GetComponentsByEntityAndType(Guid entityID) where TComponent : struct, IComponent { - var entity_components = GetComponentsByEntity(entityID).Select((kv) => new KeyValuePair(kv.Key, (TComponent)kv.Value)); - var active_components_by_type = GetActiveComponentsByType(); - return entity_components.Intersect(active_components_by_type); + var entityComponentsByType = GetComponentsByEntity(entityID).Where((pair) => componentIDToType[pair.Key] == typeof(TComponent)).Select((pair) => new KeyValuePair(pair.Key, (TComponent)pair.Value)); + var activeComponentsByType = GetActiveComponentsByType(); + return activeComponentsByType.Intersect(entityComponentsByType); } internal IEnumerable> GetComponentsByEntityAndType(Guid entityID, Type type) diff --git a/test/EntityRendererTest.cs b/test/EntityRendererTest.cs index 6577eff..ceaf243 100644 --- a/test/EntityRendererTest.cs +++ b/test/EntityRendererTest.cs @@ -1,7 +1,9 @@ using System; using NUnit.Framework; +using FluentAssertions; using Encompass; +using System.Collections.Generic; namespace Tests { @@ -93,11 +95,13 @@ namespace Tests } static bool calledOnDraw = false; + static IEnumerable> resultComponents; [Renders(typeof(TestDrawComponent), typeof(AComponent), typeof(CComponent))] class CalledRenderer : EntityRenderer { public override void Render(Entity entity) { + resultComponents = entity.GetComponents(); calledOnDraw = true; } } @@ -110,12 +114,12 @@ namespace Tests AComponent aComponent; CComponent cComponent; - TestDrawComponent testDrawComponent = default(TestDrawComponent); + TestDrawComponent testDrawComponent; var entity = worldBuilder.CreateEntity(); entity.AddComponent(aComponent); entity.AddComponent(cComponent); - entity.AddDrawComponent(testDrawComponent, 2); + var testDrawComponentID = entity.AddDrawComponent(testDrawComponent, 2); var world = worldBuilder.Build(); @@ -124,6 +128,7 @@ namespace Tests Assert.IsTrue(renderer.IsTracking(entity.id)); Assert.IsTrue(calledOnDraw); + resultComponents.Should().Contain(new KeyValuePair(testDrawComponentID, testDrawComponent)); } } }